/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.cost;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Cost estimator that weights the shape of an encoding (rows scanned, number of distinct values,
 * number of columns, sparsity) by how many times each kind of operation is expected to be
 * executed on it.
 *
 * <p>The operation counts are supplied once at construction (directly, or from an
 * {@link InstructionTypeCounter}) and never change afterwards.
 */
public class ComputationCostEstimator extends ACostEstimate {
    private static final long serialVersionUID = -1205636215389161815L;

    /**
     * Fraction of rows that the most frequent value must exceed before those rows are excluded
     * from the number of scanned rows in {@link #getCostSafe(CompressedSizeInfoColGroup)}.
     */
    private static final double cvThreshold = 0.2;

    /** Expected number of scan operations. */
    private final int _scans;
    /** Expected number of decompressions. */
    private final int _decompressions;
    /** Expected number of dictionary operations. */
    private final int _dictionaryOps;
    /** Expected number of overlapping decompressions. */
    private final int _overlappingDecompressions;
    /** Expected number of left matrix multiplications. */
    private final int _leftMultiplications;
    /** Expected number of right matrix multiplications. */
    private final int _rightMultiplications;
    /** Expected number of compressed-compressed multiplications. */
    private final int _compressedMultiplication;
    /** If true, sparsity is ignored and every encoding is costed as dense. */
    private final boolean _isDensifying;

    /**
     * Create an estimator from collected instruction counts and log it at debug level.
     *
     * @param counts the instruction-type counts to weight the cost terms with
     */
    protected ComputationCostEstimator(InstructionTypeCounter counts) {
        this(counts.scans, counts.decompressions, counts.overlappingDecompressions,
            counts.leftMultiplications, counts.rightMultiplications,
            counts.compressedMultiplications, counts.dictionaryOps, counts.isDensifying);
        if (LOG.isDebugEnabled()) {
            LOG.debug(this);
        }
    }

    /**
     * Create an estimator from explicit instruction counts.
     *
     * @param scans                     number of scans
     * @param decompressions            number of decompressions
     * @param overlappingDecompressions number of overlapping decompressions
     * @param leftMultiplications       number of left multiplications
     * @param rightMultiplications      number of right multiplications
     * @param compressedMultiplication  number of compressed multiplications
     * @param dictOps                   number of dictionary operations
     * @param isDensifying              whether to cost every encoding as dense
     */
    public ComputationCostEstimator(int scans, int decompressions, int overlappingDecompressions, int leftMultiplications, int rightMultiplications, int compressedMultiplication, int dictOps, boolean isDensifying) {
        this._scans = scans;
        this._decompressions = decompressions;
        this._overlappingDecompressions = overlappingDecompressions;
        this._leftMultiplications = leftMultiplications;
        this._rightMultiplications = rightMultiplications;
        this._compressedMultiplication = compressedMultiplication;
        this._dictionaryOps = dictOps;
        this._isDensifying = isDensifying;
    }

    /**
     * Estimate the cost of a column group from its size information.
     *
     * <p>Empty groups get a near-zero sparsity unless densifying. Constant (or densified empty)
     * groups are costed as a single dense value scanned once. Otherwise, if the most frequent
     * value covers more than {@link #cvThreshold} of the rows, those rows are not counted as
     * scanned.
     *
     * @param g size information of the candidate column group
     * @return the estimated cost, never negative
     */
    @Override
    protected double getCostSafe(CompressedSizeInfoColGroup g) {
        final int nVals = g.getNumVals();
        final int nCols = g.getColumns().length;
        final int nRows = g.getNumRows();
        // Never assume a fully sparse group; narrow or densifying groups are treated as dense.
        final double sparsity = nCols < 3 || _isDensifying ? 1.0 : g.getTupleSparsity() + 1.0E-10;

        if (g.isEmpty() && !_isDensifying) {
            // Empty and staying sparse: give it a tiny but non-zero cost.
            return getCost(nRows, 1, nCols, 1, 1.0E-5);
        }
        if (g.isEmpty() || g.isConst()) {
            // Constant, or empty but densified: one dense value.
            return getCost(nRows, 1, nCols, 1, 1.0);
        }

        final int largestOff = g.getLargestOffInstances();
        // Compare a fraction of rows (not the raw row count) against the fractional threshold.
        final double commonFraction = nRows > 0 ? (double) largestOff / nRows : 0.0;
        if (commonFraction > cvThreshold) {
            return getCost(nRows, nRows - largestOff, nCols, nVals, sparsity);
        }
        return getCost(nRows, nRows, nCols, nVals, sparsity);
    }

    /**
     * Sum all weighted operation costs for an encoding with the given shape, plus a base cost
     * of 100.
     *
     * @param nRows        total number of rows
     * @param nRowsScanned number of rows actually touched by row-wise operations
     * @param nCols        number of columns
     * @param nVals        number of distinct values
     * @param sparsity     estimated fraction of non-zeros
     * @return the estimated cost
     * @throws DMLCompressionException if the resulting cost is negative
     */
    public double getCost(int nRows, int nRowsScanned, int nCols, int nVals, double sparsity) {
        // Treat narrow (< 3 columns) or mostly dense (> 40% non-zero) encodings as fully dense.
        sparsity = nCols < 3 || sparsity > 0.4 ? 1.0 : sparsity;
        double cost = 0.0;
        cost += leftMultCost(nRowsScanned, nRows, nCols, nVals, sparsity);
        cost += scanCost(nRowsScanned, nVals, nCols, sparsity);
        cost += dictionaryOpsCost(nVals, nCols, sparsity);
        cost += rightMultCost(nVals, nCols, sparsity);
        cost += decompressionCost(nCols, nRowsScanned, sparsity);
        cost += overlappingDecompressionCost(nRowsScanned);
        cost += compressedMultiplicationCost(nRowsScanned, nRows, nVals, nCols, sparsity);
        cost += 100.0; // base cost of any encoding
        if (cost < 0.0) {
            throw new DMLCompressionException("Invalid negative cost: " + cost);
        }
        return cost;
    }

