/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportResponseHandler;

/**
 * Task runner for ML "execute" requests.
 *
 * <p>Each request is handed to a dedicated thread pool executor, where the request's
 * {@link Input} is passed to {@link MLEngine#execute} and the resulting {@link Output}
 * is wrapped in an {@link MLExecuteTaskResponse}. Node-level and action-level counters
 * in {@code mlStats} are updated around the execution.
 */
public class MLExecuteTaskRunner
extends MLTaskRunner<MLExecuteTaskRequest, MLExecuteTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLExecuteTaskRunner.class);

    /** Transport action name reported by {@link #getTransportActionName()}. */
    private static final String TRANSPORT_ACTION_NAME = "cluster:admin/opensearch/ml/execute";

    /** Name of the thread pool executor on which execute tasks are run. */
    private static final String TASK_THREAD_POOL_NAME = "OPENSEARCH_ML_TASK_THREAD_POOL";

    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;

    /**
     * Creates the runner.
     *
     * @param threadPool              thread pool that supplies the executor used by {@link #executeTask}
     * @param clusterService          cluster service; also forwarded to the superclass
     * @param client                  client stored on this instance (not used by the code in this class)
     * @param mlTaskManager           forwarded to the superclass
     * @param mlStats                 forwarded to the superclass; read here through the inherited {@code mlStats} field
     * @param mlInputDatasetHandler   dataset handler stored on this instance (not used by the code in this class)
     * @param mlTaskDispatcher        forwarded to the superclass
     * @param mlCircuitBreakerService forwarded to the superclass
     */
    public MLExecuteTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService) {
        super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
    }

    /**
     * @return the transport action name for the execute action
     */
    @Override
    protected String getTransportActionName() {
        return TRANSPORT_ACTION_NAME;
    }

    /**
     * Builds the transport response handler that forwards to {@code listener}.
     *
     * @param listener listener to receive the response or failure
     * @return a handler that reads an {@link MLExecuteTaskResponse} via its stream constructor
     */
    @Override
    protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(ActionListener<MLExecuteTaskResponse> listener) {
        // Diamond instead of a raw type so the handler's response type is checked by the compiler.
        return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new);
    }

    /**
     * Runs the request asynchronously on the ML task thread pool executor.
     *
     * <p>On success the listener receives an {@link MLExecuteTaskResponse} carrying the request's
     * function name and the engine output. Any {@link Exception} raised inside the task bumps the
     * action-level failure counter and is passed to {@link ActionListener#onFailure}.
     *
     * @param request  execute request supplying the function name and input
     * @param listener listener notified with the response or the failure
     */
    @Override
    protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
        this.threadPool.executor(TASK_THREAD_POOL_NAME).execute(() -> {
            try {
                this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
                this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
                FunctionName functionName = request.getFunctionName();
                this.mlStats.createCounterStatIfAbsent(functionName, ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
                Input input = request.getInput();
                Output output = MLEngine.execute(input);
                MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output);
                // NOTE(review): onResponse is invoked inside the try, so an exception thrown by the
                // listener itself is also caught below and reported through onFailure.
                listener.onResponse(response);
            }
            catch (Exception e) {
                this.mlStats.createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
                listener.onFailure(e);
            }
            finally {
                // Always undo the executing-task increment, whether the task succeeded or failed.
                this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
            }
        });
    }
}

