/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.ad;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.ad.AnomalyDetectionLibSVMParams;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.Trainable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.MutableDataset;
import org.tribuo.anomaly.AnomalyFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.libsvm.LibSVMAnomalyModel;
import org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer;
import org.tribuo.anomaly.libsvm.SVMAnomalyType;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.common.libsvm.SVMType;

/**
 * Anomaly detection algorithm backed by Tribuo's LibSVM one-class SVM.
 *
 * <p>Training fits a one-class SVM ({@code SVMAnomalyType.SVMMode.ONE_CLASS}) over the input
 * {@link DataFrame}. Prediction emits, for every input row, a result row containing the Tribuo
 * {@link Event} score (column {@code "score"}) and event type name (column {@code "anomaly_type"}).
 *
 * <p>Hyper-parameters not supplied in {@link AnomalyDetectionLibSVMParams} fall back to
 * gamma = 1.0, nu = 0.1 and an RBF kernel; cost, coeff, epsilon and degree are only applied
 * when explicitly set.
 */
@Function(value = FunctionName.AD_LIBSVM)
public class AnomalyDetectionLibSVM implements Trainable, Predictable {
    public static final String VERSION = "1.0.0";

    private static final double DEFAULT_GAMMA = 1.0;
    private static final double DEFAULT_NU = 0.1;
    private static final KernelType DEFAULT_KERNEL_TYPE = KernelType.RBF;

    private AnomalyDetectionLibSVMParams parameters;
    /** The trained/deployed model; {@code null} until {@link #initModel} or a predict-with-model call. */
    private LibSVMModel<Event> libSVMAnomalyModel = null;

    /**
     * Creates an instance with default hyper-parameters.
     *
     * <p>Delegating to the parameterized constructor guarantees {@link #parameters} is never
     * {@code null}, so {@link #train} works for instances created this way.
     */
    public AnomalyDetectionLibSVM() {
        this(null);
    }

    /**
     * Creates an instance with the given hyper-parameters.
     *
     * @param parameters algorithm parameters; {@code null} means "all defaults". A non-null value is
     *     expected to be an {@link AnomalyDetectionLibSVMParams}.
     * @throws IllegalArgumentException if gamma or nu is set but not positive
     */
    public AnomalyDetectionLibSVM(MLAlgoParams parameters) {
        this.parameters = parameters == null
                ? AnomalyDetectionLibSVMParams.builder().build()
                : (AnomalyDetectionLibSVMParams) parameters;
        this.validateParameters();
    }

    /** Rejects non-positive gamma/nu values; unset (null) values are allowed and defaulted later. */
    private void validateParameters() {
        if (this.parameters.getGamma() != null && this.parameters.getGamma() <= 0.0) {
            throw new IllegalArgumentException("gamma should be positive.");
        }
        if (this.parameters.getNu() != null && this.parameters.getNu() <= 0.0) {
            throw new IllegalArgumentException("nu should be positive.");
        }
    }

