/*
 * Decompiled with CFR 0.152.
 */
package org.drugis.mtc.jags;

import com.jgoodies.binding.list.ObservableList;
import edu.uci.ics.jung.graph.util.Pair;
import java.io.IOException;
import java.io.InputStream;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.text.Format;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeSet;
import org.apache.commons.collections15.CollectionUtils;
import org.apache.commons.collections15.PredicateUtils;
import org.apache.commons.collections15.Transformer;
import org.apache.commons.lang.StringUtils;
import org.drugis.common.CollectionHelper;
import org.drugis.mtc.Parameter;
import org.drugis.mtc.data.DataType;
import org.drugis.mtc.model.Measurement;
import org.drugis.mtc.model.Network;
import org.drugis.mtc.model.Study;
import org.drugis.mtc.model.Treatment;
import org.drugis.mtc.parameterization.BasicParameter;
import org.drugis.mtc.parameterization.InconsistencyParameter;
import org.drugis.mtc.parameterization.InconsistencyParameterization;
import org.drugis.mtc.parameterization.InconsistencyStartingValueGenerator;
import org.drugis.mtc.parameterization.NetworkModel;
import org.drugis.mtc.parameterization.NetworkParameter;
import org.drugis.mtc.parameterization.NodeSplitParameterization;
import org.drugis.mtc.parameterization.ParameterComparator;
import org.drugis.mtc.parameterization.Parameterization;
import org.drugis.mtc.parameterization.PriorGenerator;
import org.drugis.mtc.parameterization.SplitParameter;
import org.drugis.mtc.parameterization.StartingValueGenerator;
import org.mvel2.templates.CompiledTemplate;
import org.mvel2.templates.TemplateCompiler;
import org.mvel2.templates.TemplateRuntime;

/**
 * Generates JAGS or BUGS syntax (model text, data file, initial values, run script and an R
 * analysis snippet) for a {@link Network} under a given {@link Parameterization}.
 */
public class JagsSyntaxModel {
    /**
     * Pattern for prior constants written into the model template. Stored as a pattern rather
     * than a shared {@link Format} instance because {@code Format} objects are not thread-safe.
     */
    private static final String SCIENTIFIC_PATTERN = "0.0##E0";

    /** Renders a parameter as its bare name. */
    private static final Transformer<NetworkParameter, String> s_idTrans = new Transformer<NetworkParameter, String>() {
        @Override
        public String transform(NetworkParameter input) {
            return input.getName();
        }
    };

    /** Renders a parameter as the R column lookup {@code x[, "name"]}. */
    private static final Transformer<NetworkParameter, String> s_rTransform = new Transformer<NetworkParameter, String>() {
        @Override
        public String transform(NetworkParameter input) {
            return "x[, \"" + input.getName() + "\"]";
        }
    };

    private final Parameterization d_pmtz;
    /** {@code true} to emit JAGS syntax, {@code false} to emit BUGS syntax. */
    private final boolean d_isJags;
    /** Whether {@link #d_pmtz} is an {@link InconsistencyParameterization}. */
    private final boolean d_inconsistency;
    /** Whether {@link #d_pmtz} is a {@link NodeSplitParameterization}. */
    private final boolean d_nodeSplit;
    private final Network d_network;
    private final PriorGenerator d_priorGen;

    /**
     * Creates a syntax generator.
     *
     * @param network the network whose studies and treatments are written out
     * @param pmtz    the parameterization; its runtime type selects inconsistency or node-split output
     * @param isJags  {@code true} for JAGS syntax, {@code false} for BUGS syntax
     */
    public JagsSyntaxModel(Network network, Parameterization pmtz, boolean isJags) {
        this.d_network = network;
        this.d_pmtz = pmtz;
        this.d_isJags = isJags;
        this.d_inconsistency = pmtz instanceof InconsistencyParameterization;
        this.d_nodeSplit = pmtz instanceof NodeSplitParameterization;
        this.d_priorGen = new PriorGenerator(network);
    }

