/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.basta;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evolution.datatype.HiddenCodons;
import dr.evolution.datatype.HiddenDataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.bigfasttree.BestSignalsFromBigFastTreeIntervals;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.coalescent.basta.BastaLikelihoodDelegate;
import dr.evomodel.coalescent.basta.CoalescentIntervalTraversal;
import dr.evomodel.coalescent.basta.ProcessOnCoalescentIntervalDelegate;
import dr.evomodel.substmodel.SVSComplexSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Profileable;
import dr.inference.model.Variable;
import dr.util.Citable;
import dr.util.Citation;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.logging.Logger;

/**
 * Model likelihood for the BASTA machinery in this package.
 *
 * <p>The class caches a log-likelihood value and recomputes it lazily. The numerical work is
 * delegated to a {@link BastaLikelihoodDelegate}; the lists of operations handed to the delegate
 * are produced by a {@link CoalescentIntervalTraversal} built over the tree and its intervals.
 *
 * <p>Four "known" flags track which cached pieces are still valid: the likelihood itself, the
 * population sizes, the tree intervals and the transition matrices. Change events from the
 * registered sub-models and from the population-size parameter clear the relevant flags.
 *
 * <p>The class also registers a node trait (an {@code int[]} of states per node). Note that this
 * class never allocates or fills the reconstructed-state arrays, so requesting that trait fails
 * with an {@link IllegalStateException}.
 *
 * <p>Several counters record how often the main entry points are hit; they are only read by
 * {@link #getReport()}.
 */
