/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.query.calcite.KylinSumSplitter;
import org.apache.kylin.query.relnode.OlapAggregateRel;
import org.apache.kylin.query.relnode.OlapJoinRel;
import org.apache.kylin.query.relnode.OlapProjectRel;
import org.apache.kylin.query.relnode.OlapRel;
import org.apache.kylin.query.util.RuleUtils;

public class OlapCountDistinctJoinRule
extends RelOptRule {
    public static final OlapCountDistinctJoinRule COUNT_DISTINCT_JOIN_ONE_SIDE_AGG = new OlapCountDistinctJoinRule(OlapCountDistinctJoinRule.operand(OlapAggregateRel.class, (RelOptRuleOperand)OlapCountDistinctJoinRule.operand(OlapJoinRel.class, (RelOptRuleOperandChildren)OlapCountDistinctJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "OlapCountDistinctJoinRule:agg(contain-count-distinct)-join-oneSideAgg");
    public static final OlapCountDistinctJoinRule COUNT_DISTINCT_AGG_PROJECT_JOIN = new OlapCountDistinctJoinRule(OlapCountDistinctJoinRule.operand(OlapAggregateRel.class, (RelOptRuleOperand)OlapCountDistinctJoinRule.operand(OlapProjectRel.class, (RelOptRuleOperand)OlapCountDistinctJoinRule.operand(OlapJoinRel.class, (RelOptRuleOperandChildren)OlapCountDistinctJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, "OlapCountDistinctJoinRule:agg(contain-count-distinct)-agg-project-join");

    public OlapCountDistinctJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    public boolean matches(RelOptRuleCall call) {
        OlapAggregateRel aggregate = (OlapAggregateRel)call.rel(0);
        OlapJoinRel join = call.rel(1) instanceof OlapJoinRel ? (OlapJoinRel)call.rel(1) : (OlapJoinRel)call.rel(2);
        return aggregate.isContainCountDistinct() && RuleUtils.isJoinOnlyOneAggChild(join);
    }

    public void onMatch(RelOptRuleCall call) {
        OlapAggregateRel aggregate = (OlapAggregateRel)call.rel(0);
        OlapRel inputRel = (OlapRel)call.rel(1);
        ImmutableList.Builder bottomAggCallsBuilder = ImmutableList.builder();
        ImmutableBitSet.Builder bottomGroupSetBuilder = ImmutableBitSet.builder();
        bottomGroupSetBuilder.addAll(aggregate.getGroupSet());
        for (AggregateCall agg : aggregate.getAggCallList()) {
            if (agg.getAggregation().getKind() == SqlKind.COUNT && agg.isDistinct()) {
                bottomGroupSetBuilder.addAll((Iterable)Lists.newArrayList((Iterable)agg.getArgList()));
                continue;
            }
            bottomAggCallsBuilder.add((Object)agg.copy((List)Lists.newArrayList((Iterable)agg.getArgList()), agg.filterArg));
        }
        ImmutableBitSet bottomGroupSetBuild = bottomGroupSetBuilder.build();
        ImmutableList bottomAggCallsBuild = bottomAggCallsBuilder.build();
        List bottomGroupSets = bottomGroupSetBuild.asList();
        RelTraitSet relTraitSet = aggregate.getTraitSet();
        aggregate.getClass();
        Aggregate bottomAggregate = aggregate.copy(relTraitSet, (RelNode)inputRel, false, bottomGroupSetBuild, null, (List)bottomAggCallsBuild);
        ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
        ArrayList<Integer> topGroupSetList = new ArrayList<Integer>();
        this.setTopAggregateGroupSet(bottomAggregate, (Aggregate)aggregate, topGroupSetList, topGroupSet);
        int topAggArgsIndex = bottomGroupSets.size();
        ImmutableList.Builder topAggCalls = ImmutableList.builder();
        for (AggregateCall agg : aggregate.getAggCallList()) {
            if (agg.getAggregation().getKind() == SqlKind.COUNT && agg.isDistinct()) {
                ArrayList<Integer> aggArgsList = new ArrayList<Integer>();
                for (Integer arg : agg.getArgList()) {
                    aggArgsList.add(bottomGroupSets.indexOf(arg));
                }
                topAggCalls.add((Object)agg.copy(aggArgsList, agg.filterArg));
                continue;
            }
            if (agg.getAggregation().getKind() == SqlKind.COUNT) {
                topAggCalls.add((Object)AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.SUM0, (boolean)false, (boolean)false, (List)Lists.newArrayList((Object[])new Integer[]{topAggArgsIndex++}), (int)-1, (RelDataType)agg.type, (String)agg.name));
                continue;
            }
            if (agg.getAggregation().getKind() == SqlKind.SUM) {
                topAggCalls.add((Object)AggregateCall.create((SqlAggFunction)KylinSumSplitter.KYLIN_SUM, (boolean)false, (boolean)false, (boolean)false, (List)Lists.newArrayList((Object[])new Integer[]{topAggArgsIndex++}), (int)-1, (ImmutableBitSet)agg.distinctKeys, (RelCollation)agg.collation, (RelDataType)agg.type, (String)agg.name));
                continue;
            }
            topAggCalls.add((Object)agg.copy((List)Lists.newArrayList((Object[])new Integer[]{topAggArgsIndex++}), agg.filterArg));
        }
        RelTraitSet relTraitSet2 = aggregate.getTraitSet();
        aggregate.getClass();
        Aggregate topAggregate = aggregate.copy(relTraitSet2, (RelNode)bottomAggregate, false, topGroupSet.build(), null, (List)topAggCalls.build());
        call.transformTo((RelNode)topAggregate);
    }

    private void setTopAggregateGroupSet(Aggregate bottomAggregate, Aggregate aggregate, List<Integer> topGroupSetList, ImmutableBitSet.Builder topGroupSet) {
        List bottomAggregateGroupIndexList = bottomAggregate.getGroupSet().asList();
        List aggregateGroupIndexList = aggregate.getGroupSet().asList();
        for (int i = 0; i < bottomAggregateGroupIndexList.size(); ++i) {
            if (!aggregateGroupIndexList.contains(bottomAggregateGroupIndexList.get(i))) continue;
            topGroupSetList.add(i);
        }
        topGroupSet.addAll(topGroupSetList);
    }
}

