/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.DerivativeWrtParameterProvider;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.hmc.JointGradient;
import dr.inference.hmc.ParallelGradientExecutor;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.DerivativeOrder;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;

/**
 * Combines several {@link GradientWrtParameterProvider}s into a single provider whose
 * derivative vectors are the concatenation, in list order, of the component vectors.
 *
 * <p>When exactly one provider is supplied, its likelihood and parameter are reused
 * directly. Otherwise the component likelihoods are merged (duplicates dropped) into a
 * {@link CompoundLikelihood} and the component parameters are collected into a
 * {@link CompoundParameter} named {@code "hmc"}.
 *
 * <p>Gradient and diagonal-Hessian evaluation is delegated to a
 * {@link ParallelGradientExecutor} when one was created by the constructor, and is
 * performed sequentially on the calling thread otherwise.
 */
public class CompoundGradient
        implements GradientWrtParameterProvider,
        HessianWrtParameterProvider,
        DerivativeWrtParameterProvider,
        Reportable {

    /** Total length of the concatenated gradient (sum of the component parameter dimensions). */
    protected final int dimension;

    /** The component providers, in concatenation order; this is the list instance passed to the constructor. */
    final List<GradientWrtParameterProvider> derivativeList;

    private final Likelihood likelihood;
    private final Parameter parameter;

    /** The subset of {@link #derivativeList} that also implements {@link DerivativeWrtParameterProvider}. */
    private final List<DerivativeWrtParameterProvider> newDerivativeList;

    private final DerivativeOrder highestOrder;

    /** {@code null} when derivatives are evaluated serially. */
    private final ParallelGradientExecutor parallelExecutor;

    /**
     * Creates a compound gradient that evaluates its components serially.
     *
     * @param list the component providers, in concatenation order
     */
    public CompoundGradient(List<GradientWrtParameterProvider> list) {
        this(list, 0);
    }

    /**
     * Creates a compound gradient over the given providers.
     *
     * @param list        the component providers, in concatenation order; the list is stored,
     *                    not copied
     * @param threadCount forwarded unchanged to {@link ParallelGradientExecutor} (presumably a
     *                    worker-thread count - confirm against that class); the values
     *                    {@code 0} and {@code 1} select serial evaluation and no executor is
     *                    created
     */
    public CompoundGradient(List<GradientWrtParameterProvider> list, int threadCount) {
        this.derivativeList = list;

        if (list.size() == 1) {
            // A single provider needs no wrapping: reuse its likelihood and parameter as-is.
            GradientWrtParameterProvider onlyProvider = list.get(0);
            this.likelihood = onlyProvider.getLikelihood();
            this.parameter = onlyProvider.getParameter();
            this.dimension = this.parameter.getDimension();
        } else {
            List<Likelihood> uniqueLikelihoods = new ArrayList<Likelihood>();
            CompoundParameter jointParameter = new CompoundParameter("hmc") {

                @Override
                public void fireParameterChangedEvent() {
                    // Notify every underlying parameter with upward propagation disabled,
                    // then fire one ALL_VALUES_CHANGED event for the compound as a whole.
                    this.doNotPropagateChangeUp = true;
                    for (Parameter member : this.uniqueParameters) {
                        member.fireParameterChangedEvent();
                    }
                    this.doNotPropagateChangeUp = false;
                    this.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
                }
            };

            int totalDimension = 0;
            for (GradientWrtParameterProvider provider : list) {
                // Providers may share likelihood terms; keep each term only once.
                for (Likelihood term : provider.getLikelihood().getLikelihoodSet()) {
                    if (!uniqueLikelihoods.contains(term)) {
                        uniqueLikelihoods.add(term);
                    }
                }
                Parameter providerParameter = provider.getParameter();
                jointParameter.addParameter(providerParameter);
                totalDimension += providerParameter.getDimension();
            }

            this.likelihood = new CompoundLikelihood(uniqueLikelihoods);
            this.parameter = jointParameter;
            this.dimension = totalDimension;
        }

        this.newDerivativeList = new ArrayList<DerivativeWrtParameterProvider>();
        for (GradientWrtParameterProvider provider : list) {
            if (provider instanceof DerivativeWrtParameterProvider) {
                this.newDerivativeList.add((DerivativeWrtParameterProvider) provider);
            }
        }
        this.highestOrder = DerivativeWrtParameterProvider.getHighestOrder(this.newDerivativeList);

        boolean evaluateInParallel = threadCount > 1 || threadCount < 0;
        this.parallelExecutor = evaluateInParallel
                ? new ParallelGradientExecutor(threadCount, list)
                : null;
    }

    /** Returns the single provider's likelihood, or the compound likelihood built from all providers. */
    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /** Returns the single provider's parameter, or the compound parameter built from all providers. */
    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    /**
     * Returns the derivative length for the given order, as computed by
     * {@link DerivativeOrder#getDerivativeDimension(int)} from the gradient dimension.
     */
    @Override
    public int getDimension(DerivativeOrder derivativeOrder) {
        return derivativeOrder.getDerivativeDimension(this.dimension);
    }

    /** Returns the length of the concatenated gradient. */
    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Concatenates the requested-order derivative of every component that implements
     * {@link DerivativeWrtParameterProvider}.
     *
     * @param derivativeOrder the derivative order to evaluate; asserted not to exceed
     *                        {@link #getHighestOrder()}
     * @return an array of length {@link #getDimension()} holding the component results
     *         back to back
     */
    @Override
    public double[] getDerivativeLogDensity(DerivativeOrder derivativeOrder) {
        assert (this.highestOrder.getValue() >= derivativeOrder.getValue());

        // NOTE(review): the result always has `dimension` entries, whereas
        // getDimension(DerivativeOrder) may report a different length for some orders;
        // components that are not DerivativeWrtParameterProvider are also skipped here,
        // leaving trailing zeros. Confirm both are intended.
        double[] joined = new double[this.dimension];
        int offset = 0;
        for (DerivativeWrtParameterProvider provider : this.newDerivativeList) {
            double[] part = provider.getDerivativeLogDensity(derivativeOrder);
            System.arraycopy(part, 0, joined, offset, part.length);
            offset += part.length;
        }
        return joined;
    }

    /** Returns the order computed by {@code DerivativeWrtParameterProvider.getHighestOrder} at construction. */
    @Override
    public DerivativeOrder getHighestOrder() {
        return this.highestOrder;
    }

    /** Returns the concatenated gradient of all components, using the parallel executor when present. */
    @Override
    public double[] getGradientLogDensity() {
        return this.parallelExecutor != null
                ? this.getDerivativeLogDensityParallelImpl(JointGradient.DerivativeType.GRADIENT)
                : this.getDerivativeLogDensitySerialImpl(JointGradient.DerivativeType.GRADIENT);
    }

    /**
     * Evaluates the given derivative type through the parallel executor and joins the
     * per-provider results in the order the executor hands back its futures.
     */
    private double[] getDerivativeLogDensityParallelImpl(JointGradient.DerivativeType derivativeType) {
        return this.parallelExecutor.getDerivativeLogDensityInParallel(derivativeType, (futures, length) -> {
            double[] joined = new double[length];
            int offset = 0;
            for (Future<?> pending : futures) {
                double[] part = (double[]) pending.get();
                System.arraycopy(part, 0, joined, offset, part.length);
                offset += part.length;
            }
            return joined;
        }, this.dimension);
    }

    /**
     * Evaluates the given derivative type on each component in turn and joins the results.
     * Each component contributes exactly {@code getDimension()} entries.
     */
    private double[] getDerivativeLogDensitySerialImpl(JointGradient.DerivativeType derivativeType) {
        double[] joined = new double[this.dimension];
        int offset = 0;
        for (GradientWrtParameterProvider provider : this.derivativeList) {
            double[] part = derivativeType.getDerivativeLogDensity(provider);
            int partLength = provider.getDimension();
            System.arraycopy(part, 0, joined, offset, partLength);
            offset += partLength;
        }
        return joined;
    }

    /**
     * Returns a report headed by {@code compoundGradient.<parameter name>} followed by the
     * output of {@code GradientWrtParameterProvider.getReportAndCheckForError} for this
     * provider over an unbounded range.
     */
    @Override
    public String getReport() {
        String header = "compoundGradient." + this.parameter.getParameterName() + "\n";
        return header + GradientWrtParameterProvider.getReportAndCheckForError(
                this,
                Double.NEGATIVE_INFINITY,
                Double.POSITIVE_INFINITY,
                GradientWrtParameterProvider.TOLERANCE);
    }

    /** Returns the component providers; this is the internal list itself, not a copy. */
    public List<GradientWrtParameterProvider> getDerivativeList() {
        return this.derivativeList;
    }

    /** Returns the concatenated diagonal Hessian of all components, using the parallel executor when present. */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return this.parallelExecutor != null
                ? this.getDerivativeLogDensityParallelImpl(JointGradient.DerivativeType.DIAGONAL_HESSIAN)
                : this.getDerivativeLogDensitySerialImpl(JointGradient.DerivativeType.DIAGONAL_HESSIAN);
    }

    /**
     * The full Hessian is not supported by this provider.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        // UnsupportedOperationException is a RuntimeException, so callers that caught
        // the previously thrown RuntimeException keep working.
        throw new UnsupportedOperationException("Not implemented yet");
    }
}

