/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.inference.model.GradientProvider;
import dr.math.GammaFunction;
import dr.math.distributions.MultivariateDistribution;
import dr.util.EuclideanToInfiniteNormUnitBallTransform;

/**
 * Spherical beta distribution on the closed unit ball in {@code R^dim}.
 *
 * <p>The density is
 * {@code p(x) = Gamma(shape + dim/2) / (pi^(dim/2) * Gamma(shape)) * (1 - |x|^2)^(shape - 1)}
 * for {@code |x|^2 <= 1} and zero outside the ball. With {@code shape == 1} the density is
 * constant on the ball (uniform distribution).
 */
public class SphericalBetaDistribution
        implements MultivariateDistribution, GradientProvider {

    public static final String TYPE = "SphericalBetaDistribution";

    /** Shape parameter of the density; strictly positive. */
    private final double shape;
    /** Dimension of the space; every point passed in is expected to have this length. */
    private final int dim;
    /** Cached log normalizing constant; depends only on {@code shape} and {@code dim}. */
    private final double logNormalizationConstant;

    /**
     * Creates a spherical beta distribution.
     *
     * @param n dimension of the points the distribution is defined over
     * @param d shape parameter; must be strictly positive
     * @throws IllegalArgumentException if {@code d} is not strictly positive (including NaN)
     */
    public SphericalBetaDistribution(int n, double d) {
        // Checked unconditionally (not via assert): with assertions disabled a non-positive
        // shape would silently produce a NaN normalizing constant and poison every logPdf.
        if (!(d > 0.0)) {
            throw new IllegalArgumentException(
                    "Shape parameter of the spherical beta distribution must be positive.");
        }
        this.shape = d;
        this.dim = n;
        this.logNormalizationConstant = computeLogNormalizationConstant();
    }

    /**
     * Creates the uniform distribution on the unit ball (shape parameter 1).
     *
     * @param n dimension of the points the distribution is defined over
     */
    SphericalBetaDistribution(int n) {
        this(n, 1.0);
    }

    /**
     * Computes {@code lnGamma(shape + dim/2) - (dim/2) ln(pi) - lnGamma(shape)}, the log of the
     * constant that normalizes {@code (1 - |x|^2)^(shape - 1)} over the unit ball.
     */
    private double computeLogNormalizationConstant() {
        final double halfDim = 0.5 * dim;
        return GammaFunction.lnGamma(shape + halfDim)
                - halfDim * Math.log(Math.PI)
                - GammaFunction.lnGamma(shape);
    }

    /**
     * Log density at {@code x}.
     *
     * @param x point of length {@code dim}
     * @return the log density; {@code Double.NEGATIVE_INFINITY} when {@code x} lies outside the
     *     unit ball, where the density is zero
     */
    @Override
    public double logPdf(double[] x) {
        assert x.length == dim;
        final double r2 = EuclideanToInfiniteNormUnitBallTransform.squaredNorm(x);
        if (r2 > 1.0) {
            // Outside the support: density is zero. Returning -inf (instead of NaN from
            // Math.log of a negative number, or the uniform constant when shape == 1)
            // lets callers reject such points cleanly.
            return Double.NEGATIVE_INFINITY;
        }
        if (shape == 1.0) {
            // Uniform on the ball: the (1 - |x|^2) factor has exponent 0.
            return logNormalizationConstant;
        }
        return logNormalizationConstant + (shape - 1.0) * Math.log(1.0 - r2);
    }

    /**
     * Squared norm of {@code x} as computed by
     * {@code EuclideanToInfiniteNormUnitBallTransform.squaredNorm}, asserting (under
     * {@code -ea} only) that it does not exceed 1.
     */
    private double squaredNorm(double[] x) {
        final double r2 = EuclideanToInfiniteNormUnitBallTransform.squaredNorm(x);
        assert r2 <= 1.0;
        return r2;
    }

    /**
     * Gradient of {@link #logPdf(double[])} with respect to the point.
     *
     * @param object a {@code double[]} of length {@code dim}, expected to lie inside the unit
     *     ball (the gradient is undefined outside the support)
     * @return a newly allocated gradient array
     */
    @Override
    public double[] getGradientLogDensity(Object object) {
        return gradLogPdf((double[]) object);
    }

    /**
     * Computes {@code -2 (shape - 1) x / (1 - |x|^2)}, the gradient of the log density.
     * Returns all zeros when {@code shape == 1}, since the density is then constant.
     */
    private double[] gradLogPdf(double[] x) {
        assert x.length == dim;
        final double[] gradient = new double[x.length];
        if (shape == 1.0) {
            return gradient;
        }
        final double scale = -2.0 * (shape - 1.0) / (1.0 - squaredNorm(x));
        for (int i = 0; i < x.length; ++i) {
            gradient[i] = scale * x[i];
        }
        return gradient;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Override
    public int getDimension() {
        return dim;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getScaleMatrix() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Returns the mean, which is the zero vector. The density depends on {@code x} only
     * through {@code |x|^2}, so {@code p(x) == p(-x)}, and the support is bounded, so the
     * mean exists.
     *
     * @return a new zero array of length {@code dim}
     */
    @Override
    public double[] getMean() {
        return new double[dim];
    }
}

