/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.matrix.data.Pair;

/**
 * Statement-block rewrite that splits the HOP DAG of a last-level statement block
 * at data-dependent operators.
 *
 * <p>Candidate operators are:
 * <ul>
 *   <li>removeEmpty;</li>
 *   <li>table() with fewer than four inputs;</li>
 *   <li>eval;</li>
 *   <li>SQL reads;</li>
 *   <li>grouped aggregates with an unknown number of groups;</li>
 *   <li>non-literal, non-data inputs at positions 2 and 3 of a sort;</li>
 *   <li>operators that are, or require, compression.</li>
 * </ul>
 * Most candidates are only collected when a split is required, i.e. when the hop
 * is not only consumed by writes, its dimensions are unknown, and execution is
 * not forced to single-node.
 *
 * <p>Each candidate is computed in a new, preceding statement block and handed over
 * through a transient variable, which the original block reads back. The error
 * message used below refers to operators "with unknown size". The split
 * presumably lets the consuming block be recompiled once those sizes are known.
 * NOTE(review): confirm.
 */
public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule {

    /**
     * Indicates that this rule introduces split DAGs.
     *
     * @return always {@code true}
     */
    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Splits {@code sb} into a new preceding block that computes all data-dependent
     * operators, followed by {@code sb} itself, which reads their results through
     * transient reads.
     *
     * <p>The rewrite is skipped, and {@code sb} returned unchanged, in two cases:
     * in forced single-node mode with compression disabled, and when {@code sb} is
     * not a last-level statement block. The newly created block is rewritten
     * recursively, because it may contain further candidates below the ones cut
     * here.
     *
     * @param sb    the statement block to rewrite
     * @param state rewrite status, passed on to the recursive rewrite
     * @return the resulting blocks in execution order: the block(s) computing the
     *         candidates, followed by {@code sb}; or just {@code sb} if nothing was split
     * @throws HopsException if splitting the DAG fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
        // Locale.ROOT: parsing a config keyword must not depend on the default locale.
        Compression.CompressConfig compress = Compression.CompressConfig.valueOf(ConfigurationManager
            .getDMLConfig().getTextValue("sysds.compressed.linalg").toUpperCase(Locale.ROOT));
        boolean singleNodeWithoutCompression = DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE
            && compress == Compression.CompressConfig.FALSE;
        if (singleNodeWithoutCompression || !HopRewriteUtils.isLastLevelStatementBlock(sb)) {
            return Arrays.asList(sb);
        }

        ArrayList<StatementBlock> ret = new ArrayList<>();
        ArrayList<Hop> cand = new ArrayList<>();
        collectDataDependentOperators(sb.getHops(), cand);
        Hop.resetVisitStatus(sb.getHops());
        if (cand.isEmpty()) {
            ret.add(sb);
            return ret;
        }

        // All candidates plus every hop reachable below them. Consumers in this set
        // lie inside a candidate sub-DAG and keep their direct references.
        HashSet<Hop> candChilds = new HashSet<>();
        collectCandidateChildOperators(cand, candChilds);
        candChilds.addAll(cand);
        try {
            StatementBlock sb1 = new StatementBlock();
            sb1.setDMLProg(sb.getDMLProg());
            sb1.setParseInfo(sb);
            sb1.setLiveIn(new VariableSet());
            sb1.setLiveOut(new VariableSet());
            ArrayList<Hop> sb1hops = new ArrayList<>();
            for (Hop c : cand) {
                // Reuse the first existing transient write of c if rHasSimpleReadChain
                // holds for its variable; otherwise introduce a fresh cut variable.
                Hop firstTWrite = getFirstTransientWriteParent(c);
                boolean moveTWrite = firstTWrite != null
                    && HopRewriteUtils.rHasSimpleReadChain(c, firstTWrite.getName());
                String varname = moveTWrite ? firstTWrite.getName() : createCutVarName(false);

                // Capture c's metadata before any rewiring, for the live-variable sets below.
                DataIdentifier diVar = createDataIdentifier(varname, c);
                DataOp tread = HopRewriteUtils.createTransientRead(varname, c);

                // Iterate over a copy, since rewiring may modify c's parent list.
                for (Hop parent : new ArrayList<>(c.getParent())) {
                    if (candChilds.contains(parent)) {
                        continue;
                    }
                    if (moveTWrite && parent == firstTWrite) {
                        // the reused transient write moves into the first block
                        sb.getHops().remove(parent);
                    } else {
                        HopRewriteUtils.replaceChildReference(parent, c, tread);
                    }
                }
                sb1hops.add(moveTWrite ? firstTWrite : HopRewriteUtils.createTransientWrite(varname, c));

                sb1.liveOut().addVariable(varname, new DataIdentifier(diVar));
                sb.liveIn().addVariable(varname, new DataIdentifier(diVar));
                sb.variablesRead().addVariable(varname, new DataIdentifier(diVar));
            }
            // Hops shared by both DAGs are read by sb from sb1 instead of being recomputed.
            handleReplicatedOperators(sb1hops, sb.getHops(), sb1.liveOut(), sb.liveIn());
            sb1.setHops(Recompiler.deepCopyHopsDag(sb1hops));
            sb1.updateRecompilationFlag();
            sb1.setSplitDag(true);
            ret.addAll(rewriteStatementBlock(sb1, state));
            ret.add(sb);
            sb.setSplitDag(true);
        } catch (Exception ex) {
            throw new HopsException("Failed to split hops dag for data dependent operators with unknown size.", ex);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Applied splitDagDataDependentOperators (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
        }
        return ret;
    }

    /**
     * Collects the data-dependent operators of the DAG rooted at {@code roots} into
     * {@code cand}. Visit status is reset before the traversal but not after it.
     *
     * @param roots DAG roots; {@code null} is a no-op
     * @param cand  output list of candidates, without duplicates
     */
    private void collectDataDependentOperators(ArrayList<Hop> roots, ArrayList<Hop> cand) {
        if (roots == null) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            rCollectDataDependentOperators(root, cand);
        }
    }

    /**
     * Recursively collects candidates below {@code hop}.
     *
     * <p>Once a hop is identified as a candidate, its inputs are not traversed
     * further, because the whole sub-DAG is later deep-copied into the first block.
     * For removeEmpty and table() this also sets output flags based on the consumers.
     *
     * @param hop  current hop
     * @param cand output list of candidates, without duplicates
     */
    private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
        if (hop.isVisited()) {
            return;
        }
        boolean noSplitRequired = HopRewriteUtils.hasOnlyWriteParents(hop, true, false)
            || hop.dimsKnown()
            || DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE;
        boolean investigateChilds = true;

        // removeEmpty: skip it if its only consumer is a ternary op with the ignore-zero rewrite
        if (hop instanceof ParameterizedBuiltinOp
            && ((ParameterizedBuiltinOp) hop).getOp() == Types.ParamBuiltinOp.RMEMPTY
            && !noSplitRequired
            && !hasSingleIgnoreZeroTernaryParent(hop)) {
            ParameterizedBuiltinOp pbhop = (ParameterizedBuiltinOp) hop;
            addCandidate(cand, pbhop);
            investigateChilds = false;

            // Empty output blocks are only dropped when every consumer is either
            // the left input of a matrix multiply or nrow().
            boolean noEmptyBlocks = true;
            boolean onlyPMM = true;
            boolean diagInput = pbhop.isTargetDiagInput();
            for (Hop p : hop.getParent()) {
                boolean leftMatMultInput = p instanceof AggBinaryOp && hop == p.getInput().get(0);
                noEmptyBlocks &= leftMatMultInput || HopRewriteUtils.isUnary(p, Types.OpOp1.NROW);
                onlyPMM &= leftMatMultInput;
            }
            pbhop.setOutputEmptyBlocks(!noEmptyBlocks);
            if (onlyPMM && diagInput) {
                if (ConfigurationManager.isDynamicRecompilation()) {
                    pbhop.setOutputPermutationMatrix(true);
                }
                // onlyPMM guarantees every parent is an AggBinaryOp
                for (Hop p : hop.getParent()) {
                    ((AggBinaryOp) p).setHasLeftPMInput(true);
                }
            }
        }

        // table() with fewer than 4 inputs (presumably: output dims not given)
        if (HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE) && hop.getInput().size() < 4 && !noSplitRequired) {
            addCandidate(cand, hop);
            investigateChilds = false;
            boolean onlyPMM = true;
            for (Hop p : hop.getParent()) {
                onlyPMM &= p instanceof AggBinaryOp && hop == p.getInput().get(0);
            }
            if (onlyPMM && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0))) {
                hop.setOutputEmptyBlocks(false);
            }
        }

        // sort: non-literal, non-data inputs at positions 2 and 3 become candidates
        if (HopRewriteUtils.isReorg(hop, Types.ReOrgOp.SORT)) {
            for (int i = 2; i <= 3; ++i) {
                Hop c = hop.getInput().get(i);
                if (c instanceof LiteralOp || c instanceof DataOp) {
                    continue;
                }
                addCandidate(cand, c);
                c.setVisited();
                investigateChilds = false;
            }
        }

        if (isBasicDataDependentOperator(hop, noSplitRequired)) {
            addCandidate(cand, hop);
            investigateChilds = false;
        }

        if (investigateChilds && hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                rCollectDataDependentOperators(c, cand);
            }
        }
        hop.setVisited();
    }

    /** Adds {@code hop} to {@code cand} unless it is already present, so each candidate is split once. */
    private static void addCandidate(ArrayList<Hop> cand, Hop hop) {
        if (!cand.contains(hop)) {
            cand.add(hop);
        }
    }

    /** Returns true if {@code hop} has exactly one parent: a TernaryOp for which the ignore-zero rewrite applies. */
    private static boolean hasSingleIgnoreZeroTernaryParent(Hop hop) {
        return hop.getParent().size() == 1
            && hop.getParent().get(0) instanceof TernaryOp
            && ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable();
    }

    /**
     * Tests the remaining candidate types: eval, SQL read and grouped aggregate with
     * unknown number of groups (each only if a split is required), and compression.
     * Compression probes write parents directly instead of using
     * {@code noSplitRequired}, so it is split even when the dimensions are known.
     */
    private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
        return HopRewriteUtils.isNary(hop, Types.OpOpN.EVAL) && !noSplitRequired
            || HopRewriteUtils.isData(hop, Types.OpOpData.SQLREAD) && !noSplitRequired
            || HopRewriteUtils.isParameterBuiltinOp(hop, Types.ParamBuiltinOp.GROUPEDAGG)
                && !((ParameterizedBuiltinOp) hop).isKnownNGroups() && !noSplitRequired
            || (HopRewriteUtils.isUnary(hop, Types.OpOp1.COMPRESS) || hop.requiresCompression())
                && !HopRewriteUtils.hasOnlyWriteParents(hop, true, true);
    }

    /**
     * Returns the first parent of {@code hop} that is a transient write.
     *
     * @return the first transient-write parent, or {@code null} if there is none
     */
    private static Hop getFirstTransientWriteParent(Hop hop) {
        for (Hop p : hop.getParent()) {
            if (p instanceof DataOp && ((DataOp) p).getOp() == Types.OpOpData.TRANSIENTWRITE) {
                return p;
            }
        }
        return null;
    }

    /**
     * Builds a DataIdentifier named {@code varname} with the dimensions, block size,
     * data type and value type of {@code h}, in that order.
     */
    private static DataIdentifier createDataIdentifier(String varname, Hop h) {
        DataIdentifier diVar = new DataIdentifier(varname);
        diVar.setDimensions(h.getDim1(), h.getDim2());
        diVar.setBlocksize(h.getBlocksize());
        diVar.setDataType(h.getDataType());
        diVar.setValueType(h.getValueType());
        return diVar;
    }

    /**
     * Cuts hops that are referenced from both DAGs.
     *
     * <p>Every input of a second-DAG hop that also occurs in the first DAG is
     * replaced by a transient read. The matching transient write is added to
     * {@code rootsSB1}, with one cut variable per shared hop. Non-persistent data
     * ops and literals are not considered shared.
     *
     * @param rootsSB1 roots of the first DAG; new transient writes are appended
     * @param rootsSB2 roots of the second DAG, rewired in place
     * @param sb1out   live-out set of the first block, extended with the new variables
     * @param sb2in    live-in set of the second block, extended with the new variables
     */
    private void handleReplicatedOperators(ArrayList<Hop> rootsSB1, ArrayList<Hop> rootsSB2, VariableSet sb1out, VariableSet sb2in) {
        HashSet<Hop> probeSet = new HashSet<>();
        Hop.resetVisitStatus(rootsSB1);
        for (Hop hop : rootsSB1) {
            rAddHopsToProbeSet(hop, probeSet);
        }

        // Insertion order makes the assignment of cut variables deterministic.
        HashSet<Pair<Hop, Hop>> candSet = new LinkedHashSet<>();
        Hop.resetVisitStatus(rootsSB2);
        for (Hop h : rootsSB2) {
            rProbeAndAddHopsToCandidateSet(h, probeSet, candSet);
        }

        HashMap<Long, DataOp> reuseTRead = new HashMap<>();
        for (Pair<Hop, Hop> edge : candSet) {
            Hop hop = edge.getKey();
            Hop c = edge.getValue();
            DataOp tread = reuseTRead.get(c.getHopID());
            if (tread == null) {
                String varname = createCutVarName(false);
                tread = HopRewriteUtils.createTransientRead(varname, c);
                reuseTRead.put(c.getHopID(), tread);
                DataOp twrite = HopRewriteUtils.createTransientWrite(varname, c);
                DataIdentifier diVar = createDataIdentifier(varname, c);
                sb1out.addVariable(varname, new DataIdentifier(diVar));
                sb2in.addVariable(varname, new DataIdentifier(diVar));
                rootsSB1.add(twrite);
            }
            // replace the reference at the same input position
            int pos = HopRewriteUtils.getChildReferencePos(hop, c);
            HopRewriteUtils.removeChildReferenceByPos(hop, c, pos);
            HopRewriteUtils.addChildReference(hop, tread, pos);
        }
    }

    /** Adds every hop reachable from {@code hop} to {@code probeSet}, except non-persistent data ops and literals. */
    private void rAddHopsToProbeSet(Hop hop, HashSet<Hop> probeSet) {
        if (hop.isVisited()) {
            return;
        }
        boolean nonPersistentData = hop instanceof DataOp && !((DataOp) hop).isPersistentReadWrite();
        if (!nonPersistentData && !(hop instanceof LiteralOp)) {
            probeSet.add(hop);
        }
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                rAddHopsToProbeSet(c, probeSet);
            }
        }
        hop.setVisited();
    }

    /**
     * Records every (consumer, input) edge whose input is in {@code probeSet}. The
     * traversal does not descend below such inputs.
     */
    private void rProbeAndAddHopsToCandidateSet(Hop hop, HashSet<Hop> probeSet, HashSet<Pair<Hop, Hop>> candSet) {
        if (hop.isVisited()) {
            return;
        }
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                if (probeSet.contains(c)) {
                    candSet.add(new Pair<Hop, Hop>(hop, c));
                } else {
                    rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Collects all hops strictly below any candidate into {@code candChilds}. The
     * candidates themselves are only added if reachable from another candidate.
     * Visit status of the candidate DAGs is reset before and after the traversal.
     */
    private void collectCandidateChildOperators(ArrayList<Hop> cand, HashSet<Hop> candChilds) {
        if (cand == null) {
            return;
        }
        Hop.resetVisitStatus(cand);
        for (Hop root : cand) {
            rCollectCandidateChildOperators(root, cand, candChilds, false);
        }
        Hop.resetVisitStatus(cand);
    }

    /** Recursive helper: collection turns on once a candidate has been passed on the current path. */
    private void rCollectCandidateChildOperators(Hop hop, ArrayList<Hop> cand, HashSet<Hop> candChilds, boolean collect) {
        if (hop.isVisited()) {
            return;
        }
        if (collect) {
            candChilds.add(hop);
        }
        boolean passedFlag = collect || cand.contains(hop);
        if (hop.getInput() != null) {
            for (Hop c : hop.getInput()) {
                rCollectCandidateChildOperators(c, cand, candChilds, passedFlag);
            }
        }
        hop.setVisited();
    }

    /**
     * Returns the given statement blocks unchanged. This rule works only through
     * {@link #rewriteStatementBlock}.
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
        return sbs;
    }
}

