/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.Config;
import java.util.Objects;
import java.util.logging.Logger;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.classification.sgd.fm.FMClassificationModel;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.common.sgd.AbstractFMTrainer;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.provenance.ModelProvenance;

public class FMClassificationTrainer
extends AbstractFMTrainer<Label, Integer, FMClassificationModel> {
    private static final Logger logger = Logger.getLogger(FMClassificationTrainer.class.getName());
    @Config(description="The classification objective function to use.")
    private LabelObjective objective = new LogMulticlass();

    public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance) {
        super(optimiser, epochs, loggingInterval, minibatchSize, seed, factorizedDimSize, variance);
        this.objective = objective;
    }

    public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed, int factorizedDimSize, double variance) {
        this(objective, optimiser, epochs, loggingInterval, 1, seed, factorizedDimSize, variance);
    }

    public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed, int factorizedDimSize, double variance) {
        this(objective, optimiser, epochs, 1000, 1, seed, factorizedDimSize, variance);
    }

    private FMClassificationTrainer() {
    }

    protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output) {
        return outputInfo.getID((Output)output);
    }

    protected SGDObjective<Integer> getObjective() {
        return this.objective;
    }

    protected FMClassificationModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, FMParameters parameters) {
        return new FMClassificationModel(name, provenance, featureMap, outputInfo, parameters, this.objective.getNormalizer(), this.objective.isProbabilistic());
    }

    protected String getModelClassName() {
        return FMClassificationModel.class.getName();
    }

    public String toString() {
        return "FMClassificationTrainer(objective=" + this.objective.toString() + ",optimiser=" + this.optimiser.toString() + ",epochs=" + this.epochs + ",minibatchSize=" + this.minibatchSize + ",seed=" + this.seed + ",factorizedDimSize=" + this.factorizedDimSize + ",variance=" + this.variance + ")";
    }
}

