/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.UnorderedPair;
import bartMachine.Classifier;
import bartMachine.StatToolbox;
import bartMachine.Tools;
import bartMachine.bartMachineClassification;
import bartMachine.bartMachineRegression;
import bartMachine.bartMachineTreeNode;
import bartMachine.bartMachine_b_hyperparams;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Builds a BART regression model by running one {@link bartMachineRegression} Gibbs chain per
 * core and pooling the post-burn-in tree samples of all chains into a single array that is then
 * used for prediction and diagnostics.
 *
 * <p>Typical use: configure via the setters, call {@link #setData(ArrayList)}, then
 * {@link #Build()}. Every chain runs the full burn-in plus an equal share of the requested
 * post-burn-in iterations (see {@link #setNumGibbsTotalIterations(int)}).
 *
 * <p>NOTE(review): the class is {@link Serializable} but declares no {@code serialVersionUID},
 * so adding or changing fields or non-private methods alters the default serialized form.
 */
public class bartMachineRegressionMultThread
extends Classifier
implements Serializable {
    /** Number of Gibbs chains to build; also the size of the thread pool used by {@link #Build()}. */
    protected int num_cores = 1;
    /** Number of trees, forwarded to every chain. */
    protected int num_trees = 50;
    /** One single-threaded model per core; populated by {@link #SetupBARTModels()}. */
    protected bartMachineRegression[] bart_gibbs_chain_threads;
    /** Pooled post-burn-in samples, indexed [gibbs sample][tree]; populated at the end of {@link #Build()}. */
    protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in;
    /** Forwarded to every chain; stays {@code null} until {@link #setSampleVarY(double)} is called. */
    private Double sample_var_y;
    /** Number of leading Gibbs iterations of every chain that are discarded. */
    protected int num_gibbs_burn_in = 250;
    /** Requested total number of iterations (burn-in plus retained samples) across all chains. */
    protected int num_gibbs_total_iterations = 1250;
    /** Iterations each individual chain runs: burn-in plus its share of the retained samples. */
    protected int total_iterations_multithreaded;
    /** Optional covariate split prior; only forwarded to the chains when non-null. */
    protected double[] cov_split_prior;
    /** Hyperparameter forwarded to every chain via {@code setAlpha}. */
    protected Double alpha = 0.95;
    /** Hyperparameter forwarded to every chain via {@code setBeta}. */
    protected Double beta = 2.0;
    /** Hyperparameter forwarded to every chain via {@code setK}. */
    protected Double hyper_k = 2.0;
    /** Hyperparameter forwarded to every non-classification chain via {@code setQ}. */
    protected Double hyper_q = 0.9;
    /** Hyperparameter forwarded to every non-classification chain via {@code setNu}. */
    protected Double hyper_nu = 3.0;
    /** Forwarded to every chain via {@code setProbGrow}. */
    protected Double prob_grow = 0.2777777777777778;
    /** Forwarded to every chain via {@code setProbPrune}. */
    protected Double prob_prune = 0.2777777777777778;
    /** When true, progress messages are printed to standard output. */
    protected boolean verbose = true;
    /** Forwarded to every chain via {@code setMemCacheForSpeed}. */
    protected boolean mem_cache_for_speed = true;
    /** Forwarded to every chain via {@code setFlushIndicesToSaveRAM}. */
    protected boolean flush_indices_to_save_ram = true;
    /** Copied onto every chain's {@code tree_illust} field; switched on by {@link #printTreeIllustations()}. */
    private boolean tree_illust;
    /** Optional interaction constraints keyed by an integer; only forwarded to the chains when non-null. */
    private HashMap<Integer, IntOpenHashSet> interaction_constraints;

    /** Creates a model with the default settings and derives the per-chain iteration count from them. */
    public bartMachineRegressionMultThread() {
        this.setNumGibbsTotalIterations(this.num_gibbs_total_iterations);
    }

    /**
     * Sets the total number of Gibbs iterations (burn-in plus retained samples) and updates the
     * number of iterations each chain has to run.
     *
     * @param n total number of Gibbs iterations across all chains
     */
    public void setNumGibbsTotalIterations(int n) {
        this.num_gibbs_total_iterations = n;
        this.recomputeIterationsPerThread();
    }

    /**
     * Derives {@link #total_iterations_multithreaded}: every chain runs the whole burn-in plus the
     * post-burn-in iterations divided by the number of cores, rounded up.
     */
    private void recomputeIterationsPerThread() {
        int postBurnIn = this.num_gibbs_total_iterations - this.num_gibbs_burn_in;
        this.total_iterations_multithreaded = this.num_gibbs_burn_in + (int)Math.ceil((double)postBurnIn / (double)this.num_cores);
    }

    /**
     * @return the number of retained Gibbs samples, i.e. total iterations minus burn-in
     */
    public int numSamplesAfterBurning() {
        return this.num_gibbs_total_iterations - this.num_gibbs_burn_in;
    }

    /** Creates and configures one fresh {@link bartMachineRegression} per core. */
    protected void SetupBARTModels() {
        this.bart_gibbs_chain_threads = new bartMachineRegression[this.num_cores];
        for (int t = 0; t < this.num_cores; ++t) {
            this.SetupBartModel(new bartMachineRegression(), t);
        }
    }

    /**
     * Copies this object's configuration and training data onto a single chain and registers the
     * chain at slot {@code n} of {@link #bart_gibbs_chain_threads}.
     *
     * @param bartMachineRegression2 the chain to configure
     * @param n zero-based index of the chain
     */
    protected void SetupBartModel(bartMachineRegression bartMachineRegression2, int n) {
        bartMachineRegression2.setVerbose(this.verbose);
        bartMachineRegression2.num_trees = this.num_trees;
        // each chain only runs its own share of the iterations, not the global total
        bartMachineRegression2.num_gibbs_total_iterations = this.total_iterations_multithreaded;
        bartMachineRegression2.num_gibbs_burn_in = this.num_gibbs_burn_in;
        bartMachineRegression2.sample_var_y = this.sample_var_y;
        bartMachineRegression2.setAlpha(this.alpha);
        bartMachineRegression2.setBeta(this.beta);
        bartMachineRegression2.setK(this.hyper_k);
        bartMachineRegression2.setProbGrow(this.prob_grow);
        bartMachineRegression2.setProbPrune(this.prob_prune);
        bartMachineRegression2.setThreadNum(n);
        bartMachineRegression2.setTotalNumThreads(this.num_cores);
        bartMachineRegression2.setMemCacheForSpeed(this.mem_cache_for_speed);
        bartMachineRegression2.setFlushIndicesToSaveRAM(this.flush_indices_to_save_ram);
        if (this.cov_split_prior != null) {
            bartMachineRegression2.setCovSplitPrior(this.cov_split_prior);
        }
        if (this.interaction_constraints != null) {
            bartMachineRegression2.setInteractionConstraints(this.interaction_constraints);
        }
        // nu and q are skipped for classification chains
        if (!(bartMachineRegression2 instanceof bartMachineClassification)) {
            bartMachineRegression2.setNu(this.hyper_nu);
            bartMachineRegression2.setQ(this.hyper_q);
        }
        bartMachineRegression2.setData(this.X_y);
        bartMachineRegression2.tree_illust = this.tree_illust;
        this.bart_gibbs_chain_threads[n] = bartMachineRegression2;
    }

    /**
     * Stores the given array and its length in the static fields
     * {@code bartMachine_b_hyperparams.samps_std_normal} and {@code samps_std_normal_length}.
     * Being static, the values are shared by every model in the JVM.
     *
     * @param dArray the samples to install; must not be null
     */
    public void setNormSamples(double[] dArray) {
        bartMachine_b_hyperparams.samps_std_normal = dArray;
        bartMachine_b_hyperparams.samps_std_normal_length = dArray.length;
    }

    /**
     * Stores the given array and its length in the static fields
     * {@code bartMachine_b_hyperparams.samps_chi_sq_df_eq_nu_plus_n} and
     * {@code samps_chi_sq_df_eq_nu_plus_n_length}. Being static, the values are shared by every
     * model in the JVM.
     *
     * @param dArray the samples to install; must not be null
     */
    public void setGammaSamples(double[] dArray) {
        bartMachine_b_hyperparams.samps_chi_sq_df_eq_nu_plus_n = dArray;
        bartMachine_b_hyperparams.samps_chi_sq_df_eq_nu_plus_n_length = dArray.length;
    }

    /**
     * Sets up one chain per core, builds all chains in parallel and pools their post-burn-in
     * samples. Timing information is printed when {@link #verbose} is set.
     */
    @Override
    public void Build() {
        this.SetupBARTModels();
        long start = System.currentTimeMillis();
        if (this.verbose) {
            System.out.println("building BART " + (this.mem_cache_for_speed ? "with" : "without") + " mem-cache speedup...");
        }
        this.BuildOnAllThreads();
        long end = System.currentTimeMillis();
        if (this.verbose) {
            System.out.println("done building BART in " + (double)(end - start) / 1000.0 + " sec \n");
        }
        this.ConstructBurnedChainForTreesAndOtherInformation();
    }

    /**
     * Concatenates the post-burn-in tree samples of every chain, in chain order, into
     * {@link #gibbs_samples_of_bart_trees_after_burn_in}. Because the per-chain share is rounded
     * up, the chains may jointly hold more samples than requested; the surplus is dropped.
     */
    protected void ConstructBurnedChainForTreesAndOtherInformation() {
        int totalSamples = this.numSamplesAfterBurning();
        int samplesPerThread = this.total_iterations_multithreaded - this.num_gibbs_burn_in;
        this.gibbs_samples_of_bart_trees_after_burn_in = new bartMachineTreeNode[totalSamples][this.num_trees];
        if (this.verbose) {
            System.out.print("burning and aggregating chains from all threads... ");
        }
        for (int t = 0; t < this.num_cores; ++t) {
            int offset = t * samplesPerThread;
            // never write past the requested number of samples
            int count = Math.min(samplesPerThread, totalSamples - offset);
            if (count <= 0) {
                continue;
            }
            System.arraycopy(this.bart_gibbs_chain_threads[t].gibbs_samples_of_bart_trees, this.num_gibbs_burn_in, this.gibbs_samples_of_bart_trees_after_burn_in, offset, count);
        }
        if (this.verbose) {
            System.out.print("done\n");
        }
    }

    /**
     * Builds every chain on a fixed thread pool with one thread per core and blocks until all of
     * them have finished. An interrupt received while waiting does not abort the wait (the chains
     * must be complete before they are aggregated); the interrupt status is restored afterwards.
     */
    private void BuildOnAllThreads() {
        ExecutorService pool = Executors.newFixedThreadPool(this.num_cores);
        for (int t = 0; t < this.num_cores; ++t) {
            final int chainIndex = t;
            pool.execute(new Runnable(){

                @Override
                public void run() {
                    bartMachineRegressionMultThread.this.bart_gibbs_chain_threads[chainIndex].Build();
                }
            });
        }
        pool.shutdown();
        boolean interrupted = false;
        while (!pool.isTerminated()) {
            try {
                pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
            }
            catch (InterruptedException interruptedException) {
                // keep waiting so that aggregation never sees a half-built chain
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Returns the given records, or the rows of the training data when {@code records} is null.
     * Training rows are returned exactly as stored in {@code X_y}.
     */
    private double[][] recordsOrTrainingData(double[][] records) {
        if (records != null) {
            return records;
        }
        double[][] training = new double[this.n][];
        for (int i = 0; i < this.n; ++i) {
            training[i] = (double[])this.X_y.get(i);
        }
        return training;
    }

    /**
     * For every record, retained Gibbs sample and tree, flags the training indices listed in the
     * {@code indicies} array of the node that {@code EvaluateNode} returns for that record.
     *
     * @param dArray records to evaluate, or null to use the training data
     * @return flags indexed [record][gibbs sample][tree][training index]
     */
    protected boolean[][][][] getNodePredictionTrainingIndicies(double[][] dArray) {
        double[][] records = this.recordsOrTrainingData(dArray);
        int numRecords = records.length;
        int numSamples = this.numSamplesAfterBurning();
        boolean[][][][] membership = new boolean[numRecords][numSamples][this.num_trees][this.n];
        for (int i = 0; i < numRecords; ++i) {
            for (int g = 0; g < numSamples; ++g) {
                bartMachineTreeNode[] trees = this.gibbs_samples_of_bart_trees_after_burn_in[g];
                for (int t = 0; t < this.num_trees; ++t) {
                    for (int trainingIndex : trees[t].EvaluateNode((double[])records[i]).indicies) {
                        membership[i][g][t][trainingIndex] = true;
                    }
                }
            }
        }
        return membership;
    }

    /**
     * Computes, for every record, one weight per training observation. Each (sample, tree) pair
     * adds {@code 1 / (flagged count * num_trees)} to every training observation flagged by
     * {@link #getNodePredictionTrainingIndicies(double[][])}; the sums are then divided by the
     * number of retained samples.
     *
     * @param dArray records to evaluate, or null to use the training data
     * @return weights indexed [record][training index]
     */
    protected double[][] getProjectionWeights(double[][] dArray) {
        double[][] records = this.recordsOrTrainingData(dArray);
        int numRecords = records.length;
        int numSamples = this.numSamplesAfterBurning();
        boolean[][][][] membership = this.getNodePredictionTrainingIndicies(records);
        double[][] weights = new double[numRecords][];
        for (int i = 0; i < numRecords; ++i) {
            double[] recordWeights = new double[this.n];
            for (int g = 0; g < numSamples; ++g) {
                for (int t = 0; t < this.num_trees; ++t) {
                    boolean[] inNode = membership[i][g][t];
                    int nodeSize = Tools.sum_array(inNode);
                    for (int k = 0; k < this.n; ++k) {
                        recordWeights[k] = recordWeights[k] + (double)(inNode[k] ? 1 : 0) / ((double)nodeSize * (double)this.num_trees);
                    }
                }
            }
            for (int k = 0; k < this.n; ++k) {
                recordWeights[k] = recordWeights[k] * (1.0 / (double)numSamples);
            }
            weights[i] = recordWeights;
        }
        return weights;
    }

    /**
     * Evaluates every retained Gibbs sample on every record: the outputs of all trees of a sample
     * are summed and passed through the first chain's {@code un_transform_y}.
     *
     * <p>With more than one thread the records are dealt out round-robin to worker threads. This
     * method returns only after all workers have finished; an interrupt received meanwhile is
     * re-asserted on the calling thread once the results are complete.
     *
     * @param dArray the records to predict for; may be empty
     * @param n number of threads to use; values below 1 are treated as 1
     * @return samples indexed [record][gibbs sample]
     */
    protected double[][] getGibbsSamplesForPrediction(final double[][] dArray, final int n) {
        final int numSamples = this.numSamplesAfterBurning();
        final bartMachineRegression firstChain = this.bart_gibbs_chain_threads[0];
        final int numRecords = dArray.length;
        // rows are filled in by the tasks below, so only the outer array is allocated here
        final double[][] samples = new double[numRecords][];
        final int numWorkers = Math.max(1, n);
        Thread[] workers = new Thread[numWorkers];
        for (int w = 0; w < numWorkers; ++w) {
            final int firstRecord = w;
            Runnable task = new Runnable(){

                @Override
                public void run() {
                    // this task owns records firstRecord, firstRecord + numWorkers, ...
                    for (int i = firstRecord; i < numRecords; i += numWorkers) {
                        double[] recordSamples = new double[numSamples];
                        for (int g = 0; g < numSamples; ++g) {
                            bartMachineTreeNode[] trees = bartMachineRegressionMultThread.this.gibbs_samples_of_bart_trees_after_burn_in[g];
                            double sum = 0.0;
                            for (int t = 0; t < bartMachineRegressionMultThread.this.num_trees; ++t) {
                                sum += trees[t].Evaluate(dArray[i]);
                            }
                            recordSamples[g] = firstChain.un_transform_y(sum);
                        }
                        samples[i] = recordSamples;
                    }
                }
            };
            if (numWorkers == 1) {
                // single-threaded: do the work on the calling thread
                task.run();
            } else {
                workers[w] = new Thread(task);
                workers[w].start();
            }
        }
        boolean interrupted = false;
        for (Thread worker : workers) {
            if (worker == null) {
                continue;
            }
            while (true) {
                try {
                    worker.join();
                    break;
                }
                catch (InterruptedException interruptedException) {
                    // retry: returning early would hand back rows that are not filled in yet
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        return samples;
    }

    /**
     * Computes an interval from the sorted Gibbs samples of a single record by picking the
     * samples at the lower and upper tail positions for the given coverage. The positions are
     * clamped to the valid index range, so a small number of samples yields the minimum and/or
     * maximum instead of an out-of-range index.
     *
     * @param dArray the record to predict for
     * @param d the coverage, e.g. 0.95
     * @param n number of threads to use
     * @return a two-element array {lower, upper}
     */
    protected double[] getPostPredictiveIntervalForPrediction(double[] dArray, double d, int n) {
        double[] sorted = this.getGibbsSamplesForPrediction(new double[][]{dArray}, n)[0];
        Arrays.sort(sorted);
        int last = sorted.length - 1;
        int lower = (int)Math.round((1.0 - d) / 2.0 * (double)sorted.length) - 1;
        int upper = (int)Math.round(((1.0 - d) / 2.0 + d) * (double)sorted.length) - 1;
        lower = Math.min(Math.max(lower, 0), last);
        upper = Math.min(Math.max(upper, 0), last);
        return new double[]{sorted[lower], sorted[upper]};
    }

    /**
     * Shorthand for {@link #getPostPredictiveIntervalForPrediction(double[], double, int)} with a
     * coverage of 0.95.
     *
     * @param dArray the record to predict for
     * @param n number of threads to use
     * @return a two-element array {lower, upper}
     */
    protected double[] get95PctPostPredictiveIntervalForPrediction(double[] dArray, int n) {
        return this.getPostPredictiveIntervalForPrediction(dArray, 0.95, n);
    }

    /**
     * Pools the sigma-squared samples of all chains: the first chain contributes everything it
     * reports (burn-in included), every further chain only its post-burn-in range.
     *
     * @return the pooled samples, exactly as long as the number of pooled values
     */
    public double[] getGibbsSamplesSigsqs() {
        DoubleArrayList pooled = new DoubleArrayList(this.num_gibbs_total_iterations);
        for (int t = 0; t < this.num_cores; ++t) {
            DoubleArrayList chainSamples = new DoubleArrayList(this.bart_gibbs_chain_threads[t].getGibbsSamplesSigsqs());
            if (t == 0) {
                pooled.addAll((DoubleList)chainSamples);
                continue;
            }
            pooled.addAll(chainSamples.subList(this.num_gibbs_burn_in, this.total_iterations_multithreaded));
        }
        // elements() exposes the backing array, which may be longer than size(); trim it
        return Arrays.copyOf(pooled.elements(), pooled.size());
    }

    /**
     * Returns the first chain's accept/reject records at indices 1 through the burn-in count
     * (index 0 is skipped).
     *
     * @return rows indexed [burn-in iteration]; the rows are shared with the chain, not copied
     */
    public boolean[][] getAcceptRejectMHsBurnin() {
        boolean[][] all = this.bart_gibbs_chain_threads[0].getAcceptRejectMH();
        boolean[][] burnIn = new boolean[this.num_gibbs_burn_in][];
        System.arraycopy(all, 1, burnIn, 0, this.num_gibbs_burn_in);
        return burnIn;
    }

    /**
     * Returns the post-burn-in accept/reject records of one chain.
     *
     * @param n one-based chain number
     * @return rows indexed [post-burn-in iteration]; the rows are shared with the chain, not copied
     */
    public boolean[][] getAcceptRejectMHsAfterBurnIn(int n) {
        boolean[][] all = this.bart_gibbs_chain_threads[n - 1].getAcceptRejectMH();
        int count = this.total_iterations_multithreaded - this.num_gibbs_burn_in;
        boolean[][] afterBurnIn = new boolean[count][];
        System.arraycopy(all, this.num_gibbs_burn_in, afterBurnIn, 0, count);
        return afterBurnIn;
    }

    /**
     * Accumulates per-attribute counts for every retained Gibbs sample over the trees of that
     * sample. With {@code "splits"} the trees' {@code attributeSplitCounts()} are combined with
     * {@code Tools.add_arrays}; with {@code "trees"} they are combined with
     * {@code Tools.binary_add_arrays}. Any other value (including null) yields all-zero counts.
     *
     * @param string the count type, {@code "splits"} or {@code "trees"}
     * @return counts indexed [gibbs sample][attribute]
     */
    public int[][] getCountsForAllAttribute(String string) {
        boolean countSplits = "splits".equals(string);
        boolean countTrees = !countSplits && "trees".equals(string);
        int numSamples = this.num_gibbs_total_iterations - this.num_gibbs_burn_in;
        int[][] counts = new int[numSamples][];
        for (int g = 0; g < numSamples; ++g) {
            int[] sampleCounts = new int[this.p];
            for (bartMachineTreeNode tree : this.gibbs_samples_of_bart_trees_after_burn_in[g]) {
                if (countSplits) {
                    sampleCounts = Tools.add_arrays(sampleCounts, tree.attributeSplitCounts());
                } else if (countTrees) {
                    sampleCounts = Tools.binary_add_arrays(sampleCounts, tree.attributeSplitCounts());
                }
            }
            counts[g] = sampleCounts;
        }
        return counts;
    }

    /**
     * Sums the counts of {@link #getCountsForAllAttribute(String)} over all retained samples and
     * passes the totals through {@code Tools.normalize_array}.
     *
     * @param string the count type, {@code "splits"} or {@code "trees"}
     * @return one value per attribute
     */
    public double[] getAttributeProps(String string) {
        int[][] counts = this.getCountsForAllAttribute(string);
        double[] totals = new double[this.p];
        int numSamples = this.num_gibbs_total_iterations - this.num_gibbs_burn_in;
        for (int g = 0; g < numSamples; ++g) {
            totals = Tools.add_arrays(totals, counts[g]);
        }
        Tools.normalize_array(totals);
        return totals;
    }

    /**
     * Counts, over every tree of every retained sample, the attribute pairs reported by the
     * tree's {@code findInteractions}. A pair is counted at most once per tree and is recorded at
     * {@code [first][second]} only, as ordered by the pair itself.
     *
     * @return a p-by-p matrix of counts
     */
    public int[][] getInteractionCounts() {
        int[][] counts = new int[this.p][this.p];
        for (bartMachineTreeNode[] trees : this.gibbs_samples_of_bart_trees_after_burn_in) {
            for (bartMachineTreeNode tree : trees) {
                HashSet<UnorderedPair<Integer>> pairs = new HashSet<UnorderedPair<Integer>>();
                tree.findInteractions(pairs);
                for (UnorderedPair<Integer> pair : pairs) {
                    int first = pair.getFirst();
                    int second = pair.getSecond();
                    counts[first][second] = counts[first][second] + 1;
                }
            }
        }
        return counts;
    }

    /** Calls {@code FlushData()} on every chain. */
    @Override
    protected void FlushData() {
        for (int t = 0; t < this.num_cores; ++t) {
            this.bart_gibbs_chain_threads[t].FlushData();
        }
    }

    /**
     * Predicts a single record with the single-threaded sample average.
     *
     * @param dArray the record to predict for
     * @return the average over the retained Gibbs samples
     */
    @Override
    public double Evaluate(double[] dArray) {
        return this.EvaluateViaSampAvg(dArray, 1);
    }

    /**
     * Predicts a single record with the sample average.
     *
     * @param dArray the record to predict for
     * @param n number of threads to use
     * @return the average over the retained Gibbs samples
     */
    @Override
    public double Evaluate(double[] dArray, int n) {
        return this.EvaluateViaSampAvg(dArray, n);
    }

    /**
     * @param dArray the record to predict for
     * @param n number of threads to use
     * @return {@code StatToolbox.sample_average} of the record's Gibbs samples
     */
    public double EvaluateViaSampAvg(double[] dArray, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(new double[][]{dArray}, n);
        return StatToolbox.sample_average(samples[0]);
    }

    /**
     * @param dArray the record to predict for
     * @param n number of threads to use
     * @return {@code StatToolbox.sample_median} of the record's Gibbs samples
     */
    public double EvaluateViaSampMed(double[] dArray, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(new double[][]{dArray}, n);
        return StatToolbox.sample_median(samples[0]);
    }

    /**
     * @param n one-based chain number
     * @return that chain's {@code getDepthsForTrees} over its post-burn-in iteration range
     */
    public int[][] getDepthsForTreesInGibbsSampAfterBurnIn(int n) {
        return this.bart_gibbs_chain_threads[n - 1].getDepthsForTrees(this.num_gibbs_burn_in, this.total_iterations_multithreaded);
    }

    /**
     * @param n one-based chain number
     * @return that chain's {@code getNumNodesAndLeavesForTrees} over its post-burn-in iteration range
     */
    public int[][] getNumNodesAndLeavesForTreesInGibbsSampAfterBurnIn(int n) {
        return this.bart_gibbs_chain_threads[n - 1].getNumNodesAndLeavesForTrees(this.num_gibbs_burn_in, this.total_iterations_multithreaded);
    }

    /**
     * Installs the training data. The number of observations is the list size and the number of
     * attributes is the length of the first row minus one.
     *
     * @param arrayList the training rows; must contain at least one row
     */
    @Override
    public void setData(ArrayList<double[]> arrayList) {
        this.X_y = arrayList;
        this.n = arrayList.size();
        this.p = arrayList.get(0).length - 1;
    }

    /** Switches on the tree-illustration flag that is copied onto every chain when it is set up. */
    public void printTreeIllustations() {
        this.tree_illust = true;
    }

    /**
     * @param dArray the covariate split prior to forward to the chains
     */
    public void setCovSplitPrior(double[] dArray) {
        this.cov_split_prior = dArray;
    }

    /**
     * Replaces any existing interaction constraints with an empty map.
     *
     * @param n the initial capacity of the map
     */
    public void intializeInteractionConstraints(int n) {
        this.interaction_constraints = new HashMap<Integer, IntOpenHashSet>(n);
    }

    /**
     * Adds the given values to the constraint set stored under key {@code n}, creating the set
     * (and the constraint map itself, if it was never initialized) on demand.
     *
     * @param n the key of the constraint set
     * @param nArray the values to add to that set
     */
    public void addInteractionConstraint(int n, int[] nArray) {
        if (this.interaction_constraints == null) {
            this.interaction_constraints = new HashMap<Integer, IntOpenHashSet>();
        }
        IntOpenHashSet allowed = this.interaction_constraints.get(n);
        if (allowed == null) {
            allowed = new IntOpenHashSet();
            this.interaction_constraints.put(n, allowed);
        }
        for (int value : nArray) {
            allowed.add(value);
        }
    }

    /**
     * Sets the number of burn-in iterations and updates the number of iterations each chain has
     * to run.
     *
     * @param n number of burn-in iterations
     */
    public void setNumGibbsBurnIn(int n) {
        this.num_gibbs_burn_in = n;
        this.recomputeIterationsPerThread();
    }

    /**
     * @param n number of trees
     */
    public void setNumTrees(int n) {
        this.num_trees = n;
    }

    /**
     * @param d the value copied onto every chain's {@code sample_var_y} field
     */
    public void setSampleVarY(double d) {
        this.sample_var_y = d;
    }

    /**
     * @param d the value forwarded to every chain via {@code setAlpha}
     */
    public void setAlpha(double d) {
        this.alpha = d;
    }

    /**
     * @param d the value forwarded to every chain via {@code setBeta}
     */
    public void setBeta(double d) {
        this.beta = d;
    }

    /**
     * @param d the value forwarded to every chain via {@code setK}
     */
    public void setK(double d) {
        this.hyper_k = d;
    }

    /**
     * @param d the value forwarded to every non-classification chain via {@code setQ}
     */
    public void setQ(double d) {
        this.hyper_q = d;
    }

    /**
     * @param d the value forwarded to every non-classification chain via {@code setNu}
     */
    public void setNU(double d) {
        this.hyper_nu = d;
    }

    /**
     * @param d the value forwarded to every chain via {@code setProbGrow}
     */
    public void setProbGrow(double d) {
        this.prob_grow = d;
    }

    /**
     * @param d the value forwarded to every chain via {@code setProbPrune}
     */
    public void setProbPrune(double d) {
        this.prob_prune = d;
    }

    /**
     * @param bl whether progress messages are printed
     */
    public void setVerbose(boolean bl) {
        this.verbose = bl;
    }

    /**
     * Forwards the seed to the static {@code StatToolbox.setSeed}.
     *
     * @param n the seed
     */
    public void setSeed(int n) {
        StatToolbox.setSeed(n);
    }

    /**
     * Sets the number of chains/threads and updates the number of iterations each chain has to
     * run.
     *
     * @param n number of cores; must be at least 1
     * @throws IllegalArgumentException if {@code n} is less than 1
     */
    public void setNumCores(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("num_cores must be at least 1, got " + n);
        }
        this.num_cores = n;
        this.recomputeIterationsPerThread();
    }

    /**
     * @param bl the value forwarded to every chain via {@code setMemCacheForSpeed}
     */
    public void setMemCacheForSpeed(boolean bl) {
        this.mem_cache_for_speed = bl;
    }

    /**
     * @param bl the value forwarded to every chain via {@code setFlushIndicesToSaveRAM}
     */
    public void setFlushIndicesToSaveRAM(boolean bl) {
        this.flush_indices_to_save_ram = bl;
    }

    /** Does nothing in this implementation. */
    @Override
    public void StopBuilding() {
    }

    /**
     * @param n zero-based index of a retained Gibbs sample
     * @return a new array holding the trees of that sample (the tree objects are not copied)
     */
    public bartMachineTreeNode[] extractRawNodeInformation(int n) {
        bartMachineTreeNode[] trees = new bartMachineTreeNode[this.num_trees];
        System.arraycopy(this.gibbs_samples_of_bart_trees_after_burn_in[n], 0, trees, 0, this.num_trees);
        return trees;
    }
}

