/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.BranchRates;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ContinuousBranchValueProvider;
import dr.evomodel.branchratemodel.CountableBranchCategoryProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.Transform;
import java.util.Collections;
import java.util.List;

/**
 * Per-branch fixed (regression) effects.
 *
 * <p>Each implementation exposes a coefficient parameter, a per-branch design vector, and the
 * resulting per-branch effect value.
 */
public interface BranchSpecificFixedEffects {

    /**
     * Returns the fixed effect for the branch identified by {@code node}.
     *
     * @param tree the tree containing {@code node}
     * @param node the node identifying the branch
     * @return the effect value for that branch
     */
    double getEffect(Tree tree, NodeRef node);

    /**
     * Returns the design (covariate) vector for the branch identified by {@code node}.
     *
     * <p>The implementations in this file return a freshly allocated array, which
     * {@link Base#getDifferential} relies on when it scales the result in place.
     *
     * @param tree the tree containing {@code node}
     * @param node the node identifying the branch
     * @return a new array of length {@link #getDimension()}
     */
    double[] getDesignVector(Tree tree, NodeRef node);

    /** Returns the coefficient parameter associated with the design vector. */
    Parameter getFixedEffectsParameter();

    /**
     * Returns a per-coefficient differential for the branch identified by {@code node}.
     *
     * @param rate the branch rate supplied by the caller
     * @param tree the tree containing {@code node}
     * @param node the node identifying the branch
     * @return one entry per coefficient
     */
    double[] getDifferential(double rate, Tree tree, NodeRef node);

    /** Returns the number of coefficients, which is also the length of the design vector. */
    int getDimension();

    /**
     * Linear fixed effects: a branch's effect is the dot product of its design vector with the
     * coefficient parameter.
     *
     * <p>Design-vector layout, in order:
     * <ol>
     *   <li>an optional intercept column (always 1.0);</li>
     *   <li>one 0/1 indicator column per category provider;</li>
     *   <li>one column per continuous value provider;</li>
     *   <li>one column per branch-rate provider, holding the log of that provider's rate.</li>
     * </ol>
     */
    public static class Default extends Base implements BranchSpecificFixedEffects, Citable {
        private final Parameter coefficients;
        private final List<CountableBranchCategoryProvider> categoryProviders;
        private final List<ContinuousBranchValueProvider> valueProviders;
        private final List<BranchRates> branchRateProviders;
        private final boolean hasIntercept;
        /** Total number of design columns; must equal the coefficient dimension. */
        private final int dim;
        public static Citation CITATION = new Citation(
                new Author[]{new Author("X", "Ji"), new Author("P", "Lemey"), new Author("MA", "Suchard")},
                Citation.Status.IN_PREPARATION);

        /**
         * Creates linear fixed effects over the given providers.
         *
         * <p>Every provider that is also a {@link Model} is registered as a sub-model, and
         * {@code coefficients} is registered as a variable.
         *
         * @param name model name
         * @param categoryProviders providers of 0/1 branch indicators (one column each)
         * @param valueProviders providers of continuous branch covariates (one column each)
         * @param branchRateProviders rate models whose log rate is used as a covariate (one column each)
         * @param coefficients regression coefficients, one per design column
         * @param hasIntercept whether to prepend an intercept column
         * @throws IllegalArgumentException if {@code coefficients} has the wrong dimension
         */
        public Default(String name,
                       List<CountableBranchCategoryProvider> categoryProviders,
                       List<ContinuousBranchValueProvider> valueProviders,
                       List<BranchRates> branchRateProviders,
                       Parameter coefficients,
                       boolean hasIntercept) {
            super(name);
            this.coefficients = coefficients;
            this.categoryProviders = categoryProviders;
            this.valueProviders = valueProviders;
            this.branchRateProviders = branchRateProviders;
            this.hasIntercept = hasIntercept;
            this.dim = categoryProviders.size() + valueProviders.size() + branchRateProviders.size()
                    + (hasIntercept ? 1 : 0);
            if (coefficients.getDimension() != this.dim) {
                throw new IllegalArgumentException("Invalid parameter dimensions");
            }
            addModels(categoryProviders);
            addModels(valueProviders);
            addModels(branchRateProviders);
            addVariable(coefficients);
        }

