/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.Context;
import org.apache.hadoop.hive.ql.exec.ConditionalTask;
import org.apache.hadoop.hive.ql.exec.DefaultBucketMatcher;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.BucketMapJoinContext;
import org.apache.hadoop.hive.ql.plan.ConditionalResolver;
import org.apache.hadoop.hive.ql.plan.ConditionalResolverSkewJoin;
import org.apache.hadoop.hive.ql.plan.ConditionalWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.SparkBucketMapJoinContext;
import org.apache.hadoop.hive.ql.plan.SparkHashTableSinkDesc;
import org.apache.hadoop.hive.ql.plan.SparkWork;

/**
 * Physical plan resolver that walks the task graph and, for every {@link SparkTask},
 * splits its {@link SparkWork} at map-join boundaries.
 *
 * <p>For each work containing a {@link MapJoinOperator}, the parent works that contain a
 * {@link SparkHashTableSinkOperator} are moved into a separate {@link SparkWork}; a task is
 * then created for every resulting {@link SparkWork} and the tasks are chained through
 * {@code addDependentTask} so that the hash-table-sink side is a parent of the map-join side.
 *
 * <p>An instance keeps the set of already processed tasks, so it is stateful across a walk.
 */
public class SparkMapJoinResolver implements PhysicalPlanResolver {
    /** Spark tasks that have already been handed to {@code processCurrentTask}. */
    private final Set<Task<? extends Serializable>> visitedTasks = new HashSet<Task<? extends Serializable>>();

    /**
     * Walks the task graph starting from the root tasks of {@code pctx} and rewrites every
     * Spark task reached by the walker.
     *
     * @param pctx the physical context whose task graph is rewritten in place
     * @return the same {@code pctx} instance
     * @throws SemanticException if the graph walker or the dispatcher reports one
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        SparkMapJoinTaskDispatcher dispatcher = new SparkMapJoinTaskDispatcher(pctx);
        TaskGraphWalker graphWalker = new TaskGraphWalker(dispatcher);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getRootTasks());
        graphWalker.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Tells whether {@code work} holds at least one operator that is an instance of {@code clazz}.
     */
    private boolean containsOp(BaseWork work, Class<?> clazz) {
        // getOp never returns null, so an emptiness check is sufficient.
        return !SparkMapJoinResolver.getOp(work, clazz).isEmpty();
    }