    /**
     * Formats a value with {@link #SCIENTIFIC_PATTERN} using fixed US symbols, so the decimal
     * separator is always '.' regardless of the JVM's default locale (a ',' would produce invalid
     * model syntax). A fresh formatter is created per call since DecimalFormat is not thread-safe.
     */
    private static String formatScientific(double value) {
        DecimalFormat format = new DecimalFormat(SCIENTIFIC_PATTERN, new DecimalFormatSymbols(Locale.US));
        return format.format(value);
    }

    /**
     * For JAGS, rewrites the first exponent marker, e.g. {@code "1.0E-4"} becomes
     * {@code "1.0*10^-4"}; for BUGS the string is returned unchanged.
     */
    private String rewriteNumber(String s) {
        return d_isJags ? s.replaceFirst("E", "*10^") : s;
    }

    /**
     * Renders name/value assignments as a data file: {@code name <- value} lines for JAGS,
     * or an R {@code list(name = value, ...)} for BUGS.
     */
    private String generateDataFile(List<Pair<String>> assignments) {
        final String assign = d_isJags ? " <- " : " = ";
        String sep = d_isJags ? "\n" : ",\n";
        String head = d_isJags ? "" : "list(\n";
        String foot = d_isJags ? "\n" : "\n)\n";
        Collection<String> lines = CollectionHelper.transform(assignments, new Transformer<Pair<String>, String>() {
            @Override
            public String transform(Pair<String> input) {
                return input.getFirst() + assign + input.getSecond();
            }
        });
        return head + StringUtils.join(lines, sep) + foot;
    }

    /**
     * Generates an initial-values file: meta-parameters, baseline effects ({@code mu}),
     * relative effects ({@code delta}) and variance parameters.
     *
     * @param generator source of the starting values
     * @return the initial-values file contents
     */
    public String initialValuesText(StartingValueGenerator generator) {
        List<Pair<String>> list = new ArrayList<Pair<String>>();
        list.addAll(initMetaParameters(generator));
        list.addAll(initBaselineEffects(generator));
        list.addAll(initRelativeEffects(generator));
        list.addAll(initVarianceParameters(generator));
        return generateDataFile(list);
    }

    /**
     * Generates an R snippet defining {@code deriv}, a list of functions computing relative
     * effects not directly present in the parameterization, plus commented usage hints.
     *
     * @param prefix file prefix used in the commented {@code read.mtc} call
     * @return the R analysis text
     */
    public String analysisText(String prefix) {
        List<String> list = new ArrayList<String>();
        list.add("deriv <- list(");
        list.add(getDerivations());
        list.add("\t)");
        list.add("# source('mtc.R')");
        list.add("# data <- append.derived(read.mtc('" + prefix + "'), deriv)\n");
        return StringUtils.join(list, "\n");
    }

    /**
     * Computes a starting value for each parameter of the parameterization. Basic parameter values
     * are recorded so later inconsistency parameters can be derived from them.
     *
     * @throws IllegalStateException for a parameter type this method does not handle
     */
    private List<Pair<String>> initMetaParameters(StartingValueGenerator generator) {
        Map<BasicParameter, Double> basicValues = new HashMap<BasicParameter, Double>();
        List<Pair<String>> list = new ArrayList<Pair<String>>();
        for (NetworkParameter p : d_pmtz.getParameters()) {
            double relativeEffect;
            if (p instanceof BasicParameter) {
                relativeEffect = generator.getRelativeEffect((BasicParameter) p);
                basicValues.put((BasicParameter) p, relativeEffect);
            } else if (p instanceof SplitParameter) {
                SplitParameter sp = (SplitParameter) p;
                relativeEffect = generator.getRelativeEffect(new BasicParameter(sp.getBaseline(), sp.getSubject()));
            } else if (p instanceof InconsistencyParameter) {
                relativeEffect = InconsistencyStartingValueGenerator.generate((InconsistencyParameter) p,
                        (InconsistencyParameterization) d_pmtz, generator, basicValues);
            } else {
                throw new IllegalStateException("Unhandled parameter " + p + " of type " + p.getClass().getCanonicalName());
            }
            list.add(new Pair<String>(p.getName(), String.valueOf(relativeEffect)));
        }
        return list;
    }

