/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.maxent.quasinewton.ArrayMath;
import opennlp.tools.ml.maxent.quasinewton.NegLogLikelihood;
import opennlp.tools.ml.maxent.quasinewton.ParallelNegLogLikelihood;
import opennlp.tools.ml.maxent.quasinewton.QNMinimizer;
import opennlp.tools.ml.maxent.quasinewton.QNModel;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.util.TrainingParameters;

/**
 * Event trainer that fits a {@link QNModel} by minimizing the negative
 * log-likelihood of the training events with {@link QNMinimizer}.
 *
 * <p>Training parameters read by {@link #init(TrainingParameters, Map)}:
 * <ul>
 *   <li>{@value #THREADS_PARAM}: number of threads used to evaluate the objective
 *       (default {@value #THREADS_DEFAULT}). A value other than 1 selects
 *       {@link ParallelNegLogLikelihood}.</li>
 *   <li>{@value #L1COST_PARAM} / {@value #L2COST_PARAM}: regularization costs passed
 *       to the minimizer (defaults {@value #L1COST_DEFAULT} / {@value #L2COST_DEFAULT}).</li>
 *   <li>{@value #M_PARAM}: number of Hessian updates the minimizer remembers
 *       (default {@value #M_DEFAULT}).</li>
 *   <li>{@value #MAX_FCT_EVAL_PARAM}: maximum number of function evaluations
 *       (default {@value #MAX_FCT_EVAL_DEFAULT}).</li>
 * </ul>
 */
