/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableJoinConditionTypeCoerceRule;
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
import org.immutables.value.Value;

/**
 * Planner rule that coerces the operand types of equi-join predicates in a {@link Join}'s
 * condition.
 *
 * <p>The join condition is split into its top-level conjunctions. Every conjunct of the form
 * {@code $x = $y}, where both operands are {@link RexInputRef}s whose types differ (ignoring
 * nullability), is rewritten. Both operands are passed through {@link RexBuilder#ensureType} with
 * the least restrictive common type of the two. All other conjuncts are kept unchanged. The
 * re-assembled condition is simplified, and the join is copied with the new condition.
 *
 * <p>If no common type exists for such a pair of operands, {@link #onMatch} throws a {@link
 * TableException}.
 */
@Value.Enclosing
public class JoinConditionTypeCoerceRule
        extends RelRule<JoinConditionTypeCoerceRule.JoinConditionTypeCoerceRuleConfig> {

    /** Rule instance created from {@link JoinConditionTypeCoerceRuleConfig#DEFAULT}. */
    public static final JoinConditionTypeCoerceRule INSTANCE =
            JoinConditionTypeCoerceRuleConfig.DEFAULT.toRule();

    private JoinConditionTypeCoerceRule(JoinConditionTypeCoerceRuleConfig config) {
        super(config);
    }

    /**
     * Fires only when the join condition is not trivially {@code TRUE} and at least one of its
     * conjuncts is an equality between two input references of different types.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Join join = (Join) call.rel(0);
        if (join.getCondition().isAlwaysTrue()) {
            return false;
        }
        RelDataTypeFactory typeFactory = call.builder().getTypeFactory();
        return this.hasEqualsRefsOfDifferentTypes(typeFactory, join.getCondition());
    }

    /**
     * Rewrites the join condition so that every equality between differently-typed input
     * references compares values of a common type.
     *
     * @throws TableException if two compared input references have no common type
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Join join = (Join) call.rel(0);
        RelBuilder builder = call.builder();
        RexBuilder rexBuilder = builder.getRexBuilder();
        RelDataTypeFactory typeFactory = builder.getTypeFactory();

        List<RexNode> newJoinFilters =
                RelOptUtil.conjunctions(join.getCondition()).stream()
                        .map(filter -> coerceEquiCondition(builder, rexBuilder, typeFactory, filter))
                        .collect(Collectors.toList());

        RexNode newCondExp =
                builder.and(
                        FlinkRexUtil.simplify(
                                rexBuilder,
                                builder.and(newJoinFilters),
                                join.getCluster().getPlanner().getExecutor()));
        Join newJoin =
                join.copy(
                        join.getTraitSet(),
                        newCondExp,
                        join.getLeft(),
                        join.getRight(),
                        join.getJoinType(),
                        join.isSemiJoinDone());
        call.transformTo(newJoin);
    }

    /**
     * Returns whether any top-level conjunct of {@code predicate} is an equality between two input
     * references whose types differ (ignoring nullability).
     */
    private boolean hasEqualsRefsOfDifferentTypes(
            RelDataTypeFactory typeFactory, RexNode predicate) {
        return RelOptUtil.conjunctions(predicate).stream()
                .anyMatch(node -> isEqualsOfDifferentlyTypedRefs(typeFactory, node));
    }

    /**
     * Returns whether {@code node} is an {@code EQUALS} call whose two operands are both {@link
     * RexInputRef}s with types that are not equal when nullability is ignored.
     */
    private static boolean isEqualsOfDifferentlyTypedRefs(
            RelDataTypeFactory typeFactory, RexNode node) {
        if (!(node instanceof RexCall) || node.getKind() != SqlKind.EQUALS) {
            return false;
        }
        List<RexNode> operands = ((RexCall) node).getOperands();
        RexNode left = operands.get(0);
        RexNode right = operands.get(1);
        return left instanceof RexInputRef
                && right instanceof RexInputRef
                && !SqlTypeUtil.equalSansNullability(typeFactory, left.getType(), right.getType());
    }

    /**
     * Casts both sides of a differently-typed input-reference equality to their least restrictive
     * common type. Any other conjunct is returned unchanged.
     *
     * @throws TableException if the two operand types have no least restrictive common type
     */
    private static RexNode coerceEquiCondition(
            RelBuilder builder,
            RexBuilder rexBuilder,
            RelDataTypeFactory typeFactory,
            RexNode filter) {
        if (!isEqualsOfDifferentlyTypedRefs(typeFactory, filter)) {
            return filter;
        }
        List<RexNode> operands = ((RexCall) filter).getOperands();
        RexNode left = operands.get(0);
        RexNode right = operands.get(1);
        List<RelDataType> refTypes = Arrays.asList(left.getType(), right.getType());
        // leastRestrictive returns null when the two types cannot be unified.
        RelDataType targetType = typeFactory.leastRestrictive(refTypes);
        if (targetType == null) {
            throw new TableException(
                    "implicit type conversion between "
                            + left.getType()
                            + " and "
                            + right.getType()
                            + " is not supported on join's condition now");
        }
        // matchNullability = true: each cast keeps the nullability of its own operand.
        return builder.equals(
                rexBuilder.ensureType(targetType, left, true),
                rexBuilder.ensureType(targetType, right, true));
    }

    /** Rule configuration; {@link #DEFAULT} matches any {@link Join} with any inputs. */
    @Value.Immutable(singleton = false)
    public interface JoinConditionTypeCoerceRuleConfig extends RelRule.Config {
        JoinConditionTypeCoerceRuleConfig DEFAULT =
                ImmutableJoinConditionTypeCoerceRule.JoinConditionTypeCoerceRuleConfig.builder()
                        .build()
                        .withOperandSupplier(b0 -> b0.operand(Join.class).anyInputs())
                        .withDescription("JoinConditionTypeCoerceRule");

        @Override
        default JoinConditionTypeCoerceRule toRule() {
            return new JoinConditionTypeCoerceRule(this);
        }
    }
}

