/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.datatype.DataType;
import dr.evolution.distance.DistanceMatrix;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.bigfasttree.thorney.MutationBranchMap;
import dr.evomodel.bigfasttree.thorney.MutationList;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Supplies per-branch lengths for a {@link TreeModel}, estimated from a pairwise
 * {@link DistanceMatrix} using the Rzhetsky &amp; Nei least-squares formulas for
 * external and interior branches of a bifurcating tree.
 *
 * <p>Each branch is described by the clusters of tips around it. Tips are
 * identified by their node number, and that number is passed straight to
 * {@link DistanceMatrix#getElement(int, int)} (assumes tip node numbers line up
 * with the matrix row/column order - TODO confirm against callers).
 *
 * <p>Known limitations visible in this class:
 * <ul>
 *   <li>The root has no branch; its entry is the sentinel {@code -1.0}.</li>
 *   <li>For the two children of a bifurcating root, the "rest of the tree" cluster
 *       is empty, so the formulas divide {@code 0.0} by {@code 0.0} and the
 *       estimate is {@code NaN}. NOTE(review): decide how the root branches
 *       should be treated.</li>
 *   <li>Every call to {@link #getBranchLength(NodeRef)} recomputes the whole tree.
 *       The {@code branchLengthsKnown} / {@code updateNode} flags are maintained by
 *       the update methods but are not consulted when answering queries.</li>
 * </ul>
 */
public class RzhetskyNeiBranchLengthProvider
extends AbstractModel
implements MutationBranchMap {
    /** Model name passed to the {@link AbstractModel} constructor. */
    public static final String RZHETSKY_NEI_BRANCH_LENGTH_PROVIDER = "RzhetskyNeiBranchLengthProvider";

    /** Pairwise distances between tips, indexed by tip node number. */
    private final DistanceMatrix distanceMatrix;
    /** Node numbers of every tip below the root, captured at construction. */
    private final Set<Integer> allTaxonSet;
    /** Cleared by the update methods; currently never read. */
    private boolean branchLengthsKnown;
    // The sum/stored arrays below are allocated per node but are not read anywhere in this class.
    private double[] distanceSums;
    private double[] storedDistanceSums;
    /** Most recently computed branch length per node number ({@code -1.0} for the root). */
    private double[] branchLengths;
    private double[] storedBranchLengths;
    /** Per-node "needs update" marks set by the update methods. */
    private boolean[] updateNode;
    private boolean[] storedUpdatedNodes;
    private final TreeModel tree;
    /** Node number -> node numbers of the tips descending from that node. */
    private final Map<Integer, Set<Integer>> taxonSetMap = new HashMap<Integer, Set<Integer>>();

    /**
     * Creates a provider for the given tree and distance matrix.
     *
     * @param distanceMatrix pairwise distances between the tips of {@code treeModel}
     * @param treeModel      the tree whose branch lengths are to be estimated
     */
    public RzhetskyNeiBranchLengthProvider(DistanceMatrix distanceMatrix, TreeModel treeModel) {
        super(RZHETSKY_NEI_BRANCH_LENGTH_PROVIDER);
        this.distanceMatrix = distanceMatrix;
        this.tree = treeModel;
        int nodeCount = treeModel.getNodeCount();
        this.branchLengths = new double[nodeCount];
        this.storedBranchLengths = new double[nodeCount];
        this.distanceSums = new double[nodeCount];
        this.storedDistanceSums = new double[nodeCount];
        this.updateNode = new boolean[nodeCount];
        this.storedUpdatedNodes = new boolean[nodeCount];
        // Walking from the root fills taxonSetMap for every node and yields the full tip set.
        this.allTaxonSet = new HashSet<Integer>(this.getTaxonSets(this.tree, this.tree.getRoot()));
    }

    /**
     * Recursively collects the tip node numbers below {@code nodeRef} and records
     * the result for that node (and every descendant) in {@code taxonSetMap}.
     *
     * @param tree    the tree to traverse
     * @param nodeRef the node at which to start
     * @return the set of tip node numbers at or below {@code nodeRef}
     */
    private Set<Integer> getTaxonSets(Tree tree, NodeRef nodeRef) {
        HashSet<Integer> tips = new HashSet<Integer>();
        if (tree.isExternal(nodeRef)) {
            tips.add(nodeRef.getNumber());
        } else {
            assert (tree.getChildCount(nodeRef) == 2) : "Must be a strictly bifurcating tree";
            for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
                tips.addAll(this.getTaxonSets(tree, tree.getChild(nodeRef, i)));
            }
        }
        this.taxonSetMap.put(nodeRef.getNumber(), tips);
        return tips;
    }

    /**
     * Computes the branch length above {@code nodeRef} and, recursively, above all
     * of its descendants, storing the results in {@code branchLengths}.
     *
     * @param nodeRef  the node whose subtree is processed
     * @param nodeRef2 the sibling of {@code nodeRef}; {@code null} when
     *                 {@code nodeRef} is the root
     */
    private void calculateBranchLengths(NodeRef nodeRef, NodeRef nodeRef2) {
        double length = -1.0; // sentinel kept for the root, which has no branch
        if (this.tree.isExternal(nodeRef)) {
            length = this.externalBranchLength(nodeRef, nodeRef2);
        } else {
            NodeRef left = this.tree.getChild(nodeRef, 0);
            NodeRef right = this.tree.getChild(nodeRef, 1);
            // Each child is processed with the other child as its sibling.
            this.calculateBranchLengths(left, right);
            this.calculateBranchLengths(right, left);
            if (nodeRef != this.tree.getRoot()) {
                length = this.internalBranchLength(left, right, nodeRef2);
            }
        }
        this.branchLengths[nodeRef.getNumber()] = length;
    }

    /**
     * Length of the branch leading to a tip, from the tip itself, its sibling
     * cluster and the cluster of all remaining tips.
     */
    private double externalBranchLength(NodeRef tip, NodeRef sibling) {
        Set<Integer> self = this.taxonSetMap.get(tip.getNumber());
        Set<Integer> siblingTaxa = this.taxonSetMap.get(sibling.getNumber());
        Set<Integer> rest = this.remainingTaxa(self, siblingTaxa);

        double restCount = rest.size();
        double siblingCount = siblingTaxa.size();
        double selfToRest = this.getSumOfDistances(self, rest);
        double selfToSibling = this.getSumOfDistances(self, siblingTaxa);
        double restToSibling = this.getSumOfDistances(rest, siblingTaxa);
        // Average tip-to-cluster distances minus the average distance between the two clusters.
        return 0.5 * (selfToRest / restCount + selfToSibling / siblingCount - restToSibling / (restCount * siblingCount));
    }

    /**
     * Length of the branch above an interior node, from the four clusters around
     * it: A = rest of the tree, B = the node's sibling, C and D = the node's children.
     *
     * <p>Every term is the mean pairwise distance between two clusters, i.e. the
     * summed distance divided by the product of both cluster sizes.
     */
    private double internalBranchLength(NodeRef left, NodeRef right, NodeRef sibling) {
        Set<Integer> leftTaxa = this.taxonSetMap.get(left.getNumber());
        Set<Integer> rightTaxa = this.taxonSetMap.get(right.getNumber());
        Set<Integer> siblingTaxa = this.taxonSetMap.get(sibling.getNumber());
        Set<Integer> rest = this.remainingTaxa(leftTaxa, rightTaxa, siblingTaxa);

        double nA = rest.size();
        double nB = siblingTaxa.size();
        double nC = leftTaxa.size();
        double nD = rightTaxa.size();
        double gamma = (nB * nC + nA * nD) / ((nA + nB) * (nC + nD));

        double dAC = this.getSumOfDistances(rest, leftTaxa);
        double dBD = this.getSumOfDistances(siblingTaxa, rightTaxa);
        double dBC = this.getSumOfDistances(siblingTaxa, leftTaxa);
        double dAD = this.getSumOfDistances(rest, rightTaxa);
        double dAB = this.getSumOfDistances(rest, siblingTaxa);
        double dCD = this.getSumOfDistances(leftTaxa, rightTaxa);

        // Denominators are parenthesised: "sum / nX * nY" would multiply by nY instead of
        // dividing by it, which is inconsistent with the external-branch formula.
        return 0.5 * (gamma * (dAC / (nA * nC) + dBD / (nB * nD))
                + (1.0 - gamma) * (dBC / (nB * nC) + dAD / (nA * nD))
                - dAB / (nA * nB)
                - dCD / (nC * nD));
    }

    /** Returns a fresh set holding every tip that is in none of the given clusters. */
    private Set<Integer> remainingTaxa(Set<Integer> first, Set<Integer> second) {
        HashSet<Integer> rest = new HashSet<Integer>(this.allTaxonSet);
        rest.removeAll(first);
        rest.removeAll(second);
        return rest;
    }

    /** Returns a fresh set holding every tip that is in none of the given clusters. */
    private Set<Integer> remainingTaxa(Set<Integer> first, Set<Integer> second, Set<Integer> third) {
        Set<Integer> rest = this.remainingTaxa(first, second);
        rest.removeAll(third);
        return rest;
    }

    /**
     * Sums the matrix distance over every pair drawn from the two sets.
     *
     * @return the total distance; {@code 0.0} if either set is empty
     */
    private double getSumOfDistances(Set<Integer> set, Set<Integer> set2) {
        double total = 0.0;
        for (int from : set) {
            for (int to : set2) {
                total += this.distanceMatrix.getElement(from, to);
            }
        }
        return total;
    }

    /**
     * Returns the estimated length of the branch above {@code nodeRef}.
     *
     * <p>The clade membership of every node is rebuilt from the current tree
     * before the lengths are recomputed, so the answer reflects the present
     * topology rather than the one seen at construction time.
     *
     * @param nodeRef the node below the branch of interest
     * @return the estimated branch length, or {@code -1.0} for the root
     */
    public double getBranchLength(NodeRef nodeRef) {
        NodeRef root = this.tree.getRoot();
        this.taxonSetMap.clear();
        this.getTaxonSets(this.tree, root);
        this.calculateBranchLengths(root, null);
        return this.branchLengths[nodeRef.getNumber()];
    }

    /**
     * Marks {@code nodeRef} and each not-yet-marked ancestor as needing an update.
     *
     * @param nodeRef the node that changed
     */
    protected void updateNode(NodeRef nodeRef) {
        this.updateNode[nodeRef.getNumber()] = true;
        NodeRef parent = this.tree.getParent(nodeRef);
        // Stop climbing at the first ancestor that is already marked.
        if (parent != null && !this.updateNode[parent.getNumber()]) {
            this.updateNode(parent);
        }
        this.branchLengthsKnown = false;
    }

    /**
     * Marks {@code nodeRef}, its direct children and their ancestors as needing an update.
     *
     * @param nodeRef the node that changed
     */
    protected void updateNodeAndChildren(NodeRef nodeRef) {
        this.updateNode(nodeRef);
        for (int i = 0; i < this.tree.getChildCount(nodeRef); ++i) {
            this.updateNode(this.tree.getChild(nodeRef, i));
        }
        this.branchLengthsKnown = false;
    }

    /** Marks every node of the tree as needing an update. */
    protected void updateAllNodes() {
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            this.updateNode[i] = true;
        }
        this.branchLengthsKnown = false;
    }

    /**
     * Forwards the change to this model's own listeners and, for tree events,
     * marks the affected nodes (or all nodes) as needing an update.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        this.fireModelChanged();
        if (model == this.tree && object instanceof TreeChangedEvent) {
            TreeChangedEvent event = (TreeChangedEvent)object;
            if (event.isNodeChanged()) {
                this.updateNodeAndChildren(event.getNode());
            } else if (event.isTreeChanged()) {
                this.updateAllNodes();
            }
        }
    }

    /** Nothing to snapshot: branch lengths are recomputed from the tree on every query. */
    @Override
    protected void storeState() {
    }

    /** Nothing to roll back: branch lengths are recomputed from the tree on every query. */
    @Override
    protected void restoreState() {
    }

    /** No action needed on acceptance. */
    @Override
    protected void acceptState() {
    }

    /** This model reacts to no variable changes. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
    }

    /**
     * Not supported by this provider.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataType getDataType() {
        throw new UnsupportedOperationException("Unimplemented method 'getDataType'");
    }

    /**
     * Wraps the estimated branch length above {@code nodeRef} in a
     * {@link MutationList.SimpleMutationList}.
     *
     * @param nodeRef the node below the branch of interest
     * @return a mutation list built from the estimated branch length
     */
    @Override
    public MutationList getMutations(NodeRef nodeRef) {
        return new MutationList.SimpleMutationList(this.getBranchLength(nodeRef));
    }
}

