/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes.net.estimate;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorFullBayes;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.K2;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Statistics;
import weka.core.Utils;
import weka.estimators.Estimator;

/**
 * Bayesian model averaging (BMA) estimator for count-valued ("multinomial") data.
 * <p>
 * {@link #estimateCPTs(BayesNet)} replaces every non-class attribute by a nominal
 * "0"/"1" attribute (a value greater than zero maps to "1"). It then learns two networks on
 * that binarised data with {@link K2}:
 * <ul>
 * <li>one restricted to zero parents per node (no arcs);</li>
 * <li>one initialised as naive Bayes and limited to one parent per node.</li>
 * </ul>
 * Each non-class node of the target network receives a {@link DiscreteEstimatorFullBayes}
 * built from both networks' estimates. The two estimates are weighted by the relative
 * Dirichlet marginal likelihood of each structure.
 * <p>
 * {@link #distributionForInstance(BayesNet, Instance)} scores every class value by summing
 * {@code count * log P(attribute = "1" | parents)} over the non-class attributes and adding
 * the log probability of the class node, then normalises.
 * <p>
 * Only networks whose nodes have at most one parent are supported. Incremental updates are
 * not supported.
 */
public class MultiNomialBMAEstimator extends BayesNetEstimator {
    /**
     * If {@code true} (the default), the two candidate structures are scored with the K2
     * prior (pseudo-count 1 per value). Otherwise a BDeu-style prior is used, whose
     * pseudo-counts add up to 1 per node.
     */
    protected boolean m_bUseK2Prior = true;

    /**
     * Fills {@code bayesNet.m_Distributions}. For each node, the estimates of an empty
     * network and of a naive-Bayes network are averaged; both networks are learned on a
     * binarised copy of {@code bayesNet.m_Instances}.
     *
     * @param bayesNet the network to estimate; every node must have at most one parent
     * @throws Exception if some node has more than one parent, or if learning either
     *         auxiliary network fails
     */
    public void estimateCPTs(BayesNet bayesNet) throws Exception {
        initCPTs(bayesNet);
        for (int att = 0; att < bayesNet.m_Instances.numAttributes(); att++) {
            if (bayesNet.getParentSet(att).getNrOfParents() > 1) {
                throw new Exception("Cannot handle networks with nodes with more than 1 parent (yet).");
            }
        }

        Instances binaryData = binarize(bayesNet.m_Instances);
        // Candidate structure 1: no arcs at all.
        BayesNet emptyNet = learnNetwork(binaryData, false, 0);
        // Candidate structure 2: naive Bayes (the class as the single parent).
        BayesNet naiveBayesNet = learnNetwork(binaryData, true, 1);

        int classIndex = binaryData.classIndex();
        int numInstances = binaryData.numInstances();
        for (int att = 0; att < binaryData.numAttributes(); att++) {
            if (att == classIndex) {
                continue;
            }
            int numValues = binaryData.attribute(att).numValues();
            // NOTE(review): parent configurations are taken from bayesNet, but the matching
            // CPTs are read from naiveBayesNet. This assumes the two agree (class as the only
            // parent) -- TODO confirm.
            int parentCardinality = bayesNet.getParentSet(att).getCardinalityOfParents();
            DiscreteEstimatorBayes emptyDist = (DiscreteEstimatorBayes) emptyNet.m_Distributions[att][0];

            // Dirichlet pseudo-count per value. K2 uses 1. BDeu spreads a total of 1 per node
            // over the values and, in the naive-Bayes net, also over the parent configurations.
            double emptyAlpha = m_bUseK2Prior ? 1.0 : 1.0 / numValues;
            double naiveBayesAlpha = m_bUseK2Prior ? 1.0 : 1.0 / ((double) numValues * parentCardinality);

            double logScoreEmpty = logDirichletScore(emptyDist, numValues, emptyAlpha, numInstances);
            double logScoreNaiveBayes = 0.0;
            for (int config = 0; config < parentCardinality; config++) {
                DiscreteEstimatorBayes dist = (DiscreteEstimatorBayes) naiveBayesNet.m_Distributions[att][config];
                logScoreNaiveBayes += logDirichletScore(dist, numValues, naiveBayesAlpha, sumOfCounts(dist, numValues));
            }

            // Relative weights of the two structures, exp(a) / (exp(a) + exp(b)), written as a
            // logistic of the score difference. A large gap then saturates to 0/1 instead of
            // overflowing to Infinity/Infinity = NaN.
            double diff = logScoreNaiveBayes - logScoreEmpty;
            double weightEmpty = 1.0 / (1.0 + Math.exp(diff));
            double weightNaiveBayes = 1.0 / (1.0 + Math.exp(-diff));

            for (int config = 0; config < parentCardinality; config++) {
                bayesNet.m_Distributions[att][config] = new DiscreteEstimatorFullBayes(
                        numValues, weightEmpty, weightNaiveBayes, emptyDist,
                        (DiscreteEstimatorBayes) naiveBayesNet.m_Distributions[att][config], m_fAlpha);
            }
        }
        // The class node is taken unchanged from the empty network.
        bayesNet.m_Distributions[classIndex][0] = emptyNet.m_Distributions[classIndex][0];
    }