    /**
     * Tells whether any work inside {@code sparkWork} holds an operator that is an instance
     * of {@code clazz}.
     */
    private boolean containsOp(SparkWork sparkWork, Class<?> clazz) {
        for (BaseWork work : sparkWork.getAllWorkUnsorted()) {
            if (this.containsOp(work, clazz)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects every operator of {@code work} that is an instance of {@code clazz}.
     *
     * <p>For a {@link MapWork} the operators are discovered by walking down from the values
     * of {@code getAliasToWork()} through {@code getChildOperators()}; for any other work the
     * result of {@code getAllOperators()} is used.
     *
     * @param work  the work whose operators are inspected
     * @param clazz the operator class to look for
     * @return a new, possibly empty, set of matching operators; never {@code null}
     */
    public static Set<Operator<?>> getOp(BaseWork work, Class<?> clazz) {
        Set<Operator<?>> ops = new HashSet<Operator<?>>();
        if (work instanceof MapWork) {
            Collection<Operator<? extends OperatorDesc>> opSet = ((MapWork)work).getAliasToWork().values();
            Stack<Operator<? extends OperatorDesc>> opStack = new Stack<Operator<? extends OperatorDesc>>();
            opStack.addAll(opSet);
            while (!opStack.empty()) {
                Operator<? extends OperatorDesc> operator = opStack.pop();
                // Expand an operator only the first time it is seen. An operator reachable
                // through several parents would otherwise have its whole subtree re-walked
                // once per path; the collected set is the same either way.
                // NOTE(review): relies on Operator using identity equals/hashCode - confirm.
                if (!ops.add(operator)) {
                    continue;
                }
                if (operator.getChildOperators() != null) {
                    opStack.addAll(operator.getChildOperators());
                }
            }
        } else {
            ops.addAll(work.getAllOperators());
        }
        Set<Operator<?>> matchingOps = new HashSet<Operator<?>>();
        for (Operator<?> operator : ops) {
            if (clazz.isInstance(operator)) {
                matchingOps.add(operator);
            }
        }
        return matchingOps;
    }

    /**
     * Dispatcher invoked by the {@link TaskGraphWalker} for every task node. It holds the
     * per-task scratch state ({@code sparkWorkMap}, {@code dependencyGraph}), which is cleared
     * at the start of each not-yet-visited Spark task.
     */
    class SparkMapJoinTaskDispatcher implements Dispatcher {
        private final PhysicalContext physicalContext;
        /** Maps a work containing a map join to the new SparkWork created for its sink-side parents. */
        private final Map<BaseWork, SparkWork> sparkWorkMap;
        /** Maps a SparkWork to the SparkWorks it depends on (those whose tasks must be its parents). */
        private final Map<SparkWork, List<SparkWork>> dependencyGraph;

        public SparkMapJoinTaskDispatcher(PhysicalContext pc) {
            this.physicalContext = pc;
            this.sparkWorkMap = new LinkedHashMap<BaseWork, SparkWork>();
            this.dependencyGraph = new LinkedHashMap<SparkWork, List<SparkWork>>();
        }

        /**
         * Recursively walks from {@code work} up through its parents in {@code sparkWork},
         * placing each visited work into {@code targetWork}.
         *
         * <p>If {@code work} contains a {@link MapJoinOperator}, a new {@link SparkWork} is
         * created and recorded as a dependency of {@code targetWork}; parents that contain a
         * {@link SparkHashTableSinkOperator} continue into the new work, all other parents
         * continue into {@code targetWork}. Works are only added to targets here; they are not
         * removed from {@code sparkWork} by this method.
         *
         * @param sparkWork  the original work graph being split
         * @param work       the work currently being placed
         * @param targetWork the SparkWork that should receive {@code work}
         */
        private void moveWork(SparkWork sparkWork, BaseWork work, SparkWork targetWork) {
            List<BaseWork> parentWorks = sparkWork.getParents(work);
            if (sparkWork != targetWork) {
                targetWork.add(work);
                // Re-create the edges towards children that already live in the target.
                for (BaseWork childWork : sparkWork.getChildren(work)) {
                    if (targetWork.contains(childWork)) {
                        targetWork.connect(work, childWork, sparkWork.getEdgeProperty(work, childWork));
                    }
                }
            }
            if (!SparkMapJoinResolver.this.containsOp(work, MapJoinOperator.class)) {
                for (BaseWork parent : parentWorks) {
                    this.moveWork(sparkWork, parent, targetWork);
                }
                return;
            }
            SparkWork parentWork = new SparkWork(this.physicalContext.conf.getVar(HiveConf.ConfVars.HIVEQUERYID));
            parentWork.setCloneToWork(sparkWork.getCloneToWork());
            this.dependencyGraph.get(targetWork).add(parentWork);
            this.dependencyGraph.put(parentWork, new ArrayList<SparkWork>());
            this.sparkWorkMap.put(work, parentWork);
            for (BaseWork parent : parentWorks) {
                if (SparkMapJoinResolver.this.containsOp(parent, SparkHashTableSinkOperator.class)) {
                    this.moveWork(sparkWork, parent, parentWork);
                } else {
                    this.moveWork(sparkWork, parent, targetWork);
                }
            }
        }

        /**
         * Attaches a {@link MapredLocalWork} to every work of {@code originalTask} that contains
         * a {@link MapJoinOperator} or a {@link SparkHashTableSinkOperator} and does not have
         * one yet, and wires the temporary path (and, for bucket map joins, the bucket context)
         * between a map-join work and its hash-table-sink parents.
         *
         * @param originalTask the Spark task whose works are prepared
         */
        private void generateLocalWork(SparkTask originalTask) {
            SparkWork originalWork = (SparkWork)originalTask.getWork();
            List<BaseWork> allBaseWorks = originalWork.getAllWork();
            Context ctx = this.physicalContext.getContext();
            for (BaseWork work : allBaseWorks) {
                if (work.getMapRedLocalWork() != null) {
                    continue;
                }
                // Looked up once and reused below; getOp only reads the work.
                Set<Operator<?>> mapJoinOps = SparkMapJoinResolver.getOp(work, MapJoinOperator.class);
                if (!mapJoinOps.isEmpty() || SparkMapJoinResolver.this.containsOp(work, SparkHashTableSinkOperator.class)) {
                    work.setMapRedLocalWork(new MapredLocalWork());
                }
                if (mapJoinOps.isEmpty()) {
                    continue;
                }
                Path tmpPath = Utilities.generateTmpPath(ctx.getMRTmpPath(), originalTask.getId());
                MapredLocalWork bigTableLocalWork = work.getMapRedLocalWork();
                ArrayList<Operator<? extends OperatorDesc>> dummyOps = new ArrayList<Operator<? extends OperatorDesc>>(work.getDummyOps());
                bigTableLocalWork.setDummyParentOp(dummyOps);
                bigTableLocalWork.setTmpPath(tmpPath);

                // Only the first bucket map join found in this work configures the context.
                SparkBucketMapJoinContext bucketMJCxt = null;
                for (Operator<?> op : mapJoinOps) {
                    MapJoinOperator mapJoinOp = (MapJoinOperator)op;
                    MapJoinDesc mapJoinDesc = (MapJoinDesc)mapJoinOp.getConf();
                    if (!mapJoinDesc.isBucketMapJoin()) {
                        continue;
                    }
                    bucketMJCxt = new SparkBucketMapJoinContext(mapJoinDesc);
                    bucketMJCxt.setBucketMatcherClass(DefaultBucketMatcher.class);
                    bucketMJCxt.setPosToAliasMap(mapJoinOp.getPosToAliasMap());
                    ((MapWork)work).setUseBucketizedHiveInputFormat(true);
                    bigTableLocalWork.setBucketMapjoinContext(bucketMJCxt);
                    bigTableLocalWork.setInputFileChangeSensitive(true);
                    break;
                }

                for (BaseWork parentWork : originalWork.getParents(work)) {
                    Set<Operator<?>> hashTableSinkOps = SparkMapJoinResolver.getOp(parentWork, SparkHashTableSinkOperator.class);
                    if (hashTableSinkOps.isEmpty()) {
                        continue;
                    }
                    // NOTE(review): assumes the parent already received its local work in an
                    // earlier iteration (i.e. getAllWork() yields parents first) - confirm.
                    MapredLocalWork parentLocalWork = parentWork.getMapRedLocalWork();
                    parentLocalWork.setTmpHDFSPath(tmpPath);
                    if (bucketMJCxt == null) {
                        continue;
                    }
                    for (Operator<?> op : hashTableSinkOps) {
                        SparkHashTableSinkOperator hashTableSinkOp = (SparkHashTableSinkOperator)op;
                        SparkHashTableSinkDesc hashTableSinkDesc = (SparkHashTableSinkDesc)hashTableSinkOp.getConf();
                        BucketMapJoinContext original = hashTableSinkDesc.getBucketMapjoinContext();
                        // Reference comparison kept as in the original: the sink must share the
                        // very same bucket-file-name mapping object as the map join's context.
                        if (original == null || original.getBucketFileNameMapping() != bucketMJCxt.getBucketFileNameMapping()) {
                            continue;
                        }
                        ((MapWork)parentWork).setUseBucketizedHiveInputFormat(true);
                        parentLocalWork.setBucketMapjoinContext(bucketMJCxt);
                        parentLocalWork.setInputFileChangeSensitive(true);
                        // First matching sink is enough for this parent.
                        break;
                    }
                }
            }
        }

        /**
         * Returns the task for {@code sparkWork}, creating it (and, recursively, the tasks of
         * the works it depends on) when it is not yet in {@code createdTaskMap}.
         *
         * <p>The original task is reused when {@code sparkWork} is its own work; otherwise a
         * new task is obtained from {@link TaskFactory}. A new task whose work has no
         * dependencies takes over the original task's position: its parent tasks, its root
         * slot, or its slot inside {@code conditionalTask}.
         *
         * @param originalTask    the task being split
         * @param sparkWork       the work a task is needed for
         * @param createdTaskMap  cache of tasks already created during this split
         * @param conditionalTask the enclosing conditional task, or {@code null} if none
         * @return the task that runs {@code sparkWork}
         */
        private SparkTask createSparkTask(SparkTask originalTask, SparkWork sparkWork, Map<SparkWork, SparkTask> createdTaskMap, ConditionalTask conditionalTask) {
            if (createdTaskMap.containsKey(sparkWork)) {
                return createdTaskMap.get(sparkWork);
            }
            SparkTask resultTask = originalTask.getWork() == sparkWork ? originalTask : (SparkTask)TaskFactory.get(sparkWork, this.physicalContext.conf, new Task[0]);
            if (!this.dependencyGraph.get(sparkWork).isEmpty()) {
                for (SparkWork parentWork : this.dependencyGraph.get(sparkWork)) {
                    SparkTask parentTask = this.createSparkTask(originalTask, parentWork, createdTaskMap, conditionalTask);
                    parentTask.addDependentTask(resultTask);
                }
            } else if (originalTask != resultTask) {
                List<Task<? extends Serializable>> parentTasks = originalTask.getParentTasks();
                if (parentTasks != null && parentTasks.size() > 0) {
                    // Give the original task a fresh parent list before iterating the old one,
                    // since removeDependentTask is called on each old parent inside the loop.
                    originalTask.setParentTasks(new ArrayList<Task<? extends Serializable>>());
                    for (Task<? extends Serializable> parentTask : parentTasks) {
                        parentTask.addDependentTask(resultTask);
                        parentTask.removeDependentTask(originalTask);
                    }
                } else if (conditionalTask == null) {
                    this.physicalContext.addToRootTask(resultTask);
                    this.physicalContext.removeFromRootTask(originalTask);
                } else {
                    this.updateConditionalTask(conditionalTask, originalTask, resultTask);
                }
            }
            createdTaskMap.put(sparkWork, resultTask);
            return resultTask;
        }

        /**
         * Processes a task node: a {@link SparkTask} is processed directly, a
         * {@link ConditionalTask} has each of its listed Spark tasks processed. Tasks for
         * which {@code isMapRedTask()} is false are ignored.
         *
         * @return always {@code null}
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nos) throws SemanticException {
            @SuppressWarnings("unchecked")
            Task<? extends Serializable> currentTask = (Task<? extends Serializable>)nd;
            if (!currentTask.isMapRedTask()) {
                return null;
            }
            if (currentTask instanceof ConditionalTask) {
                ConditionalTask conditionalTask = (ConditionalTask)currentTask;
                List<Task<? extends Serializable>> taskList = conditionalTask.getListTasks();
                for (Task<? extends Serializable> tsk : taskList) {
                    if (tsk instanceof SparkTask) {
                        this.processCurrentTask((SparkTask)tsk, conditionalTask);
                        SparkMapJoinResolver.this.visitedTasks.add(tsk);
                    }
                }
            } else if (currentTask instanceof SparkTask) {
                this.processCurrentTask((SparkTask)currentTask, null);
                SparkMapJoinResolver.this.visitedTasks.add(currentTask);
            }
            return null;
        }

        /**
         * Splits the work of a not-yet-visited Spark task and creates the resulting tasks.
         * For an already visited task that sits inside a conditional task and has a single
         * Spark parent holding the hash table sink for its map join, the conditional task is
         * updated to reference that parent instead.
         *
         * @param sparkTask       the task to process
         * @param conditionalTask the enclosing conditional task, or {@code null} if none
         */
        private void processCurrentTask(SparkTask sparkTask, ConditionalTask conditionalTask) {
            SparkWork sparkWork = (SparkWork)sparkTask.getWork();
            if (!SparkMapJoinResolver.this.visitedTasks.contains(sparkTask)) {
                this.dependencyGraph.clear();
                this.sparkWorkMap.clear();
                this.generateLocalWork(sparkTask);
                this.dependencyGraph.put(sparkWork, new ArrayList<SparkWork>());
                Set<BaseWork> leaves = sparkWork.getLeaves();
                for (BaseWork leaf : leaves) {
                    this.moveWork(sparkWork, leaf, sparkWork);
                }
                // Everything that was placed into a newly created SparkWork leaves the original.
                for (SparkWork newSparkWork : this.sparkWorkMap.values()) {
                    for (BaseWork work : newSparkWork.getAllWorkUnsorted()) {
                        sparkWork.remove(work);
                    }
                }
                Map<SparkWork, SparkTask> createdTaskMap = new LinkedHashMap<SparkWork, SparkTask>();
                for (SparkWork work : this.dependencyGraph.keySet()) {
                    this.createSparkTask(sparkTask, work, createdTaskMap, conditionalTask);
                }
                return;
            }
            if (conditionalTask == null) {
                return;
            }
            List<Task<? extends Serializable>> parents = sparkTask.getParentTasks();
            if (parents == null || parents.size() != 1 || !(parents.get(0) instanceof SparkTask)) {
                return;
            }
            SparkTask parent = (SparkTask)parents.get(0);
            if (SparkMapJoinResolver.this.containsOp(sparkWork, MapJoinOperator.class) && SparkMapJoinResolver.this.containsOp((SparkWork)parent.getWork(), SparkHashTableSinkOperator.class)) {
                this.updateConditionalTask(conditionalTask, sparkTask, parent);
            }
        }

        /**
         * Replaces {@code originalTask} (and its work) with {@code newTask} (and its work)
         * inside {@code conditionalTask}: in the task list, in the work list and, for a
         * skew-join resolver, in the resolver context's directory-to-task map and no-skew task.
         * Does nothing when the original task or work is not listed.
         */
        private void updateConditionalTask(ConditionalTask conditionalTask, SparkTask originalTask, SparkTask newTask) {
            ConditionalWork conditionalWork = (ConditionalWork)conditionalTask.getWork();
            SparkWork originWork = (SparkWork)originalTask.getWork();
            SparkWork newWork = (SparkWork)newTask.getWork();
            List<Task<? extends Serializable>> listTask = conditionalTask.getListTasks();
            // A wildcard-typed list cannot be written to, so the element type is pinned with
            // an unchecked cast; only Serializable works are ever stored.
            @SuppressWarnings("unchecked")
            List<Serializable> listWork = (List<Serializable>)conditionalWork.getListWorks();
            int taskIndex = listTask.indexOf(originalTask);
            int workIndex = listWork.indexOf(originWork);
            if (taskIndex < 0 || workIndex < 0) {
                return;
            }
            listTask.set(taskIndex, newTask);
            listWork.set(workIndex, newWork);
            ConditionalResolver resolver = conditionalTask.getResolver();
            if (!(resolver instanceof ConditionalResolverSkewJoin)) {
                return;
            }
            ConditionalResolverSkewJoin.ConditionalResolverSkewJoinCtx context = (ConditionalResolverSkewJoin.ConditionalResolverSkewJoinCtx)conditionalTask.getResolverCtx();
            HashMap<Path, Task<? extends Serializable>> bigKeysDirToTaskMap = context.getDirToTaskMap();
            HashMap<Path, Task<? extends Serializable>> newbigKeysDirToTaskMap = new HashMap<Path, Task<? extends Serializable>>();
            for (Map.Entry<Path, Task<? extends Serializable>> entry : bigKeysDirToTaskMap.entrySet()) {
                Task<? extends Serializable> task = entry.getValue();
                Path bigKeyDir = entry.getKey();
                if (task.equals(originalTask)) {
                    newbigKeysDirToTaskMap.put(bigKeyDir, newTask);
                } else {
                    newbigKeysDirToTaskMap.put(bigKeyDir, task);
                }
            }
            context.setDirToTaskMap(newbigKeysDirToTaskMap);
            if (context.getNoSkewTask() != null && context.getNoSkewTask().equals(originalTask)) {
                context.setNoSkewTask(newTask);
            }
        }
    }
}

