/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.UpdateableClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class IBk
extends Classifier
implements OptionHandler,
UpdateableClassifier,
WeightedInstancesHandler {
    protected Instances m_Train;
    protected int m_NumClasses;
    protected int m_ClassType;
    protected double[] m_Min;
    protected double[] m_Max;
    protected int m_kNN;
    protected int m_kNNUpper;
    protected boolean m_kNNValid;
    protected int m_WindowSize;
    protected int m_DistanceWeighting;
    protected boolean m_CrossValidate;
    protected boolean m_MeanSquared;
    protected boolean m_DontNormalize;
    public static final int WEIGHT_NONE = 1;
    public static final int WEIGHT_INVERSE = 2;
    public static final int WEIGHT_SIMILARITY = 4;
    public static final Tag[] TAGS_WEIGHTING = new Tag[]{new Tag(1, "No distance weighting"), new Tag(2, "Weight by 1/distance"), new Tag(4, "Weight by 1-distance")};
    protected double m_NumAttributesUsed;

    public String globalInfo() {
        return "K-nearest neighbours classifier. Normalizes attributes by default. Can select appropriate value of K based on cross-validation. Can also do distance weighting. For more information, see\n\nAha, D., and D. Kibler (1991) \"Instance-based learning algorithms\", Machine Learning, vol.6, pp. 37-66.";
    }

    public IBk(int n) {
        this.init();
        this.setKNN(n);
    }

    public IBk() {
        this.init();
    }

    public String KNNTipText() {
        return "The number of neighbours to use.";
    }

    public void setKNN(int n) {
        this.m_kNN = n;
        this.m_kNNUpper = n;
        this.m_kNNValid = false;
    }

    public int getKNN() {
        return this.m_kNN;
    }

    public String windowSizeTipText() {
        return "Gets the maximum number of instances allowed in the training pool. The addition of new instances above this value will result in old instances being removed. A value of 0 signifies no limit to the number of training instances.";
    }

    public int getWindowSize() {
        return this.m_WindowSize;
    }

    public void setWindowSize(int n) {
        this.m_WindowSize = n;
    }

    public String distanceWeightingTipText() {
        return "Gets the distance weighting method used.";
    }

    public SelectedTag getDistanceWeighting() {
        return new SelectedTag(this.m_DistanceWeighting, TAGS_WEIGHTING);
    }

    public void setDistanceWeighting(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_WEIGHTING) {
            this.m_DistanceWeighting = selectedTag.getSelectedTag().getID();
        }
    }

    public String meanSquaredTipText() {
        return "Whether the mean squared error is used rather than mean absolute error when doing cross-validation for regression problems.";
    }

    public boolean getMeanSquared() {
        return this.m_MeanSquared;
    }

    public void setMeanSquared(boolean bl) {
        this.m_MeanSquared = bl;
    }

    public String crossValidateTipText() {
        return "Whether hold-one-out cross-validation will be used to select the best k value.";
    }

    public boolean getCrossValidate() {
        return this.m_CrossValidate;
    }

    public void setCrossValidate(boolean bl) {
        this.m_CrossValidate = bl;
    }

    public int getNumTraining() {
        return this.m_Train.numInstances();
    }

    public double getAttributeMin(int n) throws Exception {
        if (this.m_Min == null) {
            throw new Exception("Minimum value for attribute not available!");
        }
        return this.m_Min[n];
    }

    public double getAttributeMax(int n) throws Exception {
        if (this.m_Max == null) {
            throw new Exception("Maximum value for attribute not available!");
        }
        return this.m_Max[n];
    }

    public String noNormalizationTipText() {
        return "Whether attribute normalization is turned off.";
    }

    public boolean getNoNormalization() {
        return this.m_DontNormalize;
    }

    public void setNoNormalization(boolean bl) {
        this.m_DontNormalize = bl;
    }

    public void buildClassifier(Instances instances) throws Exception {
        if (instances.classIndex() < 0) {
            throw new Exception("No class attribute assigned to instances");
        }
        if (instances.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        try {
            this.m_NumClasses = instances.numClasses();
            this.m_ClassType = instances.classAttribute().type();
        }
        catch (Exception exception) {
            throw new Error("This should never be reached");
        }
        this.m_Train = new Instances(instances, 0, instances.numInstances());
        this.m_Train.deleteWithMissingClass();
        if (this.m_WindowSize > 0 && instances.numInstances() > this.m_WindowSize) {
            this.m_Train = new Instances(this.m_Train, this.m_Train.numInstances() - this.m_WindowSize, this.m_WindowSize);
        }
        if (this.m_DontNormalize) {
            this.m_Min = null;
            this.m_Max = null;
        } else {
            this.m_Min = new double[this.m_Train.numAttributes()];
            this.m_Max = new double[this.m_Train.numAttributes()];
            for (int i = 0; i < this.m_Train.numAttributes(); ++i) {
                this.m_Max[i] = Double.NaN;
                this.m_Min[i] = Double.NaN;
            }
            Enumeration enumeration = this.m_Train.enumerateInstances();
            while (enumeration.hasMoreElements()) {
                this.updateMinMax((Instance)enumeration.nextElement());
            }
        }
        this.m_NumAttributesUsed = 0.0;
        for (int i = 0; i < this.m_Train.numAttributes(); ++i) {
            if (i == this.m_Train.classIndex() || !this.m_Train.attribute(i).isNominal() && !this.m_Train.attribute(i).isNumeric()) continue;
            this.m_NumAttributesUsed += 1.0;
        }
        this.m_kNNValid = false;
    }

    public void updateClassifier(Instance instance) throws Exception {
        if (!this.m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types");
        }
        if (instance.classIsMissing()) {
            return;
        }
        if (!this.m_DontNormalize) {
            this.updateMinMax(instance);
        }
        this.m_Train.add(instance);
        this.m_kNNValid = false;
        if (this.m_WindowSize > 0 && this.m_Train.numInstances() > this.m_WindowSize) {
            while (this.m_Train.numInstances() > this.m_WindowSize) {
                this.m_Train.delete(0);
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        if (this.m_WindowSize > 0 && this.m_Train.numInstances() > this.m_WindowSize) {
            this.m_kNNValid = false;
            while (this.m_Train.numInstances() > this.m_WindowSize) {
                this.m_Train.delete(0);
            }
        }
        if (!this.m_kNNValid && this.m_CrossValidate && this.m_kNNUpper >= 1) {
            this.crossValidate();
        }
        if (!this.m_DontNormalize) {
            this.updateMinMax(instance);
        }
        NeighborList neighborList = this.findNeighbors(instance);
        return this.makeDistribution(neighborList);
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>(8);
        vector.addElement(new Option("\tWeight neighbours by the inverse of their distance\n\t(use when k > 1)", "I", 0, "-I"));
        vector.addElement(new Option("\tWeight neighbours by 1 - their distance\n\t(use when k > 1)", "F", 0, "-F"));
        vector.addElement(new Option("\tNumber of nearest neighbours (k) used in classification.\n\t(Default = 1)", "K", 1, "-K <number of neighbors>"));
        vector.addElement(new Option("\tMinimise mean squared error rather than mean absolute\n\terror when using -X option with numeric prediction.", "E", 0, "-E"));
        vector.addElement(new Option("\tMaximum number of training instances maintained.\n\tTraining instances are dropped FIFO. (Default = no window)", "W", 1, "-W <window size>"));
        vector.addElement(new Option("\tSelect the number of nearest neighbours between 1\n\tand the k value specified using hold-one-out evaluation\n\ton the training data (use when k > 1)", "X", 0, "-X"));
        vector.addElement(new Option("\tDon't normalize the data.\n", "N", 0, "-N"));
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        String string = Utils.getOption('K', stringArray);
        if (string.length() != 0) {
            this.setKNN(Integer.parseInt(string));
        } else {
            this.setKNN(1);
        }
        String string2 = Utils.getOption('W', stringArray);
        if (string2.length() != 0) {
            this.setWindowSize(Integer.parseInt(string2));
        } else {
            this.setWindowSize(0);
        }
        if (Utils.getFlag('I', stringArray)) {
            this.setDistanceWeighting(new SelectedTag(2, TAGS_WEIGHTING));
        } else if (Utils.getFlag('F', stringArray)) {
            this.setDistanceWeighting(new SelectedTag(4, TAGS_WEIGHTING));
        } else {
            this.setDistanceWeighting(new SelectedTag(1, TAGS_WEIGHTING));
        }
        this.setCrossValidate(Utils.getFlag('X', stringArray));
        this.setMeanSquared(Utils.getFlag('E', stringArray));
        this.setNoNormalization(Utils.getFlag('N', stringArray));
        Utils.checkForRemainingOptions(stringArray);
    }

    public String[] getOptions() {
        String[] stringArray = new String[11];
        int n = 0;
        stringArray[n++] = "-K";
        stringArray[n++] = "" + this.getKNN();
        stringArray[n++] = "-W";
        stringArray[n++] = "" + this.m_WindowSize;
        if (this.getCrossValidate()) {
            stringArray[n++] = "-X";
        }
        if (this.getMeanSquared()) {
            stringArray[n++] = "-E";
        }
        if (this.m_DistanceWeighting == 2) {
            stringArray[n++] = "-I";
        } else if (this.m_DistanceWeighting == 4) {
            stringArray[n++] = "-F";
        }
        if (this.m_DontNormalize) {
            stringArray[n++] = "-N";
        }
        while (n < stringArray.length) {
            stringArray[n++] = "";
        }
        return stringArray;
    }

    public String toString() {
        if (this.m_Train == null) {
            return "IBk: No model built yet.";
        }
        if (!this.m_kNNValid && this.m_CrossValidate) {
            this.crossValidate();
        }
        String string = "IB1 instance-based classifier\nusing " + this.m_kNN;
        switch (this.m_DistanceWeighting) {
            case 2: {
                string = string + " inverse-distance-weighted";
                break;
            }
            case 4: {
                string = string + " similarity-weighted";
            }
        }
        string = string + " nearest neighbour(s) for classification\n";
        if (this.m_WindowSize != 0) {
            string = string + "using a maximum of " + this.m_WindowSize + " (windowed) training instances\n";
        }
        return string;
    }

    protected void init() {
        this.setKNN(1);
        this.m_WindowSize = 0;
        this.m_DistanceWeighting = 1;
        this.m_CrossValidate = false;
        this.m_MeanSquared = false;
        this.m_DontNormalize = false;
    }

    protected double distance(Instance instance, Instance instance2) {
        double d = 0.0;
        int n = 0;
        int n2 = 0;
        while (n < instance.numValues() || n2 < instance2.numValues()) {
            double d2;
            int n3 = n >= instance.numValues() ? this.m_Train.numAttributes() : instance.index(n);
            int n4 = n2 >= instance2.numValues() ? this.m_Train.numAttributes() : instance2.index(n2);
            if (n3 == this.m_Train.classIndex()) {
                ++n;
                continue;
            }
            if (n4 == this.m_Train.classIndex()) {
                ++n2;
                continue;
            }
            if (n3 == n4) {
                d2 = this.difference(n3, instance.valueSparse(n), instance2.valueSparse(n2));
                ++n;
                ++n2;
            } else if (n3 > n4) {
                d2 = this.difference(n4, 0.0, instance2.valueSparse(n2));
                ++n2;
            } else {
                d2 = this.difference(n3, instance.valueSparse(n), 0.0);
                ++n;
            }
            d += d2 * d2;
        }
        return Math.sqrt(d / this.m_NumAttributesUsed);
    }

    protected double difference(int n, double d, double d2) {
        switch (this.m_Train.attribute(n).type()) {
            case 1: {
                if (Instance.isMissingValue(d) || Instance.isMissingValue(d2) || (int)d != (int)d2) {
                    return 1.0;
                }
                return 0.0;
            }
            case 0: {
                if (Instance.isMissingValue(d) || Instance.isMissingValue(d2)) {
                    if (Instance.isMissingValue(d) && Instance.isMissingValue(d2)) {
                        return 1.0;
                    }
                    double d3 = Instance.isMissingValue(d2) ? this.norm(d, n) : this.norm(d2, n);
                    if (d3 < 0.5) {
                        d3 = 1.0 - d3;
                    }
                    return d3;
                }
                return this.norm(d, n) - this.norm(d2, n);
            }
        }
        return 0.0;
    }

    protected double norm(double d, int n) {
        if (this.m_DontNormalize) {
            return d;
        }
        if (Double.isNaN(this.m_Min[n]) || Utils.eq(this.m_Max[n], this.m_Min[n])) {
            return 0.0;
        }
        return (d - this.m_Min[n]) / (this.m_Max[n] - this.m_Min[n]);
    }

    protected void updateMinMax(Instance instance) {
        for (int i = 0; i < this.m_Train.numAttributes(); ++i) {
            if (instance.isMissing(i)) continue;
            if (Double.isNaN(this.m_Min[i])) {
                this.m_Min[i] = instance.value(i);
                this.m_Max[i] = instance.value(i);
                continue;
            }
            if (instance.value(i) < this.m_Min[i]) {
                this.m_Min[i] = instance.value(i);
                continue;
            }
            if (!(instance.value(i) > this.m_Max[i])) continue;
            this.m_Max[i] = instance.value(i);
        }
    }

    protected NeighborList findNeighbors(Instance instance) {
        NeighborList neighborList = new NeighborList(this.m_kNN);
        Enumeration enumeration = this.m_Train.enumerateInstances();
        int n = 0;
        while (enumeration.hasMoreElements()) {
            Instance instance2 = (Instance)enumeration.nextElement();
            if (instance == instance2) continue;
            double d = this.distance(instance, instance2);
            if (neighborList.isEmpty() || n < this.m_kNN || d <= neighborList.m_Last.m_Distance) {
                neighborList.insertSorted(d, instance2);
            }
            ++n;
        }
        return neighborList;
    }

    protected double[] makeDistribution(NeighborList neighborList) throws Exception {
        double d = 0.0;
        double[] dArray = new double[this.m_NumClasses];
        if (this.m_ClassType == 1) {
            for (int i = 0; i < this.m_NumClasses; ++i) {
                dArray[i] = 1.0 / (double)Math.max(1, this.m_Train.numInstances());
            }
            d = (double)this.m_NumClasses / (double)Math.max(1, this.m_Train.numInstances());
        }
        if (!neighborList.isEmpty()) {
            NeighborNode neighborNode = neighborList.m_First;
            while (neighborNode != null) {
                double d2;
                switch (this.m_DistanceWeighting) {
                    case 2: {
                        d2 = 1.0 / (neighborNode.m_Distance + 0.001);
                        break;
                    }
                    case 4: {
                        d2 = 1.0 - neighborNode.m_Distance;
                        break;
                    }
                    default: {
                        d2 = 1.0;
                    }
                }
                d2 *= neighborNode.m_Instance.weight();
                try {
                    switch (this.m_ClassType) {
                        case 1: {
                            int n = (int)neighborNode.m_Instance.classValue();
                            dArray[n] = dArray[n] + d2;
                            break;
                        }
                        case 0: {
                            dArray[0] = dArray[0] + neighborNode.m_Instance.classValue() * d2;
                        }
                    }
                }
                catch (Exception exception) {
                    throw new Error("Data has no class attribute!");
                }
                d += d2;
                neighborNode = neighborNode.m_Next;
            }
        }
        if (d > 0.0) {
            Utils.normalize(dArray, d);
        }
        return dArray;
    }

    protected void crossValidate() {
        try {
            int n;
            double[] dArray = new double[this.m_kNNUpper];
            double[] dArray2 = new double[this.m_kNNUpper];
            for (int i = 0; i < this.m_kNNUpper; ++i) {
                dArray[i] = 0.0;
                dArray2[i] = 0.0;
            }
            this.m_kNN = this.m_kNNUpper;
            for (n = 0; n < this.m_Train.numInstances(); ++n) {
                if (this.m_Debug && n % 50 == 0) {
                    System.err.print("Cross validating " + n + "/" + this.m_Train.numInstances() + "\r");
                }
                Instance instance = this.m_Train.instance(n);
                NeighborList neighborList = this.findNeighbors(instance);
                for (int i = this.m_kNNUpper - 1; i >= 0; --i) {
                    double[] dArray3 = this.makeDistribution(neighborList);
                    double d = Utils.maxIndex(dArray3);
                    if (this.m_Train.classAttribute().isNumeric()) {
                        d = dArray3[0];
                        double d2 = d - instance.classValue();
                        int n2 = i;
                        dArray2[n2] = dArray2[n2] + d2 * d2;
                        int n3 = i;
                        dArray[n3] = dArray[n3] + Math.abs(d2);
                    } else if (d != instance.classValue()) {
                        int n4 = i;
                        dArray[n4] = dArray[n4] + 1.0;
                    }
                    if (i < 1) continue;
                    neighborList.pruneToK(i);
                }
            }
            for (n = 0; n < this.m_kNNUpper; ++n) {
                if (this.m_Debug) {
                    System.err.print("Hold-one-out performance of " + (n + 1) + " neighbors ");
                }
                if (this.m_Train.classAttribute().isNumeric()) {
                    if (!this.m_Debug) continue;
                    if (this.m_MeanSquared) {
                        System.err.println("(RMSE) = " + Math.sqrt(dArray2[n] / (double)this.m_Train.numInstances()));
                        continue;
                    }
                    System.err.println("(MAE) = " + dArray[n] / (double)this.m_Train.numInstances());
                    continue;
                }
                if (!this.m_Debug) continue;
                System.err.println("(%ERR) = " + 100.0 * dArray[n] / (double)this.m_Train.numInstances());
            }
            double[] dArray4 = dArray;
            if (this.m_Train.classAttribute().isNumeric() && this.m_MeanSquared) {
                dArray4 = dArray2;
            }
            double d = Double.NaN;
            int n5 = 1;
            for (int i = 0; i < this.m_kNNUpper; ++i) {
                if (!Double.isNaN(d) && !(d > dArray4[i])) continue;
                d = dArray4[i];
                n5 = i + 1;
            }
            this.m_kNN = n5;
            if (this.m_Debug) {
                System.err.println("Selected k = " + n5);
            }
            this.m_kNNValid = true;
        }
        catch (Exception exception) {
            throw new Error("Couldn't optimize by cross-validation: " + exception.getMessage());
        }
    }

    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new IBk(), stringArray));
        }
        catch (Exception exception) {
            exception.printStackTrace();
            System.err.println(exception.getMessage());
        }
    }

    protected class NeighborList {
        protected NeighborNode m_First;
        protected NeighborNode m_Last;
        protected int m_Length = 1;

        public NeighborList(int n) {
            this.m_Length = n;
        }

        public boolean isEmpty() {
            return this.m_First == null;
        }

        public int currentLength() {
            int n = 0;
            NeighborNode neighborNode = this.m_First;
            while (neighborNode != null) {
                ++n;
                neighborNode = neighborNode.m_Next;
            }
            return n;
        }

        public void insertSorted(double d, Instance instance) {
            if (this.isEmpty()) {
                this.m_First = this.m_Last = new NeighborNode(d, instance);
            } else {
                NeighborNode neighborNode = this.m_First;
                if (d < this.m_First.m_Distance) {
                    this.m_First = new NeighborNode(d, instance, this.m_First);
                } else {
                    while (neighborNode.m_Next != null && neighborNode.m_Next.m_Distance < d) {
                        neighborNode = neighborNode.m_Next;
                    }
                    neighborNode.m_Next = new NeighborNode(d, instance, neighborNode.m_Next);
                    if (neighborNode.equals(this.m_Last)) {
                        this.m_Last = neighborNode.m_Next;
                    }
                }
                int n = 0;
                neighborNode = this.m_First;
                while (neighborNode.m_Next != null) {
                    if (++n >= this.m_Length && neighborNode.m_Distance != neighborNode.m_Next.m_Distance) {
                        this.m_Last = neighborNode;
                        neighborNode.m_Next = null;
                        break;
                    }
                    neighborNode = neighborNode.m_Next;
                }
            }
        }

        public void pruneToK(int n) {
            if (this.isEmpty()) {
                return;
            }
            if (n < 1) {
                n = 1;
            }
            int n2 = 0;
            double d = this.m_First.m_Distance;
            NeighborNode neighborNode = this.m_First;
            while (neighborNode.m_Next != null) {
                d = neighborNode.m_Distance;
                if (++n2 >= n && d != neighborNode.m_Next.m_Distance) {
                    this.m_Last = neighborNode;
                    neighborNode.m_Next = null;
                    break;
                }
                neighborNode = neighborNode.m_Next;
            }
        }

        public void printList() {
            if (this.isEmpty()) {
                System.out.println("Empty list");
            } else {
                NeighborNode neighborNode = this.m_First;
                while (neighborNode != null) {
                    System.out.println("Node: instance " + neighborNode.m_Instance + ", distance " + neighborNode.m_Distance);
                    neighborNode = neighborNode.m_Next;
                }
                System.out.println();
            }
        }
    }

    protected class NeighborNode {
        protected Instance m_Instance;
        protected double m_Distance;
        protected NeighborNode m_Next;

        public NeighborNode(double d, Instance instance, NeighborNode neighborNode) {
            this.m_Distance = d;
            this.m_Instance = instance;
            this.m_Next = neighborNode;
        }

        public NeighborNode(double d, Instance instance) {
            this(d, instance, null);
        }
    }
}

