/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.exponentialfamily;

import eu.amidst.core.distribution.ConditionalLinearGaussian;
import eu.amidst.core.exponentialfamily.EF_ConditionalDistribution;
import eu.amidst.core.exponentialfamily.EF_Normal;
import eu.amidst.core.exponentialfamily.EF_Normal_Normal_Gamma;
import eu.amidst.core.exponentialfamily.MomentParameters;
import eu.amidst.core.exponentialfamily.NaturalParameters;
import eu.amidst.core.exponentialfamily.ParameterVariables;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.DistributionType;
import eu.amidst.core.variables.Variable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

/**
 * Exponential-family representation of a Normal child variable X whose parents Y_1..Y_n are all Normal,
 * i.e. a conditional linear Gaussian X | Y ~ N(beta0 + betas^T Y, variance).
 *
 * <p>Sufficient statistics are the vector [X, Y_1, ..., Y_n] together with its (n + 1) x (n + 1) outer
 * product, stored in a {@link CompoundVector}.
 */
public class EF_Normal_NormalParents
extends EF_ConditionalDistribution {
    /** Number of sufficient statistics: (n + 1) entries of [X, Y] plus the (n + 1)^2 outer-product entries. */
    int sizeSS;
    /** Number of (Gaussian) parent variables. */
    int nOfParents;
    /** Regression coefficients, one per parent, in the order of {@code this.parents}. */
    double[] betas;
    /** Intercept of the regression of the child on its parents. */
    double beta0;
    /** Conditional variance of the child given its parents. */
    double variance;

    /**
     * Creates a Normal|Normal-parents distribution with all coefficients, the intercept and the variance set to 0.
     *
     * @param var_ the child variable; must be Normal.
     * @param parents_ the parent variables; each must be Normal.
     * @throws UnsupportedOperationException if the child or any parent is not Normal.
     */
    public EF_Normal_NormalParents(Variable var_, List<Variable> parents_) {
        this.var = var_;
        this.parents = parents_;
        if (!var_.isNormal()) {
            throw new UnsupportedOperationException("Creating a Normal|Normal EF distribution for a non-gaussian child variable.");
        }
        for (Variable v : this.parents) {
            if (!v.isNormal()) {
                throw new UnsupportedOperationException("Creating a Normal|Normal EF distribution for a non-gaussian parent variable.");
            }
        }
        this.nOfParents = this.parents.size();
        // (n + 1) first moments plus (n + 1)^2 second moments = n^2 + 3n + 2.
        this.sizeSS = this.nOfParents * this.nOfParents + 3 * this.nOfParents + 2;
        this.momentParameters = null;
        this.naturalParameters = null;
        this.beta0 = 0.0;
        this.variance = 0.0;
        this.betas = new double[this.nOfParents];
    }

    /**
     * Recomputes {@code betas}, {@code beta0} and {@code variance} from the stored moment parameters.
     *
     * <p>The moment parameters are expected to hold E[X, Y] and E[[X, Y][X, Y]^T]. The fit is the usual
     * least-squares one: betas = Cov(Y,Y)^-1 Cov(X,Y), beta0 = E[X] - betas^T E[Y], and
     * variance = Var(X) - betas^T Cov(X,Y).
     */
    @Override
    public void updateNaturalFromMomentParameters() {
        CompoundVector globalMomentParam = (CompoundVector)this.momentParameters;
        RealMatrix secondMoments = globalMomentParam.getcovbaseMatrix();
        double mean_X = globalMomentParam.getXYbaseMatrix().getEntry(0);
        double cov_XX = secondMoments.getEntry(0, 0) - mean_X * mean_X;
        if (this.nOfParents == 0) {
            // No parents to regress on: the child is a plain Gaussian. The sub-vector/sub-matrix
            // extraction below would otherwise fail on an empty parent block.
            this.betas = new double[0];
            this.beta0 = mean_X;
            this.variance = cov_XX;
            return;
        }
        RealVector mean_Y = globalMomentParam.getTheta_beta0BetaRV();
        RealMatrix cov_YY = secondMoments.getSubMatrix(1, this.nOfParents, 1, this.nOfParents).subtract(mean_Y.outerProduct(mean_Y));
        RealVector cov_XY = secondMoments.getSubMatrix(0, 0, 1, this.nOfParents).getRowVector(0).subtract(mean_Y.mapMultiply(mean_X));
        RealMatrix cov_YYInverse = new LUDecompositionImpl(cov_YY).getSolver().getInverse();
        RealVector beta = cov_YYInverse.preMultiply(cov_XY);
        this.betas = beta.toArray();
        this.beta0 = mean_X - beta.dotProduct(mean_Y);
        this.variance = cov_XX - beta.dotProduct(cov_XY);
    }

    /**
     * Sets the regression coefficients. The array is stored by reference, not copied.
     *
     * @param betas one coefficient per parent, in the order of the parent list.
     */
    public void setBetas(double[] betas) {
        this.betas = betas;
    }

    /**
     * Sets the regression intercept.
     *
     * @param beta0 the intercept.
     */
    public void setBeta0(double beta0) {
        this.beta0 = beta0;
    }

    /**
     * Sets the conditional variance of the child.
     *
     * @param variance the variance.
     */
    public void setVariance(double variance) {
        this.variance = variance;
    }

    /** Intentionally a no-op in this implementation. */
    @Override
    public void updateMomentFromNaturalParameters() {
    }

    /**
     * Builds the natural parameters from {@code beta0}, {@code betas} and {@code variance}, caches them in
     * {@code this.naturalParameters}, and returns them.
     *
     * <p>Expanding -(x - beta0 - beta^T y)^2 / (2 variance) gives these coefficients:
     * <ul>
     *   <li>x: beta0/variance</li>
     *   <li>y_i: -beta_i beta0/variance</li>
     *   <li>x^2: -1/(2 variance)</li>
     *   <li>x y_i: beta_i/variance, split evenly between the two symmetric matrix entries</li>
     *   <li>y_i y_j: -beta_i beta_j/(2 variance)</li>
     * </ul>
     *
     * @return the natural parameters as a {@link CompoundVector}.
     */
    @Override
    public NaturalParameters getNaturalParameters() {
        CompoundVector naturalParametersCompound = this.createEmtpyCompoundVector();
        double theta_0 = this.beta0 / this.variance;
        naturalParametersCompound.setThetaBeta0_NatParam(theta_0);
        double variance2Inv = 1.0 / (2.0 * this.variance);
        double[] theta0_beta = Arrays.stream(this.betas).map(w -> -w * this.beta0 / this.variance).toArray();
        naturalParametersCompound.setThetaBeta0Beta_NatParam(theta0_beta);
        double theta_Minus1 = -variance2Inv;
        naturalParametersCompound.setThetaCov_NatParam(theta_Minus1, this.betas, variance2Inv);
        this.naturalParameters = naturalParametersCompound;
        return this.naturalParameters;
    }

    /**
     * Computes the sufficient statistics of one data instance: the vector v = [x, y_1..y_n] and its outer
     * product v v^T.
     *
     * @param data the assignment holding values for the child and all parents.
     * @return the sufficient statistics as a {@link CompoundVector}.
     */
    @Override
    public SufficientStatistics getSufficientStatistics(Assignment data) {
        CompoundVector vectorSS = this.createEmtpyCompoundVector();
        double[] Xarray = new double[]{data.getValue(this.var)};
        double[] Yarray = this.parents.stream().mapToDouble(w -> data.getValue((Variable)w)).toArray();
        ArrayRealVector XYRealVector = new ArrayRealVector(Xarray, Yarray);
        vectorSS.setXYbaseVector(XYRealVector);
        RealMatrix covRealmatrix = XYRealVector.outerProduct((RealVector)XYRealVector);
        vectorSS.setcovbaseVector(covRealmatrix);
        return vectorSS;
    }

    /** @return the number of sufficient statistics, n^2 + 3n + 2 for n parents. */
    @Override
    public int sizeOfSufficientStatistics() {
        return this.sizeSS;
    }

    /**
     * Returns the Gaussian log base measure, -0.5 log(2 pi). It does not depend on the data instance.
     *
     * @param dataInstance ignored.
     * @return -0.5 log(2 pi).
     */
    @Override
    public double computeLogBaseMeasure(Assignment dataInstance) {
        return -0.5 * Math.log(Math.PI * 2);
    }

    /** @return the log normalizer beta0^2/(2 variance) + 0.5 log(variance). */
    @Override
    public double computeLogNormalizer() {
        return this.beta0 * this.beta0 / (2.0 * this.variance) + 0.5 * Math.log(this.variance);
    }

    /** @return a new all-zero {@link CompoundVector} sized for this distribution's parents. */
    @Override
    public CompoundVector createZeroVector() {
        return new CompoundVector(this.nOfParents);
    }

    /**
     * Creates initial sufficient statistics. The first-moment vector is all zeros. The second-moment matrix
     * is filled with pseudo-random values in [0.01, 1.01) from a fixed seed (0), so the result is
     * deterministic. Note that the matrix is not symmetric.
     *
     * @return the initial sufficient statistics.
     */
    @Override
    public SufficientStatistics createInitSufficientStatistics() {
        CompoundVector vectorSS = this.createEmtpyCompoundVector();
        double[] Xarray = new double[]{0.0};
        double[] Yarray = this.parents.stream().mapToDouble(w -> 0.0).toArray();
        ArrayRealVector XYRealVector = new ArrayRealVector(Xarray, Yarray);
        vectorSS.setXYbaseVector(XYRealVector);
        Array2DRowRealMatrix covRealmatrix = new Array2DRowRealMatrix(Yarray.length + 1, Yarray.length + 1);
        Random rand = new Random(0L);
        for (int i = 0; i < Yarray.length + 1; ++i) {
            for (int j = 0; j < Yarray.length + 1; ++j) {
                covRealmatrix.addToEntry(i, j, rand.nextDouble() + 0.01);
            }
        }
        vectorSS.setcovbaseVector(covRealmatrix);
        return vectorSS;
    }

    /** Not supported by this distribution. */
    @Override
    public double getExpectedLogNormalizer(Variable parent, Map<Variable, MomentParameters> momentChildCoParents) {
        throw new UnsupportedOperationException("No Implemented. This method is no really needed");
    }

    /**
     * Computes E[A(Y)] = E[(beta0 + beta^T Y)^2] / (2 variance) + 0.5 log(variance), where A is the log
     * normalizer of X given its parents.
     *
     * <p>The square is expanded using E[Y_i] E[Y_j] for cross terms (i != j), i.e. the parents are treated
     * as independent:
     * E[(beta^T Y)^2] = (sum beta_i E[Y_i])^2 - sum (beta_i E[Y_i])^2 + sum beta_i^2 E[Y_i^2].
     *
     * @param momentParents for each parent, its moments (index 0: E[Y], index 1: E[Y^2]).
     * @return the expected log normalizer.
     */
    @Override
    public double getExpectedLogNormalizer(Map<Variable, MomentParameters> momentParents) {
        int nOfBetas = this.betas.length;
        double dotProductBetaY = 0.0;
        double sumSquaredMoments = 0.0;
        double sumSquaredMeanMoments = 0.0;
        for (int i = 0; i < nOfBetas; ++i) {
            MomentParameters parentMoments = momentParents.get(this.parents.get(i));
            double weightedMean = this.betas[i] * parentMoments.get(0);
            dotProductBetaY += weightedMean;
            sumSquaredMoments += this.betas[i] * this.betas[i] * parentMoments.get(1);
            sumSquaredMeanMoments += weightedMean * weightedMean;
        }
        double beta0Squared = this.beta0 * this.beta0;
        double invVariance = 1.0 / this.variance;
        double logVar = Math.log(this.variance);
        // +0.5 log(variance), matching computeLogNormalizer(); with no parents both methods now agree.
        return 0.5 * logVar + 0.5 * invVariance * (beta0Squared + dotProductBetaY * dotProductBetaY - sumSquaredMeanMoments + sumSquaredMoments + 2.0 * this.beta0 * dotProductBetaY);
    }

    /**
     * Computes the message from the parents to the child.
     *
     * <p>NOTE(review): entry 0 is the expected conditional mean beta0 + sum beta_i E[Y_i] and entry 1 is the
     * precision 1/variance. These are mean/precision values rather than canonical natural parameters;
     * presumably {@code EF_Normal.ArrayVectorParameter} expects this form. Confirm against EF_Normal.
     *
     * @param momentParents for each parent, its moments (index 0: E[Y]).
     * @return a two-entry parameter vector [mean, precision].
     */
    @Override
    public NaturalParameters getExpectedNaturalFromParents(Map<Variable, MomentParameters> momentParents) {
        int nOfBetas = this.betas.length;
        double dotProductBetaY = 0.0;
        for (int i = 0; i < nOfBetas; ++i) {
            dotProductBetaY += momentParents.get(this.parents.get(i)).get(0) * this.betas[i];
        }
        EF_Normal.ArrayVectorParameter naturalParameters = new EF_Normal.ArrayVectorParameter(2);
        naturalParameters.set(0, this.beta0 + dotProductBetaY);
        naturalParameters.set(1, 1.0 / this.variance);
        return naturalParameters;
    }

    /**
     * Computes the message from the child (and co-parents) to one parent Y_i.
     *
     * <p>The result has the same [mean, precision] layout as
     * {@link #getExpectedNaturalFromParents(Map)}:
     * <ul>
     *   <li>mean = (E[X] - beta0 - sum_{j != i} beta_j E[Y_j]) / beta_i</li>
     *   <li>precision = beta_i^2 / variance</li>
     * </ul>
     * If beta_i is 0 the child carries no information about Y_i, and [0, 0] is returned.
     *
     * @param parent the parent receiving the message; must be one of this distribution's parents.
     * @param momentChildCoParents moments (index 0: expectation) of the child and every parent.
     * @return a two-entry parameter vector [mean, precision].
     * @throws IllegalArgumentException if {@code parent} is not a parent of this distribution.
     */
    @Override
    public NaturalParameters getExpectedNaturalToParent(Variable parent, Map<Variable, MomentParameters> momentChildCoParents) {
        int parentID = this.parents.indexOf(parent);
        if (parentID < 0) {
            throw new IllegalArgumentException("Variable " + parent.getName() + " is not a parent of " + this.var.getName());
        }
        if (this.betas[parentID] == 0.0) {
            EF_Normal.ArrayVectorParameter naturalParameters = new EF_Normal.ArrayVectorParameter(2);
            naturalParameters.set(0, 0.0);
            naturalParameters.set(1, 0.0);
            return naturalParameters;
        }
        int nOfBetas = this.betas.length;
        double dotProductBetaY = 0.0;
        for (int i = 0; i < nOfBetas; ++i) {
            dotProductBetaY += momentChildCoParents.get(this.parents.get(i)).get(0) * this.betas[i];
        }
        double X = momentChildCoParents.get(this.var).get(0);
        double beta_iSquared = this.betas[parentID] * this.betas[parentID];
        double beta_i = this.betas[parentID];
        double Y_i = momentChildCoParents.get(this.parents.get(parentID)).get(0);
        double invVariance = 1.0 / this.variance;
        double factor = beta_i / beta_iSquared;
        // Remove Y_i's own contribution from the full dot product to obtain the co-parents' part.
        double mean = factor * (-this.beta0 + X - (dotProductBetaY - beta_i * Y_i));
        double precision = beta_iSquared * invVariance;
        EF_Normal.ArrayVectorParameter naturalParameters = new EF_Normal.ArrayVectorParameter(2);
        naturalParameters.set(0, mean);
        naturalParameters.set(1, precision);
        return naturalParameters;
    }

    /**
     * Converts this EF distribution to a {@link ConditionalLinearGaussian} with the same intercept,
     * coefficients (a copy of the array) and variance.
     *
     * @return the equivalent conditional linear Gaussian.
     */
    public ConditionalLinearGaussian toConditionalDistribution() {
        ConditionalLinearGaussian normal_normal = new ConditionalLinearGaussian(this.getVariable(), this.getConditioningVariables());
        normal_normal.setIntercept(this.beta0);
        normal_normal.setCoeffParents(Arrays.copyOfRange(this.betas, 0, this.betas.length));
        normal_normal.setVariance(this.variance);
        return normal_normal;
    }

    /** @return a new all-zero {@link CompoundVector} sized for this distribution's parents. */
    public CompoundVector createEmtpyCompoundVector() {
        return new CompoundVector(this.nOfParents);
    }

    /**
     * Builds the Bayesian-learning version of this distribution. New parameter variables are created in
     * {@code variables}: one Gamma variable, one Gaussian for beta0, and one Gaussian per parent coefficient.
     *
     * @param variables the parameter-variable factory that receives the new variables.
     * @param nameSuffix suffix appended to each new parameter variable's name.
     * @return the priors for the new parameter variables, followed by an {@link EF_Normal_Normal_Gamma}
     *     linking the child, its parents and those parameter variables.
     */
    @Override
    public List<EF_ConditionalDistribution> toExtendedLearningDistribution(ParameterVariables variables, String nameSuffix) {
        ArrayList<EF_ConditionalDistribution> conditionalDistributions = new ArrayList<EF_ConditionalDistribution>();
        Variable varGamma = variables.newGammaParameter(this.var.getName() + "_Gamma_Parameter_" + nameSuffix + "_" + variables.getNumberOfVars());
        conditionalDistributions.add((EF_ConditionalDistribution)((DistributionType)varGamma.getDistributionType()).newEFUnivariateDistribution());
        Variable normalBeta0 = variables.newGaussianParameter(this.var.getName() + "_Beta0_Parameter_" + nameSuffix + "_" + variables.getNumberOfVars());
        conditionalDistributions.add((EF_ConditionalDistribution)((DistributionType)normalBeta0.getDistributionType()).newEFUnivariateDistribution());
        ArrayList<Variable> betas = new ArrayList<Variable>();
        for (Variable variableParent : this.parents) {
            Variable normalBetai = variables.newGaussianParameter(this.var.getName() + "_Beta_" + variableParent.getName() + "_Parameter_" + nameSuffix + "_" + variables.getNumberOfVars());
            betas.add(normalBetai);
            conditionalDistributions.add((EF_ConditionalDistribution)((DistributionType)normalBetai.getDistributionType()).newEFUnivariateDistribution());
        }
        EF_Normal_Normal_Gamma condDist = new EF_Normal_Normal_Gamma(this.var, this.parents, normalBeta0, betas, varGamma);
        conditionalDistributions.add(condDist);
        return conditionalDistributions;
    }

    /**
     * Two-part vector used for sufficient statistics, moment parameters and natural parameters:
     * <ul>
     *   <li>{@code XYbaseVector}: length n + 1, the [X, Y] part.</li>
     *   <li>{@code covbaseVector}: (n + 1) x (n + 1), the second-order part.</li>
     * </ul>
     * The flat index layout used by {@link #get(int)} and {@link #set(int, double)} is:
     * <ul>
     *   <li>0: X</li>
     *   <li>1..n: Y_1..Y_n</li>
     *   <li>n + 1 onwards: the matrix in row-major order</li>
     * </ul>
     */
    public static class CompoundVector
    implements SufficientStatistics,
    MomentParameters,
    NaturalParameters,
    Serializable {
        private static final long serialVersionUID = -3436599636425587512L;
        int size;
        int nOfParents;
        RealVector XYbaseVector;
        RealMatrix covbaseVector;

        /**
         * Creates an all-zero compound vector for {@code nOfParents_} parents.
         *
         * @param nOfParents_ number of parents n.
         */
        public CompoundVector(int nOfParents_) {
            this.nOfParents = nOfParents_;
            this.XYbaseVector = new ArrayRealVector(this.nOfParents + 1);
            this.covbaseVector = new Array2DRowRealMatrix(this.nOfParents + 1, this.nOfParents + 1);
            this.size = this.nOfParents * this.nOfParents + 3 * this.nOfParents + 2;
        }

        /** Replaces the [X, Y] part (stored by reference). */
        public void setXYbaseVector(RealVector XYbaseVector_) {
            this.XYbaseVector = XYbaseVector_;
        }

        /** Replaces the matrix part (stored by reference). */
        public void setcovbaseVector(RealMatrix covbaseVector_) {
            this.covbaseVector = covbaseVector_;
        }

        /**
         * Replaces one component.
         *
         * <ul>
         *   <li>Position 0: the [X, Y] vector. Accepts either the single-column matrix produced by
         *       {@link #getMatrixByPosition(int)} or a matrix whose row 0 holds the vector.</li>
         *   <li>Position 1: the matrix part, stored by reference.</li>
         * </ul>
         *
         * @param position 0 or 1.
         * @param vec the new component.
         * @throws IndexOutOfBoundsException for any other position.
         */
        public void setMatrixByPosition(int position, RealMatrix vec) {
            switch (position) {
                case 0: {
                    // getMatrixByPosition(0) returns an (n + 1) x 1 column matrix; accept that shape so the
                    // two methods round-trip, while still supporting a row-vector matrix.
                    this.XYbaseVector = vec.getColumnDimension() == 1 ? vec.getColumnVector(0) : vec.getRowVector(0);
                    return;
                }
                case 1: {
                    this.covbaseVector = vec;
                    return;
                }
                default: {
                    throw new IndexOutOfBoundsException("There are only two components (indexes 0 for XY and 1 for the cov. matrix) in a normal|normal EF distribution.");
                }
            }
        }

        /** @return the [X, Y] part (by reference). */
        public RealVector getXYbaseMatrix() {
            return this.XYbaseVector;
        }

        /** @return the matrix part (by reference). */
        public RealMatrix getcovbaseMatrix() {
            return this.covbaseVector;
        }

        /**
         * Returns one component.
         *
         * <ul>
         *   <li>Position 0: a new (n + 1) x 1 column matrix copied from the [X, Y] vector.</li>
         *   <li>Position 1: the matrix part, by reference.</li>
         * </ul>
         *
         * @param position 0 or 1.
         * @return the requested component.
         * @throws IndexOutOfBoundsException for any other position.
         */
        public RealMatrix getMatrixByPosition(int position) {
            switch (position) {
                case 0: {
                    return new Array2DRowRealMatrix(this.XYbaseVector.getData());
                }
                case 1: {
                    return this.covbaseVector;
                }
            }
            throw new IndexOutOfBoundsException("There are only two components (indexes 0 for XY and 1 for the cov. matrix) in a normal|normal EF distribution.");
        }

        /** Inclusive range check: lower <= x <= upper. */
        private static boolean isBetween(int x, int lower, int upper) {
            return lower <= x && x <= upper;
        }

        /** Reads a value by flat index (see the class documentation for the layout). */
        @Override
        public double get(int i) {
            if (i == 0) {
                return this.XYbaseVector.getEntry(0);
            }
            if (CompoundVector.isBetween(i, 1, this.nOfParents)) {
                return this.XYbaseVector.getEntry(i);
            }
            int offset = i - (this.nOfParents + 1);
            int row = offset / (this.nOfParents + 1);
            int column = offset - row * (this.nOfParents + 1);
            return this.covbaseVector.getEntry(row, column);
        }

        /** Writes a value by flat index (see the class documentation for the layout). */
        @Override
        public void set(int i, double val) {
            if (i == 0) {
                this.XYbaseVector.setEntry(0, val);
            } else if (CompoundVector.isBetween(i, 1, this.nOfParents)) {
                this.XYbaseVector.setEntry(i, val);
            } else {
                int offset = i - (this.nOfParents + 1);
                int row = offset / (this.nOfParents + 1);
                int column = offset - row * (this.nOfParents + 1);
                this.covbaseVector.setEntry(row, column, val);
            }
        }

        /** Sets the coefficient of X (entry 0 of the vector part). */
        public void setThetaBeta0_NatParam(double val) {
            this.XYbaseVector.setEntry(0, val);
        }

        /** Sets the coefficients of Y_1..Y_n (entries 1..n of the vector part). */
        public void setThetaBeta0Beta_NatParam(double[] val) {
            this.XYbaseVector.setSubVector(1, val);
        }

        /**
         * Fills the matrix part with natural parameters.
         *
         * <ul>
         *   <li>Entry (0,0): {@code theta_Minus1}.</li>
         *   <li>Row 0 and column 0 (off-diagonal): beta_i * variance2Inv.</li>
         *   <li>Block (1..n, 1..n): -variance2Inv * beta beta^T.</li>
         * </ul>
         */
        public void setThetaCov_NatParam(double theta_Minus1, double[] beta, double variance2Inv) {
            double[] theta_Minus1array = new double[]{theta_Minus1};
            double[] theta_beta = Arrays.stream(beta).map(w -> w * variance2Inv).toArray();
            ArrayRealVector covXY = new ArrayRealVector(theta_Minus1array, theta_beta);
            this.covbaseVector.setColumnVector(0, covXY);
            this.covbaseVector.setRowVector(0, covXY);
            ArrayRealVector betaRV = new ArrayRealVector(beta);
            RealMatrix theta_betaBeta = betaRV.outerProduct((RealVector)betaRV).scalarMultiply(-variance2Inv);
            this.covbaseVector.setSubMatrix(theta_betaBeta.getData(), 1, 1);
        }

        /** @return entry 0 of the vector part. */
        public double getTheta_beta0() {
            return this.XYbaseVector.getEntry(0);
        }

        /** @return a copy of entries 1..n of the vector part. */
        public double[] getTheta_beta0Beta() {
            return this.getXYbaseMatrix().getSubVector(1, this.nOfParents).toArray();
        }

        /** @return matrix entry (0,0). */
        public double getTheta_Minus1() {
            return this.getcovbaseMatrix().getEntry(0, 0);
        }

        /** @return a copy of matrix row 0, columns 1..n. */
        public double[] getTheta_Beta() {
            return this.getcovbaseMatrix().getSubMatrix(0, 0, 1, this.nOfParents).getRow(0);
        }

        /** @return a copy of the matrix block (1..n, 1..n). */
        public double[][] getTheta_BetaBeta() {
            return this.getcovbaseMatrix().getSubMatrix(1, this.nOfParents, 1, this.nOfParents).getData();
        }

        /** @return entries 1..n of the vector part as a new vector. */
        public RealVector getTheta_beta0BetaRV() {
            return this.getXYbaseMatrix().getSubVector(1, this.nOfParents);
        }

        /** @return matrix row 0, columns 1..n, as a new vector. */
        public RealVector getTheta_BetaRV() {
            return this.getcovbaseMatrix().getSubMatrix(0, 0, 1, this.nOfParents).getRowVector(0);
        }

        /** @return the matrix block (1..n, 1..n) as a new matrix. */
        public RealMatrix getTheta_BetaBetaRM() {
            return this.getcovbaseMatrix().getSubMatrix(1, this.nOfParents, 1, this.nOfParents);
        }

        /** @return the total number of entries, n^2 + 3n + 2. */
        @Override
        public int size() {
            return this.size;
        }

        /** Adds {@code vector} (which must be a {@link CompoundVector}) to this one. */
        @Override
        public void sum(Vector vector) {
            this.sum((CompoundVector)vector);
        }

        /** Copies {@code vector} (which must be a {@link CompoundVector}) into this one. */
        @Override
        public void copy(Vector vector) {
            this.copy((CompoundVector)vector);
        }

        /** Divides every entry of both parts by {@code val}. */
        @Override
        public void divideBy(double val) {
            this.XYbaseVector.mapDivideToSelf(val);
            this.covbaseVector = this.covbaseVector.scalarMultiply(1.0 / val);
        }

        /** Dot product with {@code vec}, which must be a {@link CompoundVector}. */
        @Override
        public double dotProduct(Vector vec) {
            return this.dotProduct((CompoundVector)vec);
        }

        /**
         * Dot product over all entries: the vector parts' dot product plus the entrywise (Frobenius) product
         * of the matrix parts.
         */
        public double dotProduct(CompoundVector vec) {
            double result = this.getXYbaseMatrix().dotProduct(vec.getXYbaseMatrix());
            return result += IntStream.range(0, this.nOfParents + 1).mapToDouble(p -> this.getcovbaseMatrix().getRowVector(p).dotProduct(vec.getcovbaseMatrix().getRowVector(p))).sum();
        }

        /** Deep-copies both parts of {@code vector} into this one. */
        public void copy(CompoundVector vector) {
            this.XYbaseVector = vector.getXYbaseMatrix().copy();
            this.covbaseVector = vector.getcovbaseMatrix().copy();
        }

        /** Adds both parts of {@code vector} to this one; the sums are stored as new objects. */
        public void sum(CompoundVector vector) {
            this.XYbaseVector = this.XYbaseVector.add(vector.getXYbaseMatrix());
            this.covbaseVector = this.covbaseVector.add(vector.getcovbaseMatrix());
        }
    }
}