public class QNTrainer
extends AbstractEventTrainer {
    public static final String MAXENT_QN_VALUE = "MAXENT_QN";
    public static final String THREADS_PARAM = "Threads";
    public static final int THREADS_DEFAULT = 1;
    public static final String L1COST_PARAM = "L1Cost";
    public static final double L1COST_DEFAULT = 0.1;
    public static final String L2COST_PARAM = "L2Cost";
    public static final double L2COST_DEFAULT = 0.1;
    public static final String M_PARAM = "NumOfUpdates";
    public static final int M_DEFAULT = 15;
    public static final String MAX_FCT_EVAL_PARAM = "MaxFctEval";
    public static final int MAX_FCT_EVAL_DEFAULT = 30000;
    // NOTE: deliberately no field initializers. They would run after super(...)
    // returns and could overwrite values an init(...) call made from the
    // superclass constructor has already set.
    private int threads;
    private double l1Cost;
    private double l2Cost;
    private int m;
    private int maxFctEval;

    /**
     * Creates a trainer from the given training parameters.
     *
     * @param parameters the training parameters, forwarded to the superclass constructor
     */
    public QNTrainer(TrainingParameters parameters) {
        super(parameters);
    }

    /**
     * Creates a trainer with default settings.
     *
     * @param printMessages whether progress messages are printed
     */
    public QNTrainer(boolean printMessages) {
        this(M_DEFAULT, printMessages);
    }

    /**
     * Creates a verbose trainer that remembers {@code m} Hessian updates.
     *
     * @param m number of Hessian updates to remember; negative values select {@link #M_DEFAULT}
     */
    public QNTrainer(int m) {
        this(m, true);
    }

    /**
     * Creates a trainer that remembers {@code m} Hessian updates.
     *
     * @param m       number of Hessian updates to remember; negative values select {@link #M_DEFAULT}
     * @param verbose whether progress messages are printed
     */
    public QNTrainer(int m, boolean verbose) {
        this(m, MAX_FCT_EVAL_DEFAULT, verbose);
    }

    /**
     * Creates a trainer with explicit minimizer limits.
     *
     * @param m             number of Hessian updates to remember; negative values select
     *                      {@link #M_DEFAULT}
     * @param maxFctEval    maximum number of function evaluations; negative values select
     *                      {@link #MAX_FCT_EVAL_DEFAULT}
     * @param printMessages whether progress messages are printed
     */
    public QNTrainer(int m, int maxFctEval, boolean printMessages) {
        this.printMessages = printMessages;
        // Negative arguments fall back to the defaults rather than being rejected here.
        this.m = m < 0 ? M_DEFAULT : m;
        this.maxFctEval = maxFctEval < 0 ? MAX_FCT_EVAL_DEFAULT : maxFctEval;
        this.threads = THREADS_DEFAULT;
        this.l1Cost = L1COST_DEFAULT;
        this.l2Cost = L2COST_DEFAULT;
    }

    /**
     * Creates a trainer with default settings.
     *
     * <p>Intended to be followed by a call to {@link #init(TrainingParameters, Map)}.
     */
    public QNTrainer() {
        // Start from the defaults so an instance that is never passed through
        // init(...) is still valid; left at 0, threads would fail validate() and
        // would send trainModel() down the multi-threaded branch.
        this.m = M_DEFAULT;
        this.maxFctEval = MAX_FCT_EVAL_DEFAULT;
        this.threads = THREADS_DEFAULT;
        this.l1Cost = L1COST_DEFAULT;
        this.l2Cost = L2COST_DEFAULT;
    }

    /**
     * Initializes the trainer and reads the quasi-Newton settings from
     * {@code trainingParameters}, using the class defaults for absent keys.
     *
     * @param trainingParameters the training parameters
     * @param reportMap          map passed to the superclass initializer
     */
    @Override
    public void init(TrainingParameters trainingParameters, Map<String, String> reportMap) {
        super.init(trainingParameters, reportMap);
        this.m = trainingParameters.getIntParameter(M_PARAM, M_DEFAULT);
        this.maxFctEval = trainingParameters.getIntParameter(MAX_FCT_EVAL_PARAM, MAX_FCT_EVAL_DEFAULT);
        this.threads = trainingParameters.getIntParameter(THREADS_PARAM, THREADS_DEFAULT);
        this.l1Cost = trainingParameters.getDoubleParameter(L1COST_PARAM, L1COST_DEFAULT);
        this.l2Cost = trainingParameters.getDoubleParameter(L2COST_PARAM, L2COST_DEFAULT);
    }

    /**
     * Wraps {@code trainParams} in a {@link TrainingParameters} and delegates to
     * {@link #init(TrainingParameters, Map)}.
     *
     * @deprecated use {@link #init(TrainingParameters, Map)} instead
     */
    @Override
    @Deprecated
    public void init(Map<String, String> trainParams, Map<String, String> reportMap) {
        this.init(new TrainingParameters(trainParams), reportMap);
    }

    /**
     * Checks that the algorithm name (if set) is {@value #MAXENT_QN_VALUE} and
     * that all numeric settings are in range.
     *
     * @throws IllegalArgumentException if any setting is invalid
     */
    @Override
    public void validate() {
        super.validate();
        String algorithmName = this.getAlgorithm();
        if (algorithmName != null && !MAXENT_QN_VALUE.equals(algorithmName)) {
            throw new IllegalArgumentException("algorithmName must be MAXENT_QN");
        }
        if (this.m < 0) {
            throw new IllegalArgumentException("Number of Hessian updates to remember must be >= 0");
        }
        if (this.maxFctEval < 0) {
            throw new IllegalArgumentException("Maximum number of function evaluations must be >= 0");
        }
        if (this.threads < 1) {
            throw new IllegalArgumentException("Number of threads must be >= 1");
        }
        if (this.l1Cost < 0.0) {
            throw new IllegalArgumentException("Regularization costs must be >= 0");
        }
        if (this.l2Cost < 0.0) {
            throw new IllegalArgumentException("Regularization costs must be >= 0");
        }
    }

    /**
     * Returns whether {@link #validate()} passes.
     *
     * @deprecated call {@link #validate()} and handle its exception instead
     */
    @Override
    @Deprecated
    public boolean isValid() {
        try {
            this.validate();
            return true;
        }
        catch (IllegalArgumentException e) {
            return false;
        }
    }

    @Override
    public boolean isSortAndMerge() {
        return true;
    }

    @Override
    public AbstractModel doTrain(DataIndexer indexer) throws IOException {
        int iterations = this.getIterations();
        return this.trainModel(iterations, indexer);
    }

    /**
     * Estimates model parameters from the indexed events.
     *
     * @param iterations iteration limit passed to the minimizer
     * @param indexer    indexed training events
     * @return the trained model
     */
    public QNModel trainModel(int iterations, DataIndexer indexer) {
        NegLogLikelihood objectiveFunction;
        if (this.threads == 1) {
            if (this.printMessages) {
                System.out.println("Computing model parameters ...");
            }
            objectiveFunction = new NegLogLikelihood(indexer);
        } else {
            if (this.printMessages) {
                System.out.println("Computing model parameters in " + this.threads + " threads ...");
            }
            objectiveFunction = new ParallelNegLogLikelihood(indexer, this.threads);
        }
        QNMinimizer minimizer = new QNMinimizer(this.l1Cost, this.l2Cost, iterations, this.m, this.maxFctEval, this.printMessages);
        minimizer.setEvaluator(new ModelEvaluator(indexer));
        double[] parameters = minimizer.minimize(objectiveFunction);
        String[] predLabels = indexer.getPredLabels();
        String[] outcomeNames = indexer.getOutcomeLabels();
        Context[] params = toContexts(parameters, predLabels.length, outcomeNames.length);
        return new QNModel(params, predLabels, outcomeNames);
    }

    /**
     * Unpacks the flat parameter vector into one {@link Context} per predicate,
     * each holding a weight for every outcome.
     */
    private static Context[] toContexts(double[] parameters, int nPredLabels, int nOutcomes) {
        Context[] contexts = new Context[nPredLabels];
        for (int ci = 0; ci < nPredLabels; ++ci) {
            int[] outcomePattern = new int[nOutcomes];
            double[] alpha = new double[nOutcomes];
            for (int oi = 0; oi < nOutcomes; ++oi) {
                outcomePattern[oi] = oi;
                // Outcome-major layout: the weight for (outcome oi, predicate ci)
                // lives at oi * nPredLabels + ci.
                alpha[oi] = parameters[oi * nPredLabels + ci];
            }
            contexts[ci] = new Context(outcomePattern, alpha);
        }
        return contexts;
    }

    /**
     * Scores a parameter vector by the fraction of training events whose
     * highest-probability outcome equals the observed outcome. Each event is
     * weighted by how often it was seen.
     */
    private static class ModelEvaluator
    implements QNMinimizer.Evaluator {
        private final DataIndexer indexer;

        public ModelEvaluator(DataIndexer indexer) {
            this.indexer = indexer;
        }

        /**
         * @param parameters the current flat parameter vector
         * @return weighted training accuracy in [0, 1]; NaN when there are no events
         *         (0.0 / 0.0)
         */
        @Override
        public double evaluate(double[] parameters) {
            int[][] contexts = this.indexer.getContexts();
            float[][] values = this.indexer.getValues();
            int[] nEventsSeen = this.indexer.getNumTimesEventsSeen();
            int[] outcomeList = this.indexer.getOutcomeList();
            int nOutcomes = this.indexer.getOutcomeLabels().length;
            int nPredLabels = this.indexer.getPredLabels().length;
            int nCorrect = 0;
            int nTotalEvents = 0;
            for (int ei = 0; ei < contexts.length; ++ei) {
                int[] context = contexts[ei];
                // Real-valued features are optional; the indexer may supply none.
                float[] value = values == null ? null : values[ei];
                double[] probs = new double[nOutcomes];
                QNModel.eval(context, value, probs, nOutcomes, nPredLabels, parameters);
                int outcome = ArrayMath.maxIdx(probs);
                if (outcome == outcomeList[ei]) {
                    nCorrect += nEventsSeen[ei];
                }
                nTotalEvents += nEventsSeen[ei];
            }
            return (double)nCorrect / (double)nTotalEvents;
        }
    }
}

