/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.MutableDataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.clustering.kmeans.KMeansTrainer;

@Function(value=FunctionName.KMEANS)
public class KMeans
implements TrainAndPredictable {
    private static final KMeansParams.DistanceType DEFAULT_DISTANCE_TYPE = KMeansParams.DistanceType.EUCLIDEAN;
    private static int DEFAULT_CENTROIDS = 2;
    private static int DEFAULT_ITERATIONS = 10;
    private KMeansParams parameters;
    private int numThreads = Math.max(Runtime.getRuntime().availableProcessors() / 2, 1);
    private long seed = System.currentTimeMillis();
    private KMeansTrainer.Distance distance;

    public KMeans() {
    }

    public KMeans(MLAlgoParams parameters) {
        this.parameters = parameters == null ? KMeansParams.builder().build() : (KMeansParams)parameters;
        this.validateParameters();
        this.createDistance();
    }

    private void validateParameters() {
        if (this.parameters.getCentroids() != null && this.parameters.getCentroids() <= 0) {
            throw new IllegalArgumentException("K should be positive.");
        }
        if (this.parameters.getIterations() != null && this.parameters.getIterations() <= 0) {
            throw new IllegalArgumentException("Iterations should be positive.");
        }
    }

    private void createDistance() {
        KMeansParams.DistanceType distanceType = Optional.ofNullable(this.parameters.getDistanceType()).orElse(DEFAULT_DISTANCE_TYPE);
        switch (distanceType) {
            case COSINE: {
                this.distance = KMeansTrainer.Distance.COSINE;
                break;
            }
            case L1: {
                this.distance = KMeansTrainer.Distance.L1;
                break;
            }
            default: {
                this.distance = KMeansTrainer.Distance.EUCLIDEAN;
            }
        }
    }

    @Override
    public MLOutput predict(DataFrame dataFrame, Model model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for KMeans prediction.");
        }
        MutableDataset predictionDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);
        KMeansModel kMeansModel = (KMeansModel)ModelSerDeSer.deserialize(model.getContent());
        List predictions = kMeansModel.predict(predictionDataset);
        ArrayList listClusterID = new ArrayList();
        predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", ((ClusterID)e.getOutput()).getID())));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
    }

    @Override
    public Model train(DataFrame dataFrame) {
        MutableDataset trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training data from opensearch", TribuoOutputType.CLUSTERID);
        Integer centroids = Optional.ofNullable(this.parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
        Integer iterations = Optional.ofNullable(this.parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
        KMeansTrainer trainer = new KMeansTrainer(centroids.intValue(), iterations.intValue(), this.distance, this.numThreads, this.seed);
        KMeansModel kMeansModel = trainer.train(trainDataset);
        Model model = new Model();
        model.setName(FunctionName.KMEANS.name());
        model.setVersion(1);
        model.setContent(ModelSerDeSer.serialize(kMeansModel));
        return model;
    }

    @Override
    public MLOutput trainAndPredict(DataFrame dataFrame) {
        MutableDataset trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID);
        Integer centroids = Optional.ofNullable(this.parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
        Integer iterations = Optional.ofNullable(this.parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
        KMeansTrainer trainer = new KMeansTrainer(centroids.intValue(), iterations.intValue(), this.distance, this.numThreads, this.seed);
        KMeansModel kMeansModel = trainer.train(trainDataset);
        List predictions = kMeansModel.predict(trainDataset);
        ArrayList listClusterID = new ArrayList();
        predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", ((ClusterID)e.getOutput()).getID())));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
    }
}

