/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.query.calcite.KylinSumSplitter;
import org.apache.kylin.query.relnode.OlapAggregateRel;
import org.apache.kylin.query.relnode.OlapJoinRel;
import org.apache.kylin.query.util.RuleUtils;
import org.jetbrains.annotations.NotNull;

/**
 * Planner rule that pushes an {@link OlapAggregateRel} below an {@link OlapJoinRel}.
 *
 * <p>For each join input, the rule does one of two things:
 * <ul>
 *   <li>If the columns the aggregate needs from that side are unique there (according to
 *       {@link RelMetadataQuery#areColumnsUnique}), the input is only projected.</li>
 *   <li>Otherwise a partial aggregate is placed on it.</li>
 * </ul>
 * The join is then rebuilt on the new inputs, and the partial results are combined on top via
 * {@link SqlSplittableAggFunction#topSplit}.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>the join is INNER or LEFT;</li>
 *   <li>the join condition is a pure equi-join;</li>
 *   <li>every aggregate call is splittable and has no FILTER argument.</li>
 * </ul>
 *
 * <p>NOTE(review): the overall shape appears to mirror Calcite's
 * {@code AggregateJoinTransposeRule}; confirm before porting upstream fixes.
 */
public class OlapAggJoinTransposeRule
extends RelOptRule {
    private static final String STAR_TOKEN = "*";
    public static final OlapAggJoinTransposeRule INSTANCE_JOIN_RIGHT_AGG = new OlapAggJoinTransposeRule(
            operand(OlapAggregateRel.class, operand(OlapJoinRel.class, any())),
            RelFactories.LOGICAL_BUILDER, "OlapAggJoinTransposeRule:agg-join-rightAgg");

    public OlapAggJoinTransposeRule(RelOptRuleOperand operand) {
        super(operand);
    }

    public OlapAggJoinTransposeRule(RelOptRuleOperand operand, String description) {
        super(operand, description);
    }

    public OlapAggJoinTransposeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Accepts the match only when the aggregate does not report containing a COUNT DISTINCT and
     * {@link RuleUtils#isJoinOnlyOneAggChild} accepts the join.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        OlapAggregateRel aggregate = (OlapAggregateRel) call.rel(0);
        OlapJoinRel joinRel = (OlapJoinRel) call.rel(1);
        return !aggregate.isContainCountDistinct() && RuleUtils.isJoinOnlyOneAggChild(joinRel);
    }

    /**
     * Checks the remaining preconditions. If they hold and the push-down changes the plan,
     * registers the rewritten tree with {@code call.transformTo}.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        OlapAggregateRel aggregate = (OlapAggregateRel) call.rel(0);
        OlapJoinRel join = (OlapJoinRel) call.rel(1);
        RelBuilder relBuilder = call.builder();

        // Every aggregate call must be splittable across the join and must not carry a FILTER.
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null
                    || aggregateCall.filterArg >= 0) {
                return;
            }
        }
        JoinRelType joinType = join.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT) {
            return;
        }

        ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        RelMetadataQuery mq = call.getMetadataQuery();
        ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
                mq.getPulledUpPredicates(join).pulledUpPredicates);
        ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);

        // Only pure equi-joins are handled: any non-equi residue aborts the rewrite.
        List<Integer> leftKeys = new ArrayList<>();
        List<Integer> rightKeys = new ArrayList<>();
        List<Boolean> filterNulls = new ArrayList<>();
        RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(),
                join.getCondition(), leftKeys, rightKeys, filterNulls);
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }
        if (this.aggPushDown(aggregate, join, belowAggregateColumns, mq, relBuilder, allColumnsInAggregate)) {
            call.transformTo(relBuilder.build());
        }
    }

    /**
     * Builds new join inputs (projected or pre-aggregated) and, when that changes anything, pushes
     * the rewritten join plus top aggregate/project onto {@code relBuilder}.
     *
     * @return {@code true} if a new tree was left on {@code relBuilder}; {@code false} if both
     *         inputs are unique or neither input changed
     */
    private boolean aggPushDown(OlapAggregateRel aggregate, OlapJoinRel join, ImmutableBitSet belowAggregateColumns,
            RelMetadataQuery mq, RelBuilder relBuilder, boolean allColumnsInAggregate) {
        // Maps a column of the original join row type to its position in the new join row type.
        Map<Integer, Integer> map = new HashMap<>();
        List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; ++s) {
            Side side = new Side();
            RelNode joinInput = join.getInput(s);
            int fieldCount = joinInput.getRowType().getFieldCount();
            ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            // Translates join-level column indexes into this input's local indexes.
            Mappings.TargetMapping mapping = s == 0
                    ? Mappings.createIdentity(fieldCount)
                    : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
            boolean unique = Boolean.TRUE.equals(mq.areColumnsUnique(joinInput, belowAggregateKey));
            if (unique) {
                ++uniqueCount;
                this.processUnique(side, relBuilder, joinInput, aggregate, fieldSet, mapping, belowAggregateKey);
            } else {
                this.processUnUnique(side, aggregate, relBuilder, joinInput, fieldSet, mapping, belowAggregateKey);
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }
        if (uniqueCount == 2 || this.isJoinInputNotChanged(join, sides)) {
            return false;
        }
        updateCondition(sides, map, aggregate, join, belowOffset, relBuilder, allColumnsInAggregate);
        return true;
    }

    /** Returns true when every rebuilt input equals (per {@link Objects#equals}) the original input at the same position. */
    private boolean isJoinInputNotChanged(OlapJoinRel join, List<Side> sides) {
        for (int i = 0; i < sides.size(); ++i) {
            if (!Objects.equals(sides.get(i).newInput, join.getInput(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Handles an input whose below-aggregate key is unique.
     *
     * <p>No aggregate is needed on this side. The input is projected to the key columns, in bit
     * order, followed by each applicable aggregate call's singleton expression. The position of
     * each result is recorded in {@code side.split}.
     */
    private void processUnique(Side side, RelBuilder relBuilder, RelNode joinInput, OlapAggregateRel aggregate,
            ImmutableBitSet fieldSet, Mappings.TargetMapping mapping, ImmutableBitSet belowAggregateKey) {
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        side.aggregate = false;
        relBuilder.push(joinInput);
        List<RexNode> projects = new ArrayList<>();
        for (int i : belowAggregateKey) {
            projects.add(relBuilder.field(i));
        }
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            AggregateCall aggregateCall = aggCall.e;
            SqlSplittableAggFunction splitter = Preconditions.checkNotNull(
                    aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
            List<Integer> args = aggregateCall.getArgList();
            // Only calls whose arguments all come from this side contribute a value here.
            if (args.isEmpty() || !fieldSet.contains(ImmutableBitSet.of(args))) {
                continue;
            }
            RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggregateCall.transform(mapping));
            if (singleton instanceof RexInputRef && belowAggregateKey.get(((RexInputRef) singleton).getIndex())) {
                // The referenced column is already projected as a key column. The key columns
                // occupy the first slots of `projects` in bit order, so its output position is its
                // ordinal within the key, not its index in joinInput.
                int index = ((RexInputRef) singleton).getIndex();
                side.split.put(aggCall.i, belowAggregateKey.indexOf(index));
            } else {
                projects.add(singleton);
                side.split.put(aggCall.i, projects.size() - 1);
            }
        }
        relBuilder.project(projects);
        side.newInput = relBuilder.build();
    }

    /**
     * Handles an input whose below-aggregate key is not unique.
     *
     * <p>A partial aggregate grouped by {@code belowAggregateKey} is placed on the input. Split
     * calls are registered (deduplicated) in that aggregate, and each call's output position
     * (after the group keys) is recorded in {@code side.split}.
     */
    private void processUnUnique(Side side, OlapAggregateRel aggregate, RelBuilder relBuilder, RelNode joinInput,
            ImmutableBitSet fieldSet, Mappings.TargetMapping mapping, ImmutableBitSet belowAggregateKey) {
        side.aggregate = true;
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        List<AggregateCall> belowAggCalls = new ArrayList<>();
        SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
        int oldGroupKeyCount = aggregate.getGroupCount();
        int newGroupKeyCount = belowAggregateKey.cardinality();
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            AggregateCall aggregateCall = aggCall.e;
            SqlSplittableAggFunction splitter = Preconditions.checkNotNull(
                    aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
            AggregateCall belowCall;
            if (fieldSet.contains(ImmutableBitSet.of(aggregateCall.getArgList()))) {
                AggregateCall splitCall = splitter.split(aggregateCall, mapping);
                belowCall = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg,
                        oldGroupKeyCount, newGroupKeyCount);
            } else {
                // Arguments come from the other side; ask the splitter what this side contributes.
                belowCall = splitter.other(rexBuilder.getTypeFactory(), aggregateCall);
            }
            if (belowCall != null) {
                side.split.put(aggCall.i, newGroupKeyCount + belowAggCallRegistry.register(belowCall));
            }
        }
        side.newInput = relBuilder.push(joinInput)
                .aggregate(relBuilder.groupKey(belowAggregateKey), belowAggCalls)
                .build();
    }

    /**
     * Rebuilds the join over the new inputs (condition remapped through {@code map}).
     *
     * <p>Then adds a project that computes the top-split expressions. For LEFT joins the
     * multiplication factors in that project are wrapped in {@code COALESCE(x, 1)}. Finally it
     * either converts the top aggregate to a project (when {@code allColumnsInAggregate} and every
     * new call has a singleton form) or re-aggregates with the remapped group set(s).
     */
    private static void updateCondition(List<Side> sides, Map<Integer, Integer> map, OlapAggregateRel aggregate,
            OlapJoinRel join, int belowOffset, RelBuilder relBuilder, boolean allColumnsInAggregate) {
        Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
        relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);

        List<AggregateCall> newAggCalls = new ArrayList<>();
        int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
        List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            SqlSplittableAggFunction splitter = getSqlSplittableAggFunction(aggCall.e.getAggregation());
            Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            // -1 tells topSplit that the side contributes no subtotal; right indexes are shifted
            // past the left input's columns.
            newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount,
                    relBuilder.peek().getRowType(), aggCall.e,
                    leftSubTotal == null ? -1 : leftSubTotal,
                    rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
        }
        if (join.getJoinType() == JoinRelType.LEFT) {
            projects = createNewProjects(rexBuilder, projects);
        }
        relBuilder.project(projects);

        ImmutableBitSet groupSet = Mappings.apply(mapping, aggregate.getGroupSet());
        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate) {
            List<RexNode> projects2 = new ArrayList<>();
            for (int key : groupSet) {
                projects2.add(relBuilder.field(key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                }
            }
            // Only replace the aggregate when every new call produced a singleton expression.
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            List<ImmutableBitSet> groupSets = Mappings.apply2(mapping, aggregate.getGroupSets());
            if (groupSets == null) {
                relBuilder.aggregate(relBuilder.groupKey(groupSet), newAggCalls);
            } else {
                relBuilder.aggregate(relBuilder.groupKey(groupSet, groupSets), newAggCalls);
            }
        }
    }

    /**
     * Returns the splitter for {@code aggregation}, substituting {@link KylinSumSplitter} for
     * Calcite's default SUM splitter.
     *
     * @throws NullPointerException if the function does not unwrap to a splitter
     */
    @NotNull
    private static SqlSplittableAggFunction getSqlSplittableAggFunction(SqlAggFunction aggregation) {
        SqlSplittableAggFunction splitter = Objects.requireNonNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        if (splitter.equals(SqlSplittableAggFunction.SumSplitter.INSTANCE)) {
            splitter = KylinSumSplitter.INSTANCE;
        }
        return splitter;
    }

    /**
     * Rewrites the top projects for a LEFT join.
     *
     * <p>Input refs are kept as-is. Every other expression has its two-operand multiplications
     * (optionally under a CAST) rewritten so that each factor becomes {@code COALESCE(x, 1)}.
     * NOTE(review): presumably this keeps a product non-null when one side's subtotal is null
     * for unmatched rows — confirm.
     */
    private static List<RexNode> createNewProjects(RexBuilder rexBuilder, List<RexNode> projects) {
        List<RexNode> converted = new ArrayList<>(projects.size());
        Map<Integer, RexInputRef> rexInputRefMap = new HashMap<>();
        OriginInputRefReplacer replacer = new OriginInputRefReplacer(rexInputRefMap);
        for (RexNode rexNode : projects) {
            if (rexNode instanceof RexInputRef) {
                RexInputRef inputRef = (RexInputRef) rexNode;
                converted.add(inputRef);
                rexInputRefMap.put(inputRef.getIndex(), inputRef);
            } else {
                converted.add(rewriteRexNode(rexNode, rexBuilder).accept(replacer));
            }
        }
        return converted;
    }

    /**
     * Wraps each factor of a top-level two-operand {@code *} call in {@code COALESCE(x, 1)}.
     * Also handles such a call nested directly under a single-operand CAST. Any other node is
     * returned unchanged.
     */
    private static RexNode rewriteRexNode(RexNode rexNode, RexBuilder rexBuilder) {
        if (!(rexNode instanceof RexCall)) {
            return rexNode;
        }
        RexCall rexCall = (RexCall) rexNode;
        if (isMultiplicationRexCall(rexCall)) {
            List<RexNode> rewriteRexNodeList = rewriteRexNodeList(rexCall, rexBuilder);
            return rexBuilder.makeCall(rexCall.type, SqlStdOperatorTable.MULTIPLY, rewriteRexNodeList);
        }
        SqlOperator sqlOperator = rexCall.getOperator();
        List<RexNode> operands = rexCall.getOperands();
        if (sqlOperator.getKind() == SqlKind.CAST && operands.size() == 1
                && operands.get(0) instanceof RexCall && isMultiplicationRexCall((RexCall) operands.get(0))) {
            RexCall innerRexCall = (RexCall) operands.get(0);
            List<RexNode> rewriteRexNodeList = rewriteRexNodeList(innerRexCall, rexBuilder);
            RexNode rewriteInnerRexCall = rexBuilder.makeCall(innerRexCall.type, SqlStdOperatorTable.MULTIPLY,
                    rewriteRexNodeList);
            return rexBuilder.makeCast(rexCall.type, rewriteInnerRexCall);
        }
        return rexNode;
    }

    /** Returns {@code COALESCE(op, 1)} for each operand of {@code rexCall}, in order. */
    private static List<RexNode> rewriteRexNodeList(RexCall rexCall, RexBuilder rexBuilder) {
        // Use the cluster's own type factory rather than constructing a separate one per call.
        RelDataType intType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
        RexNode one = rexBuilder.makeLiteral(1, intType, false);
        List<RexNode> rewriteRexNodeList = new ArrayList<>(rexCall.getOperands().size());
        for (RexNode operand : rexCall.getOperands()) {
            rewriteRexNodeList.add(rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, operand, one));
        }
        return rewriteRexNodeList;
    }

    /** True when the call's operator is named {@code "*"} and it has exactly two operands. */
    private static boolean isMultiplicationRexCall(RexCall rexCall) {
        return rexCall.getOperator().getName().equals(STAR_TOKEN) && rexCall.getOperands().size() == 2;
    }

    /**
     * Expands {@code aggregateColumns} with every column declared equal to one of them by an
     * {@code $i = $j} predicate in {@code predicates}.
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, List<RexNode> predicates) {
        Map<Integer, BitSet> equivalence = new TreeMap<>();
        for (RexNode predicate : predicates) {
            populateEquivalences(equivalence, predicate);
        }
        ImmutableBitSet keyColumns = aggregateColumns;
        for (Integer aggregateColumn : aggregateColumns) {
            BitSet bitSet = equivalence.get(aggregateColumn);
            if (bitSet != null) {
                keyColumns = keyColumns.union(bitSet);
            }
        }
        return keyColumns;
    }

    /** Records both directions of an {@code $i = $j} predicate; other predicates are ignored. */
    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        if (predicate.getKind() != SqlKind.EQUALS) {
            return;
        }
        List<RexNode> operands = ((RexCall) predicate).getOperands();
        if (operands.get(0) instanceof RexInputRef && operands.get(1) instanceof RexInputRef) {
            int index0 = ((RexInputRef) operands.get(0)).getIndex();
            int index1 = ((RexInputRef) operands.get(1)).getIndex();
            populateEquivalence(equivalence, index0, index1);
            populateEquivalence(equivalence, index1, index0);
        }
    }

    /** Marks {@code i1} as equivalent to {@code i0}. */
    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        equivalence.computeIfAbsent(i0, key -> new BitSet()).set(i1);
    }

    /**
     * Returns a registry backed by {@code list}. Registering an element appends it if it is not
     * already present (by {@code equals}) and returns its index in the list.
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(List<E> list) {
        return e -> {
            int i = list.indexOf(e);
            if (i < 0) {
                i = list.size();
                list.add(e);
            }
            return i;
        };
    }

    /** Per-input state collected while rebuilding one side of the join. */
    private static class Side {
        // Aggregate-call ordinal -> column position of that call's partial result in newInput.
        final Map<Integer, Integer> split = new HashMap<Integer, Integer>();
        // Replacement for the original join input (a project or a partial aggregate over it).
        RelNode newInput;
        // True when newInput is a partial aggregate, false when it is a projection.
        boolean aggregate;

        private Side() {
        }
    }

    /**
     * Replaces each input ref with the ref registered for the same index, if any. Every other node
     * kind is copied unchanged. Unlike a {@code RexVisitorImpl}, a {@link RexShuttle} never maps an
     * unhandled node type to {@code null}.
     */
    private static class OriginInputRefReplacer
    extends RexShuttle {
        private final Map<Integer, RexInputRef> rexInputRefMap;

        OriginInputRefReplacer(Map<Integer, RexInputRef> rexInputRefMap) {
            this.rexInputRefMap = rexInputRefMap;
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            RexInputRef replacement = this.rexInputRefMap.get(inputRef.getIndex());
            return replacement == null ? inputRef : replacement;
        }
    }
}

