/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu;

import jcuda.runtime.JCuda;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.GPUInstructionParser;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Abstract base class for GPU instructions.
 *
 * <p>Holds up to two input operands and one output operand and caches whether the instruction
 * string needs label patching. It wraps {@link #processInstruction(ExecutionContext)} with
 * GPU-specific pre/post processing:
 * <ul>
 *   <li>label patching and re-parsing through {@link GPUInstructionParser};</li>
 *   <li>optional device synchronization;</li>
 *   <li>debug-level memory reporting per GPU context.</li>
 * </ul>
 * Subclasses implement the actual computation.
 */
public abstract class GPUInstruction
extends Instruction
implements LineageTraceable {
    private static final Log LOG = LogFactory.getLog(GPUInstruction.class.getName());

    /** Output operand; {@code null} when built with the operand-less constructor. */
    public final CPOperand _output;
    /** First input operand; {@code null} when built with the operand-less constructor. */
    public final CPOperand _input1;
    /** Second input operand; {@code null} when built with the operand-less constructor. */
    public final CPOperand _input2;

    // -------------------------------------------------------------------------------------
    // Short label keys for miscellaneous GPU timers. Only the keys are defined here; the
    // code that records/reports them is elsewhere (NOTE(review): presumably GPU statistics
    // utilities - confirm).
    // -------------------------------------------------------------------------------------

    // Data transfer and layout/format conversion.
    public static final String MISC_TIMER_HOST_TO_DEVICE = "H2D";
    public static final String MISC_TIMER_DEVICE_TO_HOST = "D2H";
    public static final String MISC_TIMER_DEVICE_TO_DEVICE = "D2D";
    public static final String MISC_TIMER_SPARSE_TO_DENSE = "s2d";
    public static final String MISC_TIMER_DENSE_TO_SPARSE = "d2s";
    public static final String MISC_TIMER_ROW_TO_COLUMN_MAJOR = "r2c";
    public static final String MISC_TIMER_COLUMN_TO_ROW_MAJOR = "c2r";
    public static final String MISC_TIMER_OBJECT_CLONE = "clone";

    // Synchronization and memory management.
    public static final String MISC_TIMER_CUDA_SYNC = "sync";
    public static final String MISC_TIMER_CUDA_FREE = "f";
    public static final String MISC_TIMER_ALLOCATE = "a";
    public static final String MISC_TIMER_EVICT = "evict";
    public static final String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ad";
    public static final String MISC_TIMER_ALLOCATE_SPARSE_OUTPUT = "as";
    public static final String MISC_TIMER_SET_ZERO = "az";
    public static final String MISC_TIMER_REUSE = "r";

    // Library-backed operations (matrix multiply, solvers, transposes, ...).
    public static final String MISC_TIMER_SPARSE_ALLOCATE_LIB = "Msao";
    public static final String MISC_TIMER_DENSE_DOT_LIB = "Mddot";
    public static final String MISC_TIMER_DENSE_VECTOR_DENSE_MATRIX_LIB = "Mdvdm";
    public static final String MISC_TIMER_DENSE_MATRIX_DENSE_VECTOR_LIB = "Mdmdv";
    public static final String MISC_TIMER_DENSE_MATRIX_DENSE_MATRIX_LIB = "Mdmdm";
    public static final String MISC_TIMER_SPARSE_MATRIX_DENSE_VECTOR_LIB = "Msmdv";
    public static final String MISC_TIMER_SPARSE_MATRIX_SPARSE_MATRIX_LIB = "Msmsm";
    public static final String MISC_TIMER_SPARSE_MATRIX_DENSE_MATRIX_LIB = "Msmdm";
    public static final String MISC_TIMER_SYRK_LIB = "Msyrk";
    public static final String MISC_TIMER_DAXPY_LIB = "daxpy";
    public static final String MISC_TIMER_QR_BUFFER = "qr_buffer";
    public static final String MISC_TIMER_QR = "qr";
    public static final String MISC_TIMER_ORMQR = "ormqr";
    public static final String MISC_TIMER_TRSM = "trsm";
    public static final String MISC_TIMER_SPARSE_DGEAM_LIB = "sdgeaml";
    public static final String MISC_TIMER_DENSE_DGEAM_LIB = "ddgeaml";
    public static final String MISC_TIMER_TRANSPOSE_LIB = "dtl";

    // Kernel timers (element-wise, unary builtins, append, fill, reductions).
    public static final String MISC_TIMER_MATRIX_MATRIX_CELLWISE_OP_KERNEL = "mmck";
    public static final String MISC_TIMER_COMPARE_AND_SET_KERNEL = "cask";
    public static final String MISC_TIMER_EXP_KERNEL = "expk";
    public static final String MISC_TIMER_SQRT_KERNEL = "sqrtk";
    public static final String MISC_TIMER_ROUND_KERNEL = "roundk";
    public static final String MISC_TIMER_ABS_KERNEL = "absk";
    public static final String MISC_TIMER_LOG_KERNEL = "logk";
    public static final String MISC_TIMER_FLOOR_KERNEL = "floork";
    public static final String MISC_TIMER_CEIL_KERNEL = "ceilk";
    public static final String MISC_TIMER_SIN_KERNEL = "sink";
    public static final String MISC_TIMER_COS_KERNEL = "cosk";
    public static final String MISC_TIMER_TAN_KERNEL = "tank";
    public static final String MISC_TIMER_SINH_KERNEL = "sinhk";
    public static final String MISC_TIMER_COSH_KERNEL = "coshk";
    public static final String MISC_TIMER_TANH_KERNEL = "tanhk";
    public static final String MISC_TIMER_ASIN_KERNEL = "asink";
    public static final String MISC_TIMER_ACOS_KERNEL = "acosk";
    public static final String MISC_TIMER_ATAN_KERNEL = "atank";
    public static final String MISC_TIMER_SIGN_KERNEL = "signk";
    public static final String MISC_TIMER_SIGMOID_KERNEL = "sigmk";
    public static final String MISC_TIMER_CBIND_KERNEL = "cbindk";
    public static final String MISC_TIMER_RBIND_KERNEL = "rbindk";
    public static final String MISC_TIMER_DAXPY_MV_KERNEL = "daxpymv";
    public static final String MISC_TIMER_UPPER_TO_LOWER_TRIANGLE_KERNEL = "u2lk";
    public static final String MISC_TIMER_FILL_KERNEL = "fillk";
    public static final String MISC_TIMER_MATRIX_SCALAR_OP_KERNEL = "msk";
    public static final String MISC_TIMER_REDUCE_ALL_KERNEL = "rallk";
    public static final String MISC_TIMER_REDUCE_ROW_KERNEL = "rrowk";
    public static final String MISC_TIMER_REDUCE_COL_KERNEL = "rcolk";

    // Right-indexing operations.
    public static final String MISC_TIMER_RIX_DENSE_OP = "drix";
    public static final String MISC_TIMER_RIX_SPARSE_DENSE_OP_ROWWISE = "sdrixr";
    public static final String MISC_TIMER_RIX_SPARSE_DENSE_OP_NNZ = "sdrixn";

    // Neural-network (DNN) operations and cuDNN lifecycle.
    public static final String MISC_TIMER_ACTIVATION_FORWARD_LIB = "nnaf";
    public static final String MISC_TIMER_CONVOLUTION_FORWARD_LIB = "nncf";
    public static final String MISC_TIMER_CONVOLUTION_BACKWARD_FILTER_LIB = "nncbf";
    public static final String MISC_TIMER_CONVOLUTION_BACKWARD_DATA_LIB = "nncbd";
    public static final String MISC_TIMER_MAXPOOLING_FORWARD_LIB = "nnmf";
    public static final String MISC_TIMER_MAXPOOLING_BACKWARD_LIB = "nnmb";
    public static final String MISC_TIMER_BIAS_ADD_LIB = "nnba";
    public static final String MISC_TIMER_RELU_BACKWARD_KERNEL = "nnrbk";
    public static final String MISC_TIMER_RELU_KERNEL = "nnrk";
    public static final String MISC_TIMER_CUDNN_INIT = "nni";
    public static final String MISC_TIMER_CUDNN_CLEANUP = "nnc";
    public static final String MISC_TIMER_DENSE_IM2COL_KERNEL = "nndim2c";
    public static final String MISC_TIMER_SPARSE_IM2COL_KERNEL = "nnsim2c";
    public static final String MISC_TIMER_DENSE_REORG_KNPQ_KERNEL = "nndrknpq";

    // Cumulative aggregates.
    public static final String MISC_TIMER_CUMULATIVE_SCAN_KERNEL = "cumk";
    public static final String MISC_TIMER_CUMULATIVE_SUMPROD_KERNEL = "cumSumProdk";

    /**
     * Category of this GPU instruction.
     * Not assigned in this class (NOTE(review): presumably set by subclasses - confirm).
     */
    protected GPUINSTRUCTION_TYPE _gputype;

    /** Label-patching flag, cached once at construction time. */
    protected boolean _requiresLabelUpdate = false;

    /**
     * Creates a GPU instruction with explicit input and output operands.
     *
     * @param op     operator carried by this instruction
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    protected GPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(op);
        _input1 = in1;
        _input2 = in2;
        _output = out;
        instString = istr;
        instOpcode = opcode;
        // Call the superclass implementation explicitly: this class overrides
        // requiresLabelUpdate() to return the very field being initialized here.
        _requiresLabelUpdate = super.requiresLabelUpdate();
    }

    /**
     * Creates a GPU instruction without operand fields.
     * {@link #_input1}, {@link #_input2} and {@link #_output} are all {@code null}.
     *
     * @param op     operator carried by this instruction
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    protected GPUInstruction(Operator op, String opcode, String istr) {
        this(op, null, null, null, opcode, istr);
    }

    @Override
    public Instruction.IType getType() {
        return Instruction.IType.GPU;
    }

    /**
     * Returns the GPU instruction category.
     *
     * @return the value of {@link #_gputype}, possibly {@code null} if never assigned
     */
    public GPUINSTRUCTION_TYPE getGPUInstructionType() {
        return _gputype;
    }

    /** Returns the label-patching flag cached at construction time. */
    @Override
    public boolean requiresLabelUpdate() {
        return _requiresLabelUpdate;
    }

    @Override
    public String getGraphString() {
        return getOpcode();
    }

    /**
     * Runs the inherited preprocessing, then patches labels if needed.
     *
     * <p>When the resulting instruction requires a label update, its string is rewritten
     * against the current variables and re-parsed as a new GPU instruction.
     *
     * @param ec execution context supplying the variables used for label patching
     * @return the instruction to execute, possibly a freshly parsed instance
     */
    @Override
    public Instruction preprocessInstruction(ExecutionContext ec) {
        Instruction tmp = super.preprocessInstruction(ec);
        if (tmp.requiresLabelUpdate()) {
            String updInst = CPInstruction.updateLabels(tmp.toString(), ec.getVariables());
            tmp = GPUInstructionParser.parseSingleInstruction(updInst);
        }
        return tmp;
    }

    /**
     * Executes this instruction.
     *
     * @param ec the execution context to run against
     */
    @Override
    public abstract void processInstruction(ExecutionContext ec);

    /**
     * Performs GPU-specific post-processing, then delegates to the superclass.
     *
     * <p>The GPU-specific steps are:
     * <ul>
     *   <li>synchronize the device when {@code DMLScript.SYNCHRONIZE_GPU} is set;</li>
     *   <li>at debug level, print memory info for every non-null GPU context.</li>
     * </ul>
     *
     * @param ec execution context whose GPU contexts are inspected
     */
    @Override
    public void postprocessInstruction(ExecutionContext ec) {
        if (DMLScript.SYNCHRONIZE_GPU) {
            JCuda.cudaDeviceSynchronize();
        }
        if (LOG.isDebugEnabled()) {
            for (GPUContext gpuCtx : ec.getGPUContexts()) {
                if (gpuCtx == null) {
                    continue;
                }
                gpuCtx.printMemoryInfo(getOpcode());
            }
        }
        super.postprocessInstruction(ec);
    }

    /**
     * Fetches a matrix input for GPU use, tagged with this instruction's extended opcode.
     *
     * @param ec   execution context to fetch from
     * @param name variable name of the input matrix
     * @return the matrix object returned by the execution context
     */
    protected MatrixObject getMatrixInputForGPUInstruction(ExecutionContext ec, String name) {
        return ec.getMatrixInputForGPUInstruction(name, getExtendedOpcode());
    }

    /**
     * Obtains a dense GPU output matrix, passing {@code initialize = true}.
     *
     * @param ec      execution context to allocate through
     * @param name    variable name of the output
     * @param numRows number of rows
     * @param numCols number of columns
     * @return the output matrix object
     */
    protected MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String name, long numRows, long numCols) {
        return getDenseMatrixOutputForGPUInstruction(ec, name, numRows, numCols, true);
    }

    /**
     * Obtains a dense GPU output matrix from the execution context.
     *
     * @param ec         execution context to allocate through
     * @param name       variable name of the output
     * @param numRows    number of rows
     * @param numCols    number of columns
     * @param initialize forwarded unchanged to the execution context
     *                   (NOTE(review): presumably controls buffer initialization - confirm)
     * @return the matrix object (key of the pair returned by the execution context)
     */
    protected MatrixObject getDenseMatrixOutputForGPUInstruction(ExecutionContext ec, String name, long numRows, long numCols, boolean initialize) {
        return ec.getDenseMatrixOutputForGPUInstruction(name, numRows, numCols, initialize).getKey();
    }

    /**
     * Builds the lineage entry for this instruction.
     *
     * <p>The entry is keyed by the output name and built from the opcode and both input operands.
     * Instances created with the operand-less constructor have a {@code null} {@link #_output}.
     * Calling this on such an instance throws {@code NullPointerException}
     * (NOTE(review): such subclasses presumably override this - confirm).
     *
     * @param ec execution context used to resolve input lineage
     * @return pair of output variable name and its lineage item
     */
    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        // Pass the typed values directly so Pair.of infers Pair<String, LineageItem>;
        // casting them to Object would infer Pair<Object, Object> and fail to compile here.
        return Pair.of(_output.getName(),
            new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _input1, _input2)));
    }

    /** Categories of GPU instructions. */
    public enum GPUINSTRUCTION_TYPE {
        AggregateUnary,
        AggregateBinary,
        RelationalBinary,
        Dnn,
        MMTSJ,
        Reorg,
        MatrixReshape,
        Append,
        ArithmeticBinary,
        BuiltinUnary,
        BuiltinBinary,
        Builtin,
        MatrixIndexing,
        SpoofFused
    }
}

