/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.inference;

import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.inference.ImportanceSampling;
import eu.amidst.core.inference.PointEstimator;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.models.ParentSet;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Approximate most-probable-explanation (MPE) inference for a {@link BayesianNetwork}: searches for
 * an assignment of the unobserved variables with high joint log-probability given the evidence.
 *
 * <p>Every {@link SearchAlgorithm} except {@link SearchAlgorithm#EXHAUSTIVE} first draws
 * {@code sampleSize} initial guesses with {@link ImportanceSampling}, refines each guess with a
 * local search, and keeps the refined candidate with the highest log-probability under the model.
 *
 * <p>All random numbers are derived from the value given to {@link #setSeed(int)}: the importance
 * sampler receives a seed drawn from it, and every local search gets its own {@link Random} seeded
 * from it, so local searches never share a generator, even in parallel mode.
 */
public class MPEInference implements PointEstimator {

    /** Maximum number of unobserved discrete variables re-drawn by one *_LOCAL move. */
    private static final int DISCRETE_VARIABLES_PER_MOVE = 3;

    /** Starting temperature of the simulated-annealing schedule. */
    private static final double INITIAL_TEMPERATURE = 1000.0;

    /** Factor by which the annealing temperature is multiplied after every iteration. */
    private static final double COOLING_FACTOR = 0.9;

    private BayesianNetwork model;
    /**
     * Variables of {@link #model} as returned by {@code Utils.getTopologicalOrder}; the searches
     * rely on it listing every parent before its children.
     */
    private List<Variable> causalOrder;
    private int sampleSize = 10;
    private int seed = 0;
    private int numberOfIterations = 100;
    private Assignment evidence = new HashMapAssignment(0);
    /**
     * Observed Gaussian variables plus every Gaussian ancestor reachable through Gaussian parents.
     * Recomputed at the start of each run and only read afterwards.
     */
    private List<Variable> continuousEvidenceAncestors = Collections.emptyList();
    private Assignment MPEestimate;
    private double MPEestimateLogProbability;
    private boolean parallelMode = true;

    @Override
    public void setParallelMode(boolean parallelMode_) {
        this.parallelMode = parallelMode_;
    }

    /**
     * Sets the seed from which all random numbers of subsequent runs are derived.
     *
     * @param seed the master seed
     */
    @Override
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /**
     * Sets the network to run inference on and caches its causal (topological) order.
     *
     * @param model_ the Bayesian network; must not be {@code null}
     */
    @Override
    public void setModel(BayesianNetwork model_) {
        this.model = Objects.requireNonNull(model_, "model");
        this.causalOrder = Utils.getTopologicalOrder(this.model.getDAG());
    }

    /**
     * Sets the observed values. Variables whose value in {@code evidence_} is NaN are unobserved.
     *
     * @param evidence_ the evidence assignment; must not be {@code null}
     */
    @Override
    public void setEvidence(Assignment evidence_) {
        this.evidence = Objects.requireNonNull(evidence_, "evidence");
    }

    /**
     * Sets how many importance samples seed the search (and therefore how many local searches run).
     *
     * @param sampleSize the number of initial guesses
     */
    public void setSampleSize(int sampleSize) {
        this.sampleSize = sampleSize;
    }

    @Override
    public BayesianNetwork getOriginalModel() {
        return this.model;
    }

    @Override
    public Assignment getEstimate() {
        return this.MPEestimate;
    }

    @Override
    public double getLogProbabilityOfEstimate() {
        return this.MPEestimateLogProbability;
    }

    /**
     * Sets the number of iterations performed by each hill-climbing or simulated-annealing search.
     *
     * @param numberOfIterations iterations per local search
     */
    public void setNumberOfIterations(int numberOfIterations) {
        this.numberOfIterations = numberOfIterations;
    }

    /** Runs inference with the default strategy, {@link SearchAlgorithm#HC_LOCAL}. */
    @Override
    public void runInference() {
        this.runInference(SearchAlgorithm.HC_LOCAL);
    }

    /**
     * Runs MPE inference with the given strategy and stores the estimate and its log-probability.
     *
     * @param searchAlgorithm the search strategy; must not be {@code null}
     * @throws IllegalStateException if no model has been set, or if no candidate was produced
     *     (sample size not positive)
     */
    public void runInference(SearchAlgorithm searchAlgorithm) {
        Objects.requireNonNull(searchAlgorithm, "searchAlgorithm");
        if (this.model == null) {
            throw new IllegalStateException("A model must be set before running inference");
        }
        this.continuousEvidenceAncestors = this.computeContinuousEvidenceAncestors();
        Random master = new Random(this.seed);
        if (searchAlgorithm == SearchAlgorithm.EXHAUSTIVE) {
            // Deterministic enumeration: the importance samples would be unused, so skip them.
            this.MPEestimate = this.sequentialSearch();
        } else {
            List<Assignment> initialGuesses = this.drawInitialGuesses(master.nextInt());
            BiFunction<Assignment, Random, Assignment> localSearch = this.localSearchFor(searchAlgorithm);
            // One independent generator per initial guess, with seeds drawn up front: tasks never
            // share a Random, and each task's seed does not depend on thread scheduling.
            long[] taskSeeds = new long[initialGuesses.size()];
            for (int i = 0; i < taskSeeds.length; ++i) {
                taskSeeds[i] = master.nextLong();
            }
            IntStream indices = IntStream.range(0, initialGuesses.size());
            if (this.parallelMode) {
                indices = indices.parallel();
            }
            List<Assignment> candidates = indices
                    .mapToObj(i -> localSearch.apply(initialGuesses.get(i), new Random(taskSeeds[i])))
                    .collect(Collectors.toList());
            this.MPEestimate = this.mostProbable(candidates);
        }
        this.MPEestimateLogProbability = this.model.getLogProbabiltyOf(this.MPEestimate);
    }

    /**
     * Draws {@code sampleSize} assignments consistent with the evidence using importance sampling.
     *
     * @param samplingSeed seed handed to the importance sampler
     * @return the drawn samples, in the order the sampler produced them
     */
    private List<Assignment> drawInitialGuesses(int samplingSeed) {
        ImportanceSampling importanceSampling = new ImportanceSampling();
        importanceSampling.setModel(this.model);
        importanceSampling.setSamplingModel(this.model);
        importanceSampling.setSampleSize(this.sampleSize);
        importanceSampling.setParallelMode(this.parallelMode);
        importanceSampling.setEvidence(this.evidence);
        importanceSampling.setKeepDataOnMemory(true);
        importanceSampling.setSeed(samplingSeed);
        importanceSampling.runInference();
        return importanceSampling.getSamples().collect(Collectors.toList());
    }

    /**
     * Maps a non-exhaustive strategy to the local search applied to each initial guess.
     *
     * @param searchAlgorithm the requested strategy
     * @return a function (initial guess, private generator) -> refined assignment
     */
    private BiFunction<Assignment, Random, Assignment> localSearchFor(SearchAlgorithm searchAlgorithm) {
        switch (searchAlgorithm) {
            case SAMPLING:
                // Plain sampling: the best importance sample is the answer.
                return (initialGuess, random) -> initialGuess;
            case SA_LOCAL:
                return this::simulatedAnnealingOneVar;
            case SA_GLOBAL:
                return this::simulatedAnnealingAllVars;
            case HC_GLOBAL:
                return this::hillClimbingAllVars;
            case HC_LOCAL:
            default:
                // EXHAUSTIVE never reaches here; runInference handles it separately.
                return this::hillClimbingOneVar;
        }
    }

    /**
     * Returns the candidate with the highest model log-probability; the first one wins ties.
     *
     * @param candidates the refined assignments
     * @return the most probable candidate
     * @throws IllegalStateException if {@code candidates} is empty
     */
    private Assignment mostProbable(List<Assignment> candidates) {
        Assignment best = null;
        double bestLogProbability = Double.NEGATIVE_INFINITY;
        for (Assignment candidate : candidates) {
            double logProbability = this.model.getLogProbabiltyOf(candidate);
            if (best == null || logProbability > bestLogProbability) {
                best = candidate;
                bestLogProbability = logProbability;
            }
        }
        if (best == null) {
            throw new IllegalStateException(
                    "No candidate assignments were produced; the sample size must be positive");
        }
        return best;
    }

    /** Returns whether {@code variable} has a (non-NaN) value in the evidence. */
    private boolean isObserved(Variable variable) {
        return !Double.isNaN(this.evidence.getValue(variable));
    }

    /** Returns the distribution of {@code variable} conditioned on the values in {@code assignment}. */
    private UnivariateDistribution conditionalGiven(Variable variable, Assignment assignment) {
        ConditionalDistribution distribution = this.model.getConditionalDistribution(variable);
        return distribution.getUnivariateDistribution(assignment);
    }

    /**
     * Collects the observed Gaussian variables and, breadth-first, every Gaussian ancestor reachable
     * from them through Gaussian parents.
     *
     * @return observed Gaussians first, then ancestors in discovery order (no duplicates)
     */
    private List<Variable> computeContinuousEvidenceAncestors() {
        List<Variable> reached = new ArrayList<>();
        for (Variable variable : this.causalOrder) {
            if (variable.isNormal() && this.isObserved(variable)) {
                reached.add(variable);
            }
        }
        DAG graph = this.model.getDAG();
        // `reached` doubles as the BFS queue: entries past `next` are still to be expanded.
        for (int next = 0; next < reached.size(); ++next) {
            ParentSet parents = graph.getParentSet(reached.get(next));
            for (Variable parent : parents.getParents()) {
                if (parent.isNormal() && !reached.contains(parent)) {
                    reached.add(parent);
                }
            }
        }
        return reached;
    }

    /**
     * Assigns every unobserved Gaussian variable of {@code target}, in causal order, so each
     * variable is computed after its parents.
     *
     * <p>Variables in {@code sampledVariables} are drawn from their conditional distribution. All
     * others get the first parameter of their conditional distribution (presumably the Gaussian
     * mean — confirm against the Normal distribution class).
     *
     * @param target assignment to complete in place
     * @param sampledVariables variables to sample rather than set deterministically
     * @param random generator used for sampling; may be {@code null} when
     *     {@code sampledVariables} is empty
     */
    private void fillContinuousVariables(Assignment target, List<Variable> sampledVariables, Random random) {
        for (Variable variable : this.causalOrder) {
            if (!variable.isNormal() || this.isObserved(variable)) {
                continue;
            }
            UnivariateDistribution conditional = this.conditionalGiven(variable, target);
            double value = sampledVariables.contains(variable)
                    ? conditional.sample(random)
                    : conditional.getParameters()[0];
            target.setValue(variable, value);
        }
    }

    /**
     * Draws a fresh assignment consistent with the evidence. Unobserved discrete variables are
     * forward-sampled in causal order. Gaussian ancestors of Gaussian evidence are sampled, and the
     * remaining unobserved Gaussians are set from their conditional distribution's first parameter.
     *
     * @param random generator for all draws
     * @return a new assignment containing the evidence
     */
    private Assignment obtainValues(Random random) {
        HashMapAssignment result = new HashMapAssignment(this.evidence);
        for (Variable variable : this.causalOrder) {
            if (variable.isMultinomial() && !this.isObserved(variable)) {
                result.setValue(variable, this.conditionalGiven(variable, result).sample(random));
            }
        }
        this.fillContinuousVariables(result, this.continuousEvidenceAncestors, random);
        return result;
    }

    /**
     * Returns a copy of {@code initialGuess} in which up to {@code numberOfMovements} distinct,
     * unobserved multinomial variables (chosen uniformly) get a uniformly drawn state.
     *
     * @param initialGuess assignment to perturb (not modified)
     * @param numberOfMovements maximum number of variables to re-draw; negative means none
     * @param random generator for the choice of variables and states
     * @return the perturbed copy
     */
    private Assignment moveDiscreteVariables(Assignment initialGuess, int numberOfMovements, Random random) {
        Assignment result = new HashMapAssignment(initialGuess);
        List<Variable> movable = new ArrayList<>();
        for (Variable variable : this.model.getVariables().getListOfVariables()) {
            if (variable.isMultinomial() && !this.isObserved(variable)) {
                movable.add(variable);
            }
        }
        // Shuffling picks a uniform subset directly; no rejection loop that could fail to terminate.
        Collections.shuffle(movable, random);
        int movements = Math.max(0, Math.min(numberOfMovements, movable.size()));
        for (Variable variable : movable.subList(0, movements)) {
            result.setValue(variable, random.nextInt(variable.getNumberOfStates()));
        }
        return result;
    }

    /**
     * Returns a copy of {@code initialGuess} whose unobserved Gaussian variables are recomputed for
     * its current discrete values, using the same rule as {@link #obtainValues(Random)}.
     *
     * @param initialGuess assignment whose discrete values are kept (not modified)
     * @param random generator for the sampled Gaussian ancestors of the evidence
     * @return the updated copy
     */
    private Assignment assignContinuousVariables(Assignment initialGuess, Random random) {
        Assignment result = new HashMapAssignment(initialGuess);
        this.fillContinuousVariables(result, this.continuousEvidenceAncestors, random);
        return result;
    }

    /** Simulated annealing whose proposals are complete fresh draws from {@link #obtainValues}. */
    private Assignment simulatedAnnealingAllVars(Assignment initialGuess, Random random) {
        return this.simulatedAnnealing(initialGuess, random, current -> this.obtainValues(random));
    }

    /** Simulated annealing whose proposals re-draw a few discrete variables of the current state. */
    private Assignment simulatedAnnealingOneVar(Assignment initialGuess, Random random) {
        return this.simulatedAnnealing(initialGuess, random, current -> this.assignContinuousVariables(
                this.moveDiscreteVariables(current, DISCRETE_VARIABLES_PER_MOVE, random), random));
    }

    /** Hill climbing whose proposals are complete fresh draws from {@link #obtainValues}. */
    private Assignment hillClimbingAllVars(Assignment initialGuess, Random random) {
        return this.hillClimbing(initialGuess, current -> this.obtainValues(random));
    }

    /** Hill climbing whose proposals re-draw a few discrete variables of the current state. */
    private Assignment hillClimbingOneVar(Assignment initialGuess, Random random) {
        return this.hillClimbing(initialGuess, current -> this.assignContinuousVariables(
                this.moveDiscreteVariables(current, DISCRETE_VARIABLES_PER_MOVE, random), random));
    }

    /**
     * Runs {@code numberOfIterations} steps of simulated annealing with geometric cooling.
     *
     * <p>A worse proposal is accepted with probability {@code exp(-(current - proposal) / T)}, where
     * the difference is in log-probability. The best state ever visited is returned, not the last
     * accepted one.
     *
     * @param initialGuess starting state (not modified)
     * @param random generator for acceptance tests
     * @param neighbour proposal function applied to the current state
     * @return the most probable state visited
     */
    private Assignment simulatedAnnealing(Assignment initialGuess, Random random, UnaryOperator<Assignment> neighbour) {
        Assignment current = new HashMapAssignment(initialGuess);
        double currentLogProbability = this.model.getLogProbabiltyOf(current);
        Assignment best = current;
        double bestLogProbability = currentLogProbability;
        double temperature = INITIAL_TEMPERATURE;
        for (int iteration = 0; iteration < this.numberOfIterations; ++iteration) {
            Assignment proposal = neighbour.apply(current);
            double proposalLogProbability = this.model.getLogProbabiltyOf(proposal);
            boolean accept = proposalLogProbability > currentLogProbability
                    || random.nextDouble() < Math.exp(-(currentLogProbability - proposalLogProbability) / temperature);
            if (accept) {
                current = proposal;
                currentLogProbability = proposalLogProbability;
                if (currentLogProbability > bestLogProbability) {
                    best = current;
                    bestLogProbability = currentLogProbability;
                }
            }
            temperature *= COOLING_FACTOR;
        }
        return best;
    }

    /**
     * Runs {@code numberOfIterations} steps of first-improvement hill climbing: a proposal replaces
     * the current state only when it is strictly more probable.
     *
     * @param initialGuess starting state (not modified)
     * @param neighbour proposal function applied to the current state
     * @return the final (and best) state
     */
    private Assignment hillClimbing(Assignment initialGuess, UnaryOperator<Assignment> neighbour) {
        Assignment current = new HashMapAssignment(initialGuess);
        double currentLogProbability = this.model.getLogProbabiltyOf(current);
        for (int iteration = 0; iteration < this.numberOfIterations; ++iteration) {
            Assignment proposal = neighbour.apply(current);
            double proposalLogProbability = this.model.getLogProbabiltyOf(proposal);
            if (proposalLogProbability > currentLogProbability) {
                current = proposal;
                currentLogProbability = proposalLogProbability;
            }
        }
        return current;
    }

    /**
     * Enumerates every joint state of the unobserved multinomial variables. For each state it sets
     * the unobserved Gaussians, in causal order, to the first parameter of their conditional
     * distribution, and returns the most probable complete assignment (first one wins ties).
     *
     * <p>Cost is the product of the state counts of the unobserved discrete variables. With
     * unobserved Gaussians the result is only approximate, since their values are not optimised
     * jointly with their observed children.
     *
     * @return the best complete assignment found
     */
    private Assignment sequentialSearch() {
        List<Variable> free = new ArrayList<>();
        for (Variable variable : this.causalOrder) {
            if (variable.isMultinomial() && !this.isObserved(variable)) {
                free.add(variable);
            }
        }
        int[] states = new int[free.size()];
        Assignment best = null;
        double bestLogProbability = Double.NEGATIVE_INFINITY;
        while (true) {
            Assignment candidate = new HashMapAssignment(this.evidence);
            for (int position = 0; position < states.length; ++position) {
                candidate.setValue(free.get(position), states[position]);
            }
            // Continuous values are computed only after every discrete parent has its final state.
            this.fillContinuousVariables(candidate, Collections.<Variable>emptyList(), null);
            double logProbability = this.model.getLogProbabiltyOf(candidate);
            if (best == null || logProbability > bestLogProbability) {
                best = candidate;
                bestLogProbability = logProbability;
            }
            // Odometer increment over the mixed-radix vector `states`.
            int position = states.length - 1;
            while (position >= 0 && ++states[position] == free.get(position).getNumberOfStates()) {
                states[position] = 0;
                --position;
            }
            if (position < 0) {
                return best;
            }
        }
    }

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        BayesianNetwork bn = BayesianNetworkLoader.loadFromFile("./networks/simulated/asia.bn");
        System.out.println(bn.toString());
        MPEInference mpeInference = new MPEInference();
        mpeInference.setModel(bn);
        mpeInference.setParallelMode(true);
        System.out.println("CausalOrder: " + Arrays.toString(mpeInference.causalOrder.stream().map(v -> v.getName()).toArray()));
        System.out.println();
        List<Variable> modelVariables = Utils.getTopologicalOrder(bn.getDAG());
        int parallelSamples = 10;
        int samplingMethodSize = 1000;
        mpeInference.setSampleSize(parallelSamples);
        Variable variable1 = mpeInference.causalOrder.get(1);
        Variable variable2 = mpeInference.causalOrder.get(2);
        Variable variable3 = mpeInference.causalOrder.get(4);
        int var1value = 0;
        int var2value = 1;
        int var3value = 1;
        System.out.println("Evidence: Variable " + variable1.getName() + " = " + var1value + ", Variable " + variable2.getName() + " = " + var2value + ", " + " and Variable " + variable3.getName() + " = " + var3value);
        System.out.println();
        HashMapAssignment evidenceAssignment = new HashMapAssignment(3);
        evidenceAssignment.setValue(variable1, var1value);
        evidenceAssignment.setValue(variable2, var2value);
        evidenceAssignment.setValue(variable3, var3value);
        mpeInference.setEvidence(evidenceAssignment);
        System.out.println();
        runAndReport(mpeInference, SearchAlgorithm.SA_GLOBAL, "MPE estimate (SA.All): ", modelVariables);
        runAndReport(mpeInference, SearchAlgorithm.SA_LOCAL, "MPE estimate  (SA.Some): ", modelVariables);
        runAndReport(mpeInference, SearchAlgorithm.HC_GLOBAL, "MPE estimate (HC.All): ", modelVariables);
        runAndReport(mpeInference, SearchAlgorithm.HC_LOCAL, "MPE estimate  (HC.Some): ", modelVariables);
        mpeInference.setSampleSize(samplingMethodSize);
        runAndReport(mpeInference, SearchAlgorithm.SAMPLING, "MPE estimate (SAMPLING): ", modelVariables);
        runAndReport(mpeInference, SearchAlgorithm.EXHAUSTIVE, "MPE estimate (DETERM.): ", modelVariables);
    }

    /**
     * Runs one search strategy, then prints its estimate, probability and wall-clock time.
     *
     * @param mpeInference configured inference object
     * @param algorithm strategy to run
     * @param label prefix printed before the estimate
     * @param variableOrder order in which the estimate's variables are printed
     */
    private static void runAndReport(MPEInference mpeInference, SearchAlgorithm algorithm, String label, List<Variable> variableOrder) {
        long timeStart = System.nanoTime();
        mpeInference.runInference(algorithm);
        long timeStop = System.nanoTime();
        Assignment mpeEstimate = mpeInference.getEstimate();
        double logProbability = mpeInference.getLogProbabilityOfEstimate();
        System.out.println(label + mpeEstimate.outputString(variableOrder));
        System.out.println("with probability: " + Math.exp(logProbability) + ", logProb: " + logProbability);
        double execTime = (double) (timeStop - timeStart) / 1.0E9;
        System.out.println("computed in: " + Double.toString(execTime) + " seconds");
        System.out.println();
    }

    /** Search strategies available to {@link #runInference(SearchAlgorithm)}. */
    public enum SearchAlgorithm {
        /** Enumerate every joint state of the unobserved discrete variables. */
        EXHAUSTIVE,
        /** Return the most probable of the importance samples. */
        SAMPLING,
        /** Simulated annealing; each move re-draws a few discrete variables. */
        SA_LOCAL,
        /** Simulated annealing; each move draws a complete fresh assignment. */
        SA_GLOBAL,
        /** Hill climbing; each move re-draws a few discrete variables (the default). */
        HC_LOCAL,
        /** Hill climbing; each move draws a complete fresh assignment. */
        HC_GLOBAL
    }
}