    /**
     * @return true if this estimator costs every encoding as dense
     */
    public boolean isDense() {
        return _isDensifying;
    }

    /**
     * Estimate the cost of keeping the given block uncompressed. Every row is treated as its
     * own distinct value, and no base cost is added.
     *
     * @param mb the uncompressed matrix block
     * @return the estimated cost
     * @throws DMLCompressionException if the resulting cost is negative
     */
    @Override
    public double getCost(MatrixBlock mb) {
        final double nCols = mb.getNumColumns();
        final double nRows = mb.getNumRows();
        final double sparsity = nCols < 3.0 || _isDensifying ? 1.0 : mb.getSparsity();
        final double nnz = nRows * nCols * sparsity;

        double cost = 0.0;
        cost += dictionaryOpsCost(nRows, nCols, sparsity);
        cost += leftMultCost(0.0, nnz + nCols);
        cost += rightMultCost(nnz, nRows * nCols);
        cost += scanCost(0.0, nRows, nCols, sparsity);
        cost += compressedMultiplicationCost(0.0, 0.0, nRows, nCols, sparsity);
        if (cost < 0.0) {
            throw new DMLCompressionException("Invalid negative cost : " + cost);
        }
        return cost;
    }

    /**
     * Delegate cost estimation of an existing column group to the group itself.
     *
     * @param cg    the column group
     * @param nRows number of rows in the matrix the group belongs to
     * @return the cost reported by the group for this estimator
     */
    @Override
    public double getCost(AColGroup cg, int nRows) {
        return cg.getCost(this, nRows);
    }

    /**
     * @return true if any kind of matrix multiplication is expected
     */
    @Override
    public boolean shouldSparsify() {
        return _leftMultiplications > 0 || _compressedMultiplication > 0 || _rightMultiplications > 0;
    }

    /** Cost of dictionary operations: proportional to the stored non-zeros, weighted by 2. */
    private double dictionaryOpsCost(double nVals, double nCols, double sparsity) {
        return (double) _dictionaryOps * sparsity * nVals * nCols * 2.0;
    }

    /** Left multiplication: pre-aggregation over scanned rows plus post-scaling of the values. */
    private double leftMultCost(double nRowsScanned, double nRows, double nCols, double nVals, double sparsity) {
        // At least a tenth of all rows is always charged for pre-aggregation.
        final double preScalingCost = Math.max(nRowsScanned, nRows / 10.0) + nVals * 2.0;
        final double postScalingCost = sparsity * nVals * nCols;
        return leftMultCost(preScalingCost, postScalingCost);
    }

    /** Weight the given left-multiplication phases by the number of left multiplications. */
    private double leftMultCost(double preAggregateCost, double postScalingCost) {
        return (double) _leftMultiplications * (preAggregateCost + postScalingCost);
    }

    /** Right multiplication: work over the stored non-zeros plus allocation of nVals entries. */
    private double rightMultCost(double nVals, double nCols, double sparsity) {
        final double preMultiplicationCost = sparsity * nCols * nVals;
        final double allocationCost = nVals;
        return rightMultCost(preMultiplicationCost, allocationCost);
    }

    /** Weight the given right-multiplication phases by the number of right multiplications. */
    private double rightMultCost(double preMultiplicationCost, double allocationCost) {
        return (double) _rightMultiplications * (preMultiplicationCost + allocationCost);
    }

    /** Decompression writes every scanned cell, scaled by sparsity. */
    private double decompressionCost(double nCols, double nRowsScanned, double sparsity) {
        return (double) _decompressions * (nCols * nRowsScanned * sparsity);
    }

    /** Overlapping decompression is charged per (scanned) row. */
    private double overlappingDecompressionCost(double nRows) {
        return (double) _overlappingDecompressions * nRows;
    }

    /** Scan: one unit per scanned row plus the stored non-zeros of the values. */
    private double scanCost(double nRowsScanned, double nVals, double nCols, double sparsity) {
        return (double) _scans * (nRowsScanned + nVals * nCols * sparsity);
    }

    /** Compressed multiplication: row work (at least a tenth of all rows) plus value work. */
    private double compressedMultiplicationCost(double nRowsScanned, double nRows, double nVals, double nCols, double sparsity) {
        return (double) _compressedMultiplication * (Math.max(nRowsScanned, nRows / 10.0) + nVals * nCols * sparsity);
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" --- CostVector:[");
        sb.append(_scans).append(',');
        sb.append(_decompressions).append(',');
        sb.append(_overlappingDecompressions).append(',');
        sb.append(_leftMultiplications).append(',');
        sb.append(_rightMultiplications).append(',');
        sb.append(_compressedMultiplication).append(',');
        sb.append(_dictionaryOps).append(']');
        sb.append(" Densifying:");
        sb.append(_isDensifying);
        return sb.toString();
    }
}

