/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.hmc;

import dr.evolution.coalescent.ConstantPopulation;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.BayesianSkylineLikelihood;
import dr.evomodel.coalescent.OldAbstractCoalescentLikelihood;
import dr.evomodel.coalescent.hmc.GMRFGradient;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.Binomial;
import dr.xml.Reportable;

import java.util.IdentityHashMap;
import java.util.Map;

/**
 * Gradient of a {@link BayesianSkylineLikelihood} with respect to a parameter selected by
 * {@link WrtParameter}.
 *
 * <p>Hessians are not implemented: {@link #getHessianLogDensity()} always throws
 * {@link UnsupportedOperationException}, and {@link #getDiagonalHessianLogDensity()} delegates to
 * the selected {@link WrtParameter}, which currently throws as well.
 */
public class BayesianSkylineGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {
    private final BayesianSkylineLikelihood likelihood;
    private final WrtParameter wrtParameter;
    // Boxed; forwarded unchanged to GradientWrtParameterProvider.getReportAndCheckForError.
    private final Double tolerance;

    /**
     * @param bayesianSkylineLikelihood the likelihood to differentiate
     * @param wrtParameter              the parameter to differentiate with respect to
     * @param d                         tolerance forwarded to the gradient check in
     *                                  {@link #getReport()} (boxed — presumably may be
     *                                  {@code null}; TODO confirm downstream handling)
     */
    public BayesianSkylineGradient(BayesianSkylineLikelihood bayesianSkylineLikelihood, WrtParameter wrtParameter, Double d) {
        this.likelihood = bayesianSkylineLikelihood;
        this.wrtParameter = wrtParameter;
        this.tolerance = d;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /** Returns the parameter that the gradient is taken with respect to. */
    @Override
    public Parameter getParameter() {
        return this.wrtParameter.getParameter(this.likelihood);
    }

    @Override
    public int getDimension() {
        return this.getParameter().getDimension();
    }

    @Override
    public double[] getGradientLogDensity() {
        return this.wrtParameter.getGradientLogDensity(this.likelihood);
    }

    @Override
    public double[] getDiagonalHessianLogDensity() {
        return this.wrtParameter.getDiagonalHessianLogDensity(this.likelihood);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    /**
     * Builds the report via {@link GradientWrtParameterProvider#getReportAndCheckForError}, using the
     * parameter's lower bound, +infinity as the upper bound, and the configured tolerance.
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, this.wrtParameter.getParameterLowerBound(), Double.POSITIVE_INFINITY, this.tolerance);
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "BayesianSkylineGradient check");
    }

    /** The parameters a {@link BayesianSkylineGradient} can differentiate with respect to. */
    public enum WrtParameter {
        /** All internal node heights of the likelihood's tree (stepwise skyline only). */
        NODE_HEIGHT("nodeHeight"){
            // An enum constant is a JVM-wide singleton, so a single cached field would hand the
            // first likelihood's proxy parameter to every other likelihood. Cache one proxy per
            // likelihood instance instead; the same likelihood always gets the same object.
            private final Map<BayesianSkylineLikelihood, Parameter> parameters = new IdentityHashMap<>();

            @Override
            Parameter getParameter(BayesianSkylineLikelihood bayesianSkylineLikelihood) {
                return this.parameters.computeIfAbsent(bayesianSkylineLikelihood,
                        skyline -> new NodeHeightProxyParameter("allInternalNode", (TreeModel) skyline.getTree(), true));
            }

            @Override
            double[] getGradientLogDensity(BayesianSkylineLikelihood bayesianSkylineLikelihood) {
                return this.getGradientWrtNodeHeights(bayesianSkylineLikelihood);
            }

            /**
             * Computes the gradient of the log likelihood with respect to each internal node height.
             *
             * <p>Walks the intervals in order and accumulates each interval's end time. For every
             * coalescent interval i, the matching sorted node receives
             * {@code -C(k_i, 2) * I'_i(t) + C(k_{i+1}, 2) * I'_{i+1}(t)}. Here {@code k} is the
             * lineage count, {@code I'} is the intensity gradient of a constant population whose
             * size comes from {@code getPopSize} at that interval's midpoint, and {@code t} is the
             * interval's end time. The second term is omitted for the last interval.
             *
             * @throws UnsupportedOperationException if the skyline is not of type 0
             */
            private double[] getGradientWrtNodeHeights(BayesianSkylineLikelihood bayesianSkylineLikelihood) {
                this.getWarning(bayesianSkylineLikelihood);
                bayesianSkylineLikelihood.setupIntervals();

                Tree tree = bayesianSkylineLikelihood.getTree();
                int internalNodeCount = tree.getInternalNodeCount();
                double[] sortedGradient = new double[internalNodeCount];
                double[] sortedHeights = new double[internalNodeCount];
                double[] sortScratch = new double[internalNodeCount]; // filled by sortNodeHeights; not read here
                int[] nodeOrder = new int[internalNodeCount];
                // Assumed: sortedHeights comes back ascending, and nodeOrder maps sorted position ->
                // parameter index (it is used that way below). TODO confirm against GMRFGradient.
                GMRFGradient.WrtParameter.sortNodeHeights(tree, sortedHeights, sortScratch, nodeOrder);

                int[] groupSizes = bayesianSkylineLikelihood.getGroupSizes();
                double[] groupHeights = bayesianSkylineLikelihood.getGroupHeights();
                ConstantPopulation population = new ConstantPopulation(Units.Type.YEARS);
                int intervalCount = bayesianSkylineLikelihood.getIntervalCount();

                double time = 0.0;          // end time of the interval just processed
                int group = 0;              // skyline group currently in force
                int coalescentsInGroup = 0; // coalescent events counted so far in that group
                int sortedNode = 0;         // next position in sortedHeights to match

                for (int i = 0; i < intervalCount; ++i) {
                    double interval = bayesianSkylineLikelihood.getInterval(i);
                    // Population size is taken at the interval midpoint, before this interval's own
                    // coalescent event can advance the group index.
                    population.setN0(bayesianSkylineLikelihood.getPopSize(group, time + interval / 2.0, groupHeights));

                    OldAbstractCoalescentLikelihood.CoalescentEventType type = bayesianSkylineLikelihood.getIntervalType(i);
                    boolean coalescent = type == OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT;
                    if (coalescent && ++coalescentsInGroup >= groupSizes[group]) {
                        ++group;
                        coalescentsInGroup = 0;
                    }
                    time += interval;
                    if (!coalescent) {
                        continue;
                    }

                    // Advance to the first sorted node height that is not below the interval's end.
                    // NOTE(review): this assumes the floating-point sum of intervals reaches the node
                    // height exactly. Upward drift could skip a node or run past the end of the array.
                    double nodeHeight = sortedHeights[sortedNode];
                    while (nodeHeight < time) {
                        nodeHeight = sortedHeights[++sortedNode];
                    }
                    sortedGradient[sortedNode] = this.getIntervalGradient(population, time, bayesianSkylineLikelihood.getLineageCount(i), type);
                    if (i + 1 < intervalCount) {
                        // This node's height is also where the next interval begins.
                        population.setN0(bayesianSkylineLikelihood.getPopSize(group, time + bayesianSkylineLikelihood.getInterval(i + 1) / 2.0, groupHeights));
                        sortedGradient[sortedNode] -= this.getIntervalGradient(population, time, bayesianSkylineLikelihood.getLineageCount(i + 1), type);
                    }
                    ++sortedNode;
                }

                // Scatter from sorted order back to parameter order.
                double[] gradient = new double[internalNodeCount];
                for (int i = 0; i < internalNodeCount; ++i) {
                    gradient[nodeOrder[i]] = sortedGradient[i];
                }
                return gradient;
            }

            /**
             * Returns {@code -C(n, 2) * demographicFunction.getIntensityGradient(d)}. The event type
             * is accepted but not used.
             */
            private double getIntervalGradient(DemographicFunction demographicFunction, double d, int n, OldAbstractCoalescentLikelihood.CoalescentEventType coalescentEventType) {
                return -Binomial.choose2(n) * demographicFunction.getIntensityGradient(d);
            }

            /** Not implemented; always throws {@link UnsupportedOperationException}. */
            @Override
            double[] getDiagonalHessianLogDensity(BayesianSkylineLikelihood bayesianSkylineLikelihood) {
                throw new UnsupportedOperationException("Not yet implemented!");
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            /**
             * Rejects any skyline whose {@code getType()} is not 0 (the stepwise model, per the
             * message below).
             */
            @Override
            public void getWarning(BayesianSkylineLikelihood bayesianSkylineLikelihood) {
                if (bayesianSkylineLikelihood.getType() != 0) {
                    throw new UnsupportedOperationException("Only implemented for stepwise type of Skyline model.");
                }
            }
        };

        private final String name;

        private WrtParameter(String name) {
            this.name = name;
        }

        /** Returns the parameter to differentiate with respect to for the given likelihood. */
        abstract Parameter getParameter(BayesianSkylineLikelihood var1);

        /** Returns the gradient of the log density with respect to {@link #getParameter}. */
        abstract double[] getGradientLogDensity(BayesianSkylineLikelihood var1);

        /** Returns the diagonal of the Hessian of the log density. */
        abstract double[] getDiagonalHessianLogDensity(BayesianSkylineLikelihood var1);

        /** Lower bound of the parameter, used by the report's gradient check. */
        abstract double getParameterLowerBound();

        /** Throws if the likelihood is configured in a way this parameter does not support. */
        public abstract void getWarning(BayesianSkylineLikelihood var1);

        /**
         * Looks up a constant by its name, ignoring case.
         *
         * @return the matching constant, or {@code null} if none matches
         */
        public static WrtParameter factory(String string) {
            for (WrtParameter candidate : WrtParameter.values()) {
                if (string.equalsIgnoreCase(candidate.name)) {
                    return candidate;
                }
            }
            return null;
        }
    }
}