        /** Returns the dot product of the branch's design vector with the coefficients. */
        @Override
        public double getEffect(Tree tree, NodeRef node) {
            final double[] design = getDesignVector(tree, node);
            double effect = 0.0;
            for (int i = 0; i < dim; ++i) {
                effect += design[i] * coefficients.getParameterValue(i);
            }
            return effect;
        }

        /**
         * Builds the design vector using the column layout described on this class.
         *
         * @throws IllegalStateException if a category provider returns a category other than 0 or 1
         */
        @Override
        public double[] getDesignVector(Tree tree, NodeRef node) {
            final double[] design = new double[dim];
            int column = 0;
            if (hasIntercept) {
                addIntercept(design);
                ++column;
            }
            // Each category provider owns exactly one column (see dim), so the column is chosen by
            // the provider's position. Only a 0/1 indicator fits in that single column.
            for (CountableBranchCategoryProvider provider : categoryProviders) {
                final int category = provider.getBranchCategory(tree, node);
                if (category == 1) {
                    design[column] = 1.0;
                } else if (category != 0) {
                    throw new IllegalStateException("Category provider in design column " + column
                            + " returned category " + category
                            + "; only 0 or 1 can be encoded in a single indicator column");
                }
                ++column;
            }
            for (ContinuousBranchValueProvider provider : valueProviders) {
                design[column++] = provider.getBranchValue(tree, node);
            }
            for (BranchRates rates : branchRateProviders) {
                design[column++] = transformFromBranchRateModel(rates.getBranchRate(tree, node));
            }
            return design;
        }

        /** Maps a branch rate onto its covariate scale (natural log). */
        private double transformFromBranchRateModel(double rate) {
            return Math.log(rate);
        }

        /** Sets the intercept column (index 0) to 1.0. */
        private void addIntercept(double[] design) {
            design[0] = 1.0;
        }

        @Override
        public Parameter getFixedEffectsParameter() {
            return coefficients;
        }

        @Override
        public int getDimension() {
            return dim;
        }

        /** Any registered provider change invalidates the effects, so forward it. */
        @Override
        protected void handleModelChangedEvent(Model model, Object object, int index) {
            fireModelChanged();
        }

        /**
         * Forwards changes of the coefficient parameter.
         *
         * @throws RuntimeException if {@code variable} is not the coefficient parameter
         */
        @Override
        protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
            if (variable != getFixedEffectsParameter()) {
                throw new RuntimeException("Unknown variable: " + variable.getVariableName());
            }
            fireModelChanged();
        }

        @Override
        protected void storeState() {
        }

        @Override
        protected void restoreState() {
        }

        @Override
        protected void acceptState() {
        }

        /**
         * Returns one design vector per non-root node, in node-index order with the root skipped.
         *
         * @param tree the tree whose branches are described
         * @return a matrix with {@code tree.getNodeCount() - 1} rows
         */
        public double[][] getDesignMatrix(Tree tree) {
            final int nodeCount = tree.getNodeCount();
            final NodeRef root = tree.getRoot();
            final double[][] matrix = new double[nodeCount - 1][];
            int row = 0;
            for (int i = 0; i < nodeCount; ++i) {
                final NodeRef node = tree.getNode(i);
                if (node != root) {
                    matrix[row++] = getDesignVector(tree, node);
                }
            }
            return matrix;
        }

        /** Registers each element of {@code candidates} that is a {@link Model} as a sub-model. */
        private void addModels(List<?> candidates) {
            for (Object candidate : candidates) {
                if (candidate instanceof Model) {
                    addModel((Model) candidate);
                }
            }
        }

        @Override
        public Citation.Category getCategory() {
            return Citation.Category.MOLECULAR_CLOCK;
        }

        @Override
        public String getDescription() {
            return "Location-scale relaxed clock";
        }

