/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sedona.core.joinJudgement;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sedona.core.enums.DistanceMetric;
import org.apache.sedona.core.joinJudgement.JudgementBase;
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.GeometryItemDistance;
import org.locationtech.jts.index.strtree.ItemDistance;
import org.locationtech.jts.index.strtree.STRtree;

/**
 * Spark {@code FlatMapFunction2} that pairs query geometries with their k nearest neighbours
 * found in an {@link STRtree}.
 *
 * <p>Query geometries come from the broadcast {@code broadcastQueryObjects} list when it is
 * non-null, otherwise from the stream-side partition iterator. The searched tree is the broadcast
 * {@code broadcastObjectsTreeIndex} when it is non-null, otherwise the first index of the
 * index-side partition iterator. Every emitted pair is {@code (query, neighbour)}.
 */
public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
extends JudgementBase<T, U>
implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>,
Serializable {
    /** Number of nearest neighbours requested per query geometry. */
    private final int k;
    /** Optional upper bound on the neighbour distance, compared against {@link #distanceByMetric}; null disables it. */
    private final Double searchRadius;
    /** Metric used both for the KNN search and for the search-radius filter. */
    private final DistanceMetric distanceMetric;
    /** When true, neighbours tied with the farthest of the k results are also emitted. */
    private final boolean includeTies;
    /** Broadcast query geometries (possibly wrapped in UniqueGeometry); null means "use the stream partition". */
    private final Broadcast<List> broadcastQueryObjects;
    /** Broadcast tree to search; null means "use the tree from the index partition". */
    private final Broadcast<STRtree> broadcastObjectsTreeIndex;

    /**
     * @param k number of neighbours to find per query geometry
     * @param searchRadius optional maximum neighbour distance; null disables the filter
     * @param distanceMetric metric used for the KNN search and the radius filter
     * @param includeTies whether to also emit neighbours tied with the k-th result
     * @param broadcastQueryObjects broadcast query geometries, or null to query with the stream partition
     * @param broadcastObjectsTreeIndex broadcast tree to search, or null to use the index partition
     * @param buildCount accumulator forwarded to {@link JudgementBase}
     * @param streamCount accumulator incremented once per query geometry
     * @param resultCount accumulator incremented once per emitted pair
     * @param candidateCount accumulator forwarded to {@link JudgementBase}
     */
    public KnnJoinIndexJudgement(int k, Double searchRadius, DistanceMetric distanceMetric, boolean includeTies, Broadcast<List> broadcastQueryObjects, Broadcast<STRtree> broadcastObjectsTreeIndex, LongAccumulator buildCount, LongAccumulator streamCount, LongAccumulator resultCount, LongAccumulator candidateCount) {
        super(null, buildCount, streamCount, resultCount, candidateCount);
        this.k = k;
        this.searchRadius = searchRadius;
        this.distanceMetric = distanceMetric;
        this.includeTies = includeTies;
        this.broadcastQueryObjects = broadcastQueryObjects;
        this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex;
    }

    /**
     * Joins one pair of co-located partitions.
     *
     * @param streamShapes stream-side geometries; may be null when broadcast query objects are used
     * @param treeIndexes index-side partition, expected to hold an {@link STRtree}
     * @return an iterator over {@code (query, neighbour)} pairs
     * @throws Exception if the partition's index is not an {@link STRtree}
     */
    @Override
    public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes) throws Exception {
        // Nothing to join: empty index side, empty stream side, or no query source at all
        // (the last case used to fall through and NPE on streamShapes.hasNext()).
        if (!treeIndexes.hasNext()
                || streamShapes != null && !streamShapes.hasNext()
                || streamShapes == null && this.broadcastQueryObjects == null) {
            return this.emptyPartition();
        }
        STRtree strTree = this.resolveTree(treeIndexes);
        // The metric is fixed for the whole partition, so one ItemDistance serves every query.
        ItemDistance itemDistance = this.getItemDistance();
        List<Pair<U, T>> result = new ArrayList<>();
        if (this.broadcastQueryObjects != null) {
            List<?> queryItems = this.broadcastQueryObjects.getValue();
            for (Object item : queryItems) {
                // Search with the unwrapped geometry but keep the broadcast item itself as the pair key.
                Geometry queryGeom = item instanceof UniqueGeometry ? (Geometry) ((UniqueGeometry) item).getOriginalGeometry() : (Geometry) item;
                this.collectNeighbours((Geometry) item, queryGeom, strTree, itemDistance, result);
            }
            return result.iterator();
        }
        while (streamShapes.hasNext()) {
            T streamShape = streamShapes.next();
            this.collectNeighbours(streamShape, streamShape, strTree, itemDistance, result);
        }
        return result.iterator();
    }

    /** Touches every accumulator with 0 and returns an empty result. */
    private Iterator<Pair<U, T>> emptyPartition() {
        this.buildCount.add(0L);
        this.streamCount.add(0L);
        this.resultCount.add(0L);
        this.candidateCount.add(0L);
        return Collections.emptyIterator();
    }

    /** Returns the broadcast tree if present, otherwise the STRtree taken from the index partition. */
    private STRtree resolveTree(Iterator<SpatialIndex> treeIndexes) throws Exception {
        if (this.broadcastObjectsTreeIndex != null) {
            return this.broadcastObjectsTreeIndex.getValue();
        }
        SpatialIndex treeIndex = treeIndexes.next();
        if (!(treeIndex instanceof STRtree)) {
            throw new Exception("[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN search.");
        }
        return (STRtree) treeIndex;
    }

    /**
     * Runs the KNN search (plus optional ties and radius filtering) for one query geometry and
     * appends a {@code (pairKey, neighbour)} pair to {@code result} for every neighbour found.
     */
    @SuppressWarnings("unchecked") // tree items and pair keys are Geometry at runtime; U/T erase to Geometry
    private void collectNeighbours(Geometry pairKey, Geometry queryGeom, STRtree strTree, ItemDistance itemDistance, List<Pair<U, T>> result) {
        this.streamCount.add(1L);
        Object[] localK = strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom, itemDistance, this.k);
        if (this.includeTies) {
            localK = this.getUpdatedLocalKWithTies(queryGeom, localK, strTree);
        }
        if (this.searchRadius != null) {
            localK = this.getInSearchRadius(localK, queryGeom);
        }
        for (Object neighbour : localK) {
            result.add(Pair.of((U) pairKey, (T) neighbour));
            this.resultCount.add(1L);
        }
    }

    /** Keeps only the candidates whose metric distance to {@code queryGeom} is at most {@link #searchRadius}. */
    private Object[] getInSearchRadius(Object[] localK, Geometry queryGeom) {
        double radius = this.searchRadius;
        return Arrays.stream(localK)
                .filter(candidate -> KnnJoinIndexJudgement.distanceByMetric(queryGeom, (Geometry) candidate, this.distanceMetric) <= radius)
                .toArray();
    }

    /**
     * Computes the distance between two geometries using the given metric.
     * Unlisted metrics fall back to {@link Geometry#distance(Geometry)}.
     */
    public static double distanceByMetric(Geometry queryGeom, Geometry candidateGeom, DistanceMetric distanceMetric) {
        switch (distanceMetric) {
            case EUCLIDEAN: {
                EuclideanItemDistance euclideanItemDistance = new EuclideanItemDistance();
                return euclideanItemDistance.distance(queryGeom, candidateGeom);
            }
            case HAVERSINE: {
                HaversineItemDistance haversineItemDistance = new HaversineItemDistance();
                return haversineItemDistance.distance(queryGeom, candidateGeom);
            }
            case SPHEROID: {
                SpheroidDistance spheroidDistance = new SpheroidDistance();
                return spheroidDistance.distance(queryGeom, candidateGeom);
            }
        }
        return queryGeom.distance(candidateGeom);
    }

    /** Returns the STRtree item distance matching this judgement's metric. */
    private ItemDistance getItemDistance() {
        return KnnJoinIndexJudgement.getItemDistanceByMetric(this.distanceMetric);
    }

    /**
     * Maps a metric to an STRtree {@link ItemDistance}; unlisted metrics use {@link GeometryItemDistance}.
     */
    public static ItemDistance getItemDistanceByMetric(DistanceMetric distanceMetric) {
        ItemDistance itemDistance;
        switch (distanceMetric) {
            case EUCLIDEAN: {
                itemDistance = new EuclideanItemDistance();
                break;
            }
            case HAVERSINE: {
                itemDistance = new HaversineItemDistance();
                break;
            }
            case SPHEROID: {
                itemDistance = new SpheroidDistance();
                break;
            }
            default: {
                itemDistance = new GeometryItemDistance();
            }
        }
        return itemDistance;
    }

    /**
     * Extends the k nearest neighbours with every other tree item whose distance to
     * {@code queryGeom} exactly equals the largest distance among those k results.
     *
     * <p>NOTE(review): tie detection and the envelope expansion use planar
     * {@link Geometry#distance(Geometry)}, not the configured {@link #distanceMetric}. For
     * HAVERSINE/SPHEROID this may not match the ordering used by the KNN search; confirm intended.
     */
    private Object[] getUpdatedLocalKWithTies(Geometry queryGeom, Object[] localK, STRtree strTree) {
        double maxDistance = 0.0;
        LinkedHashSet<Geometry> uniqueCandidates = new LinkedHashSet<>();
        for (Object obj : localK) {
            Geometry candidate = (Geometry) obj;
            uniqueCandidates.add(candidate);
            double distance = queryGeom.distance(candidate);
            if (distance > maxDistance) {
                maxDistance = distance;
            }
        }
        // Copy so expandBy can never touch envelope state held by the geometry itself.
        Envelope searchEnvelope = new Envelope(queryGeom.getEnvelopeInternal());
        searchEnvelope.expandBy(maxDistance);
        List<?> candidates = strTree.query(searchEnvelope);
        if (candidates.isEmpty()) {
            return localK;
        }
        List<Object> tiedResults = new ArrayList<>(localK.length + candidates.size());
        Collections.addAll(tiedResults, localK);
        for (Object obj : candidates) {
            Geometry candidate = (Geometry) obj;
            // Exact equality is intentional: only true ties with the k-th distance are added.
            if (queryGeom.distance(candidate) == maxDistance && !uniqueCandidates.contains(candidate)) {
                tiedResults.add(candidate);
            }
        }
        return tiedResults.toArray();
    }

    /**
     * Computes the distance between {@code key} and {@code value} with the given metric.
     * Unlisted metrics fall back to Euclidean distance.
     */
    public static <U extends Geometry, T extends Geometry> double distance(U key, T value, DistanceMetric distanceMetric) {
        switch (distanceMetric) {
            case EUCLIDEAN: {
                return new EuclideanItemDistance().distance(key, value);
            }
            case HAVERSINE: {
                return new HaversineItemDistance().distance(key, value);
            }
            case SPHEROID: {
                return new SpheroidDistance().distance(key, value);
            }
        }
        return new EuclideanItemDistance().distance(key, value);
    }
}

