/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.GammaDistributionModel;
import dr.inference.distribution.LogNormalDistributionModel;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.repeatedMeasures.GammaGibbsProvider;
import dr.math.MathUtils;
import dr.math.distributions.Distribution;
import dr.math.distributions.GammaDistribution;
import dr.math.matrixAlgebra.Vector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;

public class NormalGammaPrecisionGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator,
Reportable {
    public static final String OPERATOR_NAME = "normalGammaPrecisionGibbsOperator";
    public static final String LIKELIHOOD = "likelihood";
    private static final String NORMAL_EXTENSION = "normalExtension";
    public static final String PRIOR = "prior";
    private static final String WORKING = "workingDistribution";
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newDoubleRule("weight"), new XORRule(new ElementRule("likelihood", new XMLSyntaxRule[]{new ElementRule(DistributionLikelihood.class)}), new ElementRule(GammaGibbsProvider.class)), new ElementRule("prior", new XMLSyntaxRule[]{new ElementRule(DistributionLikelihood.class)}), new ElementRule("workingDistribution", new XMLSyntaxRule[]{new ElementRule(DistributionLikelihood.class)}, true)};

        @Override
        public String getParserName() {
            return NormalGammaPrecisionGibbsOperator.OPERATOR_NAME;
        }

        private void checkGammaDistribution(DistributionLikelihood distributionLikelihood) throws XMLParseException {
            if (!(distributionLikelihood.getDistribution() instanceof GammaDistribution) && !(distributionLikelihood.getDistribution() instanceof GammaDistributionModel)) {
                throw new XMLParseException("Gibbs operator assumes normal-gamma model");
            }
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            GammaGibbsProvider gammaGibbsProvider;
            double d = xMLObject.getDoubleAttribute("weight");
            DistributionLikelihood distributionLikelihood = (DistributionLikelihood)xMLObject.getElementFirstChild(NormalGammaPrecisionGibbsOperator.PRIOR);
            this.checkGammaDistribution(distributionLikelihood);
            DistributionLikelihood distributionLikelihood2 = xMLObject.hasChildNamed(NormalGammaPrecisionGibbsOperator.WORKING) ? (DistributionLikelihood)xMLObject.getElementFirstChild(NormalGammaPrecisionGibbsOperator.WORKING) : null;
            Distribution distribution = null;
            if (distributionLikelihood2 != null) {
                this.checkGammaDistribution(distributionLikelihood2);
                distribution = distributionLikelihood2.getDistribution();
            }
            if (xMLObject.hasChildNamed(NormalGammaPrecisionGibbsOperator.LIKELIHOOD)) {
                DistributionLikelihood distributionLikelihood3 = (DistributionLikelihood)xMLObject.getElementFirstChild(NormalGammaPrecisionGibbsOperator.LIKELIHOOD);
                if (!(distributionLikelihood3.getDistribution() instanceof NormalDistributionModel) && !(distributionLikelihood3.getDistribution() instanceof LogNormalDistributionModel)) {
                    throw new XMLParseException("Gibbs operator assumes normal-gamma model");
                }
                gammaGibbsProvider = new GammaGibbsProvider.Default(distributionLikelihood3);
            } else {
                gammaGibbsProvider = (GammaGibbsProvider)xMLObject.getChild(GammaGibbsProvider.class);
            }
            return new NormalGammaPrecisionGibbsOperator(gammaGibbsProvider, distributionLikelihood.getDistribution(), distribution, d);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a operator on the precision parameter of a normal model with gamma prior.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };
    private final GammaGibbsProvider gammaGibbsProvider;
    private final Parameter precisionParameter;
    private final GammaParametrization priorParametrization;
    private final GammaParametrization workingParametrization;
    private double pathParameter = 1.0;

    public NormalGammaPrecisionGibbsOperator(GammaGibbsProvider gammaGibbsProvider, Distribution distribution, double d) {
        this(gammaGibbsProvider, distribution, null, d);
    }

    public NormalGammaPrecisionGibbsOperator(GammaGibbsProvider gammaGibbsProvider, Distribution distribution, Distribution distribution2, double d) {
        this.gammaGibbsProvider = gammaGibbsProvider;
        this.precisionParameter = gammaGibbsProvider.getPrecisionParameter();
        this.priorParametrization = new GammaParametrization(distribution.mean(), distribution.variance());
        this.workingParametrization = distribution2 != null ? new GammaParametrization(distribution2.mean(), distribution2.variance()) : null;
        this.setWeight(d);
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return OPERATOR_NAME;
    }

    @Override
    public String getReport() {
        int n = this.precisionParameter.getDimension();
        double[] dArray = new double[n];
        double[] dArray2 = new double[n];
        this.gammaGibbsProvider.drawValues();
        for (int i = 0; i < n; ++i) {
            GammaGibbsProvider.SufficientStatistics sufficientStatistics = this.gammaGibbsProvider.getSufficientStatistics(i);
            dArray[i] = sufficientStatistics.observationCount;
            dArray2[i] = sufficientStatistics.sumOfSquaredErrors;
        }
        StringBuilder stringBuilder = new StringBuilder("normalGammaPrecisionGibbsOperator report:\n");
        stringBuilder.append("Observation counts:\t");
        stringBuilder.append(new Vector(dArray));
        stringBuilder.append("\n");
        stringBuilder.append("Sum of squared errors:\t");
        stringBuilder.append(new Vector(dArray2));
        return stringBuilder.toString();
    }

    private double weigh(double d, double d2) {
        return (1.0 - this.pathParameter) * d + this.pathParameter * d2;
    }

    @Override
    public double doOperation() {
        this.gammaGibbsProvider.drawValues();
        for (int i = 0; i < this.precisionParameter.getDimension(); ++i) {
            GammaGibbsProvider.SufficientStatistics sufficientStatistics = this.gammaGibbsProvider.getSufficientStatistics(i);
            double d = this.pathParameter * (double)sufficientStatistics.observationCount / 2.0;
            double d2 = this.pathParameter * sufficientStatistics.sumOfSquaredErrors / 2.0;
            if (this.workingParametrization == null) {
                d += this.priorParametrization.getShape();
                d2 += this.priorParametrization.getRate();
            } else {
                d += this.weigh(this.priorParametrization.getShape(), this.priorParametrization.getShape());
                d2 += this.weigh(this.priorParametrization.getRate(), this.priorParametrization.getShape());
            }
            double d3 = MathUtils.nextGamma(d, d2);
            this.precisionParameter.setParameterValue(i, d3);
        }
        return 0.0;
    }

    @Override
    public void setPathParameter(double d) {
        if (d < 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Invalid pathParameter value");
        }
        this.pathParameter = d;
    }

    public int getStepCount() {
        return 1;
    }

    static class GammaParametrization {
        private final double rate;
        private final double shape;

        GammaParametrization(double d, double d2) {
            if (d == 0.0) {
                this.rate = 0.0;
                this.shape = -0.5;
            } else {
                this.rate = d / d2;
                this.shape = d * this.rate;
            }
        }

        double getRate() {
            return this.rate;
        }

        double getShape() {
            return this.shape;
        }
    }
}

