/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.branchratemodel.NodeRateMap;
import dr.evomodel.branchratemodel.RandomLocalClockModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Citable;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import org.ejml.data.DenseMatrix64F;

/**
 * Branch rate model that rescales the rates of an underlying {@link BranchRateModel} so that the
 * branch-length-weighted mean rate over the tree equals a mean-rate value.
 *
 * <p>For every non-root node {@code i} with branch length {@code t_i} and base rate {@code b_i},
 * the reported rate is {@code r_i = m * s * b_i}. Here {@code s = T / B}, with
 * {@code T = sum(t_i)} and {@code B = sum(t_i * b_i)}. The value {@code m} comes from the optional
 * mean-rate parameter, or is 1.0 when no parameter was supplied. Hence
 * {@code sum(t_i * r_i) / T == m}.
 *
 * <p>The scale factor is computed lazily and cached. Any change event from the tree, the base
 * model or the mean-rate parameter invalidates the cache.
 *
 * <p>Derivative-related methods delegate to the base model. They throw a {@link RuntimeException}
 * if the base model does not implement {@link DifferentiableBranchRates}.
 */
public class ScaledByTreeTimeBranchRateModel extends AbstractBranchRateModel
        implements DifferentiableBranchRates, Citable {

    private final TreeModel treeModel;
    private final BranchRateModel branchRateModel;
    /** The base model viewed as differentiable, or {@code null} if it is not differentiable. */
    private final DifferentiableBranchRates differentiableBranchRateModel;
    /** Optional mean-rate parameter; {@code null} means a fixed mean rate of 1.0. */
    private final Parameter meanRateParameter;

    // Cached normalisation state; scaleFactor == timeTotal / branchTotal when scaleFactorKnown.
    private boolean scaleFactorKnown;
    private boolean storedScaleFactorKnown;
    private double scaleFactor;
    private double storedScaleFactor;
    /** Sum over non-root nodes of (branch length * base rate). */
    private double branchTotal;
    private double storedBranchTotal;
    /** Sum of non-root branch lengths. */
    private double timeTotal;
    private double storedTimeTotal;
    private double meanRateParameterValue;
    private double storedMeanRate;
    private DenseMatrix64F Jacobian; // NOTE(review): never read or written in this class
    private static final boolean USE_GENERIC = true; // NOTE(review): not referenced in this class

    /**
     * Creates the scaled model.
     *
     * @param treeModel       tree whose branch lengths weight the rates
     * @param branchRateModel base model whose rates are rescaled
     * @param parameter       optional mean-rate parameter (its first value is used); may be
     *                        {@code null}, in which case the mean rate is fixed at 1.0
     */
    public ScaledByTreeTimeBranchRateModel(TreeModel treeModel, BranchRateModel branchRateModel, Parameter parameter) {
        super("scaledByTreeTimeBranchRates");
        this.treeModel = treeModel;
        this.branchRateModel = branchRateModel;
        this.differentiableBranchRateModel = branchRateModel instanceof DifferentiableBranchRates
                ? (DifferentiableBranchRates) branchRateModel
                : null;
        this.meanRateParameter = parameter;
        this.meanRateParameterValue = 1.0;
        this.addModel(treeModel);
        this.addModel(branchRateModel);
        if (parameter != null) {
            this.addVariable(parameter);
        }
        this.scaleFactorKnown = false;
    }

    /** Any change in the tree or base model invalidates the cached scale factor. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        scaleFactorKnown = false;
        fireModelChanged();
    }

    /** A change in the mean-rate parameter invalidates the cached scale factor. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        scaleFactorKnown = false;
        fireModelChanged();
    }

    @Override
    protected void storeState() {
        storedScaleFactor = scaleFactor;
        storedScaleFactorKnown = scaleFactorKnown;
        storedBranchTotal = branchTotal;
        storedTimeTotal = timeTotal;
        storedMeanRate = meanRateParameterValue;
    }

    @Override
    protected void restoreState() {
        scaleFactor = storedScaleFactor;
        scaleFactorKnown = storedScaleFactorKnown;
        branchTotal = storedBranchTotal;
        timeTotal = storedTimeTotal;
        meanRateParameterValue = storedMeanRate;
    }

    /** Nothing to do: all cached state is plain fields copied by {@link #storeState()}. */
    @Override
    protected void acceptState() {
    }

    /** Delegates to the base model; throws if the base model is not differentiable. */
    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef node) {
        checkDifferentiability();
        return differentiableBranchRateModel.getBranchRateDifferential(tree, node);
    }

    /** Delegates to the base model; throws if the base model is not differentiable. */
    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef node) {
        checkDifferentiability();
        return differentiableBranchRateModel.getBranchRateSecondDifferential(tree, node);
    }

    /** Returns the base model's rate parameter; throws if the base model is not differentiable. */
    @Override
    public Parameter getRateParameter() {
        checkDifferentiability();
        return differentiableBranchRateModel.getRateParameter();
    }

    /** Delegates to the base model; throws if the base model is not differentiable. */
    @Override
    public int getParameterIndexFromNode(NodeRef node) {
        checkDifferentiability();
        return differentiableBranchRateModel.getParameterIndexFromNode(node);
    }

    /** Always returns 0.0. */
    public double getPriorRateAsIncrement(Tree tree) {
        return 0.0;
    }

    /** Throws if the base model does not implement {@link DifferentiableBranchRates}. */
    private void checkDifferentiability() {
        if (differentiableBranchRateModel == null) {
            throw new RuntimeException("Non-differentiable base BranchRateModel");
        }
    }

    /** Not supported by this model. */
    @Override
    public ArbitraryBranchRates.BranchRateTransform getTransform() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Converts a gradient with respect to the scaled rates {@code r_i} into a gradient with
     * respect to the base rates {@code b_k}.
     *
     * <p>Differentiating {@code r_i = m * s * b_i} with {@code s = T / B} gives
     * {@code dr_i/db_k = m * (s * delta_ik - b_i * s^2 * t_k / T)}. Therefore
     * {@code result[k] = m * s * g_k - (s^2 * t_k / T) * m * sum_i(b_i * g_i)}.
     *
     * @param gradient gradient with respect to the scaled rates, using the indices supplied by the
     *                 base model's rate iteration; it is not modified
     * @param value    unused
     * @param from     unused
     * @param to       unused
     * @return a new array of length {@code nodeCount - 1} holding the transformed gradient
     * @throws RuntimeException if the base model is not differentiable
     */
    @Override
    public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) {
        ensureScaleFactor();
        if (meanRateParameter != null) {
            meanRateParameterValue = meanRateParameter.getParameterValue(0);
        }

        final double meanRate = meanRateParameterValue;
        final double[] result = new double[treeModel.getNodeCount() - 1];

        // sum_i b_i * g_i is shared by every output entry, so it is computed once (O(n) overall).
        // Assumes the rate handed to the callback is the base (unscaled) rate b_i.
        final double weightedGradientSum = mapReduceOverRates(
                (index, node, rate) -> rate * gradient[index], Double::sum, 0.0);
        final double diagonalFactor = scaleFactor * meanRate;
        final double offDiagonalFactor = scaleFactor * scaleFactor / timeTotal * meanRate * weightedGradientSum;

        // Results go into a fresh array; 'gradient' itself is left untouched.
        forEachOverRates((index, node, rate) -> {
            result[index] = diagonalFactor * gradient[index]
                    - offDiagonalFactor * treeModel.getBranchLength(node);
            return 0.0;
        });
        return result;
    }

    /** Not supported by this model. */
    @Override
    public double[] updateDiagonalHessianLogDensity(double[] gradient, double[] hessian, double[] value, int from, int to) {
        throw new RuntimeException("Not yet implemented");
    }

    /** Delegates to the base model; throws if the base model is not differentiable. */
    @Override
    public double mapReduceOverRates(NodeRateMap map, DoubleBinaryOperator reduce, double initial) {
        checkDifferentiability();
        return differentiableBranchRateModel.mapReduceOverRates(map, reduce, initial);
    }

    /** Delegates to the base model; throws if the base model is not differentiable. */
    @Override
    public void forEachOverRates(NodeRateMap map) {
        checkDifferentiability();
        differentiableBranchRateModel.forEachOverRates(map);
    }

    /**
     * Returns {@code m * s * b} for the given node, where {@code b} is the base model's rate.
     *
     * @param tree must be the tree this model was built with (checked only with assertions on)
     * @param node the node whose parent branch rate is requested
     * @return the scaled branch rate
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef node) {
        assert tree == treeModel;
        ensureScaleFactor();
        return meanRateParameterValue * scaleFactor * branchRateModel.getBranchRate(tree, node);
    }

    /** Recomputes the cached scale factor if it has been invalidated. */
    private void ensureScaleFactor() {
        if (!scaleFactorKnown) {
            calculateScaleFactor();
            scaleFactorKnown = true;
        }
    }

    /**
     * Recomputes {@code timeTotal}, {@code branchTotal} and {@code scaleFactor = timeTotal /
     * branchTotal} over all non-root nodes. It also refreshes the cached mean-rate value from the
     * parameter, when one is present.
     */
    private void calculateScaleFactor() {
        double totalTime = 0.0;
        double totalRateTime = 0.0;
        for (int i = 0; i < treeModel.getNodeCount(); ++i) {
            NodeRef node = treeModel.getNode(i);
            if (treeModel.isRoot(node)) {
                continue; // the root has no parent branch
            }
            double length = treeModel.getBranchLength(node);
            totalTime += length;
            totalRateTime += length * branchRateModel.getBranchRate(treeModel, node);
        }
        // NOTE(review): yields Infinity/NaN if totalRateTime is 0 — confirm base rates are positive
        double factor = totalTime / totalRateTime;
        if (meanRateParameter != null) {
            meanRateParameterValue = meanRateParameter.getParameterValue(0);
        }
        scaleFactor = factor;
        branchTotal = totalRateTime;
        timeTotal = totalTime;
    }

    /**
     * Returns {@code -b(rateNode) * t(lengthNode) * scaleFactor / branchTotal}, using the cached
     * scale factor and branch total.
     * NOTE(review): not called anywhere in this class.
     */
    private double getTempTotal(NodeRef rateNode, NodeRef lengthNode) {
        double total = -branchRateModel.getBranchRate(treeModel, rateNode)
                * treeModel.getBranchLength(lengthNode) * scaleFactor;
        return total / branchTotal;
    }

    @Override
    public Tree getTree() {
        return treeModel;
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.MOLECULAR_CLOCK;
    }

    /** Describes the base model (if citable) followed by " with scaling-by-tree-time". */
    @Override
    public String getDescription() {
        String base = branchRateModel instanceof Citable
                ? ((Citable) branchRateModel).getDescription()
                : "Unknown clock model";
        return base + " with scaling-by-tree-time";
    }

    /** Returns the base model's citations (if citable) plus the random local clock citation. */
    @Override
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<>();
        if (branchRateModel instanceof Citable) {
            citations.addAll(((Citable) branchRateModel).getCitations());
        }
        citations.add(RandomLocalClockModel.CITATION);
        return citations;
    }
}