    /** Builds the {@code mu} vector: each study's treatment effect for its baseline. */
    private List<Pair<String>> initBaselineEffects(StartingValueGenerator generator) {
        int nStudies = d_network.getStudies().size();
        Double[] baselineEffects = new Double[nStudies];
        for (int i = 0; i < nStudies; ++i) {
            Study study = (Study) d_network.getStudies().get(i);
            Treatment baseline = d_pmtz.getStudyBaseline(study);
            baselineEffects[i] = generator.getTreatmentEffect(study, baseline);
        }
        return Collections.singletonList(new Pair<String>("mu", writeVector(baselineEffects, d_isJags)));
    }

    /**
     * Builds the {@code delta} matrix of per-study relative effects versus the study baseline.
     * The baseline cell is left {@code null}, which {@link #writeNumber} renders as {@code NA}.
     */
    private List<Pair<String>> initRelativeEffects(final StartingValueGenerator generator) {
        Double[][] relativeEffects = getDoubleMatrix(new StudyTreatmentTransformer<Double>() {
            @Override
            public Double transform(Study s, Treatment t) {
                Treatment b = d_pmtz.getStudyBaseline(s);
                if (b.equals(t)) {
                    return null;
                }
                return generator.getRelativeEffect(s, new BasicParameter(b, t));
            }
        });
        return Collections.singletonList(new Pair<String>("delta", writeMatrix(relativeEffects, d_isJags)));
    }

    /** Adds {@code sd.d}, and {@code sd.w} when the inconsistency parameterization is in use. */
    private List<Pair<String>> initVarianceParameters(StartingValueGenerator generator) {
        List<Pair<String>> list = new ArrayList<Pair<String>>();
        list.add(new Pair<String>("sd.d", String.valueOf(generator.getStandardDeviation())));
        if (d_inconsistency) {
            list.add(new Pair<String>("sd.w", String.valueOf(generator.getStandardDeviation())));
        }
        return list;
    }

    /**
     * Emits one R function per treatment pair that appears in neither orientation among the
     * parameterization's parameters. Each function expresses that effect in terms of columns of
     * a sample matrix {@code x}.
     */
    private String getDerivations() {
        List<String> lines = NetworkModel.transformTreatmentPairs(d_network, new Transformer<Pair<Treatment>, String>() {
            @Override
            public String transform(Pair<Treatment> input) {
                Treatment ti = input.getFirst();
                Treatment tj = input.getSecond();
                BasicParameter p = new BasicParameter(ti, tj);
                BasicParameter q = new BasicParameter(tj, ti);
                if (!d_pmtz.getParameters().contains(p) && !d_pmtz.getParameters().contains(q)) {
                    String e = expressRelativeEffect(ti, tj, s_rTransform);
                    return "\t`" + p + "` = function(x) { " + e + " }";
                }
                return null; // filtered out below
            }
        });
        CollectionUtils.filter(lines, PredicateUtils.notNullPredicate());
        return StringUtils.join(lines, ",\n");
    }

    /**
     * Compiles an MVEL template loaded as a class-path resource relative to this class.
     *
     * @param path resource path, as accepted by {@link Class#getResourceAsStream(String)}
     * @return the compiled template
     * @throws IllegalStateException if the resource cannot be found
     */
    public CompiledTemplate readTemplate(String path) {
        InputStream in = getClass().getResourceAsStream(path);
        if (in == null) {
            throw new IllegalStateException("Template resource not found: " + path);
        }
        try {
            return TemplateCompiler.compileTemplate(in);
        } finally {
            try {
                in.close();
            } catch (IOException ignored) {
                // Best-effort close of a read-only class-path resource; the template is already compiled.
            }
        }
    }

    /**
     * Renders the model file from {@code modelTemplate.txt}. The template receives the data type,
     * the node-split/inconsistency flags, the relative effect matrix, the prior constants and the
     * parameter list.
     */
    public String modelText() {
        CompiledTemplate template = readTemplate("modelTemplate.txt");
        Map<String, Object> map = new HashMap<String, Object>();
        map.put("dichotomous", d_network.getType() == DataType.RATE);
        map.put("nodeSplit", d_nodeSplit);
        if (d_nodeSplit) {
            map.put("indirectNode", getIndirectEvidenceExpression());
        }
        map.put("inconsistency", d_inconsistency);
        map.put("relativeEffectMatrix", getRelativeEffectMatrix());
        double sd = d_priorGen.getVagueNormalSigma();
        map.put("priorPrecision", rewriteNumber(formatScientific(1.0 / (sd * sd))));
        map.put("stdDevUpperLimit", rewriteNumber(formatScientific(d_priorGen.getRandomEffectsSigma())));
        map.put("parameters", d_pmtz.getParameters());
        map.put("inconsClass", InconsistencyParameter.class);
        return String.valueOf(TemplateRuntime.execute(template, map));
    }