    /**
     * Returns a copy of {@code data} in which every non-class attribute is replaced by a
     * nominal attribute with the values "0" and "1". A source value greater than zero maps
     * to "1"; zero, negative and missing values map to "0". Class values are copied as-is.
     */
    private static Instances binarize(Instances data) {
        Instances binary = new Instances(data, data.numInstances());
        for (int att = binary.numAttributes() - 1; att >= 0; att--) {
            if (att == binary.classIndex()) {
                continue;
            }
            FastVector values = new FastVector();
            values.addElement("0");
            values.addElement("1");
            Attribute flag = new Attribute(binary.attribute(att).name(), values);
            binary.deleteAttributeAt(att);
            binary.insertAttributeAt(flag, att);
        }

        int classIndex = binary.classIndex();
        for (int i = 0; i < data.numInstances(); i++) {
            Instance source = data.instance(i);
            Instance converted = new Instance(binary.numAttributes());
            for (int att = 0; att < binary.numAttributes(); att++) {
                double value = source.value(att);
                if (att == classIndex) {
                    converted.setValue(att, value);
                } else {
                    // Set 0 explicitly: a fresh Instance starts out all-missing.
                    converted.setValue(att, value > 0.0 ? 1.0 : 0.0);
                }
            }
            // Previously the converted instance was built but never added, so both
            // candidate networks were learned from an empty data set.
            binary.add(converted);
        }
        return binary;
    }

    /** Learns a network on {@code data} with a fresh {@link K2} search using the given limits. */
    private static BayesNet learnNetwork(Instances data, boolean initAsNaiveBayes, int maxNrOfParents)
            throws Exception {
        K2 search = new K2();
        search.setInitAsNaiveBayes(initAsNaiveBayes);
        search.setMaxNrOfParents(maxNrOfParents);
        BayesNet net = new BayesNet();
        net.setSearchAlgorithm(search);
        net.buildClassifier(data);
        return net;
    }

    /**
     * Computes the log Dirichlet-multinomial score of one count vector under a symmetric
     * prior. The score is the sum over values of {@code lnGamma(alpha + count) - lnGamma(alpha)},
     * plus {@code lnGamma(numValues * alpha) - lnGamma(numValues * alpha + total)}.
     */
    private static double logDirichletScore(DiscreteEstimatorBayes dist, int numValues,
                                            double alpha, double total) {
        double score = 0.0;
        for (int v = 0; v < numValues; v++) {
            score += Statistics.lnGamma(alpha + dist.getCount(v)) - Statistics.lnGamma(alpha);
        }
        double alphaTotal = numValues * alpha;
        return score + Statistics.lnGamma(alphaTotal) - Statistics.lnGamma(alphaTotal + total);
    }

    /**
     * Sums the counts of the first {@code numValues} symbols of {@code dist}. The sum is
     * accumulated as a double so that fractional counts are not truncated.
     */
    private static double sumOfCounts(DiscreteEstimatorBayes dist, int numValues) {
        double total = 0.0;
        for (int v = 0; v < numValues; v++) {
            total += dist.getCount(v);
        }
        return total;
    }

    /**
     * Not supported: this estimator only works in batch mode.
     *
     * @throws Exception always
     */
    public void updateClassifier(BayesNet bayesNet, Instance instance) throws Exception {
        throw new Exception("updateClassifier does not apply to BMA estimator");
    }

