/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.speciation;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxa;
import dr.evomodel.speciation.CalibrationLineagesIterator;
import dr.inference.model.Statistic;
import dr.math.distributions.Distribution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A set of node-height calibrations on (possibly nested) clades of a tree, together with the
 * correction applied to the summed calibration log densities.
 *
 * <p>{@link #getCorrection} evaluates each calibration density at its calibrated node height.
 * Depending on the {@link CorrectionType}, it then does one of the following:
 * <ul>
 *   <li>EXACT, PEXACT: subtracts a log marginal density of those heights;</li>
 *   <li>APPROXIMATED: applies an approximate adjustment;</li>
 *   <li>a user supplied {@link Statistic}, if given: subtracts its value;</li>
 *   <li>NONE: leaves the sum uncorrected.</li>
 * </ul>
 * Without a user statistic only Yule models are supported (the error messages call this a
 * "conditional calibration prior").
 *
 * <p>Instances keep a mutable evaluation cache (see {@code lastLam}) and are not synchronized.
 */
public class CalibrationPoints {
    /** How {@link #getCorrection} adjusts the summed calibration log densities. */
    private final CorrectionType correctionType;
    /**
     * Tree taxon indices ({@link Tree#getTaxonIndex}) of each calibrated clade.
     *
     * <p>Clades are ordered so that a clade always has a lower index than any calibrated clade
     * containing it. The last clade is contained in no other.
     */
    private final int[][] clades;
    /** {@code densities[k]} is the calibration density for the height calibrated by clade k. */
    private final Distribution[] densities;
    /** {@code forParent[k]}: calibration k applies to the parent of the clade's common ancestor. */
    private final boolean[] forParent;
    /**
     * {@code taxaPartialOrder[k]} lists the calibrated clades whose smallest enclosing calibrated
     * clade is clade k.
     */
    private final int[][] taxaPartialOrder;
    /**
     * Per clade k, the following count (used only by the APPROXIMATED correction):
     * <ul>
     *   <li>start from the taxon count of clade k;</li>
     *   <li>subtract 1 if clade k calibrates its parent, or 2 otherwise;</li>
     *   <li>for each directly nested clade, subtract its taxon count, less 1 unless that nested
     *       clade calibrates its parent.</li>
     * </ul>
     */
    private final int[] freeHeights;
    /**
     * True when the outermost calibrated clade omits some taxa of the tree. Not currently
     * consulted by any correction.
     */
    private final boolean rootCorrection;
    /** Optional statistic whose value (index 0) is subtracted instead of a built-in correction. */
    private final Statistic calibrationLogPDF;
    /** log(2). */
    private final double lg2 = Math.log(2.0);
    /** {@code lc2[k]} = log(k choose 2); filled by {@link #setUpTables}. */
    private double[] lc2;
    /**
     * {@code lNR[k]} = lc2[2] + ... + lc2[k], i.e. the log of the number of ranked labelled trees
     * on k tips.
     */
    private double[] lNR;
    /** {@code lfactorials[k]} = log(k!). */
    private double[] lfactorials;
    /** Lineage iterator for the general EXACT correction; null when that path is not used. */
    private CalibrationLineagesIterator linsIter = null;
    // Cache for the general EXACT correction: lambda and heights of the last evaluation.
    // NEGATIVE_INFINITY in lastLam means "nothing cached".
    double lastLam = Double.NEGATIVE_INFINITY;
    double[] lastHeights;
    /** Corrected value produced by the last general EXACT evaluation (also refreshed on cache hits). */
    double lastValue = Double.NEGATIVE_INFINITY;
    /** Log marginal density computed by the last general EXACT evaluation. */
    private double lastMarginal = Double.NEGATIVE_INFINITY;

    /**
     * Builds the calibrations. The argument lists are not modified.
     *
     * @param tree           tree providing the taxon indices and the taxon count
     * @param isYule         whether the tree prior is a Yule model; must be true when
     *                       {@code userPDF} is null
     * @param dists          calibration densities, parallel to {@code taxa}
     * @param taxa           calibrated clades; any two must be disjoint or nested
     * @param forParent      per clade, whether the parent of its common ancestor is calibrated
     * @param userPDF        optional statistic subtracted instead of a built-in correction (may be null)
     * @param correctionType correction applied by {@link #getCorrection}
     * @throws IllegalArgumentException if no clade is given, the lists differ in size, clades
     *         overlap or are duplicated, a taxon is missing from the tree, or the requested
     *         correction is not implemented for this configuration
     */
    public CalibrationPoints(Tree tree, boolean isYule, List<Distribution> dists, List<Taxa> taxa,
                             List<Boolean> forParent, Statistic userPDF, CorrectionType correctionType) {
        final int nCal = taxa.size();
        if (nCal == 0) {
            throw new IllegalArgumentException("No calibrated clades given.");
        }
        if (dists.size() != nCal || forParent.size() != nCal) {
            throw new IllegalArgumentException("Expected one calibration density and one for-parent flag per clade.");
        }
        checkNested(taxa);

        // Work on copies so the caller's lists are left untouched (also on failure).
        final List<Distribution> remainingDists = new ArrayList<Distribution>(dists);
        final List<Taxa> remainingTaxa = new ArrayList<Taxa>(taxa);
        final List<Boolean> remainingForParent = new ArrayList<Boolean>(forParent);

        this.densities = new Distribution[nCal];
        this.clades = new int[nCal][];
        this.forParent = new boolean[nCal];
        final Taxa[] taxaInOrder = new Taxa[nCal];
        // Fill from the end with a currently-maximal clade, so containing clades get higher indices.
        for (int k = nCal - 1; k >= 0; --k) {
            final int j = indexOfMaximal(remainingTaxa);
            this.densities[k] = remainingDists.remove(j);
            this.forParent[k] = remainingForParent.remove(j);
            final Taxa clade = remainingTaxa.remove(j);
            this.clades[k] = taxonIndices(tree, clade);
            taxaInOrder[k] = clade;
        }

        this.taxaPartialOrder = directlyNested(taxaInOrder);

        this.freeHeights = new int[nCal];
        for (int k = 0; k < nCal; ++k) {
            int inNested = 0;
            for (int c : this.taxaPartialOrder[k]) {
                inNested += this.clades[c].length - (this.forParent[c] ? 0 : 1);
            }
            this.freeHeights[k] = this.clades[k].length - (this.forParent[k] ? 1 : 2) - inNested;
            assert (this.freeHeights[k] >= 0);
        }

        // A clade is top-level when no other calibrated clade contains it.
        final boolean[] isTopLevel = new boolean[nCal];
        Arrays.fill(isTopLevel, true);
        for (int[] nested : this.taxaPartialOrder) {
            for (int c : nested) {
                isTopLevel[c] = false;
            }
        }

        this.rootCorrection = this.clades[nCal - 1].length < tree.getExternalNodeCount();
        this.calibrationLogPDF = userPDF;
        this.correctionType = correctionType;
        if (userPDF == null) {
            if (!isYule) {
                throw new IllegalArgumentException("Sorry, not implemented: conditional calibration prior for this non Yule models.");
            }
            if (correctionType == CorrectionType.EXACT) {
                if (nCal != 1) {
                    for (boolean onParent : this.forParent) {
                        if (onParent) {
                            throw new IllegalArgumentException("Sorry, not implemented: calibration on parent for more than one clade.");
                        }
                    }
                    // The single nested pair has a closed form (see getCorrection); anything else
                    // needs the lineage iterator and its tables.
                    if (nCal != 2 || !taxaInOrder[1].containsAll(taxaInOrder[0])) {
                        this.setUpTables(tree);
                        this.linsIter = new CalibrationLineagesIterator(this.clades, this.taxaPartialOrder, isTopLevel, tree.getExternalNodeCount());
                        this.lastHeights = new double[nCal];
                    }
                }
            } else if (correctionType == CorrectionType.PEXACT) {
                this.setUpTables(tree);
            }
        }
    }

    /** Throws if two clades share taxa without one containing the other. */
    private static void checkNested(List<Taxa> taxa) {
        for (int k = 0; k < taxa.size(); ++k) {
            final Taxa earlier = taxa.get(k);
            for (int i = k + 1; i < taxa.size(); ++i) {
                final Taxa later = taxa.get(i);
                if (later.containsAny(earlier) && !later.containsAll(earlier) && !earlier.containsAll(later)) {
                    throw new IllegalArgumentException("Overlapping clades??");
                }
            }
        }
    }

    /** Returns the first index whose clade is contained in no other clade of {@code taxa}. */
    private static int indexOfMaximal(List<Taxa> taxa) {
        for (int j = 0; j < taxa.size(); ++j) {
            if (isMaximal(taxa, j)) {
                return j;
            }
        }
        // With nested-or-disjoint clades this only happens when a clade appears twice.
        throw new IllegalArgumentException("Duplicate calibrated clades.");
    }

    /** Maps each taxon of {@code clade} to its index in {@code tree}. */
    private static int[] taxonIndices(Tree tree, Taxa clade) {
        final int[] indices = new int[clade.getTaxonCount()];
        for (int j = 0; j < indices.length; ++j) {
            indices[j] = tree.getTaxonIndex(clade.getTaxon(j));
            if (indices[j] < 0) {
                throw new IllegalArgumentException("Taxon not found in tree: " + clade.getTaxon(j));
            }
        }
        return indices;
    }

    /**
     * For clades ordered as in {@link #clades}, returns for each clade i the clades k &lt; i whose
     * first containing clade (in index order, hence the smallest one) is i.
     */
    private static int[][] directlyNested(Taxa[] taxaInOrder) {
        final int n = taxaInOrder.length;
        final List<List<Integer>> children = new ArrayList<List<Integer>>(n);
        for (int i = 0; i < n; ++i) {
            children.add(new ArrayList<Integer>());
        }
        for (int k = 0; k < n; ++k) {
            for (int i = k + 1; i < n; ++i) {
                if (taxaInOrder[i].containsAll(taxaInOrder[k])) {
                    children.get(i).add(k);
                    break;
                }
            }
        }
        final int[][] order = new int[n][];
        for (int i = 0; i < n; ++i) {
            final List<Integer> nested = children.get(i);
            order[i] = new int[nested.size()];
            for (int j = 0; j < nested.size(); ++j) {
                order[i][j] = nested.get(j);
            }
        }
        return order;
    }

    /**
     * Fills the log tables {@link #lc2}, {@link #lfactorials} and {@link #lNR}.
     * Each table has (number of tree taxa + 1) entries.
     */
    private void setUpTables(Tree tree) {
        final int size = tree.getExternalNodeCount() + 1;
        final double[] logs = new double[size];
        this.lc2 = new double[size];
        this.lfactorials = new double[size];
        this.lNR = new double[size];
        logs[0] = Double.NEGATIVE_INFINITY;
        logs[1] = 0.0;
        for (int k = 2; k < size; ++k) {
            logs[k] = Math.log(k);
        }
        this.lc2[1] = Double.NEGATIVE_INFINITY;
        this.lc2[0] = Double.NEGATIVE_INFINITY;
        for (int k = 2; k < size; ++k) {
            this.lc2[k] = logs[k] + logs[k - 1] - this.lg2;
        }
        this.lfactorials[0] = 0.0;
        for (int k = 1; k < size; ++k) {
            this.lfactorials[k] = this.lfactorials[k - 1] + logs[k];
        }
        this.lNR[0] = Double.NEGATIVE_INFINITY;
        this.lNR[1] = 0.0;
        for (int k = 2; k < size; ++k) {
            this.lNR[k] = this.lNR[k - 1] + this.lc2[k];
        }
    }

    /** True when no other entry of {@code list} contains all taxa of entry {@code n}. */
    private static boolean isMaximal(List<Taxa> list, int n) {
        final Taxa candidate = list.get(n);
        for (int i = 0; i < list.size(); ++i) {
            if (i != n && list.get(i).containsAll(candidate)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Evaluates the calibrations on {@code tree}.
     *
     * @param tree tree whose calibrated node heights are read
     * @param lam  rate used by the built-in corrections (presumably the Yule birth rate)
     * @return the sum of the calibration log densities minus the configured correction, or
     *         {@code Double.NEGATIVE_INFINITY} if a calibrated clade is not monophyletic in
     *         {@code tree}; an infinite density sum is returned uncorrected
     */
    public double getCorrection(Tree tree, double lam) {
        final int nCal = this.densities.length;
        final double[] heights = new double[nCal];
        double logp = 0.0;
        for (int k = 0; k < nCal; ++k) {
            final int[] clade = this.clades[k];
            NodeRef node;
            if (clade.length > 1) {
                node = TreeUtils.getCommonAncestor(tree, clade);
                if (TreeUtils.getLeafCount(tree, node) != clade.length) {
                    // The calibrated taxa do not form a clade in this tree.
                    return Double.NEGATIVE_INFINITY;
                }
            } else {
                node = tree.getNode(clade[0]);
                assert (this.forParent[k]);
            }
            if (this.forParent[k]) {
                node = tree.getParent(node);
            }
            final double h = tree.getNodeHeight(node);
            logp += this.densities[k].logPdf(h);
            heights[k] = h;
        }
        if (Double.isInfinite(logp) || this.correctionType == CorrectionType.NONE) {
            return logp;
        }
        if (this.calibrationLogPDF != null) {
            final double userLogPdf = this.calibrationLogPDF.getStatisticValue(0);
            return Double.isNaN(userLogPdf) || Double.isInfinite(userLogPdf) ? Double.NEGATIVE_INFINITY : logp - userLogPdf;
        }
        switch (this.correctionType) {
            case EXACT:
                return this.exactCorrection(tree, lam, heights, logp);
            case APPROXIMATED:
                return this.approximatedCorrection(logp, lam, heights);
            case PEXACT:
                return logp - this.pexactLogMarginal(tree, lam, heights);
            default:
                return logp;
        }
    }

    /**
     * EXACT correction: {@code logp} minus the log marginal density of the calibrated heights.
     * Closed forms are used for one clade and for one nested pair; otherwise the lineage
     * iterator is used and its result is cached.
     */
    private double exactCorrection(Tree tree, double lam, double[] heights, double logp) {
        final int nTaxa = tree.getExternalNodeCount();
        if (heights.length == 1) {
            return logp - this.logMarginalDensity(lam, nTaxa, heights[0], this.clades[0].length, this.forParent[0]);
        }
        if (heights.length == 2 && this.taxaPartialOrder[1].length == 1) {
            assert (!this.forParent[0] && !this.forParent[1]);
            return logp - this.logMarginalDensity(lam, nTaxa, heights[0], this.clades[0].length, heights[1], this.clades[1].length);
        }
        // Only the lambda/height-dependent marginal is cached. The densities were re-evaluated
        // above and Distribution implementations are not guaranteed to be immutable.
        if (this.lastLam != Double.NEGATIVE_INFINITY && this.lastLam == lam && sameHeights(heights, this.lastHeights)) {
            this.lastValue = logp - this.lastMarginal;
            return this.lastValue;
        }
        // ranks[i] = 1 + number of heights strictly below heights[i]; sorted[ranks[i] - 1] = heights[i].
        final int n = heights.length;
        final double[] sorted = new double[n];
        final int[] ranks = new int[n];
        for (int i = 0; i < n; ++i) {
            int below = 0;
            for (double h : heights) {
                if (h < heights[i]) {
                    ++below;
                }
            }
            ranks[i] = below + 1;
            sorted[below] = heights[i];
        }
        final double marginal = this.logMarginalDensity(lam, sorted, ranks, this.linsIter);
        this.lastLam = lam;
        System.arraycopy(heights, 0, this.lastHeights, 0, this.lastHeights.length);
        this.lastMarginal = marginal;
        this.lastValue = logp - marginal;
        return this.lastValue;
    }

    /** True when the first {@code a.length} entries of both arrays compare equal with {@code ==}. */
    private static boolean sameHeights(double[] a, double[] b) {
        for (int k = 0; k < a.length; ++k) {
            if (a[k] != b[k]) {
                return false;
            }
        }
        return true;
    }

    /** APPROXIMATED correction of {@code logp}; a NaN result is mapped to negative infinity. */
    private double approximatedCorrection(double logp, double lam, double[] heights) {
        final double logLam = Math.log(lam);
        int oldest = 0;
        for (int k = 0; k < heights.length; ++k) {
            final double x = -lam * heights[k];
            if (this.freeHeights[k] > 0) {
                logp -= Math.log1p(-Math.exp(x)) * (double) this.freeHeights[k];
            }
            logp -= x + logLam;
            if (heights[k] > heights[oldest]) {
                oldest = k;
            }
        }
        // Adds lam * h(oldest) unless the oldest calibration is on a parent node.
        logp -= (this.forParent[oldest] ? 0.0 : -1.0) * lam * heights[oldest];
        return Double.isNaN(logp) ? Double.NEGATIVE_INFINITY : logp;
    }

    /**
     * PEXACT log marginal: counts the tree's internal nodes falling strictly between consecutive
     * sorted calibration heights (nodes exactly at a calibration height are skipped) and combines
     * those counts with the tables from {@link #setUpTables}.
     */
    private double pexactLogMarginal(Tree tree, double lam, double[] heights) {
        final int nCal = heights.length;
        final double[] sorted = heights.clone();
        Arrays.sort(sorted);
        // counts[k] for k < nCal: nodes in (sorted[k-1], sorted[k]); counts[nCal]: nodes above sorted[nCal-1].
        final int[] counts = new int[nCal + 1];
        final int nInternal = tree.getInternalNodeCount();
        for (int i = 0; i < nInternal; ++i) {
            final double h = tree.getNodeHeight(tree.getInternalNode(i));
            int k = 0;
            while (k < nCal && !(sorted[k] >= h)) {
                ++k;
            }
            if (k == nCal) {
                counts[nCal]++;
            } else if (h < sorted[k]) {
                counts[k]++;
            }
        }
        double logm = 0.0;
        logm += (double) counts[0] * Math.log1p(-Math.exp(-lam * sorted[0])) - lam * sorted[0] - this.lfactorials[counts[0]];
        for (int k = 1; k < nCal; ++k) {
            final int c = counts[k];
            logm += (double) c * (Math.log1p(-Math.exp(-lam * (sorted[k] - sorted[k - 1]))) - lam * sorted[k - 1]);
            logm += -lam * sorted[k] - this.lfactorials[c];
        }
        logm += -lam * (double) (counts[nCal] + 1) * sorted[nCal - 1] - this.lfactorials[counts[nCal] + 1];
        logm += Math.log(lam) * (double) nCal;
        return logm;
    }

    /**
     * Closed-form EXACT log marginal for a single calibration.
     *
     * @param lam      rate parameter
     * @param nTaxa    number of taxa in the tree
     * @param h        calibrated height
     * @param nClade   number of taxa in the clade
     * @param onParent whether the calibration is on the clade's parent
     */
    private double logMarginalDensity(double lam, int nTaxa, double h, int nClade, boolean onParent) {
        final double lh = lam * h;
        double lgp;
        if (onParent) {
            lgp = -2.0 * lh + Math.log(lam);
            if (nClade > 1) {
                lgp += (double) (nClade - 1) * Math.log(1.0 - Math.exp(-lh));
            }
        } else {
            assert (nClade > 1);
            lgp = -3.0 * lh + (double) (nClade - 2) * Math.log(1.0 - Math.exp(-lh)) + Math.log(lam);
            if (nTaxa == nClade) {
                lgp += lh;
            }
        }
        return lgp;
    }

    /**
     * Closed-form EXACT log marginal for two nested calibrations: the inner clade (height
     * {@code h1}, {@code n1} taxa) lies inside the outer clade (height {@code h2}, {@code n2}
     * taxa). The formula differs depending on whether the outer clade covers all
     * {@code nTaxa} taxa.
     */
    private double logMarginalDensity(double lam, int nTaxa, double h1, int n1, double h2, int n2) {
        assert (h1 <= h2 && n1 < n2);
        final int nd = n2 - n1;
        final double e1 = Math.exp(-lam * h1);
        final double e2 = Math.exp(-lam * h2);
        double lgp = 2.0 * Math.log(lam);
        lgp += (double) (n1 - 2) * Math.log(1.0 - e1);
        lgp += (double) (nd - 3) * Math.log(1.0 - e2);
        lgp += Math.log(1.0 - (double) (2 * nd) * e2 + (double) (2 * (nd - 1)) * e1
                - (double) (nd * (nd - 1)) * e2 * e1
                + (double) (nd * (nd + 1)) / 2.0 * e2 * e2
                + (double) ((nd - 1) * (nd - 2)) / 2.0 * e1 * e1);
        if (n2 < nTaxa) {
            lgp -= lam * (h1 + 3.0 * h2);
        } else {
            lgp -= lam * (h1 + 2.0 * h2);
        }
        return lgp;
    }

    /**
     * General EXACT log marginal.
     *
     * <p>Sums, in log space, the contribution of every configuration produced by {@code iter}
     * over the intervals delimited by the sorted heights {@code hs}. A normalising term built
     * from {@code iter.nStart(...)} is then added.
     *
     * @param ranks 1-based rank of each calibration height, passed to {@code iter.setup}
     */
    private double logMarginalDensity(double lam, double[] hs, int[] ranks, CalibrationLineagesIterator iter) {
        // NOTE(review): meaning of setup()'s return value assumed from its use below; confirm.
        final int nSetup = iter.setup(ranks);
        final int nHeights = hs.length;
        // lhs[0] = 0, lhs[k] = -lam * hs[k-1].
        final double[] lhs = new double[nHeights + 1];
        lhs[0] = 0.0;
        for (int k = 1; k < lhs.length; ++k) {
            lhs[k] = -lam * hs[k - 1];
        }
        final boolean extraInterval = nSetup == lhs.length;
        final int nIntervals = nHeights + (extraInterval ? 1 : 0);
        // log(exp(lhs[k]) - exp(lhs[k+1])) per interval, plus lhs[nHeights] for the extra one.
        final double[] logWeights = new double[nIntervals];
        for (int k = 0; k < nHeights; ++k) {
            logWeights[k] = lhs[k] + Math.log1p(-Math.exp(lhs[k + 1] - lhs[k]));
        }
        if (extraInterval) {
            logWeights[nHeights] = lhs[nHeights];
        }
        final int[] joinCounts = new int[nIntervals];
        final int[][] joiners = iter.allJoiners();
        double logSum = 0.0;
        boolean first = true;
        int[][] config;
        while ((config = iter.next()) != null) {
            double lgp = this.countRankedTrees(nIntervals, config, joiners, joinCounts);
            if (extraInterval) {
                final int top = joinCounts[nIntervals - 1] + 2;
                joinCounts[nIntervals - 1] = top;
                lgp -= this.lc2[top] + this.lg2;
            }
            for (int k = 0; k < nIntervals; ++k) {
                lgp += (double) joinCounts[k] * logWeights[k];
            }
            // Numerically stable log(exp(logSum) + exp(lgp)).
            if (first) {
                logSum = lgp;
                first = false;
            } else if (logSum > lgp) {
                logSum += Math.log1p(Math.exp(lgp - logSum));
            } else {
                logSum = lgp + Math.log1p(Math.exp(logSum - lgp));
            }
        }
        double logRanked = 0.0;
        int nStartTotal = 0;
        for (int k = 0; k < nSetup; ++k) {
            final int ns = iter.nStart(k);
            if (ns > 0) {
                logRanked += this.lNR[ns];
                nStartTotal += ns;
            }
        }
        final double logFact = this.lfactorials[nStartTotal];
        double logHeights = (double) nHeights * Math.log(lam);
        for (int k = 1; k < nHeights + 1; ++k) {
            logHeights += lhs[k];
        }
        if (!extraInterval) {
            logHeights += lhs[nHeights];
        }
        return logSum + (logRanked + logFact + logHeights);
    }

    /**
     * Log-count term for one iterator configuration.
     *
     * <p>For each interval i and each row j &ge; i of {@code lins}:
     * <ul>
     *   <li>start from {@code c = lins[j][i]}, incremented by one when {@code joiners[j][i] > 0};</li>
     *   <li>add {@code lc2[c]} when that increment happened and {@code c > 1};</li>
     *   <li>subtract {@code lfactorials[c - lins[j][i+1]]}.</li>
     * </ul>
     * The per-interval sums of {@code c - lins[j][i+1]} are written to {@code joinCounts[i]}.
     */
    private double countRankedTrees(int nIntervals, int[][] lins, int[][] joiners, int[] joinCounts) {
        double logCount = 0.0;
        for (int i = 0; i < nIntervals; ++i) {
            int joined = 0;
            for (int j = i; j < nIntervals; ++j) {
                final int[] row = lins[j];
                int c = row[i];
                if (joiners[j][i] > 0) {
                    ++c;
                    if (c > 1) {
                        logCount += this.lc2[c];
                    }
                }
                final int delta = c - row[i + 1];
                logCount -= this.lfactorials[delta];
                joined += delta;
            }
            joinCounts[i] = joined;
        }
        return logCount;
    }

    /** Selects how {@link #getCorrection} corrects the summed calibration log densities. */
    public enum CorrectionType {
        EXACT("exact"),
        APPROXIMATED("approximated"),
        PEXACT("pexact"),
        NONE("none");

        private final String name;

        CorrectionType(String name) {
            this.name = name;
        }

        /** Returns the lower-case name of this correction type. */
        @Override
        public String toString() {
            return this.name;
        }
    }
}

