/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.distribution;

import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.distribution.Distribution;
import eu.amidst.core.distribution.Multinomial;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.Variable;
import java.util.List;
import java.util.Random;

/**
 * Conditional distribution of a multinomial variable given a non-empty set of parents,
 * parameterised as a multinomial logistic model.
 *
 * <p>For every state {@code k} except the last one, a linear score
 * {@code alpha_k + sum_j beta_kj * value(parent_j)} is computed. The last state is the reference
 * category and its score is fixed to 0. The scores are converted to probabilities with
 * {@code Utils.logs2probs} (presumably exponentiate-and-normalize, i.e. a softmax; confirm
 * against {@code Utils}).
 */
public class Multinomial_LogisticParents
extends ConditionalDistribution {
    /** Intercept (alpha) of the score of each non-reference state; length {@code numberOfStates - 1}. */
    private final double[] intercept;
    /**
     * Parent coefficients (beta). Row {@code k} belongs to state {@code k} and, as allocated by the
     * constructor, holds one coefficient per parent in parent-list order.
     */
    private final double[][] coeffParents;

    /**
     * Creates the distribution with every intercept set to 0 and every parent coefficient set to 1.
     *
     * @param var1 the child variable (its number of states determines the number of scores).
     * @param parents1 the parent variables; must not be empty.
     * @throws UnsupportedOperationException if {@code parents1} is empty.
     */
    public Multinomial_LogisticParents(Variable var1, List<Variable> parents1) {
        if (parents1.isEmpty()) {
            throw new UnsupportedOperationException("A multinomial logistic distribution can not be created from a empty set of parents.");
        }
        this.var = var1;
        this.parents = parents1;
        int numberOfScores = this.var.getNumberOfStates() - 1;
        int numberOfParents = this.parents.size();
        // Java zero-initialises arrays, so every intercept starts at 0.0.
        this.intercept = new double[numberOfScores];
        // Exactly one coefficient per parent: getMultinomial only reads indices
        // 0..parents.size()-1, so an extra slot would just be a phantom parameter.
        this.coeffParents = new double[numberOfScores][numberOfParents];
        for (double[] row : this.coeffParents) {
            for (int i = 0; i < row.length; ++i) {
                row[i] = 1.0;
            }
        }
    }

    /**
     * Returns a new array holding all parameters: first every intercept (in state order), then the
     * coefficient row of each state, concatenated in state order.
     *
     * @return a freshly allocated array of length {@link #getNumberOfParameters()}.
     */
    @Override
    public double[] getParameters() {
        double[] param = new double[this.getNumberOfParameters()];
        System.arraycopy(this.intercept, 0, param, 0, this.intercept.length);
        int offset = this.intercept.length;
        for (double[] row : this.coeffParents) {
            System.arraycopy(row, 0, param, offset, row.length);
            offset += row.length;
        }
        return param;
    }

    /**
     * Returns the total number of parameters: all intercepts plus the length of every coefficient row.
     *
     * @return the parameter count.
     */
    @Override
    public int getNumberOfParameters() {
        int count = this.intercept.length;
        for (double[] row : this.coeffParents) {
            count += row.length;
        }
        return count;
    }

    /**
     * Returns the intercept of the given (non-reference) state.
     *
     * @param state index in {@code [0, numberOfStates - 1)}.
     * @return the intercept value.
     */
    public double getIntercept(int state) {
        return this.intercept[state];
    }

    /**
     * Sets the intercept of the given (non-reference) state.
     *
     * @param state index in {@code [0, numberOfStates - 1)}.
     * @param intercept the new intercept value.
     */
    public void setIntercept(int state, double intercept) {
        this.intercept[state] = intercept;
    }

    /**
     * Returns the coefficient row of the given state. This is the internal array, not a copy, so
     * writes through it modify this distribution.
     *
     * @param state index in {@code [0, numberOfStates - 1)}.
     * @return the live coefficient array for {@code state}.
     */
    public double[] getCoeffParents(int state) {
        return this.coeffParents[state];
    }

    /**
     * Replaces the coefficient row of the given state. The array is stored by reference, not copied.
     * It is expected to hold one coefficient per parent, in parent-list order.
     *
     * @param state index in {@code [0, numberOfStates - 1)}.
     * @param coeffParents the new coefficient row.
     */
    public void setCoeffParents(int state, double[] coeffParents) {
        this.coeffParents[state] = coeffParents;
    }

    /**
     * Builds the multinomial over the child variable induced by the given parent values.
     *
     * @param parentsAssignment assignment providing a value for every parent.
     * @return a new {@link Multinomial} over {@code var}.
     */
    public Multinomial getMultinomial(Assignment parentsAssignment) {
        int numberOfStates = this.var.getNumberOfStates();
        double[] scores = new double[numberOfStates];
        for (int k = 0; k < numberOfStates - 1; ++k) {
            double[] beta = this.coeffParents[k];
            double score = this.intercept[k];
            int j = 0;
            for (Variable parent : this.parents) {
                score += beta[j++] * parentsAssignment.getValue(parent);
            }
            scores[k] = score;
        }
        // scores[numberOfStates - 1] stays 0: the last state is the reference category.
        Multinomial multinomial = new Multinomial(this.var);
        multinomial.setProbabilities(Utils.logs2probs(scores));
        return multinomial;
    }

    /**
     * Returns the log-probability of the child's value in {@code assignment}, conditioned on the
     * parents' values in the same assignment.
     */
    @Override
    public double getLogConditionalProbability(Assignment assignment) {
        double value = assignment.getValue(this.var);
        return this.getMultinomial(assignment).getLogProbability(value);
    }

    /** Returns the child's distribution given the parent values in {@code assignment}. */
    @Override
    public UnivariateDistribution getUnivariateDistribution(Assignment assignment) {
        return this.getMultinomial(assignment);
    }

    @Override
    public String label() {
        return "Multinomial Logistic";
    }

    /**
     * Draws every intercept and every parent coefficient independently from a standard Gaussian.
     *
     * @param random the source of randomness.
     */
    @Override
    public void randomInitialization(Random random) {
        for (int i = 0; i < this.coeffParents.length; ++i) {
            this.intercept[i] = random.nextGaussian();
            double[] row = this.coeffParents[i];
            for (int j = 0; j < row.length; ++j) {
                row[j] = random.nextGaussian();
            }
        }
    }

    /** Renders one {@code [ alpha = …, beta = …]} line per non-reference state. */
    @Override
    public String toString() {
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < this.var.getNumberOfStates() - 1; ++i) {
            str.append("[ alpha = ").append(this.getIntercept(i));
            for (double beta : this.getCoeffParents(i)) {
                str.append(", beta = ").append(beta);
            }
            str.append("]\n");
        }
        return str.toString();
    }

    /**
     * Returns whether {@code dist} is a {@code Multinomial_LogisticParents} whose parameters all lie
     * within {@code threshold} of this one's.
     */
    @Override
    public boolean equalDist(Distribution dist, double threshold) {
        if (dist instanceof Multinomial_LogisticParents) {
            return this.equalDist((Multinomial_LogisticParents)dist, threshold);
        }
        return false;
    }

    /**
     * Compares all parameters element-wise. Distributions with different shapes (number of states
     * or coefficients per state) are never equal, and a {@code null} argument is never equal.
     *
     * @param dist the distribution to compare against.
     * @param threshold the maximum allowed absolute difference per parameter.
     * @return {@code true} iff both have the same shape and every parameter differs by at most
     *         {@code threshold}. A NaN difference counts as "not equal".
     */
    public boolean equalDist(Multinomial_LogisticParents dist, double threshold) {
        if (dist == null
                || this.intercept.length != dist.intercept.length
                || this.coeffParents.length != dist.coeffParents.length) {
            return false;
        }
        for (int i = 0; i < this.intercept.length; ++i) {
            // Written as !(x <= t) so that NaN differences are treated as unequal.
            if (!(Math.abs(this.getIntercept(i) - dist.getIntercept(i)) <= threshold)) {
                return false;
            }
        }
        for (int i = 0; i < this.coeffParents.length; ++i) {
            double[] mine = this.coeffParents[i];
            double[] theirs = dist.coeffParents[i];
            if (mine.length != theirs.length) {
                return false;
            }
            for (int j = 0; j < mine.length; ++j) {
                if (!(Math.abs(mine[j] - theirs[j]) <= threshold)) {
                    return false;
                }
            }
        }
        return true;
    }
}

