/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric.bayesian.utils;

import eu.amidst.core.exponentialfamily.NaturalParameters;
import eu.amidst.core.inference.messagepassing.Message;
import eu.amidst.core.inference.messagepassing.Node;
import eu.amidst.core.inference.messagepassing.VMP;
import eu.amidst.core.learning.parametric.bayesian.utils.PlateuStructure;
import eu.amidst.core.utils.CompoundVector;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link VMP} variant driven by a {@link PlateuStructure}.
 *
 * <p>{@link #runInference()} first iterates message passing only over the active, unobserved
 * variables that are NOT non-replicated. It then performs one single update of the unobserved,
 * non-replicated variables. The results of that update are written into a posterior vector that
 * is handed to {@link PlateuStructure#updateNaturalParameterPosteriors}. Each node's
 * q-distribution is then reset to the natural parameters it had before this final update.
 */
public class VMPLocalUpdates
extends VMP {
    static final Logger logger = LoggerFactory.getLogger(VMPLocalUpdates.class);

    /** Structure used to classify variables (replicated / non-replicated) and to read/write posteriors. */
    PlateuStructure plateuStructure;

    /**
     * Creates the inference engine.
     *
     * @param plateuStructure structure used to tell replicated from non-replicated variables and
     *                        to read and update the non-replicated natural-parameter posteriors.
     */
    public VMPLocalUpdates(PlateuStructure plateuStructure) {
        this.plateuStructure = plateuStructure;
    }

    /**
     * Runs inference in two phases.
     *
     * <ol>
     *   <li>Repeatedly updates every active, unobserved node whose main variable is not
     *       non-replicated. It stops when {@code testConvergence()} reports convergence, when
     *       every updated node reports {@code isDone()}, or after {@code maxIter} iterations.</li>
     *   <li>Updates each unobserved, non-replicated node once. The new natural parameters are
     *       stored in {@code posteriorNew}, and the node's q-distribution is restored to the
     *       parameters found in {@code posteriorOLD}. {@code posteriorNew} is then passed to
     *       {@code plateuStructure.updateNaturalParameterPosteriors}.</li>
     * </ol>
     *
     * <p>On exit {@code probOfEvidence} holds {@code local_elbo} and {@code nIter} holds
     * {@code local_iter}.
     */
    @Override
    public void runInference() {
        this.nIter = 0;
        this.probOfEvidence = Double.NEGATIVE_INFINITY;
        this.local_elbo = Double.NEGATIVE_INFINITY;
        this.local_iter = 0;
        // NOTE(review): never incremented in this method, so the progress line always reports 0.
        int global_iter = 0;
        // Result intentionally discarded; presumably primes the convergence test's state — TODO confirm in VMP.
        this.testConvergence();

        boolean convergence = false;
        while (!convergence && this.local_iter++ < this.maxIter) {
            boolean done = true;
            for (Node node : this.nodes) {
                if (!node.isActive() || node.isObserved()
                        || this.plateuStructure.isNonReplicatedVar(node.getMainVariable())) {
                    continue;
                }
                this.updateFromChildMessages(node);
                done &= node.isDone();
            }
            // testConvergence() is evaluated first on purpose so it still runs every iteration.
            boolean converged = this.testConvergence();
            convergence = converged || done;
        }

        // NOTE(review): the restore below assumes each call returns an independent copy,
        // i.e. writing into posteriorNew does not alter posteriorOLD — confirm in PlateuStructure.
        CompoundVector posteriorOLD = this.plateuStructure.getPlateauNaturalParameterPosterior();
        CompoundVector posteriorNew = this.plateuStructure.getPlateauNaturalParameterPosterior();
        int count = 0;
        for (Node node : this.nodes) {
            if (node.isObserved() || this.plateuStructure.isReplicatedVar(node.getMainVariable())) {
                continue;
            }
            if (!node.isActive() && this.plateuStructure.isNonReplicatedVar(node.getMainVariable())) {
                // Inactive nodes still advance the position, presumably to keep `count`
                // aligned with the positions in the posterior vectors.
                ++count;
                continue;
            }
            this.updateFromChildMessages(node);
            posteriorNew.setVectorByPosition(count, node.getQDist().getNaturalParameters());
            node.getQDist().setNaturalParameters((NaturalParameters)posteriorOLD.getVectorByPosition(count));
            node.getQDist().fixNumericalInstability();
            node.getQDist().updateMomentFromNaturalParameters();
            ++count;
        }
        this.plateuStructure.updateNaturalParameterPosteriors(posteriorNew);

        this.probOfEvidence = this.local_elbo;
        if (this.output) {
            // NOTE(review): progress is reported both on stdout and through the logger.
            System.out.println("N Iter: " + global_iter + " " + this.local_iter + ", elbo:" + this.local_elbo);
            logger.info("N Iter: {}, {}, elbo: {}", global_iter, this.local_iter, this.local_elbo);
        }
        this.nIter = this.local_iter;
    }

    /**
     * Recomputes {@code node}'s variational update.
     *
     * <p>The node's self message is combined with the messages sent by its active children (if
     * any). The result is applied through {@code updateCombinedMessage}.
     *
     * @param node the node to update.
     */
    private void updateFromChildMessages(Node node) {
        Message<NaturalParameters> selfMessage = this.newSelfMessage(node);
        Optional<Message> message = node.getChildren().stream()
                .filter(children -> children.isActive())
                .map(children -> this.newMessageToParent((Node)children, node))
                .reduce(Message::combineNonStateless);
        if (message.isPresent()) {
            selfMessage.combine(message.get());
        }
        this.updateCombinedMessage(node, selfMessage);
    }
}