    /** Deserializes the stored model content into a LibSVM anomaly model. */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        this.libSVMAnomalyModel = deserializeModel(model);
    }

    /** Releases the in-memory model; {@link #isModelReady()} returns {@code false} afterwards. */
    @Override
    public void close() {
        this.libSVMAnomalyModel = null;
    }

    /** @return {@code true} once a model has been loaded */
    @Override
    public boolean isModelReady() {
        return this.libSVMAnomalyModel != null;
    }

    /**
     * Scores every row of the input data frame with the loaded model.
     *
     * @param mlInput input whose dataset is a {@link DataFrameInputDataset}
     * @return a prediction output with one row per input row, holding {@code "score"} and
     *     {@code "anomaly_type"}
     * @throws IllegalArgumentException if no model has been loaded
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        // Fail fast before doing any dataset conversion work.
        if (this.libSVMAnomalyModel == null) {
            throw new IllegalArgumentException("model not deployed");
        }
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(
                dataFrame,
                new AnomalyFactory(),
                "Anomaly detection LibSVM prediction data from OpenSearch",
                TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

        List<Map<String, Object>> adResults = new ArrayList<>();
        // The typed model makes the lambda parameter a Prediction<Event>, so getOutput() is an Event.
        this.libSVMAnomalyModel.predict(predictionDataset).forEach(prediction -> {
            Event event = prediction.getOutput();
            Map<String, Object> result = new HashMap<>();
            result.put("score", event.getScore());
            result.put("anomaly_type", event.getType().name());
            adResults.add(result);
        });
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
    }

    /**
     * Loads {@code model} and then scores {@code mlInput} with it.
     *
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for AD LibSVM prediction.");
        }
        this.libSVMAnomalyModel = deserializeModel(model);
        return this.predict(mlInput);
    }

    /**
     * Trains a one-class SVM on the input data frame.
     *
     * @param mlInput input whose dataset is a {@link DataFrameInputDataset}
     * @return a {@link MLModelState#TRAINED} model whose content is the Base64-serialized LibSVM model
     */
    @Override
    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        SVMParameters<Event> params = this.buildSvmParameters();
        MutableDataset<Event> data = TribuoUtil.generateDataset(
                dataFrame,
                new AnomalyFactory(),
                "Anomaly detection LibSVM training data from OpenSearch",
                TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
        LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
        LibSVMModel<Event> libSVMModel = trainer.train(data);
        return MLModel.builder()
                .name(FunctionName.AD_LIBSVM.name())
                .algorithm(FunctionName.AD_LIBSVM)
                .version(VERSION)
                .content(ModelSerDeSer.serializeToBase64(libSVMModel))
                .modelState(MLModelState.TRAINED)
                .build();
    }

    /**
     * Builds one-class SVM parameters from {@link #parameters}. Gamma and nu are always set (with
     * defaults); cost, coeff, epsilon and degree only override LibSVM's own defaults when present.
     */
    private SVMParameters<Event> buildSvmParameters() {
        SVMParameters<Event> params = new SVMParameters<>(
                new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), this.parseKernelType());
        double gamma = Optional.ofNullable(this.parameters.getGamma()).orElse(DEFAULT_GAMMA);
        double nu = Optional.ofNullable(this.parameters.getNu()).orElse(DEFAULT_NU);
        params.setGamma(gamma);
        params.setNu(nu);
        if (this.parameters.getCost() != null) {
            params.setCost(this.parameters.getCost().doubleValue());
        }
        if (this.parameters.getCoeff() != null) {
            params.setCoeff(this.parameters.getCoeff().doubleValue());
        }
        if (this.parameters.getEpsilon() != null) {
            params.setEpsilon(this.parameters.getEpsilon().doubleValue());
        }
        if (this.parameters.getDegree() != null) {
            params.setDegree(this.parameters.getDegree().intValue());
        }
        return params;
    }

    /** Maps the request's kernel type onto Tribuo's {@link KernelType}; unset means RBF. */
    private KernelType parseKernelType() {
        if (this.parameters.getKernelType() == null) {
            return DEFAULT_KERNEL_TYPE;
        }
        switch (this.parameters.getKernelType()) {
            case LINEAR:
                return KernelType.LINEAR;
            case POLY:
                return KernelType.POLY;
            case RBF:
                return KernelType.RBF;
            case SIGMOID:
                return KernelType.SIGMOID;
            default:
                return DEFAULT_KERNEL_TYPE;
        }
    }

    /**
     * Deserializes stored model content. The content is produced by {@link #train}, which always
     * serializes a {@code LibSVMModel<Event>}, so the unchecked cast is confined here.
     */
    @SuppressWarnings("unchecked")
    private static LibSVMModel<Event> deserializeModel(MLModel model) {
        return (LibSVMModel<Event>) ModelSerDeSer.deserialize(model);
    }
}

