/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.active;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import java.util.LinkedList;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.active.ALClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Active learning wrapper that decides, instance by instance, whether to acquire
 * the true label and train the wrapped base learner on it. The decision is based on
 * the base learner's maximum (normalized) posterior for that instance.
 *
 * <p>Strategies, selected via {@link #activeLearningStrategyOption}:
 * <ul>
 *   <li>{@code FixedUncertainty}: label when the max posterior is below
 *       {@link #fixedThresholdOption}.</li>
 *   <li>{@code VarUncertainty}: label when the max posterior is below the adaptive
 *       threshold {@link #newThreshold}. The threshold is multiplied by
 *       {@code (1 - step)} after a label is acquired and by {@code (1 + step)}
 *       otherwise.</li>
 *   <li>{@code RandVarUncertainty}: like {@code VarUncertainty}, but the posterior is
 *       first divided by the random factor {@code N(0,1) + 1} drawn from
 *       {@code classifierRandom}.</li>
 *   <li>{@code SelSampling}: label with probability
 *       {@code b / (b + |maxPosterior - 1/numClasses|)}, where {@code b} is the
 *       budget.</li>
 * </ul>
 *
 * <p>The first {@link #numInstancesInitOption} instances are always labeled. After
 * that warm-up, no strategy is consulted while the fraction of post-warm-up
 * instances that were labeled is at or above {@link #budgetOption}.
 */
public class ALUncertainty
extends AbstractClassifier
implements ALClassifier {
    private static final long serialVersionUID = 1L;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "drift.SingleClassifierDrift");
    public MultiChoiceOption activeLearningStrategyOption = new MultiChoiceOption("activeLearningStrategy", 'd', "Active Learning Strategy to use.", new String[]{"FixedUncertainty", "VarUncertainty", "RandVarUncertainty", "SelSampling"}, new String[]{"Fixed uncertainty strategy", "Uncertainty strategy with variable threshold", "Uncertainty strategy with randomized variable threshold", "Selective Sampling"}, 0);
    public FloatOption budgetOption = new FloatOption("budget", 'b', "Budget to use.", 0.1, 0.0, 1.0);
    public FloatOption fixedThresholdOption = new FloatOption("fixedThreshold", 'u', "Fixed threshold.", 0.9, 0.0, 1.0);
    public FloatOption stepOption = new FloatOption("step", 's', "Floating budget step.", 0.01, 0.0, 1.0);
    public FloatOption numInstancesInitOption = new FloatOption("numInstancesInit", 'n', "Number of instances at beginning without active learning.", 0.0, 0.0, 2.147483647E9);

    /** The wrapped base learner; (re)created from {@link #baseLearnerOption} on reset. */
    public Classifier classifier;
    /** Labels acquired since the last call to {@link #getLastLabelAcqReport()}. */
    public int lastLabelAcq = 0;
    /** Total labels acquired since reset, including the warm-up phase. */
    public int costLabeling;
    /** Number of training instances seen since reset. */
    public int iterationControl;
    /** Adaptive threshold used by the variable-uncertainty strategies; 1.0 after reset. */
    public double newThreshold;
    /** Most recently computed (un-randomized) maximum posterior of the base learner. */
    public double maxPosterior;
    /**
     * Reported as a percentage of {@link #costLabeling}.
     * NOTE(review): never updated inside this class, so it stays 0.0 unless set externally.
     */
    public double accuracyBaseLearner;

    @Override
    public String getPurposeString() {
        return "Active learning classifier for evolving data streams based on uncertainty";
    }

    /**
     * Returns the largest entry of the vote vector after normalizing it to sum to 1.
     *
     * @param incomingPrediction raw votes from the base learner
     * @return the maximum normalized vote. Returns 0.0 when fewer than two votes are
     *     given. When all votes are zero, normalization is skipped and the maximum
     *     raw vote is returned.
     */
    private double getMaxPosterior(double[] incomingPrediction) {
        if (incomingPrediction.length <= 1) {
            return 0.0;
        }
        DoubleVector vote = new DoubleVector(incomingPrediction);
        if (vote.sumOfValues() > 0.0) {
            vote.normalize();
        }
        double[] normalized = vote.getArrayRef();
        return normalized[Utils.maxIndex(normalized)];
    }

    /**
     * Acquires the label of {@code inst}: trains the base learner on it and updates
     * both the total and the since-last-report label counters.
     */
    private void acquireLabel(Instance inst) {
        this.classifier.trainOnInstance(inst);
        ++this.costLabeling;
        ++this.lastLabelAcq;
    }

    /** Fixed-threshold strategy: label when the posterior is below the fixed threshold. */
    private void labelFixed(double incomingPosterior, Instance inst) {
        if (incomingPosterior < this.fixedThresholdOption.getValue()) {
            this.acquireLabel(inst);
        }
    }

    /**
     * Variable-threshold strategy: label when the posterior is below
     * {@link #newThreshold}. The threshold then shrinks by the step factor after a
     * label is acquired and grows by it otherwise.
     */
    private void labelVar(double incomingPosterior, Instance inst) {
        if (incomingPosterior < this.newThreshold) {
            this.acquireLabel(inst);
            this.newThreshold *= 1.0 - this.stepOption.getValue();
        } else {
            this.newThreshold *= 1.0 + this.stepOption.getValue();
        }
    }

    /**
     * Selective-sampling strategy: label with probability {@code b / (b + p)}, where
     * {@code b} is the budget and {@code p = |posterior - 1/numClasses|}.
     */
    private void labelSelSampling(double incomingPosterior, Instance inst) {
        double p = Math.abs(incomingPosterior - 1.0 / (double) inst.numClasses());
        double budget = this.budgetOption.getValue() / (this.budgetOption.getValue() + p);
        if (this.classifierRandom.nextDouble() < budget) {
            this.acquireLabel(inst);
        }
    }

    @Override
    public void resetLearningImpl() {
        this.classifier = ((Classifier) this.getPreparedClassOption(this.baseLearnerOption)).copy();
        this.classifier.resetLearning();
        this.costLabeling = 0;
        this.iterationControl = 0;
        this.newThreshold = 1.0;
        this.maxPosterior = 0.0;
        this.accuracyBaseLearner = 0.0;
        this.lastLabelAcq = 0;
    }

    /**
     * Processes one stream instance.
     *
     * <p>During the warm-up phase the label is always acquired. Afterwards, the chosen
     * strategy is consulted only while the labeling cost is below the budget.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        ++this.iterationControl;
        double numInit = this.numInstancesInitOption.getValue();
        if ((double) this.iterationControl <= numInit) {
            // Warm-up: every label is acquired, and it must also be reported via
            // getLastLabelAcqReport() like any other acquired label.
            this.acquireLabel(inst);
            return;
        }
        // Fraction of post-warm-up instances whose label has been acquired so far.
        // The denominator is positive because iterationControl > numInit here.
        double costNow = ((double) this.costLabeling - numInit) / ((double) this.iterationControl - numInit);
        if (costNow < this.budgetOption.getValue()) {
            this.maxPosterior = this.getMaxPosterior(this.classifier.getVotesForInstance(inst));
            switch (this.activeLearningStrategyOption.getChosenIndex()) {
                case 0:
                    this.labelFixed(this.maxPosterior, inst);
                    break;
                case 1:
                    this.labelVar(this.maxPosterior, inst);
                    break;
                case 2:
                    // Randomize only the value compared against the threshold, so the
                    // reported maxPosterior remains an actual posterior.
                    this.labelVar(this.maxPosterior / (this.classifierRandom.nextGaussian() + 1.0), inst);
                    break;
                case 3:
                    this.labelSelSampling(this.maxPosterior, inst);
                    break;
                default:
                    break;
            }
        }
    }

    /** Delegates prediction to the wrapped base learner. */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        return this.classifier.getVotesForInstance(inst);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Delegates to the base learner; assumes it is an {@link AbstractClassifier}. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        ((AbstractClassifier) this.classifier).getModelDescription(out, indent);
    }

    /**
     * Reports this wrapper's counters followed by the base learner's own measurements
     * (when it provides any).
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        LinkedList<Measurement> measurementList = new LinkedList<Measurement>();
        measurementList.add(new Measurement("labeling cost", this.costLabeling));
        measurementList.add(new Measurement("newThreshold", this.newThreshold));
        measurementList.add(new Measurement("maxPosterior", this.maxPosterior));
        // Guard against 0/0 (NaN) before any label has been acquired.
        double accuracyPercent = this.costLabeling > 0
                ? 100.0 * this.accuracyBaseLearner / (double) this.costLabeling
                : 0.0;
        measurementList.add(new Measurement("accuracyBaseLearner (percent)", accuracyPercent));
        Measurement[] modelMeasurements = ((AbstractClassifier) this.classifier).getModelMeasurements();
        if (modelMeasurements != null) {
            for (Measurement measurement : modelMeasurements) {
                measurementList.add(measurement);
            }
        }
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }

    /**
     * Returns the number of labels acquired since the previous call, then resets that
     * counter to zero.
     */
    @Override
    public int getLastLabelAcqReport() {
        int acquired = this.lastLabelAcq;
        this.lastLabelAcq = 0;
        return acquired;
    }

    /** Sets the model context on this wrapper and propagates it to the base learner. */
    @Override
    public void setModelContext(InstancesHeader ih) {
        super.setModelContext(ih);
        this.classifier.setModelContext(ih);
    }
}

