/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.utils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.flink.calcite.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor;

public class MultiJoinUtil {

    /**
     * Builds, per input of a multi-way join, the list of equality-condition attribute references
     * that involve that input as the higher-indexed side.
     *
     * <p>Field indexes inside {@code joinConditions} are resolved against the concatenation of all
     * input row types in list order: input {@code i} owns the index range that starts at the sum of
     * the field counts of inputs {@code 0..i-1}. Every {@code RexInputRef = RexInputRef} equality
     * found in the conditions is turned into a
     * {@link AttributeBasedJoinKeyExtractor.ConditionAttributeRef} pointing from the lower-indexed
     * input to the higher-indexed one and is stored under the higher input's index.
     *
     * @param joinInputs the join inputs, in the order their fields appear in the combined row
     * @param joinConditions the conditions to scan for field-to-field equalities
     * @return a mutable map from input index to the attribute references collected for it; an
     *     input that takes part in no recognised equality has no entry
     */
    public static Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> createJoinAttributeMap(
            List<RelNode> joinInputs, List<? extends RexNode> joinConditions) {
        final Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap =
                new HashMap<>();
        final List<Integer> inputFieldCounts =
                joinInputs.stream()
                        .map(input -> input.getRowType().getFieldCount())
                        .collect(Collectors.toList());
        final List<Integer> inputOffsets = computeInputOffsets(inputFieldCounts);

        for (final RexNode condition : joinConditions) {
            extractEqualityConditions(condition, inputOffsets, inputFieldCounts, joinAttributeMap);
        }
        return joinAttributeMap;
    }

    /**
     * Returns the start offset of each input's fields within the combined row, i.e. the running
     * prefix sum of {@code inputFieldCounts} (first element is always 0).
     */
    private static List<Integer> computeInputOffsets(List<Integer> inputFieldCounts) {
        final List<Integer> inputOffsets = new ArrayList<>(inputFieldCounts.size());
        int currentOffset = 0;
        for (final int fieldCount : inputFieldCounts) {
            inputOffsets.add(currentOffset);
            currentOffset += fieldCount;
        }
        return inputOffsets;
    }

    /**
     * Recursively collects {@code RexInputRef = RexInputRef} equalities from {@code condition} into
     * {@code joinAttributeMap}.
     *
     * <p>Non-call nodes are ignored. For any call that is not {@code EQUALS}, each operand is
     * searched recursively. An {@code EQUALS} call is recorded only when it has exactly two operands
     * that are both plain input references resolving to some input; its operands are not searched
     * further.
     *
     * <p>NOTE(review): the recursion descends into every non-EQUALS call, including OR and NOT, so
     * equalities under a disjunction or negation are collected as well. Confirm callers only pass
     * conjunctive conditions.
     *
     * @param condition the expression to scan
     * @param inputOffsets start offset of each input's fields in the combined row
     * @param inputFieldCounts field count of each input
     * @param joinAttributeMap accumulator keyed by input index, mutated in place
     */
    private static void extractEqualityConditions(
            RexNode condition,
            List<Integer> inputOffsets,
            List<Integer> inputFieldCounts,
            Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap) {
        if (!(condition instanceof RexCall)) {
            return;
        }
        final RexCall call = (RexCall) condition;
        final SqlKind kind = call.getOperator().getKind();

        if (kind != SqlKind.EQUALS) {
            for (final RexNode operand : call.getOperands()) {
                extractEqualityConditions(operand, inputOffsets, inputFieldCounts, joinAttributeMap);
            }
            return;
        }

        if (call.getOperands().size() != 2) {
            return;
        }
        final RexNode op1 = call.getOperands().get(0);
        final RexNode op2 = call.getOperands().get(1);
        if (!(op1 instanceof RexInputRef) || !(op2 instanceof RexInputRef)) {
            return;
        }

        final InputRef inputRef1 =
                findInputRef(((RexInputRef) op1).getIndex(), inputOffsets, inputFieldCounts);
        final InputRef inputRef2 =
                findInputRef(((RexInputRef) op2).getIndex(), inputOffsets, inputFieldCounts);
        if (inputRef1 == null || inputRef2 == null) {
            // At least one side lies outside every input's field range.
            return;
        }

        // Orient the reference so that "left" is the lower-indexed input. On a tie (both sides in
        // the same input) the second operand becomes "left".
        // NOTE(review): a same-input equality yields a reference from an input to itself; confirm
        // this is intended.
        final InputRef leftRef;
        final InputRef rightRef;
        if (inputRef1.inputIndex < inputRef2.inputIndex) {
            leftRef = inputRef1;
            rightRef = inputRef2;
        } else {
            leftRef = inputRef2;
            rightRef = inputRef1;
        }

        // Input 0 has no input to its left, so it additionally receives a reference whose left
        // side is the sentinel (-1, -1).
        if (leftRef.inputIndex == 0) {
            final AttributeBasedJoinKeyExtractor.ConditionAttributeRef firstAttrRef =
                    new AttributeBasedJoinKeyExtractor.ConditionAttributeRef(
                            -1, -1, leftRef.inputIndex, leftRef.attributeIndex);
            joinAttributeMap
                    .computeIfAbsent(leftRef.inputIndex, k -> new ArrayList<>())
                    .add(firstAttrRef);
        }

        final AttributeBasedJoinKeyExtractor.ConditionAttributeRef attrRef =
                new AttributeBasedJoinKeyExtractor.ConditionAttributeRef(
                        leftRef.inputIndex,
                        leftRef.attributeIndex,
                        rightRef.inputIndex,
                        rightRef.attributeIndex);
        joinAttributeMap.computeIfAbsent(rightRef.inputIndex, k -> new ArrayList<>()).add(attrRef);
    }

    /**
     * Maps a field index in the combined row to the input that owns it.
     *
     * @param fieldIndex index into the concatenated row of all inputs
     * @param inputOffsets start offset of each input's fields
     * @param inputFieldCounts field count of each input
     * @return the owning input and the field's index within that input, or {@code null} if no
     *     input's range {@code [offset, offset + count)} contains {@code fieldIndex}
     */
    private static @Nullable InputRef findInputRef(
            int fieldIndex, List<Integer> inputOffsets, List<Integer> inputFieldCounts) {
        for (int i = 0; i < inputOffsets.size(); ++i) {
            final int offset = inputOffsets.get(i);
            if (fieldIndex >= offset && fieldIndex < offset + inputFieldCounts.get(i)) {
                return new InputRef(i, fieldIndex - offset);
            }
        }
        return null;
    }

    /** Position of a field as (input index, attribute index within that input). */
    private static final class InputRef {
        private final int inputIndex;
        private final int attributeIndex;

        private InputRef(int inputIndex, int attributeIndex) {
            this.inputIndex = inputIndex;
            this.attributeIndex = attributeIndex;
        }
    }
}

