/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.XLevelsFormulaContext;

/**
 * Converts a fitted R {@code multinom} object (an R list) into a PMML {@link RegressionModel}.
 *
 * <p>The coefficients are taken from the {@code wts} element. As indexed below, every output
 * unit owns {@code features + 2} consecutive weights: a leading weight that is not used, the
 * {@code (Intercept)} coefficient, and then one coefficient per feature.
 */
public class MultinomConverter extends ModelConverter<RGenericVector> {

    public MultinomConverter(RGenericVector multinom) {
        super(multinom);
    }

    /**
     * Registers the target label and the input features with the encoder.
     *
     * <p>The label levels come from the {@code lev} element. The feature names come from the
     * {@code vcoefnames} element with the {@code (Intercept)} term removed.
     *
     * @param encoder the encoder that collects fields and features
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector multinom = getObject();

        RStringVector lev = multinom.getStringElement("lev");
        RExp terms = multinom.getElement("terms");
        RGenericVector xlevels = multinom.getGenericElement("xlevels");
        RStringVector vcoefnames = multinom.getStringElement("vcoefnames");

        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels);
        Formula formula = FormulaUtil.createFormula(terms, context, encoder);
        FormulaUtil.setLabel(formula, terms, lev, encoder);

        List<String> names = FormulaUtil.removeSpecialSymbol(vcoefnames.getValues(), "(Intercept)", 0);
        FormulaUtil.addFeatures(formula, names, true, encoder);
    }

    /**
     * Encodes the fitted weights as a PMML regression model.
     *
     * <p>A two-level label is encoded as a binary logistic regression built from a single run of
     * weights. A label with more than two levels gets one regression table per level. The first
     * level is the reference category, with an all-zero table. Probabilities are produced by
     * softmax normalization.
     *
     * @param schema the schema built by {@link #encodeSchema(RExpEncoder)}; its label must be a
     *     {@link CategoricalLabel}
     * @return the PMML regression model
     * @throws IllegalArgumentException if the network layout or the size of the weight vector is
     *     inconsistent with the schema, if a multiclass fit is not a plain (uncensored) softmax fit,
     *     or if the label has fewer than two levels
     */
    @Override
    public RegressionModel encodeModel(Schema schema) {
        RGenericVector multinom = getObject();

        RDoubleVector n = multinom.getDoubleElement("n");
        RBooleanVector softmax = multinom.getBooleanElement("softmax");
        RBooleanVector censored = multinom.getBooleanElement("censored");
        RDoubleVector wts = multinom.getDoubleElement("wts");

        // nnet stores the layer sizes as c(input, hidden, output)
        if (n.size() != 3) {
            throw new IllegalArgumentException("Expected 3 layer sizes, got " + n.size());
        }
        // The weight indexing below assumes output units connect straight to the inputs
        if (n.getValue(1) != 0d) {
            throw new IllegalArgumentException("Expected no hidden units, got " + n.getValue(1));
        }

        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        int numberOfFeatures = schema.getFeatures().size();

        if (categoricalLabel.size() == 2) {
            // A single run of weights: [unused, intercept, coefficient_1 .. coefficient_p]
            SchemaUtil.checkSize(wts.size() - 2, schema.getFeatures());

            List<Double> coefficients = wts.getValues().subList(2, 2 + numberOfFeatures);
            Double intercept = wts.getValue(1);

            return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), coefficients, intercept, RegressionModel.NormalizationMethod.LOGIT, true, schema);
        }

        if (categoricalLabel.size() > 2) {
            SchemaUtil.checkSize(wts.size() - 2 * categoricalLabel.size(), categoricalLabel, schema.getFeatures());

            // The model is emitted with SOFTMAX normalization, which is only valid for a plain softmax fit.
            // An absent flag is tolerated; an explicit non-softmax or censored fit is rejected.
            boolean softmaxDisabled = softmax != null && Boolean.FALSE.equals(softmax.asScalar());
            boolean censoredEnabled = censored != null && Boolean.TRUE.equals(censored.asScalar());
            if (softmaxDisabled || censoredEnabled) {
                throw new IllegalArgumentException("Expected an uncensored softmax fit (softmax = " + (softmax != null ? softmax.asScalar() : null) + ", censored = " + (censored != null ? censored.asScalar() : null) + ")");
            }

            List<RegressionTable> regressionTables = new ArrayList<RegressionTable>();

            // Reference category: its weights are not read, and its table is all zeros
            regressionTables.add(new RegressionTable(0d).setTargetCategory(categoricalLabel.getValue(0)));

            // Each category owns one row of a (levels x (features + 2)) weight matrix
            int columns = numberOfFeatures + 2;
            for (int i = 1; i < categoricalLabel.size(); i++) {
                List<Double> categoryWts = CMatrixUtil.getRow(wts.getValues(), categoricalLabel.size(), columns, i);

                List<Double> coefficients = categoryWts.subList(2, 2 + numberOfFeatures);
                Double intercept = categoryWts.get(1);

                RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(schema.getFeatures(), coefficients, intercept)
                    .setTargetCategory(categoricalLabel.getValue(i));
                regressionTables.add(regressionTable);
            }

            return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
        }

        throw new IllegalArgumentException("Expected a label with at least 2 levels, got " + categoricalLabel.size());
    }
}

