/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sedona.common.raster;

import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.media.jai.RasterFactory;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.sedona.common.Functions;
import org.apache.sedona.common.raster.RasterAccessors;
import org.apache.sedona.common.raster.RasterConstructors;
import org.apache.sedona.common.raster.RasterPredicates;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoverage2D;
import org.locationtech.jts.geom.Geometry;
import org.opengis.referencing.FactoryException;

/**
 * Band-level accessors and statistics for {@link GridCoverage2D} rasters.
 *
 * <p>Band indexes accepted by the public methods are 1-based; they are converted to the 0-based
 * indexes used by {@link Raster} and {@link GridCoverage2D#getSampleDimension(int)}.
 */
public class RasterBandAccessors {

    /**
     * Returns the no-data value of a band.
     *
     * @param raster the raster to inspect
     * @param band 1-based band index, validated with {@code RasterUtils.ensureBand}
     * @return the band's no-data value, or {@code null} when the band reports NaN (i.e. has none)
     */
    public static Double getBandNoDataValue(GridCoverage2D raster, int band) {
        RasterUtils.ensureBand(raster, band);
        GridSampleDimension bandSampleDimension = raster.getSampleDimension(band - 1);
        double noDataValue = RasterUtils.getNoDataValue(bandSampleDimension);
        if (Double.isNaN(noDataValue)) {
            return null;
        }
        return noDataValue;
    }

    /** Returns the no-data value of band 1, or {@code null} if it has none. */
    public static Double getBandNoDataValue(GridCoverage2D raster) {
        return getBandNoDataValue(raster, 1);
    }

    /**
     * Counts the pixels of a band.
     *
     * @param raster the raster to inspect
     * @param band 1-based band index
     * @param excludeNoDataValue when {@code true} and the band has a no-data value, pixels equal
     *     to it (by {@link Double#compare}) are not counted
     * @return the pixel count; {@code width * height} when no-data pixels are not excluded
     */
    public static long getCount(GridCoverage2D raster, int band, boolean excludeNoDataValue) {
        // getBandNoDataValue also validates the band index.
        Double bandNoDataValue = getBandNoDataValue(raster, band);
        int width = RasterAccessors.getWidth(raster);
        int height = RasterAccessors.getHeight(raster);
        if (!excludeNoDataValue || bandNoDataValue == null) {
            return (long) width * (long) height;
        }
        double noData = bandNoDataValue;
        Raster r = RasterUtils.getRaster(raster.getRenderedImage());
        double[] pixels = r.getSamples(0, 0, width, height, band - 1, (double[]) null);
        long numberOfPixel = 0L;
        for (double bandValue : pixels) {
            if (Double.compare(bandValue, noData) != 0) {
                ++numberOfPixel;
            }
        }
        return numberOfPixel;
    }

    /** Counts the non-no-data pixels of band 1. */
    public static long getCount(GridCoverage2D raster) {
        return getCount(raster, 1, true);
    }

    /** Counts the non-no-data pixels of the given 1-based band. */
    public static long getCount(GridCoverage2D raster, int band) {
        return getCount(raster, band, true);
    }

