/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction for the fused weighted quaternary operators (wsloss, wsigmoid, wdivmm, wcemm,
 * wumm) in their map-side ("map*") and reduce-side ("red*") variants.
 *
 * <p>Map-side variants broadcast both factor inputs (input2, input3) and process the main input
 * partition-wise. Reduce-side variants broadcast a factor only when its cache flag is set and
 * otherwise join the main input with the (replicated) factor RDDs, plus an optional RDD fourth
 * input.
 */
public class QuaternarySPInstruction
extends ComputationSPInstruction {
    /** Opcodes of the map-side variants, which always broadcast both factor inputs. */
    private static final String[] MAP_SIDE_OPCODES = {
        "mapwsloss", "mapwsigmoid", "mapwdivmm", "mapwcemm", "mapwumm"
    };

    /** Optional fourth operand; null for wsigmoid and wumm, and may be a scalar literal. */
    private final CPOperand _input4;
    /** Whether input2 is broadcast (true) or joined as an RDD (false). */
    private final boolean _cacheU;
    /** Whether input3 is broadcast (true) or joined as an RDD (false). */
    private final boolean _cacheV;

    private QuaternarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
        CPOperand out, boolean cacheU, boolean cacheV, String opcode, String str) {
        super(SPInstruction.SPType.Quaternary, op, in1, in2, in3, out, opcode, str);
        _input4 = in4;
        _cacheU = cacheU;
        _cacheV = cacheV;
    }

    /**
     * Parses a distributed quaternary instruction string.
     *
     * <p>Field layouts after the opcode. The trailing {@code cacheU, cacheV} flags exist only for
     * the reduce-side "red*" variants; map-side variants always broadcast both factors.
     * <ul>
     *   <li>wsloss: in1, in2, in3, in4, out, weightsType [, cacheU, cacheV]</li>
     *   <li>wumm: unaryOpcode, in1, in2, in3, out, wummType [, cacheU, cacheV]</li>
     *   <li>wdivmm: in1, in2, in3, in4, out, wdivmmType [, cacheU, cacheV]; if the type has a
     *       scalar, the name of in4 is parsed as the operator's double parameter</li>
     *   <li>wsigmoid: in1, in2, in3, out, wsigmoidType [, cacheU, cacheV]</li>
     *   <li>wcemm: in1, in2, in3, in4, out, wcemmType [, cacheU, cacheV]; if the type has four
     *       inputs, the name of in4 is parsed as the operator's double parameter</li>
     * </ul>
     * Field counts are validated via {@code InstructionUtils.checkNumFields}.
     *
     * @param str the instruction string
     * @return the parsed instruction, never null
     * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode or has no
     *         parsing rule here
     */
    public static QuaternarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
            throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
        }

        if ("mapwsloss".equalsIgnoreCase(opcode) || "redwsloss".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwsloss".equalsIgnoreCase(opcode);
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            WeightedSquaredLoss.WeightsType wtype = WeightedSquaredLoss.WeightsType.valueOf(parts[6]);
            return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out,
                parseCacheFlag(parts, 7, isRed), parseCacheFlag(parts, 8, isRed), opcode, str);
        }

        if ("mapwumm".equalsIgnoreCase(opcode) || "redwumm".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwumm".equalsIgnoreCase(opcode);
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            String uopcode = parts[1];
            CPOperand in1 = new CPOperand(parts[2]);
            CPOperand in2 = new CPOperand(parts[3]);
            CPOperand in3 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            WeightedUnaryMM.WUMMType wtype = WeightedUnaryMM.WUMMType.valueOf(parts[6]);
            return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out,
                parseCacheFlag(parts, 7, isRed), parseCacheFlag(parts, 8, isRed), opcode, str);
        }

        if ("mapwdivmm".equalsIgnoreCase(opcode) || "redwdivmm".equalsIgnoreCase(opcode)) {
            boolean isRed = opcode.startsWith("red");
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            WeightedDivMM.WDivMMType wtype = WeightedDivMM.WDivMMType.valueOf(parts[6]);
            // a scalar fourth operand is carried inline as the operand name
            QuaternaryOperator qop = wtype.hasScalar()
                ? new QuaternaryOperator(wtype, Double.parseDouble(in4.getName()))
                : new QuaternaryOperator(wtype);
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out,
                parseCacheFlag(parts, 7, isRed), parseCacheFlag(parts, 8, isRed), opcode, str);
        }

        // remaining opcodes: wsigmoid and wcemm; wcemm carries one extra operand (in4) before out
        boolean isRed = opcode.startsWith("red");
        int offset = opcode.endsWith("wcemm") ? 1 : 0;
        InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + offset);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4 + offset]);
        boolean cacheU = parseCacheFlag(parts, 6 + offset, isRed);
        boolean cacheV = parseCacheFlag(parts, 7 + offset, isRed);

        if (opcode.endsWith("wsigmoid")) {
            QuaternaryOperator qop = new QuaternaryOperator(WeightedSigmoid.WSigmoidType.valueOf(parts[5]));
            return new QuaternarySPInstruction(qop, in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
        }
        if (opcode.endsWith("wcemm")) {
            CPOperand in4 = new CPOperand(parts[4]);
            WeightedCrossEntropy.WCeMMType wtype = WeightedCrossEntropy.WCeMMType.valueOf(parts[6]);
            QuaternaryOperator qop = wtype.hasFourInputs()
                ? new QuaternaryOperator(wtype, Double.parseDouble(in4.getName()))
                : new QuaternaryOperator(wtype);
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }

        // previously returned null here, which only surfaced later as a NullPointerException
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unsupported opcode " + opcode);
    }

    /**
     * Reads a factor-cache flag. Map-side variants have no flag fields and always broadcast.
     *
     * @param parts instruction fields
     * @param pos position of the flag (only read when {@code isRed})
     * @param isRed whether the opcode is a reduce-side variant
     * @return the parsed flag for reduce-side variants, otherwise true
     */
    private static boolean parseCacheFlag(String[] parts, int pos, boolean isRed) {
        return !isRed || Boolean.parseBoolean(parts[pos]);
    }

    /** Returns whether the opcode (case-insensitive) is one of the map-side variants. */
    private static boolean isMapSideOpcode(String opcode) {
        for (String candidate : MAP_SIDE_OPCODES) {
            if (candidate.equalsIgnoreCase(opcode)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Executes the quaternary operator on Spark.
     *
     * <p>If the operator has {@code wtype1} or {@code wtype4} set, the per-block results are
     * summed into a scalar {@link DoubleObject} output. Otherwise the result RDD becomes the
     * output variable, with lineage to all RDD and broadcast inputs. For non-basic wdivmm
     * ({@code wtype3}), blocks are first summed by key.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        QuaternaryOperator qop = (QuaternaryOperator) _optr;
        ArrayList<String> rddVars = new ArrayList<>();
        ArrayList<String> bcVars = new ArrayList<>();
        boolean aggregatesToScalar = qop.wtype1 != null || qop.wtype4 != null;

        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        if (aggregatesToScalar) {
            // NOTE(review): presumably empty blocks of the main input contribute nothing to the sum
            in = in.filter(new FilterNonEmptyBlocksFunction());
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (isMapSideOpcode(getOpcode())) {
            PartitionedBroadcast<MatrixBlock> bcU = sec.getBroadcastForVariable(input2.getName());
            PartitionedBroadcast<MatrixBlock> bcV = sec.getBroadcastForVariable(input3.getName());
            // only non-basic wdivmm rewrites block indexes (see createOutputIndexes)
            boolean noKeyChange = qop.wtype3 == null || qop.wtype3.isBasic();
            out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bcU, bcV), noKeyChange);
            rddVars.add(input1.getName());
            bcVars.add(input2.getName());
            bcVars.add(input3.getName());
        } else {
            DataCharacteristics inMc = sec.getDataCharacteristics(input1.getName());
            long rlen = inMc.getRows();
            long clen = inMc.getCols();
            int blen = inMc.getBlocksize();

            PartitionedBroadcast<MatrixBlock> bcU = _cacheU ? sec.getBroadcastForVariable(input2.getName()) : null;
            PartitionedBroadcast<MatrixBlock> bcV = _cacheV ? sec.getBroadcastForVariable(input3.getName()) : null;
            JavaPairRDD<MatrixIndexes, MatrixBlock> inU =
                _cacheU ? null : sec.getBinaryMatrixBlockRDDHandleForVariable(input2.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> inV =
                _cacheV ? null : sec.getBinaryMatrixBlockRDDHandleForVariable(input3.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> inW =
                (qop.hasFourInputs() && _input4 != null && !_input4.isLiteral())
                    ? sec.getBinaryMatrixBlockRDDHandleForVariable(_input4.getName())
                    : null;

            // align factor blocks with the block keys of the main input before joining
            if (inU != null) {
                inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, blen, true));
            }
            if (inV != null) {
                inV = inV.mapToPair(new TransposeFactorIndexesFunction())
                    .flatMapToPair(new ReplicateBlockFunction(rlen, blen, false));
            }
            out = joinAndCompute(qop, in, inU, inV, inW, bcU, bcV);

            if (inU == null) {
                bcVars.add(input2.getName());
            } else {
                rddVars.add(input2.getName());
            }
            if (inV == null) {
                bcVars.add(input3.getName());
            } else {
                rddVars.add(input3.getName());
            }
            if (inW != null) {
                rddVars.add(_input4.getName());
            }
        }

        if (aggregatesToScalar) {
            MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
            sec.setVariable(output.getName(), new DoubleObject(tmp.getValue(0, 0)));
            return;
        }

        if (qop.wtype3 != null && !qop.wtype3.isBasic()) {
            out = RDDAggregateUtils.sumByKeyStable(out, false);
        }
        sec.setRDDHandleForVariable(output.getName(), out);
        for (String rddVar : rddVars) {
            sec.addLineageRDD(output.getName(), rddVar);
        }
        for (String bcVar : bcVars) {
            sec.addLineageBroadcast(output.getName(), bcVar);
        }
        updateOutputDataCharacteristics(sec, qop);
    }

    /**
     * Joins the main input with whichever of U, V, W are RDDs (non-null) and applies the
     * matching block function. Inputs that are null are expected to be broadcast via bcU/bcV or
     * absent.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> joinAndCompute(QuaternaryOperator qop,
        JavaPairRDD<MatrixIndexes, MatrixBlock> in, JavaPairRDD<MatrixIndexes, MatrixBlock> inU,
        JavaPairRDD<MatrixIndexes, MatrixBlock> inV, JavaPairRDD<MatrixIndexes, MatrixBlock> inW,
        PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
        boolean hasU = inU != null;
        boolean hasV = inV != null;
        boolean hasW = inW != null;

        // exactly one RDD side input
        if (hasU && !hasV && !hasW) {
            return in.join(inU).mapToPair(new RDDQuaternaryFunction2(qop, bcU, bcV));
        }
        if (!hasU && hasV && !hasW) {
            return in.join(inV).mapToPair(new RDDQuaternaryFunction2(qop, bcU, bcV));
        }
        if (!hasU && !hasV && hasW) {
            return in.join(inW).mapToPair(new RDDQuaternaryFunction2(qop, bcU, bcV));
        }
        // exactly two RDD side inputs
        if (hasU && hasV && !hasW) {
            return in.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(qop, bcU, bcV));
        }
        if (hasU && !hasV && hasW) {
            return in.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bcU, bcV));
        }
        if (!hasU && hasV && hasW) {
            return in.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bcU, bcV));
        }
        // no RDD side inputs: everything comes from broadcasts
        if (!hasU && !hasV && !hasW) {
            return in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bcU, bcV), false);
        }
        // all three side inputs are RDDs
        return in.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(qop));
    }

    /**
     * Sets the output data characteristics.
     *
     * <p>If {@code wtype2} or {@code wtype5} is set, the output takes the dimensions of input1.
     * For wdivmm ({@code wtype3}), the dimensions come from the operator type and the factor rank.
     */
    private void updateOutputDataCharacteristics(SparkExecutionContext sec, QuaternaryOperator qop) {
        DataCharacteristics mcIn1 = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        // use the 3-arg set: the 4th argument of set(long, long, int, long) is the nnz,
        // so passing the block size there would record a bogus non-zero count
        if (qop.wtype2 != null || qop.wtype5 != null) {
            mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getBlocksize());
        } else if (qop.wtype3 != null) {
            DataCharacteristics mcIn2 = sec.getDataCharacteristics(input2.getName());
            DataCharacteristics mcIn3 = sec.getDataCharacteristics(input3.getName());
            long rank = qop.wtype3.isLeft() ? mcIn3.getCols() : mcIn2.getCols();
            MatrixCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mcIn1.getRows(), mcIn1.getCols(), rank);
            mcOut.set(mcTmp.getRows(), mcTmp.getCols(), mcIn1.getBlocksize());
        }
    }

    /** Swaps the row and column block index of each block and emits a copy of the block. */
    private static class TransposeFactorIndexesFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2571724736131823708L;

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
            MatrixIndexes ixIn = arg0._1();
            MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex());
            return new Tuple2<>(ixOut, new MatrixBlock(arg0._2()));
        }
    }

    /** Block function for the case where U, V and W are all joined as RDDs (nothing broadcast). */
    private static class RDDQuaternaryFunction4
    extends RDDQuaternaryBaseFunction
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 7328911771600289250L;

        public RDDQuaternaryFunction4(QuaternaryOperator qop) {
            super(qop, null, null);
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(
            Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>> arg0) {
            // join nesting: (((X, U), V), W)
            Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock> joined = arg0._2();
            MatrixBlock blkIn1 = joined._1()._1()._1();
            MatrixBlock mbU = joined._1()._1()._2();
            MatrixBlock mbV = joined._1()._2();
            MatrixBlock mbW = joined._2();
            MatrixBlock blkOut = new MatrixBlock();
            blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
            return new Tuple2<>(createOutputIndexes(arg0._1()), blkOut);
        }
    }

    /**
     * Block function for two joined side inputs; the missing factor comes from its broadcast.
     * Join nesting is ((X, in2), in3), where in2/in3 are the first/second joined side input.
     */
    private static class RDDQuaternaryFunction3
    extends RDDQuaternaryBaseFunction
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2294086455843773095L;

        public RDDQuaternaryFunction3(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            super(qop, bcU, bcV);
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn1 = arg0._2()._1()._1();
            MatrixBlock blkIn2 = arg0._2()._1()._2();
            MatrixBlock blkIn3 = arg0._2()._2();
            MatrixBlock mbU = _pmU != null ? _pmU.getBlock((int) ixIn.getRowIndex(), 1) : blkIn2;
            // if U is broadcast, V is the first joined input; otherwise it is the second
            MatrixBlock mbV = _pmV != null ? _pmV.getBlock((int) ixIn.getColumnIndex(), 1)
                : (_pmU != null ? blkIn2 : blkIn3);
            MatrixBlock mbW = _qop.hasFourInputs() ? blkIn3 : null;
            MatrixBlock blkOut = new MatrixBlock();
            blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
            return new Tuple2<>(createOutputIndexes(ixIn), blkOut);
        }
    }

    /** Block function for exactly one joined side input (U, V or W); the rest is broadcast. */
    private static class RDDQuaternaryFunction2
    extends RDDQuaternaryBaseFunction
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 7493974462943080693L;

        public RDDQuaternaryFunction2(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            super(qop, bcU, bcV);
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn1 = arg0._2()._1();
            MatrixBlock blkIn2 = arg0._2()._2();
            MatrixBlock mbU = _pmU != null ? _pmU.getBlock((int) ixIn.getRowIndex(), 1) : blkIn2;
            MatrixBlock mbV = _pmV != null ? _pmV.getBlock((int) ixIn.getColumnIndex(), 1) : blkIn2;
            MatrixBlock mbW = _qop.hasFourInputs() ? blkIn2 : null;
            MatrixBlock blkOut = new MatrixBlock();
            blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
            return new Tuple2<>(createOutputIndexes(ixIn), blkOut);
        }
    }

    /**
     * Partition-wise block function for the case where both factors are broadcast. Each block
     * looks up its U block by row-block index and its V block by column-block index.
     */
    private static class RDDQuaternaryFunction1
    extends RDDQuaternaryBaseFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;

        public RDDQuaternaryFunction1(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            super(qop, bcU, bcV);
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) {
            return new RDDQuaternaryPartitionIterator(arg);
        }

        /** Lazily applies the operator to each block of the partition. */
        private class RDDQuaternaryPartitionIterator
        extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public RDDQuaternaryPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
                MatrixIndexes ixIn = arg._1();
                MatrixBlock blkIn = arg._2();
                MatrixBlock mbU = RDDQuaternaryFunction1.this._pmU.getBlock((int) ixIn.getRowIndex(), 1);
                MatrixBlock mbV = RDDQuaternaryFunction1.this._pmV.getBlock((int) ixIn.getColumnIndex(), 1);
                MatrixBlock blkOut = new MatrixBlock();
                blkIn.quaternaryOperations(RDDQuaternaryFunction1.this._qop, mbU, mbV, null, blkOut);
                return new Tuple2<>(RDDQuaternaryFunction1.this.createOutputIndexes(ixIn), blkOut);
            }
        }
    }

    /** Shared state (operator, optional broadcasts U and V) and output-key logic. */
    private static abstract class RDDQuaternaryBaseFunction
    implements Serializable {
        private static final long serialVersionUID = -3175397651350954930L;
        protected QuaternaryOperator _qop = null;
        protected PartitionedBroadcast<MatrixBlock> _pmU = null;
        protected PartitionedBroadcast<MatrixBlock> _pmV = null;

        public RDDQuaternaryBaseFunction(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            _qop = qop;
            _pmU = bcU;
            _pmV = bcV;
        }

        /**
         * For non-basic wdivmm, maps a block to its partial-result key: (column-block, 1) for
         * left variants and (row-block, 1) otherwise. All other operators keep the input key.
         */
        protected MatrixIndexes createOutputIndexes(MatrixIndexes in) {
            if (_qop.wtype3 != null && !_qop.wtype3.isBasic()) {
                boolean left = _qop.wtype3.isLeft();
                return new MatrixIndexes(left ? in.getColumnIndex() : in.getRowIndex(), 1L);
            }
            return in;
        }
    }
}

