/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Util;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableRewriteMinusAllRule;
import org.apache.flink.table.planner.plan.utils.SetOpRewriteUtil;
import org.immutables.value.Value;

/**
 * Planner rule that rewrites a two-input {@link Minus} with {@code all == true}
 * (i.e. {@code MINUS ALL} / {@code EXCEPT ALL}) into a plan built from a union-all, an
 * aggregate, a filter and a final row-replication step.
 *
 * <p>The plan assembled by {@link #onMatch(RelOptRuleCall)} has this shape:
 *
 * <ol>
 *   <li>project every left row to its original columns plus {@code vcol_marker = 1};
 *   <li>project every right row to its original columns plus {@code vcol_marker = -1};
 *   <li>{@code UNION ALL} both projections;
 *   <li>group by the original columns and compute
 *       {@code SUM(vcol_marker) AS sum_vcol_marker};
 *   <li>keep only groups with {@code sum_vcol_marker > 0};
 *   <li>project {@code sum_vcol_marker} followed by the original columns and hand the builder
 *       to {@link SetOpRewriteUtil#replicateRows}, which produces the final node.
 * </ol>
 */
@Value.Enclosing
public class RewriteMinusAllRule
        extends RelRule<RewriteMinusAllRule.RewriteMinusAllRuleConfig> {
    /** Shared rule instance created from {@link RewriteMinusAllRuleConfig#DEFAULT}. */
    public static final RewriteMinusAllRule INSTANCE = RewriteMinusAllRuleConfig.DEFAULT.toRule();

    /** Name of the +1 / -1 marker column appended to each input of the Minus. */
    private static final String MARKER_COLUMN = "vcol_marker";

    /** Name of the per-group sum of {@link #MARKER_COLUMN}. */
    private static final String MARKER_SUM_COLUMN = "sum_vcol_marker";

    /**
     * Creates the rule from the given configuration.
     *
     * @param config rule configuration, passed unchanged to {@link RelRule}
     */
    protected RewriteMinusAllRule(RewriteMinusAllRuleConfig config) {
        super(config);
    }

    /**
     * Restricts the rule to {@code MINUS ALL} nodes that have exactly two inputs.
     *
     * @param call the rule call whose root operand is a {@link Minus}
     * @return {@code true} if the Minus has {@code all} set and exactly two inputs
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Minus minus = call.rel(0);
        return minus.all && minus.getInputs().size() == 2;
    }

    /**
     * Replaces the matched {@code MINUS ALL} with the union / aggregate / filter plan described
     * in the class documentation.
     *
     * @param call the rule call whose root operand is the {@link Minus} to rewrite
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Minus minus = call.rel(0);
        RelNode left = minus.getInput(0);
        RelNode right = minus.getInput(1);
        // Ordinals 0..n-1 of the Minus output columns.
        List<Integer> fields = Util.range(minus.getRowType().getFieldCount());
        FlinkRelBuilder baseBuilder = FlinkRelBuilder.of(minus.getCluster(), null);

        // Each sub-plan is assembled on its own builder derived from baseBuilder.
        RelNode leftWithMarker = withMarkerColumn(newBuilder(baseBuilder), left, fields, 1L);
        RelNode rightWithMarker = withMarkerColumn(newBuilder(baseBuilder), right, fields, -1L);

        RelBuilder builder = newBuilder(baseBuilder);
        // NOTE: this stays one chained expression on purpose. The builder.field(..) /
        // builder.fields(..) arguments are resolved against whatever is on top of the
        // builder's stack at the time they are evaluated, i.e. after the preceding call in
        // the chain has run.
        builder.push(leftWithMarker)
                .push(rightWithMarker)
                .union(true)
                // Net count per distinct row: +1 per left occurrence, -1 per right occurrence.
                .aggregate(
                        builder.groupKey(builder.fields(fields)),
                        builder.sum(false, MARKER_SUM_COLUMN, builder.field(MARKER_COLUMN)))
                // Only rows that occur more often on the left than on the right survive.
                .filter(
                        builder.call(
                                FlinkSqlOperatorTable.GREATER_THAN,
                                builder.field(MARKER_SUM_COLUMN),
                                builder.literal(0)))
                // Put the sum first, followed by the original columns, for replicateRows.
                .project(
                        Stream.concat(
                                        Stream.of(builder.field(MARKER_SUM_COLUMN)),
                                        builder.fields(fields).stream())
                                .collect(Collectors.toList()));

        // NOTE(review): replicateRows is presumably what emits each surviving row
        // sum_vcol_marker times with the Minus row type; confirm against SetOpRewriteUtil.
        RelNode output = SetOpRewriteUtil.replicateRows(builder, minus.getRowType(), fields);
        call.transformTo(output);
    }

    /** Derives a new builder from {@code base} with convertCorrelateToJoin switched off. */
    private static RelBuilder newBuilder(FlinkRelBuilder base) {
        return base.transform(u -> u.withConvertCorrelateToJoin(false));
    }

    /**
     * Projects {@code input} to the given fields plus a trailing BIGINT {@code vcol_marker}
     * column holding the constant {@code marker}.
     */
    private static RelNode withMarkerColumn(
            RelBuilder relBuilder, RelNode input, List<Integer> fields, long marker) {
        // push(..) must run before fields(..) is evaluated so the ordinals refer to `input`.
        return relBuilder
                .push(input)
                .project(
                        Stream.concat(
                                        relBuilder.fields(fields).stream(),
                                        Stream.of(
                                                relBuilder.alias(
                                                        relBuilder.cast(
                                                                relBuilder.literal(marker),
                                                                SqlTypeName.BIGINT),
                                                        MARKER_COLUMN)))
                                .collect(Collectors.toList()))
                .build();
    }

    /** Configuration for {@link RewriteMinusAllRule}. */
    @Value.Immutable(singleton=false)
    public interface RewriteMinusAllRuleConfig
            extends RelRule.Config {
        /**
         * Default configuration: matches any {@link Minus} regardless of its inputs, uses the
         * logical builder factory and is described as {@code "RewriteMinusAllRule"}.
         */
        RewriteMinusAllRuleConfig DEFAULT =
                ImmutableRewriteMinusAllRule.RewriteMinusAllRuleConfig.builder()
                        .operandSupplier(b0 -> b0.operand(Minus.class).anyInputs())
                        .relBuilderFactory(RelFactories.LOGICAL_BUILDER)
                        .description("RewriteMinusAllRule")
                        .build();

        /**
         * Creates a {@link RewriteMinusAllRule} backed by this configuration.
         *
         * @return a new rule instance
         */
        @Override
        default RewriteMinusAllRule toRule() {
            return new RewriteMinusAllRule(this);
        }
    }
}