    /**
     * Computes all zonal statistics of a band inside a region of interest.
     *
     * @param raster the raster to inspect
     * @param roi region of interest; reprojected to the raster's CRS if its SRID differs
     * @param band 1-based band index
     * @param excludeNoData when {@code true}, pixels equal to the band's no-data value are ignored
     * @param lenient when {@code true}, a non-intersecting {@code roi} yields {@code null} instead
     *     of an exception
     * @return {@code [count, sum, mean, median, mode, stddev, variance, min, max]}, or {@code null}
     *     when {@code lenient} and {@code roi} does not intersect the raster
     * @throws IllegalArgumentException if {@code roi} does not intersect the raster and
     *     {@code lenient} is {@code false}
     * @throws FactoryException if reprojecting {@code roi} fails
     */
    public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
        List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
        if (objects == null) {
            return null;
        }
        DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
        double[] pixelData = (double[]) objects.get(1);
        return new double[] {
            stats.getN(),
            stats.getSum(),
            stats.getMean(),
            stats.getPercentile(50.0),
            zonalMode(pixelData),
            stats.getStandardDeviation(),
            stats.getVariance(),
            stats.getMin(),
            stats.getMax()
        };
    }

    /** Same as the 5-argument overload with {@code lenient = true}. */
    public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
        return getZonalStatsAll(raster, roi, band, excludeNoData, true);
    }

    /** Same as the 5-argument overload with {@code excludeNoData = true, lenient = true}. */
    public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band) throws FactoryException {
        return getZonalStatsAll(raster, roi, band, true);
    }

    /** Zonal statistics of band 1, excluding no-data pixels, lenient on non-intersection. */
    public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) throws FactoryException {
        return getZonalStatsAll(raster, roi, 1, true);
    }

    /**
     * Computes a single zonal statistic of a band inside a region of interest.
     *
     * @param statType case-insensitive name: {@code sum}, {@code mean}/{@code avg}/{@code average},
     *     {@code count}, {@code max}, {@code min}, {@code stddev}/{@code sd}, {@code median},
     *     {@code mode}, {@code variance}
     * @return the statistic, or {@code null} when {@code lenient} and {@code roi} does not
     *     intersect the raster
     * @throws IllegalArgumentException for an unknown {@code statType}, or a non-intersecting
     *     {@code roi} when {@code lenient} is {@code false}
     * @throws FactoryException if reprojecting {@code roi} fails
     */
    public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData, boolean lenient) throws FactoryException {
        List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
        if (objects == null) {
            return null;
        }
        DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
        double[] pixelData = (double[]) objects.get(1);
        // Locale.ROOT: default-locale lowercasing breaks matching in e.g. Turkish ("MIN" -> "mın").
        switch (statType.toLowerCase(Locale.ROOT)) {
            case "sum":
                return stats.getSum();
            case "average":
            case "avg":
            case "mean":
                return stats.getMean();
            case "count":
                // long cannot be boxed to Double directly; widen explicitly.
                return (double) stats.getN();
            case "max":
                return stats.getMax();
            case "min":
                return stats.getMin();
            case "stddev":
            case "sd":
                return stats.getStandardDeviation();
            case "median":
                return stats.getPercentile(50.0);
            case "mode":
                return zonalMode(pixelData);
            case "variance":
                return stats.getVariance();
            default:
                throw new IllegalArgumentException("Please select from the accepted options. Some of the valid options are sum, mean, stddev, etc.");
        }
    }

    /** Same as the 6-argument overload with {@code lenient = true}. */
    public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException {
        return getZonalStats(raster, roi, band, statType, excludeNoData, true);
    }

    /** Same as the 6-argument overload with {@code excludeNoData = true, lenient = true}. */
    public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException {
        return getZonalStats(raster, roi, band, statType, true);
    }

    /** Single zonal statistic of band 1, excluding no-data pixels, lenient on non-intersection. */
    public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException {
        return getZonalStats(raster, roi, 1, statType, true);
    }

    /**
     * Returns the largest of the most frequent values in {@code pixelData}
     * ({@link StatUtils#mode} lists ties in increasing order), or NaN for an empty sample.
     */
    private static double zonalMode(double[] pixelData) {
        if (pixelData.length == 0) {
            // StatUtils.mode yields an empty array here; indexing it would throw.
            return Double.NaN;
        }
        double[] modes = StatUtils.mode(pixelData);
        return modes[modes.length - 1];
    }

    /**
     * Collects the band pixels covered by {@code roi}.
     *
     * @return a two-element list {@code [DescriptiveStatistics, double[] pixelValues]}, or
     *     {@code null} when {@code lenient} and {@code roi} does not intersect the raster
     * @throws IllegalArgumentException if {@code roi} does not intersect and not {@code lenient}
     * @throws FactoryException if reprojecting {@code roi} fails
     */
    private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
        RasterUtils.ensureBand(raster, band);
        if (RasterAccessors.srid(raster) != roi.getSRID()) {
            roi = RasterUtils.convertCRSIfNeeded(roi, raster.getCoordinateReferenceSystem());
            roi = Functions.setSRID(roi, RasterAccessors.srid(raster));
        }
        if (!RasterPredicates.rsIntersects(raster, roi)) {
            if (lenient) {
                return null;
            }
            throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent.");
        }
        Raster rasterData = RasterUtils.getRaster(raster.getRenderedImage());
        String datatype = getBandType(raster, band);
        Double noDataValue = getBandNoDataValue(raster, band);
        // Rasterize the ROI onto the raster's grid with burn value 150.0; cells reading 0.0 are
        // treated as outside the ROI below.
        GridCoverage2D rasterizedGeom = RasterConstructors.asRasterWithRasterExtent(roi, raster, datatype, 150.0, null);
        Raster rasterizedData = RasterUtils.getRaster(rasterizedGeom.getRenderedImage());
        int width = RasterAccessors.getWidth(rasterizedGeom);
        int height = RasterAccessors.getHeight(rasterizedGeom);
        double[] mask = rasterizedData.getSamples(0, 0, width, height, 0, (double[]) null);
        double[] values = rasterData.getSamples(0, 0, width, height, band - 1, (double[]) null);

        boolean dropNoData = excludeNoData && noDataValue != null;
        double noData = dropNoData ? noDataValue : Double.NaN;
        double[] kept = new double[values.length];
        int keptCount = 0;
        for (int k = 0; k < values.length; ++k) {
            if (mask[k] == 0.0 || (dropNoData && values[k] == noData)) {
                continue;
            }
            kept[keptCount++] = values[k];
        }
        double[] pixelsArray = Arrays.copyOf(kept, keptCount);

        List<Object> statObjects = new ArrayList<>(2);
        statObjects.add(new DescriptiveStatistics(pixelsArray));
        statObjects.add(pixelsArray);
        return statObjects;
    }

    /**
     * Computes summary statistics over a whole band.
     *
     * @param rasterGeom the raster to inspect
     * @param band 1-based band index
     * @param excludeNoDataValue when {@code true}, pixels equal to the band's no-data value are
     *     ignored
     * @return {@code [count, sum, mean, stddev, min, max]}; stddev is the population
     *     (non-bias-corrected) standard deviation
     */
    public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, boolean excludeNoDataValue) {
        RasterUtils.ensureBand(rasterGeom, band);
        Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage());
        int height = RasterAccessors.getHeight(rasterGeom);
        int width = RasterAccessors.getWidth(rasterGeom);
        double[] pixels = raster.getSamples(0, 0, width, height, band - 1, (double[]) null);
        if (excludeNoDataValue) {
            Double noDataValue = getBandNoDataValue(rasterGeom, band);
            if (noDataValue != null) {
                pixels = removeValue(pixels, noDataValue);
            }
        }
        DescriptiveStatistics stats = new DescriptiveStatistics(pixels);
        double count = stats.getN();
        double sum = stats.getSum();
        double mean = stats.getMean();
        double stddev = new StandardDeviation(false).evaluate(pixels, mean);
        double min = stats.getMin();
        double max = stats.getMax();
        return new double[] {count, sum, mean, stddev, min, max};
    }

    /** Returns the elements of {@code pixels} that are not {@code ==} to {@code value}, in order. */
    private static double[] removeValue(double[] pixels, double value) {
        double[] kept = new double[pixels.length];
        int keptCount = 0;
        for (double pixel : pixels) {
            if (pixel != value) {
                kept[keptCount++] = pixel;
            }
        }
        return Arrays.copyOf(kept, keptCount);
    }

    /** Summary statistics of the given 1-based band, excluding no-data pixels. */
    public static double[] getSummaryStats(GridCoverage2D raster, int band) {
        return getSummaryStats(raster, band, true);
    }

    /** Summary statistics of band 1, excluding no-data pixels. */
    public static double[] getSummaryStats(GridCoverage2D raster) {
        return getSummaryStats(raster, 1, true);
    }

    /**
     * Builds a new raster from selected bands of {@code rasterGeom}, in the given order.
     * Bands may be repeated. Each output band keeps its source sample dimension and no-data value.
     *
     * @param rasterGeom source raster
     * @param bandIndexes 1-based source band indexes; each is validated with
     *     {@code RasterUtils.ensureBand}
     * @return a raster with {@code bandIndexes.length} bands
     * @throws FactoryException if creating the intermediate empty raster fails
     */
    public static GridCoverage2D getBand(GridCoverage2D rasterGeom, int[] bandIndexes) throws FactoryException {
        double[] metadata = RasterAccessors.metadata(rasterGeom);
        int width = (int) metadata[2];
        int height = (int) metadata[3];
        // Only its sample-dimension array is used below; each slot is replaced in the final loop.
        GridCoverage2D resultRaster = RasterConstructors.makeEmptyRaster(bandIndexes.length, width, height, metadata[0], metadata[1], metadata[4], metadata[5], metadata[6], metadata[7], (int) metadata[8]);
        Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage());
        int dataTypeCode = raster.getDataBuffer().getDataType();
        boolean isDataTypeIntegral = RasterUtils.isDataTypeIntegral(dataTypeCode);

        // Read each distinct source band once, keyed by 0-based index. Values are int[] for
        // integral data types and double[] otherwise, hence the Object value type.
        Map<Integer, Object> bandData = new HashMap<>();
        for (int curBand : Arrays.stream(bandIndexes).distinct().toArray()) {
            RasterUtils.ensureBand(rasterGeom, curBand);
            if (isDataTypeIntegral) {
                bandData.put(curBand - 1, raster.getSamples(0, 0, width, height, curBand - 1, (int[]) null));
            } else {
                bandData.put(curBand - 1, raster.getSamples(0, 0, width, height, curBand - 1, (double[]) null));
            }
        }

        WritableRaster wr = RasterFactory.createBandedRaster(dataTypeCode, width, height, bandIndexes.length, null);
        GridSampleDimension[] sampleDimensionsOg = rasterGeom.getSampleDimensions();
        GridSampleDimension[] sampleDimensionsResult = resultRaster.getSampleDimensions();
        for (int i = 0; i < bandIndexes.length; ++i) {
            int sourceBand = bandIndexes[i] - 1;
            if (isDataTypeIntegral) {
                wr.setSamples(0, 0, width, height, i, (int[]) bandData.get(sourceBand));
            } else {
                wr.setSamples(0, 0, width, height, i, (double[]) bandData.get(sourceBand));
            }
            GridSampleDimension sampleDimension = sampleDimensionsOg[sourceBand];
            Double noDataValue = getBandNoDataValue(rasterGeom, bandIndexes[i]);
            if (noDataValue != null) {
                sampleDimension = RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension, noDataValue);
            }
            sampleDimensionsResult[i] = sampleDimension;
        }
        return RasterUtils.clone(wr, sampleDimensionsResult, rasterGeom, null, false);
    }

    /** Returns the name of the sample-dimension type of the given 1-based band. */
    public static String getBandType(GridCoverage2D raster, int band) {
        RasterUtils.ensureBand(raster, band);
        GridSampleDimension bandSampleDimension = raster.getSampleDimension(band - 1);
        return bandSampleDimension.getSampleDimensionType().name();
    }

    /** Returns the name of the sample-dimension type of band 1. */
    public static String getBandType(GridCoverage2D raster) {
        return getBandType(raster, 1);
    }

    /**
     * Tests whether every pixel of a band equals its no-data value.
     *
     * @param raster the raster to inspect
     * @param band 1-based band index
     * @return {@code true} only if the band has a (non-NaN) no-data value and every pixel equals
     *     it by {@link Double#compare}; {@code false} if the band has no no-data value
     */
    public static boolean bandIsNoData(GridCoverage2D raster, int band) {
        RasterUtils.ensureBand(raster, band);
        Raster rasterData = RasterUtils.getRaster(raster.getRenderedImage());
        int width = rasterData.getWidth();
        int height = rasterData.getHeight();
        double noDataValue = RasterUtils.getNoDataValue(raster.getSampleDimension(band - 1));
        if (Double.isNaN(noDataValue)) {
            return false;
        }
        double[] pixels = rasterData.getSamples(0, 0, width, height, band - 1, (double[]) null);
        for (double pixel : pixels) {
            if (Double.compare(pixel, noDataValue) != 0) {
                return false;
            }
        }
        return true;
    }

    /** Tests whether every pixel of band 1 equals its no-data value. */
    public static boolean bandIsNoData(GridCoverage2D raster) {
        return bandIsNoData(raster, 1);
    }
}