    /**
     * Allocates {@code bayesNet.m_Distributions} with one row per attribute. Each row has
     * room for every parent configuration of any node, and always at least two columns.
     * A fixed width of 2 overflowed as soon as a parent (normally the class) had more than
     * two values.
     *
     * @param bayesNet the network whose CPT table is (re)allocated
     * @throws Exception declared for compatibility with the estimator API
     */
    public void initCPTs(BayesNet bayesNet) throws Exception {
        int numAttributes = bayesNet.m_Instances.numAttributes();
        int width = 2;
        for (int att = 0; att < numAttributes; att++) {
            width = Math.max(width, bayesNet.getParentSet(att).getCardinalityOfParents());
        }
        bayesNet.m_Distributions = new Estimator[numAttributes][width];
    }

    /** @return whether the K2 prior is used to weight the candidate structures */
    public boolean isUseK2Prior() {
        return this.m_bUseK2Prior;
    }

    /** @param bl {@code true} to use the K2 prior, {@code false} for the BDeu-style prior */
    public void setUseK2Prior(boolean bl) {
        this.m_bUseK2Prior = bl;
    }

    /**
     * Computes the class distribution for {@code instance}. Each class value is scored as
     * the log probability of the class node plus, for every other attribute,
     * {@code value * log P(attribute = "1" | parents)}. The scores are then normalised
     * with the max-subtraction trick.
     *
     * @param bayesNet a network whose CPTs were filled by {@link #estimateCPTs(BayesNet)}
     * @param instance the instance to classify; non-class values are used as counts
     * @return the normalised class probabilities
     * @throws Exception if an estimator lookup fails
     */
    public double[] distributionForInstance(BayesNet bayesNet, Instance instance) throws Exception {
        Instances header = bayesNet.m_Instances;
        int numClasses = header.numClasses();
        int classIndex = header.classIndex();
        double[] scores = new double[numClasses];

        for (int classValue = 0; classValue < numClasses; classValue++) {
            double logScore = 0.0;
            for (int att = 0; att < header.numAttributes(); att++) {
                int config = parentConfiguration(bayesNet, header, instance, att, classValue);
                Estimator dist = bayesNet.m_Distributions[att][config];
                if (att == classIndex) {
                    logScore += Math.log(dist.getProbability(classValue));
                } else {
                    // Symbol 1 is the "1" (present) value of the binarised attribute. The
                    // previous code passed instance.value(1) -- the value of attribute #1 --
                    // as the symbol index by mistake.
                    logScore += instance.value(att) * Math.log(dist.getProbability(1));
                }
            }
            scores[classValue] = logScore;
        }

        // Subtract the maximum before exponentiating to avoid underflow.
        double max = scores[0];
        for (int c = 1; c < numClasses; c++) {
            if (scores[c] > max) {
                max = scores[c];
            }
        }
        for (int c = 0; c < numClasses; c++) {
            scores[c] = Math.exp(scores[c] - max);
        }
        Utils.normalize(scores);
        return scores;
    }

    /**
     * Returns the CPT column (parent configuration) of attribute {@code att} for
     * {@code instance}. The class parent, if any, takes the value {@code classValue}.
     */
    private static int parentConfiguration(BayesNet bayesNet, Instances header, Instance instance,
                                           int att, int classValue) {
        int numClasses = header.numClasses();
        double config = 0.0;
        for (int p = 0; p < bayesNet.getParentSet(att).getNrOfParents(); p++) {
            int parent = bayesNet.getParentSet(att).getParent(p);
            if (parent == header.classIndex()) {
                config = config * numClasses + classValue;
            } else {
                config = config * header.attribute(parent).numValues() + instance.value(parent);
            }
        }
        return (int) config;
    }

    /**
     * Lists this estimator's options (-k2) followed by those of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(1);
        options.addElement(new Option("\tWhether to use K2 prior.\n", "k2", 0, "-k2"));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.addElement((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options. The presence of {@code -k2} enables the K2 prior; the remaining
     * options are passed to the superclass.
     *
     * @param stringArray the option strings
     * @throws Exception if an option cannot be parsed
     */
    public void setOptions(String[] stringArray) throws Exception {
        this.setUseK2Prior(Utils.getFlag("k2", stringArray));
        super.setOptions(stringArray);
    }

    /**
     * Returns the current settings: {@code -k2} if enabled, then the superclass options.
     * The result is padded with empty strings to a fixed length.
     *
     * @return the option strings
     */
    public String[] getOptions() {
        String[] inherited = super.getOptions();
        String[] options = new String[1 + inherited.length];
        int current = 0;
        if (this.isUseK2Prior()) {
            options[current++] = "-k2";
        }
        for (int i = 0; i < inherited.length; i++) {
            options[current++] = inherited[i];
        }
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }
}