    /** Renders {@code <indirect parameter> <- <expression>} for the node-split parameterization. */
    private String getIndirectEvidenceExpression() {
        NodeSplitParameterization pmtz = (NodeSplitParameterization) d_pmtz;
        return pmtz.getIndirectParameter().toString() + " <- " + writeExpression(pmtz.parameterizeIndirect(), s_idTrans);
    }

    /** Expresses the t1-to-t2 effect in parameters, or {@code "0"} when t1 equals t2. */
    private String expressRelativeEffect(Treatment t1, Treatment t2, Transformer<NetworkParameter, String> transform) {
        if (t1.equals(t2)) {
            return "0";
        }
        return writeExpression(d_pmtz.parameterize(t1, t2), transform);
    }

    /** Emits {@code d[i,j] <- ...} for every (1-based) pair of treatment indices. */
    private String getRelativeEffectMatrix() {
        List<String> lines = new ArrayList<String>();
        ObservableList<Treatment> treatments = d_network.getTreatments();
        for (int i = 0; i < treatments.size(); ++i) {
            for (int j = 0; j < treatments.size(); ++j) {
                lines.add("\td[" + (i + 1) + "," + (j + 1) + "] <- "
                        + expressRelativeEffect((Treatment) treatments.get(i), (Treatment) treatments.get(j), s_idTrans));
            }
        }
        return StringUtils.join(lines, "\n");
    }

    /**
     * Renders the run script from the JAGS or BUGS script template.
     *
     * @param prefix     file prefix for the generated files
     * @param nchains    number of chains; chains are numbered 1..nchains
     * @param tuning     number of tuning iterations
     * @param simulation number of simulation iterations
     * @return the script text
     */
    public String scriptText(String prefix, int nchains, int tuning, int simulation) {
        CompiledTemplate template = readTemplate(d_isJags ? "jagsScriptTemplate.txt" : "bugsScriptTemplate.txt");
        Map<String, Object> map = new HashMap<String, Object>();
        map.put("prefix", prefix);
        map.put("nchains", nchains);
        List<Integer> chains = new ArrayList<Integer>();
        for (int i = 0; i < nchains; ++i) {
            chains.add(i + 1);
        }
        map.put("chains", chains);
        map.put("tuning", tuning);
        map.put("simulation", simulation);
        map.put("inconsistency", d_inconsistency);
        List<NetworkParameter> parameters = new ArrayList<NetworkParameter>(d_pmtz.getParameters());
        if (d_nodeSplit) {
            // The indirect node is monitored in addition to the regular parameters.
            parameters.add(((NodeSplitParameterization) d_pmtz).getIndirectParameter());
        }
        map.put("parameters", parameters);
        return String.valueOf(TemplateRuntime.execute(template, map));
    }