        @Override
        public List<Citation> getCitations() {
            return Collections.singletonList(CITATION);
        }
    }

    /**
     * Wraps other fixed effects and reports {@code transform.inverse(wrappedEffect)} as the effect.
     *
     * <p>The design vector, coefficient parameter and dimension are those of the wrapped effects.
     * The wrapped effects must also be a {@link Model}; the constructor casts and registers it.
     */
    public static class Transformed extends Base implements BranchSpecificFixedEffects {
        private final BranchSpecificFixedEffects effects;
        private final Transform transform;

        public Transformed(BranchSpecificFixedEffects effects, Transform transform) {
            super("With transform");
            this.effects = effects;
            this.transform = transform;
            addModel((Model) effects);
        }

        /**
         * Scales the {@link Base} differential by {@code transform.gradient(getEffect(...))}.
         *
         * <p>NOTE(review): this is the chain-rule factor only if {@code Transform.gradient(x)}
         * returns the derivative of the inverse transform expressed at the untransformed value
         * {@code x}. Confirm against {@code dr.util.Transform}.
         */
        @Override
        public double[] getDifferential(double rate, Tree tree, NodeRef node) {
            final double[] differential = super.getDifferential(rate, tree, node);
            final double chainFactor = transform.gradient(getEffect(tree, node));
            for (int i = 0; i < differential.length; ++i) {
                differential[i] *= chainFactor;
            }
            return differential;
        }

        @Override
        public int getDimension() {
            return effects.getDimension();
        }

        /** Returns the wrapped effect mapped through {@code transform.inverse}. */
        @Override
        public double getEffect(Tree tree, NodeRef node) {
            return transform.inverse(effects.getEffect(tree, node));
        }

        @Override
        public double[] getDesignVector(Tree tree, NodeRef node) {
            return effects.getDesignVector(tree, node);
        }

        @Override
        public Parameter getFixedEffectsParameter() {
            return effects.getFixedEffectsParameter();
        }

        /** Any change of the wrapped effects changes the transformed effect, so forward it. */
        @Override
        protected void handleModelChangedEvent(Model model, Object object, int index) {
            fireModelChanged();
        }

        /**
         * Forwards changes of the wrapped coefficient parameter.
         *
         * @throws RuntimeException for any other variable
         */
        @Override
        protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
            if (variable != effects.getFixedEffectsParameter()) {
                throw new RuntimeException("Unknown variable: " + variable.getVariableName());
            }
            fireModelChanged();
        }

        @Override
        protected void storeState() {
        }

        @Override
        protected void restoreState() {
        }

        @Override
        protected void acceptState() {
        }
    }

    /**
     * No covariates: every branch's effect is {@code location[0]}, and its design vector is {1.0}.
     */
    public static class None extends Base implements BranchSpecificFixedEffects {
        private final Parameter location;
        /** Intercept-only design; cloned on every request so callers may modify their copy. */
        private static final double[] INTERCEPT_ONLY_DESIGN = new double[]{1.0};

        public None(Parameter location) {
            super("No effects");
            this.location = location;
            addVariable(location);
        }

        /** No sub-models are registered by this class, so there is nothing to forward. */
        @Override
        protected void handleModelChangedEvent(Model model, Object object, int index) {
        }

        /**
         * Forwards changes of the location parameter.
         *
         * @throws RuntimeException for any other variable
         */
        @Override
        protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
            if (variable != location) {
                throw new RuntimeException("Unknown variable: " + variable.getVariableName());
            }
            fireModelChanged();
        }

        @Override
        protected void storeState() {
        }

        @Override
        protected void restoreState() {
        }

        @Override
        protected void acceptState() {
        }

        @Override
        public double getEffect(Tree tree, NodeRef node) {
            return location.getParameterValue(0);
        }

        @Override
        public double[] getDesignVector(Tree tree, NodeRef node) {
            return INTERCEPT_ONLY_DESIGN.clone();
        }

        @Override
        public Parameter getFixedEffectsParameter() {
            return location;
        }

        @Override
        public int getDimension() {
            return 1;
        }
    }

    /** Shared base for the fixed-effect models; supplies a default {@link #getDifferential}. */
    public static abstract class Base extends AbstractModel implements BranchSpecificFixedEffects {
        public Base(String name) {
            super(name);
        }

        /**
         * Returns {@code designVector * rate / getEffect(tree, node)}.
         *
         * <p>NOTE(review): this matches d(rate)/d(coefficient) when the rate is proportional to
         * the effect and the effect is linear in the coefficients. Confirm with callers. A zero
         * effect yields non-finite entries.
         */
        @Override
        public double[] getDifferential(double rate, Tree tree, NodeRef node) {
            // getDesignVector returns a fresh array, so scaling it in place is safe.
            final double[] differential = getDesignVector(tree, node);
            final double scale = rate / getEffect(tree, node);
            for (int i = 0; i < differential.length; ++i) {
                differential[i] *= scale;
            }
            return differential;
        }
    }
}

