/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.smooth;

import dr.evolution.coalescent.IntervalType;
import dr.evolution.tree.Tree;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MachineAccuracy;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;

public class SkyGlideLikelihood
extends AbstractModelLikelihood
implements Reportable {
    private final List<TreeModel> trees;
    private final List<BigFastTreeIntervals> intervals;
    private final Parameter logPopSizeParameter;
    private final Parameter gridPointParameter;
    private boolean likelihoodKnown = false;
    private double logLikelihood;

    public SkyGlideLikelihood(String string, List<TreeModel> list, Parameter parameter, Parameter parameter2) {
        super(string);
        this.trees = list;
        this.logPopSizeParameter = parameter;
        this.gridPointParameter = parameter2;
        this.intervals = new ArrayList<BigFastTreeIntervals>();
        for (int i = 0; i < list.size(); ++i) {
            BigFastTreeIntervals bigFastTreeIntervals = new BigFastTreeIntervals(list.get(i));
            this.intervals.add(bigFastTreeIntervals);
            this.addModel(bigFastTreeIntervals);
        }
        this.addVariable(parameter);
    }

    public List<TreeModel> getTrees() {
        return this.trees;
    }

    public BigFastTreeIntervals getIntervals(int n) {
        return this.intervals.get(n);
    }

    public TreeModel getTree(int n) {
        return this.trees.get(n);
    }

    @Override
    public String getReport() {
        return "skyGlideLikelihood(" + this.getLogLikelihood() + ")";
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
    }

    @Override
    protected void restoreState() {
        this.likelihoodKnown = false;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Model getModel() {
        return this;
    }

    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            double d = 0.0;
            for (int i = 0; i < this.trees.size(); ++i) {
                d += this.getSingleTreeLogLikelihood(i);
            }
            this.logLikelihood = d;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    public Parameter getLogPopSizeParameter() {
        return this.logPopSizeParameter;
    }

    public double[] getGradientWrtLogPopulationSize() {
        double[] dArray = new double[this.logPopSizeParameter.getDimension()];
        for (int i = 0; i < this.trees.size(); ++i) {
            BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(i);
            Tree tree = this.trees.get(i);
            int n = 0;
            for (int j = 0; j < bigFastTreeIntervals.getIntervalCount(); ++j) {
                int n2;
                double d;
                int n3 = bigFastTreeIntervals.getLineageCount(j);
                int[] nArray = bigFastTreeIntervals.getNodeNumbersForInterval(j);
                double d2 = tree.getNodeHeight(tree.getNode(nArray[0]));
                if (d2 == (d = tree.getNodeHeight(tree.getNode(nArray[1])))) continue;
                int[] nArray2 = this.getGridPoints(n, d2, d);
                int n4 = nArray2[0];
                if (n4 == (n2 = nArray2[1])) {
                    this.updateIntervalGradientWrtLogPopSize(d2, d, n4, n3, dArray);
                } else {
                    this.updateIntervalGradientWrtLogPopSize(d2, this.gridPointParameter.getParameterValue(n4), n4, n3, dArray);
                    n = n4;
                    while (n + 1 < n2) {
                        this.updateIntervalGradientWrtLogPopSize(this.gridPointParameter.getParameterValue(n), this.gridPointParameter.getParameterValue(n + 1), n + 1, n3, dArray);
                        ++n;
                    }
                    this.updateIntervalGradientWrtLogPopSize(this.gridPointParameter.getParameterValue(n), d, n + 1, n3, dArray);
                }
                n = n2;
            }
            this.updateSingleTreePopulationInverseGradientWrtLogPopSize(i, dArray);
        }
        return dArray;
    }

    public double[] getDiagonalHessianLogDensityWrtLogPopSize() {
        double[] dArray = new double[this.logPopSizeParameter.getDimension()];
        for (int i = 0; i < this.trees.size(); ++i) {
            BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(i);
            Tree tree = this.trees.get(i);
            int n = 0;
            for (int j = 0; j < bigFastTreeIntervals.getIntervalCount(); ++j) {
                int n2;
                double d;
                int n3 = bigFastTreeIntervals.getLineageCount(j);
                int[] nArray = bigFastTreeIntervals.getNodeNumbersForInterval(j);
                double d2 = tree.getNodeHeight(tree.getNode(nArray[0]));
                if (d2 == (d = tree.getNodeHeight(tree.getNode(nArray[1])))) continue;
                int[] nArray2 = this.getGridPoints(n, d2, d);
                int n4 = nArray2[0];
                if (n4 == (n2 = nArray2[1])) {
                    this.updateIntervalDiagonalHessianWrtLogPopSize(d2, d, n4, n3, dArray);
                } else {
                    this.updateIntervalDiagonalHessianWrtLogPopSize(d2, this.gridPointParameter.getParameterValue(n4), n4, n3, dArray);
                    n = n4;
                    while (n + 1 < n2) {
                        this.updateIntervalDiagonalHessianWrtLogPopSize(this.gridPointParameter.getParameterValue(n), this.gridPointParameter.getParameterValue(n + 1), n + 1, n3, dArray);
                        ++n;
                    }
                    this.updateIntervalDiagonalHessianWrtLogPopSize(this.gridPointParameter.getParameterValue(n), d, n + 1, n3, dArray);
                }
                n = n2;
            }
        }
        return dArray;
    }

    private void updateIntervalDiagonalHessianWrtLogPopSize(double d, double d2, int n, int n2, double[] dArray) {
        double d3 = this.getGridSlope(n);
        double d4 = this.getGridIntercept(n);
        double d5 = -0.5 * (double)n2 * (double)(n2 - 1);
        assert (d3 != 0.0 || d4 != 0.0);
        double d6 = this.getMagicUnderFlowBound(d3);
        if (d != d2) {
            double d7 = Math.exp(-d3 * d);
            double d8 = Math.exp(-d3 * d2);
            double d9 = Math.exp(-d4);
            double d10 = n < this.gridPointParameter.getDimension() ? this.gridPointParameter.getParameterValue(n) : 0.0;
            double d11 = n == 0 ? 0.0 : this.gridPointParameter.getParameterValue(n - 1);
            double d12 = this.getLinearInverseIntegral(d, d2, n);
            double d13 = Math.abs(d3) < d6 ? d9 * (d2 * d2 * d2 - d * d * d) / 3.0 : d9 * (-2.0 / d3 / d3 * (d2 * d8 - d * d7) + (d * d * d7 - d2 * d2 * d8) / d3 + 2.0 / d3 / d3 / d3 * (d7 - d8));
            double d14 = Math.abs(d3) < d6 ? d9 * (d * d - d2 * d2) / 2.0 : d9 * ((d2 * d8 - d * d7) / d3 - (d7 - d8) / d3 / d3);
            double d15 = -d14;
            double d16 = d10 > 0.0 ? d10 / (d10 - d11) : 1.0;
            double d17 = d10 > 0.0 ? -d11 / (d10 - d11) : 0.0;
            double d18 = d10 > 0.0 ? -1.0 / (d10 - d11) : 0.0;
            double d19 = d10 > 0.0 ? 1.0 / (d10 - d11) : 0.0;
            int n3 = n;
            dArray[n3] = dArray[n3] + d5 * (d12 * d16 * d16 + 2.0 * d15 * d18 * d16 + d13 * d18 * d18);
            if (n < this.gridPointParameter.getDimension()) {
                int n4 = n + 1;
                dArray[n4] = dArray[n4] + d5 * (d12 * d17 * d17 + 2.0 * d15 * d19 * d17 + d13 * d19 * d19);
            }
        }
    }

    public double[] getGradientWrtNodeHeight(int n) {
        return this.getDerivativeWrtNodeHeight(n, NodeHeightDerivativeType.GRADIENT);
    }

    public double[] getDiagonalHessianWrtNodeHeight(int n) {
        return this.getDerivativeWrtNodeHeight(n, NodeHeightDerivativeType.DIAGONAL_HESSIAN);
    }

    public double[] getDerivativeWrtNodeHeight(int n, NodeHeightDerivativeType nodeHeightDerivativeType) {
        int n2;
        int n3;
        BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(n);
        Tree tree = this.trees.get(n);
        double[] dArray = new double[tree.getInternalNodeCount()];
        int n4 = 0;
        double d = 0.0;
        double d2 = 1.0;
        for (n3 = 0; n3 < bigFastTreeIntervals.getIntervalCount(); ++n3) {
            n2 = bigFastTreeIntervals.getLineageCount(n3);
            int[] nArray = bigFastTreeIntervals.getNodeNumbersForInterval(n3);
            double d3 = tree.getNodeHeight(tree.getNode(nArray[0]));
            double d4 = tree.getNodeHeight(tree.getNode(nArray[1]));
            if (tree.isExternal(tree.getNode(nArray[0])) && tree.isExternal(tree.getNode(nArray[1]))) continue;
            int[] nArray2 = this.getGridPoints(n4, d3, d4);
            int n5 = nArray2[0];
            int n6 = nArray2[1];
            if (d3 == d4) {
                if (bigFastTreeIntervals.getIntervalType(n3) == IntervalType.COALESCENT) {
                    d2 += 1.0;
                }
            } else {
                double d5 = this.getGridSlope(n5);
                double d6 = this.getGridIntercept(n5);
                double d7 = this.getGridSlope(n6);
                double d8 = this.getGridIntercept(n6);
                double d9 = 0.5 * (double)n2 * (double)(n2 - 1);
                if (!tree.isExternal(tree.getNode(nArray[0]))) {
                    d += nodeHeightDerivativeType.getNodeHeightDerivative(d6, d5, d3, d9);
                }
                int n7 = 0;
                int n8 = 0;
                while (d2 - (double)n7 > 0.0 && d != 0.0) {
                    int n9 = bigFastTreeIntervals.getNodeNumbersForInterval(n3 - n8)[0];
                    if (!tree.isExternal(tree.getNode(n9))) {
                        ++n7;
                        int n10 = n9 - tree.getExternalNodeCount();
                        dArray[n10] = dArray[n10] + d / d2;
                    }
                    ++n8;
                }
                d2 = 1.0;
                d = 0.0;
                if (bigFastTreeIntervals.getIntervalType(n3) == IntervalType.COALESCENT) {
                    d = -nodeHeightDerivativeType.getNodeHeightDerivative(d8, d7, d4, d9);
                }
            }
            n4 = n6;
        }
        n3 = 0;
        n2 = 0;
        while (d2 - (double)n3 > 0.0 && d != 0.0) {
            int n11 = bigFastTreeIntervals.getNodeNumbersForInterval(bigFastTreeIntervals.getIntervalCount() - 1 - n2)[1];
            if (!tree.isExternal(tree.getNode(n11))) {
                ++n3;
                int n12 = n11 - tree.getExternalNodeCount();
                dArray[n12] = dArray[n12] + d / d2;
            }
            ++n2;
        }
        nodeHeightDerivativeType.updateSingleTreePopulationInverseGradientWrtNodeHeight(this, n, dArray);
        return dArray;
    }

    public double getSingleTreeLogLikelihood(int n) {
        BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(n);
        Tree tree = this.trees.get(n);
        int n2 = 0;
        double d = 0.0;
        for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); ++i) {
            double d2;
            int n3 = bigFastTreeIntervals.getLineageCount(i);
            int[] nArray = bigFastTreeIntervals.getNodeNumbersForInterval(i);
            double d3 = tree.getNodeHeight(tree.getNode(nArray[0]));
            if (d3 == (d2 = tree.getNodeHeight(tree.getNode(nArray[1])))) continue;
            int[] nArray2 = this.getGridPoints(n2, d3, d2);
            int n4 = nArray2[0];
            int n5 = nArray2[1];
            double d4 = 0.0;
            if (n4 == n5) {
                d4 += this.getLinearInverseIntegral(d3, d2, n4);
            } else {
                d4 += this.getLinearInverseIntegral(d3, this.gridPointParameter.getParameterValue(n4), n4);
                n2 = n4;
                while (n2 + 1 < n5) {
                    d4 += this.getLinearInverseIntegral(this.gridPointParameter.getParameterValue(n2), this.gridPointParameter.getParameterValue(n2 + 1), n2 + 1);
                    ++n2;
                }
                d4 += this.getLinearInverseIntegral(this.gridPointParameter.getParameterValue(n2), d2, n2 + 1);
            }
            n2 = n5;
            d -= 0.5 * (double)n3 * (double)(n3 - 1) * d4;
        }
        return d += this.getSingleTreePopulationInverseLogLikelihood(n);
    }

    private double getSingleTreePopulationInverseLogLikelihood(int n) {
        int n2 = 0;
        double d = 0.0;
        BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(n);
        for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); ++i) {
            if (bigFastTreeIntervals.getIntervalType(i) != IntervalType.COALESCENT) continue;
            double d2 = bigFastTreeIntervals.getIntervalTime(i + 1);
            n2 = this.getGridIndex(d2, n2);
            d -= this.getLogPopulationSize(d2, n2);
        }
        return d;
    }

    private void updateSingleTreePopulationInverseGradientWrtLogPopSize(int n, double[] dArray) {
        int n2 = 0;
        BigFastTreeIntervals bigFastTreeIntervals = this.intervals.get(n);
        for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); ++i) {
            if (bigFastTreeIntervals.getIntervalType(i) != IntervalType.COALESCENT) continue;
            double d = bigFastTreeIntervals.getIntervalTime(i + 1);
            n2 = this.getGridIndex(d, n2);
            this.updateLogPopSizeDerivative(d, n2, dArray);
        }
    }

    private double getLogPopulationSize(double d, int n) {
        double d2 = this.getGridSlope(n);
        double d3 = this.getGridIntercept(n);
        return d3 + d2 * d;
    }

    private void updateLogPopSizeDerivative(double d, int n, double[] dArray) {
        this.updateGridSlopeDerivativeWrtLogPopSize(n, dArray, -d);
        this.updateGridInterceptDerivativeWrtLogPopSize(n, dArray, -1.0);
    }

    private int getGridIndex(double d, int n) {
        int n2;
        for (n2 = n; n2 < this.gridPointParameter.getDimension() && this.gridPointParameter.getParameterValue(n2) < d; ++n2) {
        }
        return n2;
    }

    private double getLinearInverseIntegral(double d, double d2, int n) {
        double d3 = this.getGridSlope(n);
        double d4 = this.getGridIntercept(n);
        assert (d3 != 0.0 || d4 != 0.0);
        if (d == d2) {
            return 0.0;
        }
        double d5 = this.getMagicUnderFlowBound(d3);
        if (Math.abs(d3) < d5) {
            return Math.exp(-d4) * (d2 - d);
        }
        return Math.exp(-d4) * (Math.exp(-d3 * d) - Math.exp(-d3 * d2)) / d3;
    }

    private double getMagicUnderFlowBound(double d) {
        return MachineAccuracy.SQRT_EPSILON * (Math.abs(d) + 1.0);
    }

    private void updateIntervalGradientWrtLogPopSize(double d, double d2, int n, int n2, double[] dArray) {
        double d3 = this.getGridSlope(n);
        double d4 = this.getGridIntercept(n);
        double d5 = -0.5 * (double)n2 * (double)(n2 - 1);
        assert (d3 != 0.0 || d4 != 0.0);
        double d6 = this.getMagicUnderFlowBound(d3);
        if (d != d2) {
            double d7 = Math.abs(d3) < d6 ? 0.0 : Math.exp(-d4) * (-d * Math.exp(-d3 * d) + d2 * Math.exp(-d3 * d2) - (Math.exp(-d3 * d) - Math.exp(-d3 * d2)) / d3) / d3;
            double d8 = Math.abs(d3) < d6 ? (d2 - d) * -Math.exp(-d4) : Math.exp(-d4) * (-Math.exp(-d3 * d) + Math.exp(-d3 * d2)) / d3;
            this.updateGridInterceptDerivativeWrtLogPopSize(n, dArray, d5 * d8);
            this.updateGridSlopeDerivativeWrtLogPopSize(n, dArray, d5 * d7);
        }
    }

    private double getGridSlope(int n) {
        if (n == this.gridPointParameter.getDimension()) {
            return 0.0;
        }
        double d = this.gridPointParameter.getParameterValue(n);
        double d2 = n == 0 ? 0.0 : this.gridPointParameter.getParameterValue(n - 1);
        return (this.logPopSizeParameter.getParameterValue(n + 1) - this.logPopSizeParameter.getParameterValue(n)) / (d - d2);
    }

    private void updateGridSlopeDerivativeWrtLogPopSize(int n, double[] dArray, double d) {
        if (n != this.gridPointParameter.getDimension()) {
            double d2 = this.gridPointParameter.getParameterValue(n);
            double d3 = n == 0 ? 0.0 : this.gridPointParameter.getParameterValue(n - 1);
            int n2 = n + 1;
            dArray[n2] = dArray[n2] + d / (d2 - d3);
            int n3 = n;
            dArray[n3] = dArray[n3] - d / (d2 - d3);
        }
    }

    private double getGridIntercept(int n) {
        if (n == this.gridPointParameter.getDimension() || n == 0) {
            return this.logPopSizeParameter.getParameterValue(n);
        }
        double d = this.gridPointParameter.getParameterValue(n);
        double d2 = this.gridPointParameter.getParameterValue(n - 1);
        return (d * this.logPopSizeParameter.getParameterValue(n) - d2 * this.logPopSizeParameter.getParameterValue(n + 1)) / (d - d2);
    }

    private void updateGridInterceptDerivativeWrtLogPopSize(int n, double[] dArray, double d) {
        if (n == this.gridPointParameter.getDimension() || n == 0) {
            int n2 = n;
            dArray[n2] = dArray[n2] + d;
        } else {
            double d2 = this.gridPointParameter.getParameterValue(n);
            double d3 = this.gridPointParameter.getParameterValue(n - 1);
            double d4 = d2 / (d2 - d3) * d;
            double d5 = -d3 / (d2 - d3) * d;
            int n3 = n;
            dArray[n3] = dArray[n3] + d4;
            int n4 = n + 1;
            dArray[n4] = dArray[n4] + d5;
        }
    }

    private int[] getGridPoints(int n, double d, double d2) {
        int n2;
        for (n2 = n; n2 < this.gridPointParameter.getDimension() && this.gridPointParameter.getParameterValue(n2) < d; ++n2) {
        }
        int n3 = n2;
        while (n2 < this.gridPointParameter.getDimension() && this.gridPointParameter.getParameterValue(n2) < d2) {
            ++n2;
        }
        int n4 = n2;
        return new int[]{n3, n4};
    }

    @Override
    public void makeDirty() {
        this.likelihoodKnown = false;
    }

    public static enum NodeHeightDerivativeType {
        GRADIENT{

            @Override
            double getNodeHeightDerivative(double d, double d2, double d3, double d4) {
                return d4 * Math.exp(-d - d2 * d3);
            }

            @Override
            void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood skyGlideLikelihood, int n, double[] dArray) {
                int n2 = 0;
                BigFastTreeIntervals bigFastTreeIntervals = skyGlideLikelihood.getIntervals(n);
                TreeModel treeModel = skyGlideLikelihood.getTree(n);
                for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); ++i) {
                    if (bigFastTreeIntervals.getIntervalType(i) != IntervalType.COALESCENT) continue;
                    double d = bigFastTreeIntervals.getIntervalTime(i + 1);
                    int n3 = bigFastTreeIntervals.getNodeNumbersForInterval(i)[1];
                    n2 = skyGlideLikelihood.getGridIndex(d, n2);
                    double d2 = skyGlideLikelihood.getGridSlope(n2);
                    int n4 = n3 - treeModel.getExternalNodeCount();
                    dArray[n4] = dArray[n4] - d2;
                }
            }
        }
        ,
        DIAGONAL_HESSIAN{

            @Override
            double getNodeHeightDerivative(double d, double d2, double d3, double d4) {
                return -d4 * Math.exp(-d - d2 * d3) * d2;
            }

            @Override
            void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood skyGlideLikelihood, int n, double[] dArray) {
            }
        };


        abstract double getNodeHeightDerivative(double var1, double var3, double var5, double var7);

        abstract void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood var1, int var2, double[] var3);
    }
}