    /**
     * Generates the data file. It contains the split treatment indices (node-split only),
     * {@code ns}, {@code na} and {@code t}, plus {@code r}/{@code n} for rate data or
     * {@code m}/{@code e} for continuous data.
     *
     * @throws IllegalArgumentException for any other data type
     */
    public String dataText() {
        List<Pair<String>> list = new ArrayList<Pair<String>>();
        if (d_nodeSplit) {
            NodeSplitParameterization pmtz = (NodeSplitParameterization) d_pmtz;
            Treatment baseline = pmtz.getDirectParameter().getBaseline();
            Treatment subject = pmtz.getDirectParameter().getSubject();
            Integer[] split = new Integer[] {
                    d_network.getTreatments().indexOf(baseline) + 1,
                    d_network.getTreatments().indexOf(subject) + 1 };
            list.add(new Pair<String>("split", writeVector(split, d_isJags)));
        }
        list.add(new Pair<String>("ns", writeNumber(d_network.getStudies().size(), d_isJags)));
        list.add(new Pair<String>("na", writeVector(getArmCounts(), d_isJags)));
        list.add(new Pair<String>("t", writeMatrix(getTreatmentMatrix(), d_isJags)));
        switch (d_network.getType()) {
            case RATE:
                list.add(new Pair<String>("r", writeMatrix(getResponderMatrix(), d_isJags)));
                list.add(new Pair<String>("n", writeMatrix(getSampleSizeMatrix(), d_isJags)));
                break;
            case CONTINUOUS:
                list.add(new Pair<String>("m", writeMatrix(getMeanMatrix(), d_isJags)));
                list.add(new Pair<String>("e", writeMatrix(getStdErrMatrix(), d_isJags)));
                break;
            default:
                throw new IllegalArgumentException("Don't know how to generate starting values for " + d_network.getType() + " data");
        }
        return generateDataFile(list);
    }

    /** Number of measurements (arms) per study, in study order. */
    private Integer[] getArmCounts() {
        Integer[] count = new Integer[d_network.getStudies().size()];
        for (int i = 0; i < count.length; ++i) {
            count[i] = ((Study) d_network.getStudies().get(i)).getMeasurements().size();
        }
        return count;
    }

    /** @return the largest number of measurements in any study, or 0 when there are no studies */
    public int getMaxArmCount() {
        int max = 0;
        for (int i = 0; i < d_network.getStudies().size(); ++i) {
            max = Math.max(max, ((Study) d_network.getStudies().get(i)).getMeasurements().size());
        }
        return max;
    }

    /** @return per-study, per-arm 1-based treatment indices (baseline arm first) */
    public Integer[][] getTreatmentMatrix() {
        return getIntegerMatrix(new StudyTreatmentTransformer<Integer>() {
            @Override
            public Integer transform(Study s, Treatment t) {
                return d_network.getTreatments().indexOf(t) + 1;
            }
        });
    }

    /** @return per-study, per-arm responder counts */
    public Integer[][] getResponderMatrix() {
        return getIntegerMatrix(new StudyTreatmentTransformer<Integer>() {
            @Override
            public Integer transform(Study s, Treatment t) {
                return NetworkModel.findMeasurement(s, t).getResponders();
            }
        });
    }

    /** @return per-study, per-arm means */
    public Double[][] getMeanMatrix() {
        return getDoubleMatrix(new StudyTreatmentTransformer<Double>() {
            @Override
            public Double transform(Study s, Treatment t) {
                return NetworkModel.findMeasurement(s, t).getMean();
            }
        });
    }

    /** @return per-study, per-arm standard errors, computed as stdDev / sqrt(sampleSize) */
    public Double[][] getStdErrMatrix() {
        return getDoubleMatrix(new StudyTreatmentTransformer<Double>() {
            @Override
            public Double transform(Study s, Treatment t) {
                Measurement m = NetworkModel.findMeasurement(s, t);
                return m.getStdDev() / Math.sqrt(m.getSampleSize().intValue());
            }
        });
    }

    /** @return per-study, per-arm sample sizes */
    public Integer[][] getSampleSizeMatrix() {
        return getIntegerMatrix(new StudyTreatmentTransformer<Integer>() {
            @Override
            public Integer transform(Study s, Treatment t) {
                return NetworkModel.findMeasurement(s, t).getSampleSize();
            }
        });
    }

    /**
     * Builds a studies x max-arms matrix of doubles. Cells beyond a study's arm count remain
     * {@code null}.
     */
    public Double[][] getDoubleMatrix(StudyTreatmentTransformer<Double> transformer) {
        Double[][] m = new Double[d_network.getStudies().size()][getMaxArmCount()];
        getMatrix(m, transformer);
        return m;
    }

    /**
     * Builds a studies x max-arms matrix of integers. Cells beyond a study's arm count remain
     * {@code null}.
     */
    public Integer[][] getIntegerMatrix(StudyTreatmentTransformer<Integer> transformer) {
        Integer[][] m = new Integer[d_network.getStudies().size()][getMaxArmCount()];
        getMatrix(m, transformer);
        return m;
    }