public class BastaLikelihood
        extends AbstractModelLikelihood
        implements TreeTraitProvider, Citable, Profileable, Reportable {

    private static final boolean COUNT_TOTAL_OPERATIONS = true;

    private final BastaLikelihoodDelegate likelihoodDelegate;
    private final Tree tree;
    private final PatternList patternList;
    private final SubstitutionModel substitutionModel;
    private final Parameter popSizeParameter;
    private final BranchRateModel branchRateModel;
    private final int stateCount;
    private final TreeTraitProvider.Helper treeTraits = new TreeTraitProvider.Helper();
    private final CoalescentIntervalTraversal treeTraversalDelegate;
    private final BestSignalsFromBigFastTreeIntervals treeIntervals;

    // Cached likelihood and its backup for store/restore.
    private double logLikelihood;
    private double storedLogLikelihood;

    // Validity flags for the cached quantities.
    protected boolean likelihoodKnown;
    private boolean populationSizesKnown;
    private boolean treeIntervalsKnown;
    private boolean transitionMatricesKnown;

    // Per-node states; never allocated by this class (see getStatesForNode / storeState).
    private int[][] reconstructedStates;
    private int[][] storedReconstructedStates;
    protected boolean areStatesRedrawn = false;
    protected boolean storedAreStatesRedrawn = false;

    private final CodeFormatter formatter;
    private final DataType dataType;

    // Empty operation list passed to the delegate when the transition matrices are still valid.
    private final List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> NO_OPT = new ArrayList<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation>();

    // Bookkeeping counters reported by getReport().
    private int totalPropagationCount = 0;
    private int totalMatrixUpdateCount = 0;
    private int totalIntervalReductionCount = 0;
    private int totalGetLogLikelihoodCount = 0;
    private int totalModelChangedCount = 0;
    private int totalMakeDirtyCount = 0;
    private int totalCalculateLikelihoodCount = 0;
    private int totalRateUpdateAllCount = 0;
    private int totalRateUpdateSingleCount = 0;
    // Sum of per-call elapsed times, each truncated to whole milliseconds.
    private long totalLikelihoodTime = 0L;

    /**
     * Builds the likelihood, registers its sub-models and variable, loads the tip data into the
     * delegate and registers the node-state tree trait.
     *
     * @param string                  name passed to the {@link AbstractModelLikelihood} constructor
     * @param tree                    tree to evaluate; must be a {@link TreeModel}
     * @param patternList             tip data; asserted to contain exactly one pattern
     * @param substitutionModel       substitution model; its data type fixes the state count
     * @param parameter               population-size parameter, registered as a variable
     * @param branchRateModel         branch-rate model; must be a {@link StrictClockBranchRates}
     * @param bastaLikelihoodDelegate delegate that performs the numerical computation
     * @param dataType                data type used to format states in the trait string
     * @param string2                 name under which the node-state trait is registered
     * @param bl                      not used by this constructor
     * @param n                       forwarded unchanged to {@link CoalescentIntervalTraversal}
     *                                (its meaning is not visible here)
     * @param bl2                     asserted to be {@code true}; otherwise unused
     * @throws RuntimeException if {@code branchRateModel} is not a strict clock, if {@code tree}
     *                          is not a {@link TreeModel}, or if a tip state is outside the
     *                          model's state range
     */
    public BastaLikelihood(String string, Tree tree, PatternList patternList, SubstitutionModel substitutionModel, Parameter parameter, BranchRateModel branchRateModel, BastaLikelihoodDelegate bastaLikelihoodDelegate, DataType dataType, final String string2, boolean bl, int n, boolean bl2) {
        super(string);
        assert (bastaLikelihoodDelegate != null);
        assert (tree != null);
        assert (branchRateModel != null);
        assert (patternList.getPatternCount() == 1);
        assert (bl2);
        if (!(branchRateModel instanceof StrictClockBranchRates)) {
            throw new RuntimeException("Not yet implemented");
        }
        Logger logger = Logger.getLogger("dr.evomodel");
        logger.info("\nUsing BastaLikelihood");
        this.dataType = dataType;
        this.patternList = patternList;
        this.likelihoodDelegate = bastaLikelihoodDelegate;
        this.addModel(bastaLikelihoodDelegate);
        this.tree = tree;
        this.branchRateModel = branchRateModel;
        this.addModel(branchRateModel);
        this.substitutionModel = substitutionModel;
        this.addModel(substitutionModel);
        this.popSizeParameter = parameter;
        this.addVariable(parameter);
        this.stateCount = substitutionModel.getDataType().getStateCount();
        if (!(this.tree instanceof TreeModel)) {
            throw new RuntimeException("Not yet implemented");
        }
        this.treeIntervals = new BestSignalsFromBigFastTreeIntervals((TreeModel)tree);
        this.addModel(this.treeIntervals);
        this.treeTraversalDelegate = new CoalescentIntervalTraversal(tree, this.treeIntervals, branchRateModel, n);
        this.setTipData();
        this.formatter = new CodeFormatter(dataType, false);
        this.treeTraits.addTrait(new TreeTrait.IA(){

            @Override
            public String getTraitName() {
                return string2;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public Class getTraitClass() {
                return int[].class;
            }

            @Override
            public int[] getTrait(Tree tree, NodeRef nodeRef) {
                return BastaLikelihood.this.getStatesForNode(tree, nodeRef);
            }

            @Override
            public String getTraitString(Tree tree, NodeRef nodeRef) {
                return BastaLikelihood.this.formattedState(BastaLikelihood.this.getStatesForNode(tree, nodeRef), BastaLikelihood.this.formatter);
            }
        });
        this.likelihoodKnown = false;
        this.populationSizesKnown = false;
        this.treeIntervalsKnown = false;
        this.transitionMatricesKnown = false;
    }

    /**
     * Pushes one partials vector per external node into the delegate: all zeros except a 1.0 at
     * the state that pattern 0 records for the node's taxon.
     *
     * @throws RuntimeException if a tip state is not below {@code stateCount}
     */
    private void setTipData() {
        int[] pattern = this.patternList.getPattern(0);
        int tipCount = this.tree.getExternalNodeCount();
        for (int i = 0; i < tipCount; ++i) {
            NodeRef tip = this.tree.getExternalNode(i);
            int taxonIndex = this.patternList.getTaxonIndex(this.tree.getNodeTaxon(tip).getId());
            int state = pattern[taxonIndex];
            if (state >= this.stateCount) {
                throw new RuntimeException("Not yet implemented");
            }
            double[] partials = new double[this.stateCount];
            partials[state] = 1.0;
            this.likelihoodDelegate.setPartials(tip.getNumber(), partials);
        }
    }

    /** @return this object, which is its own model */
    @Override
    public final Model getModel() {
        return this;
    }

    /**
     * Returns the cached log-likelihood, recomputing it first when it is not known. Recomputation
     * is timed and counted for {@link #getReport()}.
     *
     * @return the log-likelihood
     */
    @Override
    public final double getLogLikelihood() {
        ++this.totalGetLogLikelihoodCount;
        if (!this.likelihoodKnown) {
            ++this.totalCalculateLikelihoodCount;
            long start = System.nanoTime();
            this.logLikelihood = this.calculateLogLikelihood();
            long end = System.nanoTime();
            this.totalLikelihoodTime += (end - start) / 1000000L;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /** Invalidates every cached quantity, the delegate and all nodes of the traversal. */
    @Override
    public final void makeDirty() {
        ++this.totalMakeDirtyCount;
        this.likelihoodKnown = false;
        this.treeIntervalsKnown = false;
        this.populationSizesKnown = false;
        this.transitionMatricesKnown = false;
        this.areStatesRedrawn = false;
        this.likelihoodDelegate.makeDirty();
        this.updateAllNodes();
    }

    /**
     * Marks the ancestral states as redrawn. No states are sampled here.
     *
     * <p>This method must not touch {@code logLikelihood}: it is reached from callers that
     * discard the value returned by {@link #calculateLogLikelihood()} while
     * {@code likelihoodKnown} is still true, so resetting the field would corrupt the cache.
     */
    private void redrawAncestralStates() {
        this.areStatesRedrawn = true;
    }

    /**
     * Renders a state array as a double-quoted string using the given formatter.
     *
     * @param nArray        states to render
     * @param codeFormatter formatter that maps each state to its text; it is reset first
     * @return the quoted, concatenated state codes
     */
    private String formattedState(int[] nArray, CodeFormatter codeFormatter) {
        StringBuilder text = new StringBuilder();
        text.append("\"");
        codeFormatter.reset();
        for (int state : nArray) {
            text.append(codeFormatter.getCodeString(state));
        }
        text.append("\"");
        return text.toString();
    }

    /**
     * Handles a change of the population-size parameter, the only variable this class accepts.
     *
     * @throws RuntimeException if any other variable reports a change
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable != this.popSizeParameter) {
            throw new RuntimeException("Not yet implemented");
        }
        this.populationSizesKnown = false;
        this.likelihoodKnown = false;
    }

    /**
     * Clears the flags affected by a sub-model change and propagates the change upwards.
     * Changes from the tree intervals or the branch-rate model invalidate intervals and
     * transition matrices; changes from the substitution model invalidate only the matrices.
     *
     * @throws RuntimeException if the event comes from any other model
     */
    @Override
    protected final void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.treeIntervals) {
            this.treeIntervalsKnown = false;
            this.transitionMatricesKnown = false;
        } else if (model == this.branchRateModel) {
            this.treeIntervalsKnown = false;
            this.transitionMatricesKnown = false;
        } else if (model == this.substitutionModel) {
            this.transitionMatricesKnown = false;
        } else {
            throw new RuntimeException("Not yet implemented");
        }
        ++this.totalModelChangedCount;
        this.likelihoodKnown = false;
        this.fireModelChanged();
    }

    /**
     * Saves the cached likelihood, the redrawn flag and, when present, the reconstructed states.
     */
    @Override
    protected final void storeState() {
        assert (this.likelihoodKnown) : "the likelihood should always be known at this point in the cycle";
        assert (this.populationSizesKnown);
        assert (this.treeIntervalsKnown);
        assert (this.transitionMatricesKnown);
        // The state arrays are never allocated by this class, while areStatesRedrawn is true
        // after every likelihood evaluation; copy only when there is something to copy.
        if (this.areStatesRedrawn && this.reconstructedStates != null) {
            if (this.storedReconstructedStates == null) {
                this.storedReconstructedStates = new int[this.reconstructedStates.length][];
            }
            for (int i = 0; i < this.reconstructedStates.length; ++i) {
                int[] current = this.reconstructedStates[i];
                if (this.storedReconstructedStates[i] == null) {
                    this.storedReconstructedStates[i] = new int[current.length];
                }
                System.arraycopy(current, 0, this.storedReconstructedStates[i], 0, current.length);
            }
        }
        this.storedAreStatesRedrawn = this.areStatesRedrawn;
        this.storedLogLikelihood = this.logLikelihood;
    }

    /**
     * Restores the saved likelihood and marks every cached quantity as known. The current and
     * stored state arrays are swapped rather than copied.
     */
    @Override
    protected final void restoreState() {
        this.logLikelihood = this.storedLogLikelihood;
        this.likelihoodKnown = true;
        this.populationSizesKnown = true;
        this.treeIntervalsKnown = true;
        this.transitionMatricesKnown = true;
        int[][] swap = this.reconstructedStates;
        this.reconstructedStates = this.storedReconstructedStates;
        this.storedReconstructedStates = swap;
        this.areStatesRedrawn = this.storedAreStatesRedrawn;
    }

    /** Nothing to do on accept. */
    @Override
    protected void acceptState() {
    }

    /**
     * Refreshes whatever the flags report as stale, asks the delegate for the likelihood and
     * marks everything except {@code likelihoodKnown} as known. The caller decides whether to
     * cache the returned value.
     *
     * @return the log-likelihood computed by the delegate
     */
    private double calculateLogLikelihood() {
        this.areStatesRedrawn = false;
        if (!this.transitionMatricesKnown) {
            this.likelihoodDelegate.updateEigenDecomposition(0, this.substitutionModel.getEigenDecomposition(), false);
        }
        if (!this.populationSizesKnown) {
            this.likelihoodDelegate.updatePopulationSizes(0, this.popSizeParameter.getParameterValues(), false);
        }
        if (!this.treeIntervalsKnown) {
            this.treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();
        }
        List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> branchOperations = this.treeTraversalDelegate.getBranchIntervalOperations();
        // Matrix operations are skipped entirely when the matrices are still valid.
        List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> matrixOperations = this.transitionMatricesKnown ? this.NO_OPT : this.treeTraversalDelegate.getMatrixOperations();
        List<Integer> intervalStarts = this.treeTraversalDelegate.getIntervalStarts();
        this.totalPropagationCount += branchOperations.size();
        this.totalMatrixUpdateCount += matrixOperations.size();
        this.totalIntervalReductionCount += this.treeTraversalDelegate.getCoalescentIntervalCount();
        NodeRef root = this.tree.getRoot();
        double result = this.likelihoodDelegate.calculateLikelihood(branchOperations, matrixOperations, intervalStarts, root.getNumber());
        this.setAllNodesUpdated();
        this.treeIntervalsKnown = true;
        this.populationSizesKnown = true;
        this.transitionMatricesKnown = true;
        this.redrawAncestralStates();
        return result;
    }

    /**
     * Computes a gradient from the delegate's {@code stateCount x stateCount} gradient matrix
     * {@code g}.
     *
     * <p>The result has {@code stateCount * (stateCount - 1)} entries. The first half walks the
     * pairs {@code i < j} and stores {@code (g[i][j] - g[i][i]) * freq(j)}; the second half walks
     * the same pairs and stores {@code (g[j][i] - g[j][j]) * freq(i)}.
     * NOTE(review): presumably this is the gradient with respect to the rates parameter of the
     * {@link SVSComplexSubstitutionModel} - confirm against the caller.
     *
     * @return the gradient vector described above
     */
    public double[] getGradientLogDensity() {
        assert (this.substitutionModel instanceof SVSComplexSubstitutionModel);
        // The cast enforces the model type even when assertions are disabled. The looked-up
        // parameter is not used below; the lookup is kept as in the original.
        SVSComplexSubstitutionModel sVSComplexSubstitutionModel = (SVSComplexSubstitutionModel)this.substitutionModel;
        Parameter parameter = sVSComplexSubstitutionModel.getRatesParameter();
        List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> branchOperations = this.treeTraversalDelegate.getBranchIntervalOperations();
        List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> matrixOperations = this.transitionMatricesKnown ? this.NO_OPT : this.treeTraversalDelegate.getMatrixOperations();
        List<Integer> intervalStarts = this.treeTraversalDelegate.getIntervalStarts();
        NodeRef root = this.tree.getRoot();
        // Run the likelihood pass before asking for the gradient; the value is discarded and
        // the cached logLikelihood is left untouched.
        this.calculateLogLikelihood();
        double[][] delegateGradient = this.likelihoodDelegate.calculateGradient(branchOperations, matrixOperations, intervalStarts, root.getNumber());
        double[] gradient = new double[this.stateCount * (this.stateCount - 1)];
        int index = 0;
        for (int i = 0; i < this.stateCount; ++i) {
            for (int j = i + 1; j < this.stateCount; ++j) {
                gradient[index] = (delegateGradient[i][j] - delegateGradient[i][i]) * this.substitutionModel.getFrequencyModel().getFrequency(j);
                ++index;
            }
        }
        for (int i = 0; i < this.stateCount; ++i) {
            for (int j = i + 1; j < this.stateCount; ++j) {
                gradient[index] = (delegateGradient[j][i] - delegateGradient[j][j]) * this.substitutionModel.getFrequencyModel().getFrequency(i);
                ++index;
            }
        }
        return gradient;
    }

    /**
     * Computes a gradient with one entry per state: the delegate's value for state {@code i}
     * multiplied by {@code -1 / N_i^2}, where {@code N_i} is the population-size parameter value.
     * NOTE(review): the factor looks like a chain rule from 1/N to N - confirm against the
     * delegate.
     *
     * @return the gradient vector, of length {@code stateCount}
     */
    public double[] getPopSizeGradientLogDensity() {
        Parameter popSizes = this.popSizeParameter;
        List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> branchOperations = this.treeTraversalDelegate.getBranchIntervalOperations();
        List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> matrixOperations = this.transitionMatricesKnown ? this.NO_OPT : this.treeTraversalDelegate.getMatrixOperations();
        List<Integer> intervalStarts = this.treeTraversalDelegate.getIntervalStarts();
        NodeRef root = this.tree.getRoot();
        // Likelihood pass first; the value is discarded and the cache is left untouched.
        this.calculateLogLikelihood();
        double[] delegateGradient = this.likelihoodDelegate.calculateGradientPopSize(branchOperations, matrixOperations, intervalStarts, root.getNumber());
        double[] gradient = new double[this.stateCount];
        for (int i = 0; i < this.stateCount; ++i) {
            gradient[i] = -delegateGradient[i] * Math.pow(popSizes.getParameterValue(i), -2.0);
        }
        return gradient;
    }

    /** Tells the traversal delegate that every node is up to date. */
    private void setAllNodesUpdated() {
        this.treeTraversalDelegate.setAllNodesUpdated();
    }

    /**
     * Flags a single node for update in the traversal delegate and invalidates the likelihood.
     *
     * @param nodeRef node to flag
     */
    protected void updateNode(NodeRef nodeRef) {
        ++this.totalRateUpdateSingleCount;
        this.treeTraversalDelegate.updateNode(nodeRef);
        this.likelihoodKnown = false;
    }

    /** Flags every node for update in the traversal delegate and invalidates the likelihood. */
    protected void updateAllNodes() {
        ++this.totalRateUpdateAllCount;
        this.treeTraversalDelegate.updateAllNodes();
        this.likelihoodKnown = false;
    }

    /**
     * Returns the reconstructed states of a node, bringing the likelihood up to date first.
     *
     * @param tree    unused
     * @param nodeRef node whose states are requested
     * @return the state array stored for the node
     * @throws IllegalStateException if no reconstructed states exist, which is always the case
     *                               in this class because the array is never allocated
     */
    private int[] getStatesForNode(Tree tree, NodeRef nodeRef) {
        if (!this.likelihoodKnown) {
            // Cache the value: likelihoodKnown is about to become true, so getLogLikelihood()
            // will return this field without recomputing.
            this.logLikelihood = this.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        if (!this.areStatesRedrawn) {
            this.redrawAncestralStates();
        }
        if (this.reconstructedStates == null) {
            throw new IllegalStateException("Reconstructed ancestral states are not available for node " + nodeRef.getNumber());
        }
        return this.reconstructedStates[nodeRef.getNumber()];
    }

    /**
     * Builds a text report: the delegate's report (if any), the class name with the current
     * log-likelihood, and the operation counters.
     *
     * @return the report text
     */
    @Override
    public String getReport() {
        StringBuilder report = new StringBuilder();
        double currentLogLikelihood = this.getLogLikelihood();
        String delegateReport = this.likelihoodDelegate.getReport();
        if (delegateReport != null) {
            report.append(delegateReport);
        }
        // The likelihood can be known without getLogLikelihood() ever having computed it (after
        // restoreState() or getStatesForNode()), so the divisor may be zero.
        long averageLikelihoodTime = this.totalCalculateLikelihoodCount == 0
                ? 0L
                : this.totalLikelihoodTime / (long)this.totalCalculateLikelihoodCount;
        report.append(this.getClass().getName()).append("(").append(currentLogLikelihood).append(")");
        report.append("\n  propagation operations = ").append(this.totalPropagationCount)
                .append("\n  matrix updates = ").append(this.totalMatrixUpdateCount)
                .append("\n  interval operations = ").append(this.totalIntervalReductionCount)
                .append("\n  model changes = ").append(this.totalModelChangedCount)
                .append("\n  make dirties = ").append(this.totalMakeDirtyCount)
                .append("\n  calculate likelihoods = ").append(this.totalCalculateLikelihoodCount)
                .append("\n  get likelihoods = ").append(this.totalGetLogLikelihoodCount)
                .append("\n  all rate updates = ").append(this.totalRateUpdateAllCount)
                .append("\n  partial rate updates = ").append(this.totalRateUpdateSingleCount)
                .append("\n  average likelihood time = ").append(averageLikelihoodTime);
        return report.toString();
    }

    /** @return all traits registered with this provider */
    @Override
    public TreeTrait[] getTreeTraits() {
        return this.treeTraits.getTreeTraits();
    }

    /**
     * @param string trait name
     * @return the trait registered under that name, as resolved by the helper
     */
    @Override
    public TreeTrait getTreeTrait(String string) {
        return this.treeTraits.getTreeTrait(string);
    }

    /**
     * Registers an additional trait.
     *
     * @param treeTrait trait to add
     */
    public void addTrait(TreeTrait treeTrait) {
        this.treeTraits.addTrait(treeTrait);
    }

    /**
     * Registers several additional traits.
     *
     * @param treeTraitArray traits to add
     */
    public void addTraits(TreeTrait[] treeTraitArray) {
        this.treeTraits.addTraits(treeTraitArray);
    }

    /** @return {@link Citation.Category#TREE_PRIORS} */
    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    /** @return the delegate's description when the delegate is {@link Citable}, else {@code null} */
    @Override
    public String getDescription() {
        if (this.likelihoodDelegate instanceof Citable) {
            return ((Citable)((Object)this.likelihoodDelegate)).getDescription();
        }
        return null;
    }

    /** @return the delegate that performs the numerical computation */
    public BastaLikelihoodDelegate getLikelihoodDelegate() {
        return this.likelihoodDelegate;
    }

    /** @return the delegate's citations when the delegate is {@link Citable}, else an empty list */
    @Override
    public List<Citation> getCitations() {
        if (this.likelihoodDelegate instanceof Citable) {
            return ((Citable)((Object)this.likelihoodDelegate)).getCitations();
        }
        return new ArrayList<Citation>();
    }

    /** @return the calculation count reported by the delegate */
    @Override
    public long getTotalCalculationCount() {
        return this.likelihoodDelegate.getTotalCalculationCount();
    }

    /** @return the population-size parameter */
    public Parameter getPopSizes() {
        return this.popSizeParameter;
    }

    /**
     * Turns state indices into text, one state at a time. The first code after a
     * {@link #reset()} is returned as is; every later code is passed through the appender,
     * which adds a trailing space for {@link GeneralDataType} and is the identity otherwise.
     *
     * <p>Static because it never touches the enclosing instance.
     */
    private static final class CodeFormatter {
        private final DataType dataType;
        private final Function<String, String> appender;
        private final Function<Integer, String> getter;
        private boolean first = true;

        /**
         * @param dataType data type that supplies the state codes
         * @param bl       when true, hidden-state data types use their "without hidden" accessor
         */
        CodeFormatter(DataType dataType, boolean bl) {
            this.dataType = dataType;
            this.appender = dataType instanceof GeneralDataType ? string -> string + " " : Function.identity();
            if (dataType instanceof HiddenCodons) {
                this.getter = bl ? ((HiddenCodons)dataType)::getTripletWithoutHiddenCode : dataType::getTriplet;
            } else if (dataType instanceof HiddenDataType && bl) {
                this.getter = ((HiddenDataType)((Object)dataType))::getCodeWithoutHiddenState;
            } else {
                this.getter = dataType::getCode;
            }
        }

        /**
         * @param n state index
         * @return the text for the state, post-processed by the appender unless it is the first
         *         one since the last reset
         */
        String getCodeString(int n) {
            String code = this.getter.apply(n);
            if (this.first) {
                this.first = false;
                return code;
            }
            return this.appender.apply(code);
        }

        /** Makes the next code be treated as the first one again. */
        void reset() {
            this.first = true;
        }
    }
}