    /**
     * Fills row i of {@code m} with the transformer applied to study i's treatments, baseline
     * first (see {@link #getTreatments(Study)}).
     */
    public <N extends Number> void getMatrix(N[][] m, StudyTreatmentTransformer<N> transformer) {
        ObservableList<Study> studies = d_network.getStudies();
        for (int i = 0; i < studies.size(); ++i) {
            Study study = (Study) studies.get(i);
            List<Treatment> treatments = getTreatments(study);
            for (int j = 0; j < treatments.size(); ++j) {
                m[i][j] = transformer.transform(study, treatments.get(j));
            }
        }
    }

    /**
     * Returns the study's treatments with the parameterization's baseline moved to the front.
     * Works on a copy so the list obtained from {@link NetworkModel#getTreatments} is not modified.
     */
    private List<Treatment> getTreatments(Study study) {
        List<Treatment> treatments = new ArrayList<Treatment>(NetworkModel.getTreatments(study));
        Treatment baseline = d_pmtz.getStudyBaseline(study);
        treatments.remove(baseline);
        treatments.add(0, baseline);
        return treatments;
    }

    /**
     * Writes a number in R syntax: {@code null} becomes {@code NA}. For JAGS, integral types
     * get an {@code L} suffix.
     */
    public static String writeNumber(Number x, boolean jags) {
        if (x == null) {
            return "NA";
        }
        String suffix = jags && isInteger(x) ? "L" : "";
        return String.valueOf(x) + suffix;
    }

    private static boolean isInteger(Number x) {
        return x instanceof Integer || x instanceof Long || x instanceof Short || x instanceof Byte;
    }

    /**
     * Writes a rectangular matrix as an R {@code structure(...)} with a {@code .Dim} attribute.
     * Cells are emitted in column-major order for JAGS, and in row-major order (with a
     * {@code .Data =} label) for BUGS.
     *
     * @param m    the matrix; rows are assumed to have equal length as {@code m[0]}
     * @param jags {@code true} for JAGS syntax
     */
    public static String writeMatrix(Number[][] m, boolean jags) {
        int rows = m.length;
        // An empty matrix (e.g. no studies) has no m[0]; treat it as 0 x 0.
        int cols = rows == 0 ? 0 : m[0].length;
        String[] cells = new String[rows * cols];
        for (int i = 0; i < cells.length; ++i) {
            Number cell = jags ? m[i % rows][i / rows] : m[i / cols][i % cols];
            cells[i] = writeNumber(cell, jags);
        }
        return "structure(" + (jags ? "" : ".Data = ") + "c(" + StringUtils.join(cells, ", ")
                + "), .Dim = c(" + writeNumber(rows, jags) + ", " + writeNumber(cols, jags) + "))";
    }

    /** Writes a vector as R {@code c(...)}, formatting each element with {@link #writeNumber}. */
    public static String writeVector(Number[] v, boolean jags) {
        String[] cells = new String[v.length];
        for (int i = 0; i < cells.length; ++i) {
            cells[i] = writeNumber(v[i], jags);
        }
        return "c(" + StringUtils.join(cells, ", ") + ")";
    }

    /**
     * Writes a linear combination of parameters as {@code a + -b + ...}, ordered by
     * {@link ParameterComparator}. A term is prefixed with {@code -} when its coefficient is -1.
     * NOTE(review): any other coefficient is written as if it were +1; the callers appear to supply
     * only +/-1 — confirm.
     */
    public static String writeExpression(Map<NetworkParameter, Integer> pmtz, Transformer<NetworkParameter, String> transform) {
        List<String> terms = new ArrayList<String>();
        TreeSet<NetworkParameter> keys = new TreeSet<NetworkParameter>(new ParameterComparator());
        keys.addAll(pmtz.keySet());
        for (NetworkParameter p : keys) {
            int coefficient = pmtz.get(p);
            terms.add((coefficient == -1 ? "-" : "") + transform.transform(p));
        }
        return StringUtils.join(terms, " + ");
    }

    /** Maps a (study, treatment) cell to a value. */
    private static interface StudyTreatmentTransformer<O> {
        public O transform(Study var1, Treatment var2);
    }
}

