/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.ContingencyTables;
import weka.core.DenseInstance;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

public class LADTree
extends AbstractClassifier
implements Drawable,
AdditionalMeasureProducer,
TechnicalInformationHandler {
    private static final long serialVersionUID = -4940716114518300302L;
    // NOTE(review): not referenced in the visible part of this file; presumably the cap on |z|
    // working responses used by LADInstance — confirm.
    protected double Z_MAX = 4.0;
    // Number of class values; set from instances.numClasses() in initClassifier.
    protected int m_numOfClasses;
    // Training instances (wrapped as LADInstance) that the search runs over; built in initClassifier.
    protected ReferenceInstances m_trainInstances;
    // Root prediction node of the tree; null until initClassifier has run.
    protected PredictionNode m_root = null;
    // Reset to 0 in initClassifier. Presumably the counter used to number splitters
    // (orderAdded) as they are inserted — confirm in PredictionNode.addChild.
    protected int m_lastAddedSplitNum = 0;
    // Indices of the numeric, non-class attributes; each gets a dynamic numeric split search.
    protected int[] m_numericAttIndices;
    // --- State of the best-split search performed by searchForBestTest and consumed by boost ---
    // Largest least-squares gain found so far (despite the name, larger values win).
    protected double m_search_smallestLeastSquares;
    // Prediction node under which the best splitter would be inserted.
    protected PredictionNode m_search_bestInsertionNode;
    // Best splitter found so far, or null if no candidate beat the current gain.
    protected Splitter m_search_bestSplitter;
    // Instances reaching m_search_bestInsertionNode; cleared by boost after use.
    protected Instances m_search_bestPathInstances;
    // Precomputed two-way nominal splitters, evaluated at every prediction node.
    protected FastVector m_staticPotentialSplitters2way;
    // Statistics: prediction nodes visited during searches, and instances processed at them.
    protected int m_nodesExpanded = 0;
    protected int m_examplesCounted = 0;
    // Number of boosting iterations run by buildClassifier (option -B).
    protected int m_boostingIterations = 10;

    /**
     * Returns the description shown in the Weka GUI, followed by the bibliographic reference.
     *
     * @return a human-readable description of this classifier
     */
    public String globalInfo() {
        String citation = this.getTechnicalInformation().toString();
        return "Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see\n\n" + citation;
    }

    /**
     * Builds the bibliographic reference (ECML 2001 paper) describing this algorithm.
     *
     * @return the technical information record
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR,
                "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall");
        info.setValue(TechnicalInformation.Field.TITLE, "Multiclass alternating decision trees");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "ECML");
        info.setValue(TechnicalInformation.Field.YEAR, "2001");
        info.setValue(TechnicalInformation.Field.PAGES, "161-172");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "Springer");
        return info;
    }

    /**
     * Prepares a fresh tree for boosting: resets the statistics, validates the data,
     * wraps every instance with a known class value as an LADInstance, creates a root
     * prediction node with all-zero values and precomputes the candidate splitters.
     *
     * @param instances the training data
     * @throws Exception if the data has string attributes or a non-nominal class
     */
    public void initClassifier(Instances instances) throws Exception {
        this.m_nodesExpanded = 0;
        this.m_examplesCounted = 0;
        this.m_lastAddedSplitNum = 0;
        this.m_numOfClasses = instances.numClasses();
        if (instances.checkForStringAttributes()) {
            throw new Exception("Can't handle string attributes!");
        }
        if (!instances.classAttribute().isNominal()) {
            throw new Exception("Class must be nominal!");
        }
        this.m_trainInstances = new ReferenceInstances(instances, instances.numInstances());
        for (int idx = 0; idx < instances.numInstances(); idx++) {
            Instance source = instances.instance(idx);
            // Instances without a class label cannot contribute to training.
            if (!source.classIsMissing()) {
                LADInstance wrapped = new LADInstance(source);
                this.m_trainInstances.addReference(wrapped);
                wrapped.setDataset(this.m_trainInstances);
            }
        }
        this.m_root = new PredictionNode(new double[this.m_numOfClasses]);
        this.generateStaticPotentialSplittersAndNumericIndices();
    }

    /**
     * Performs one boosting iteration (adds at most one splitter to the tree).
     *
     * @param iteration the iteration number; not used
     * @throws Exception if boosting fails
     */
    public void next(int iteration) throws Exception {
        boost();
    }

    /**
     * Finishes incremental building. Intentionally a no-op for this classifier.
     *
     * @throws Exception declared for callers of the incremental-build protocol; not thrown here
     */
    public void done() throws Exception {
        // intentionally empty
    }

    /**
     * Performs one boosting iteration. It searches all prediction nodes for the splitter
     * with the largest least-squares gain, then creates one prediction node per branch of
     * that splitter. It updates the weights of the instances reaching each branch and
     * attaches the splitter below its insertion node.
     * Does nothing if no candidate improves on a gain of zero.
     *
     * @throws Exception if there is no training data or a split cannot be applied
     */
    private void boost() throws Exception {
        if (this.m_trainInstances == null) {
            throw new Exception("Trying to boost with no training data");
        }
        // Forget the winner of a previous iteration (or a previous build): otherwise a search
        // that finds no improving split would re-apply a stale splitter to the already
        // cleared m_search_bestPathInstances and fail with a NullPointerException.
        this.m_search_bestSplitter = null;
        this.searchForBestTest();
        // Check before the debug trace, which dereferences the splitter.
        if (this.m_search_bestSplitter == null) {
            return;
        }
        if (this.m_Debug) {
            System.out.println("Best split found: " + this.m_search_bestSplitter.getNumOfBranches() + "-way split on " + this.m_search_bestSplitter.attributeString() + "\nBestGain = " + this.m_search_smallestLeastSquares);
        }
        for (int i = 0; i < this.m_search_bestSplitter.getNumOfBranches(); ++i) {
            Instances applicableInstances = this.m_search_bestSplitter.instancesDownBranch(i, this.m_search_bestPathInstances);
            double[] predictionValues = this.calcPredictionValues(applicableInstances);
            PredictionNode newPredictor = new PredictionNode(predictionValues);
            this.updateWeights(applicableInstances, predictionValues);
            this.m_search_bestSplitter.setChildForBranch(i, newPredictor);
        }
        this.m_search_bestInsertionNode.addChild(this.m_search_bestSplitter);
        if (this.m_Debug) {
            System.out.println("Tree is now:\n" + this.toString(this.m_root, 1) + "\n");
        }
        // Release the reference to the path instances; they are not needed after insertion.
        this.m_search_bestPathInstances = null;
    }

    /**
     * Passes the newly added prediction values to every instance so each can update its
     * boosting state.
     *
     * @param instances           the instances reaching the new prediction node (LADInstance elements)
     * @param newPredictionValues the per-class values of that node
     */
    private void updateWeights(Instances instances, double[] newPredictionValues) {
        int count = instances.numInstances();
        for (int idx = 0; idx < count; idx++) {
            LADInstance ladInst = (LADInstance) instances.instance(idx);
            ladInst.updateWeights(newPredictionValues);
        }
    }

    /**
     * Precomputes the candidate splitters that do not depend on the data at a node.
     * <ul>
     * <li>Every nominal, non-class attribute yields one TwoWayNominalSplit per value.
     * Binary attributes yield only the one for value 0.</li>
     * <li>The indices of numeric, non-class attributes are stored in
     * {@code m_numericAttIndices}; their split points are searched per node.</li>
     * </ul>
     */
    private void generateStaticPotentialSplittersAndNumericIndices() {
        this.m_staticPotentialSplitters2way = new FastVector();
        FastVector<Integer> numericIndices = new FastVector<Integer>();
        for (int i = 0; i < this.m_trainInstances.numAttributes(); ++i) {
            if (i == this.m_trainInstances.classIndex()) {
                continue;
            }
            Attribute att = this.m_trainInstances.attribute(i);
            if (att.isNumeric()) {
                // Integer.valueOf instead of the deprecated Integer(int) constructor.
                numericIndices.addElement(Integer.valueOf(i));
                continue;
            }
            int numValues = att.numValues();
            if (numValues == 2) {
                // For a binary attribute, the split on value 0 presumably already separates
                // both values, so a second split on value 1 would be redundant.
                this.m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, 0));
                continue;
            }
            for (int j = 0; j < numValues; ++j) {
                this.m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, j));
            }
        }
        this.m_numericAttIndices = new int[numericIndices.size()];
        for (int k = 0; k < numericIndices.size(); ++k) {
            this.m_numericAttIndices[k] = numericIndices.elementAt(k);
        }
    }

    /**
     * Starts a best-split search from the root over all training instances.
     * All search state is reset first: only candidates with a strictly positive gain are
     * accepted, and if none qualifies {@code m_search_bestSplitter} stays null.
     *
     * @throws Exception if evaluating a splitter fails
     */
    private void searchForBestTest() throws Exception {
        if (this.m_Debug) {
            System.out.println("Searching for best split...");
        }
        this.m_search_smallestLeastSquares = 0.0;
        // Clear the previous winner as well; otherwise a search that finds nothing would
        // report the splitter chosen in an earlier iteration (or an earlier build).
        this.m_search_bestSplitter = null;
        this.m_search_bestInsertionNode = null;
        this.m_search_bestPathInstances = null;
        this.searchForBestTest(this.m_root, this.m_trainInstances);
    }

    /**
     * Evaluates every candidate splitter at one prediction node, then recurses into the
     * node's existing children with the instances that reach them.
     *
     * @param currentNode the prediction node under which a splitter would be inserted
     * @param instances   the instances reaching {@code currentNode}
     * @throws Exception if evaluating a splitter fails
     */
    private void searchForBestTest(PredictionNode currentNode, Instances instances) throws Exception {
        this.m_nodesExpanded++;
        this.m_examplesCounted += instances.numInstances();
        // Precomputed nominal splitters first...
        for (Enumeration candidates = this.m_staticPotentialSplitters2way.elements(); candidates.hasMoreElements();) {
            Splitter candidate = (Splitter) candidates.nextElement();
            this.evaluateSplitter(candidate, currentNode, instances);
        }
        // ...then a data-dependent split point for each numeric attribute.
        for (int attIndex : this.m_numericAttIndices) {
            this.evaluateNumericSplit(currentNode, instances, attIndex);
        }
        if (currentNode.getChildren().size() > 0) {
            this.goDownAllPaths(currentNode, instances);
        }
    }

    /**
     * Continues the best-split search in every child prediction node of {@code currentNode},
     * passing each child only the instances that go down its branch.
     *
     * @param currentNode the node whose splitters are followed
     * @param instances   the instances reaching {@code currentNode}
     * @throws Exception if evaluating a splitter fails
     */
    private void goDownAllPaths(PredictionNode currentNode, Instances instances) throws Exception {
        for (Enumeration kids = currentNode.children(); kids.hasMoreElements();) {
            Splitter splitter = (Splitter) kids.nextElement();
            for (int b = 0; b < splitter.getNumOfBranches(); b++) {
                PredictionNode child = splitter.getChildForBranch(b);
                Instances reaching = splitter.instancesDownBranch(b, instances);
                this.searchForBestTest(child, reaching);
            }
        }
    }

    /**
     * Computes a splitter's gain at a node and records it as the best candidate if it beats
     * the best gain found so far. The gain is the squared error of the instances with a known
     * value minus the squared error of each branch.
     *
     * @param split       the candidate splitter
     * @param currentNode the node it would be inserted under
     * @param instances   the instances reaching {@code currentNode}
     * @throws Exception if the splitter cannot partition the instances
     */
    private void evaluateSplitter(Splitter split, PredictionNode currentNode, Instances instances) throws Exception {
        double gain = this.leastSquaresNonMissing(instances, split.attIndex);
        int branches = split.getNumOfBranches();
        for (int b = 0; b < branches; b++) {
            gain -= this.leastSquares(split.instancesDownBranch(b, instances));
        }
        boolean improves = gain > this.m_search_smallestLeastSquares;
        if (this.m_Debug) {
            System.out.print(branches + "-way split on " + split.attributeString() + " has leastSquares value of " + Utils.doubleToString(gain, 3));
            if (improves) {
                System.out.print(" (best so far)");
            }
        }
        if (improves) {
            this.m_search_smallestLeastSquares = gain;
            this.m_search_bestInsertionNode = currentNode;
            this.m_search_bestSplitter = split;
            this.m_search_bestPathInstances = instances;
        }
        if (this.m_Debug) {
            System.out.print("\n");
        }
    }

    /**
     * Finds the best split point on a numeric attribute at a node and records a
     * TwoWayNumericSplit as the best candidate if its gain beats the best found so far.
     * Note that finding the split point sorts {@code instances} by the attribute.
     *
     * @param currentNode the node the split would be inserted under
     * @param instances   the instances reaching {@code currentNode}
     * @param attIndex    index of the numeric attribute
     */
    private void evaluateNumericSplit(PredictionNode currentNode, Instances instances, int attIndex) {
        // Must run before leastSquaresNonMissing: it sorts the instances, which fixes the
        // summation order used afterwards.
        double[] best = this.findNumericSplitpointAndLS(instances, attIndex);
        double splitPoint = best[0];
        double residual = best[1];
        double gain = this.leastSquaresNonMissing(instances, attIndex) - residual;
        boolean improves = gain > this.m_search_smallestLeastSquares;
        if (this.m_Debug) {
            System.out.print("Numeric split on " + instances.attribute(attIndex).name() + " has leastSquares value of " + Utils.doubleToString(gain, 3));
            if (improves) {
                System.out.print(" (best so far)");
            }
        }
        if (improves) {
            this.m_search_smallestLeastSquares = gain;
            this.m_search_bestInsertionNode = currentNode;
            this.m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitPoint);
            this.m_search_bestPathInstances = instances;
        }
        if (this.m_Debug) {
            System.out.print("\n");
        }
    }

    /**
     * Finds the best binary split point on a numeric attribute in a single sweep over the
     * instances sorted on that attribute.
     * <p>
     * For every class j it keeps running sums for the left (L) and right (R) side of the
     * candidate split:
     * <ul>
     * <li>term1 = sum(w*z^2)</li>
     * <li>term2 = sum(w*z)</li>
     * <li>term3 = sum(w)</li>
     * </ul>
     * Each side's weighted squared error is then
     * term1 - 2*mean*term2 + mean^2*term3, with mean = term2/term3.
     * <p>
     * Side effect: {@code instances} is sorted by {@code attIndex}.
     *
     * @param instances the instances reaching the node (elements must be LADInstance)
     * @param attIndex  index of the numeric attribute
     * @return {@code {splitPoint, leastSquares}}. splitPoint is the midpoint of the best
     *         split; leastSquares is the summed squared error of its two sides, clamped at 0.
     *         When no split point exists the result is {@code {0.0, +Infinity}}.
     */
    private double[] findNumericSplitpointAndLS(Instances instances, int attIndex) {
        int numClasses = this.m_numOfClasses;
        // Only used by the debug trace; computed before sorting, as before.
        double allLS = this.leastSquares(instances);
        double[] term1L = new double[numClasses];
        double[] term2L = new double[numClasses];
        double[] term3L = new double[numClasses];
        double[] term1R = new double[numClasses];
        double[] term2R = new double[numClasses];
        double[] term3R = new double[numClasses];
        // Start with every instance on the right-hand side. NOTE(review): this includes
        // instances whose attIndex value is missing; they are never moved to the left.
        for (int i = 0; i < instances.numInstances(); ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            for (int j = 0; j < numClasses; ++j) {
                double wz = inst.wVector[j] * inst.zVector[j];
                term1R[j] += wz * inst.zVector[j];
                term2R[j] += wz;
                term3R[j] += inst.wVector[j];
            }
        }
        double smallestLeastSquares = Double.POSITIVE_INFINITY;
        double bestSplit = 0.0;
        instances.sort(attIndex);
        // Instances.sort places missing values last, so the sweep stops at the first
        // instance whose successor is missing.
        for (int i = 0; i < instances.numInstances() - 1 && !instances.instance(i + 1).isMissing(attIndex); ++i) {
            double value = instances.instance(i).value(attIndex);
            double nextValue = instances.instance(i + 1).value(attIndex);
            // Only a strict increase between neighbours gives a new, distinct split point.
            boolean newSplit = nextValue > value;
            LADInstance inst = (LADInstance) instances.instance(i);
            double leastSquares = 0.0;
            for (int j = 0; j < numClasses; ++j) {
                // Move instance i from the right-hand side to the left-hand side.
                double wz = inst.wVector[j] * inst.zVector[j];
                double wzz = wz * inst.zVector[j];
                term1L[j] += wzz;
                term2L[j] += wz;
                term3L[j] += inst.wVector[j];
                term1R[j] -= wzz;
                term2R[j] -= wz;
                term3R[j] -= inst.wVector[j];
                if (!newSplit) {
                    continue;
                }
                // Guard a zero weight total (as leastSquares does) instead of dividing by
                // zero: the resulting NaN used to silently disqualify the candidate split.
                double meanL = term3L[j] != 0.0 ? term2L[j] / term3L[j] : 0.0;
                double meanR = term3R[j] != 0.0 ? term2R[j] / term3R[j] : 0.0;
                leastSquares += term1L[j] - 2.0 * meanL * term2L[j] + meanL * meanL * term3L[j];
                leastSquares += term1R[j] - 2.0 * meanR * term2R[j] + meanR * meanR * term3R[j];
            }
            if (this.m_Debug && newSplit) {
                System.out.println(attIndex + "/" + (value + nextValue) / 2.0 + " = " + (allLS - leastSquares));
            }
            if (newSplit && leastSquares < smallestLeastSquares) {
                bestSplit = (value + nextValue) / 2.0;
                smallestLeastSquares = leastSquares;
            }
        }
        double clamped = smallestLeastSquares > 0.0 ? smallestLeastSquares : 0.0;
        return new double[]{bestSplit, clamped};
    }

    /**
     * Computes the weighted squared error of the working responses around their per-class
     * weighted means: the sum over instances and classes of w_j * (z_j - mean_j)^2.
     * A class whose total weight is zero keeps a mean of 0.
     *
     * @param instances the instances to measure (elements must be LADInstance)
     * @return the weighted squared error, clamped at 0
     */
    private double leastSquares(Instances instances) {
        int numClasses = this.m_numOfClasses;
        int numInst = instances.numInstances();
        double[] classMeans = new double[numClasses];
        double[] classTotals = new double[numClasses];
        for (int i = 0; i < numInst; ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            for (int j = 0; j < numClasses; ++j) {
                classMeans[j] += inst.zVector[j] * inst.wVector[j];
                classTotals[j] += inst.wVector[j];
            }
        }
        for (int j = 0; j < numClasses; ++j) {
            if (classTotals[j] != 0.0) {
                classMeans[j] /= classTotals[j];
            }
        }
        double numerator = 0.0;
        for (int i = 0; i < numInst; ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            for (int j = 0; j < numClasses; ++j) {
                double t = inst.zVector[j] - classMeans[j];
                numerator += inst.wVector[j] * (t * t);
            }
        }
        return numerator > 0.0 ? numerator : 0.0;
    }

    /**
     * Like leastSquares, but only instances whose value for {@code attIndex} is known
     * contribute squared deviations. This is the baseline a split on that attribute is
     * compared against.
     * <p>
     * NOTE(review): the per-class means are still computed over all instances, including
     * those with a missing value; preserved as-is — confirm this is intended.
     *
     * @param instances the instances to measure (elements must be LADInstance)
     * @param attIndex  the attribute whose missing values are excluded
     * @return the weighted squared error, clamped at 0
     */
    private double leastSquaresNonMissing(Instances instances, int attIndex) {
        int numClasses = this.m_numOfClasses;
        int numInst = instances.numInstances();
        double[] classMeans = new double[numClasses];
        double[] classTotals = new double[numClasses];
        for (int i = 0; i < numInst; ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            for (int j = 0; j < numClasses; ++j) {
                classMeans[j] += inst.zVector[j] * inst.wVector[j];
                classTotals[j] += inst.wVector[j];
            }
        }
        for (int j = 0; j < numClasses; ++j) {
            if (classTotals[j] != 0.0) {
                classMeans[j] /= classTotals[j];
            }
        }
        double numerator = 0.0;
        for (int i = 0; i < numInst; ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            if (inst.isMissing(attIndex)) {
                continue;
            }
            for (int j = 0; j < numClasses; ++j) {
                double t = inst.zVector[j] - classMeans[j];
                numerator += inst.wVector[j] * (t * t);
            }
        }
        return numerator > 0.0 ? numerator : 0.0;
    }

    /**
     * Computes the prediction values of a new node from the instances that reach it.
     * The weighted mean f_j of the working responses is taken per class; a class with zero
     * total weight keeps 0. The result is centred as ((J-1)/J) * (f_j - mean(f)), where J is
     * the number of classes.
     *
     * @param instances the instances reaching the new node (elements must be LADInstance)
     * @return one prediction value per class
     */
    private double[] calcPredictionValues(Instances instances) {
        int numClasses = this.m_numOfClasses;
        double[] classMeans = new double[numClasses];
        double[] classTotals = new double[numClasses];
        for (int i = 0; i < instances.numInstances(); ++i) {
            LADInstance inst = (LADInstance) instances.instance(i);
            for (int j = 0; j < numClasses; ++j) {
                classMeans[j] += inst.zVector[j] * inst.wVector[j];
                classTotals[j] += inst.wVector[j];
            }
        }
        double meansSum = 0.0;
        for (int j = 0; j < numClasses; ++j) {
            if (classTotals[j] != 0.0) {
                classMeans[j] /= classTotals[j];
            }
            meansSum += classMeans[j];
        }
        meansSum /= (double) numClasses;
        double multiplier = (double) (numClasses - 1) / (double) numClasses;
        for (int j = 0; j < numClasses; ++j) {
            classMeans[j] = multiplier * (classMeans[j] - meansSum);
        }
        return classMeans;
    }

    /**
     * Returns the class distribution for an instance. The prediction values of every node
     * the instance reaches are summed, then turned into probabilities by a softmax that
     * subtracts the maximum first for numerical stability.
     *
     * @param instance the instance to classify
     * @return one probability per class
     */
    @Override
    public double[] distributionForInstance(Instance instance) {
        // A fresh array is already all zeros, which is the starting score for every class.
        double[] scores = this.predictionValuesForInstance(instance, this.m_root, new double[this.m_numOfClasses]);
        double largest = scores[Utils.maxIndex(scores)];
        for (int c = 0; c < this.m_numOfClasses; c++) {
            scores[c] = Math.exp(scores[c] - largest);
        }
        double total = Utils.sum(scores);
        if (total > 0.0) {
            Utils.normalize(scores, total);
        }
        return scores;
    }

    /**
     * Adds the values of {@code currentNode} to {@code currentValues}, then recurses into
     * every child the instance reaches. Splitters that return a negative branch (e.g. a
     * missing value) are skipped.
     *
     * @param inst          the instance being classified
     * @param currentNode   the prediction node to start from
     * @param currentValues accumulator, modified in place
     * @return the accumulator
     */
    private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode, double[] currentValues) {
        double[] nodeValues = currentNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            currentValues[c] += nodeValues[c];
        }
        for (Enumeration splitters = currentNode.children(); splitters.hasMoreElements();) {
            Splitter splitter = (Splitter) splitters.nextElement();
            int branch = splitter.branchInstanceGoesDown(inst);
            if (branch >= 0) {
                currentValues = this.predictionValuesForInstance(inst, splitter.getChildForBranch(branch), currentValues);
            }
        }
        return currentValues;
    }

    /**
     * Returns a textual rendering of the tree followed by the legend and size/search
     * statistics, or a "not built yet" message before initialisation.
     *
     * @return the description of the model
     */
    @Override
    public String toString() {
        String className = this.getClass().getName();
        if (this.m_root == null) {
            return className + " not built yet";
        }
        double ratio = (double) this.m_examplesCounted / (double) this.m_nodesExpanded;
        StringBuilder sb = new StringBuilder();
        sb.append(className).append(":\n\n")
          .append(this.toString(this.m_root, 1))
          .append("\nLegend: ").append(this.legend())
          .append("\n#Tree size (total): ").append(this.numOfAllNodes(this.m_root))
          .append("\n#Tree size (number of predictor nodes): ").append(this.numOfPredictionNodes(this.m_root))
          .append("\n#Leaves (number of predictor nodes): ").append(this.numOfLeafNodes(this.m_root))
          .append("\n#Expanded nodes: ").append(this.m_nodesExpanded)
          .append("\n#Processed examples: ").append(this.m_examplesCounted)
          .append("\n#Ratio e/n: ").append(ratio);
        return sb.toString();
    }

    /**
     * Renders the subtree rooted at {@code currentNode}. The node's comma-separated values
     * come first, then one indented line per non-null child with its splitter's order
     * number, attribute and comparison.
     *
     * @param currentNode the subtree root
     * @param level       indentation depth of the children
     * @return the rendered subtree
     */
    private String toString(PredictionNode currentNode, int level) {
        StringBuilder text = new StringBuilder(": ");
        double[] values = currentNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            if (c > 0) {
                text.append(",");
            }
            text.append(Utils.doubleToString(values[c], 3));
        }
        for (Enumeration kids = currentNode.children(); kids.hasMoreElements();) {
            Splitter split = (Splitter) kids.nextElement();
            for (int b = 0; b < split.getNumOfBranches(); b++) {
                PredictionNode child = split.getChildForBranch(b);
                if (child == null) {
                    continue;
                }
                text.append("\n");
                for (int k = 0; k < level; k++) {
                    text.append("|  ");
                }
                text.append("(").append(split.orderAdded).append(")");
                text.append(split.attributeString()).append(" ").append(split.comparisonString(b));
                text.append(this.toString(child, level + 1));
            }
        }
        return text.toString();
    }

    /**
     * Returns the tree in GraphViz dot format.
     *
     * @return the dot source
     * @throws Exception if traversal fails
     */
    @Override
    public String graph() throws Exception {
        StringBuffer dot = new StringBuffer("digraph ADTree {\n");
        this.graphTraverse(this.m_root, dot, 0, 0);
        dot.append("}\n");
        return dot.toString();
    }

    /**
     * Appends the dot nodes and edges for the subtree rooted at {@code currentNode}.
     * Prediction nodes are named "S{splitOrder}P{predOrder}" and splitter nodes
     * "S{orderAdded}". The root (splitOrder 0) label also carries the class legend.
     *
     * @param currentNode the subtree root
     * @param text        buffer receiving the dot source
     * @param splitOrder  order number of the parent splitter (0 for the root)
     * @param predOrder   branch index under the parent splitter
     * @throws Exception if traversal fails
     */
    protected void graphTraverse(PredictionNode currentNode, StringBuffer text, int splitOrder, int predOrder) throws Exception {
        String nodeId = "S" + splitOrder + "P" + predOrder;
        text.append(nodeId).append(" [label=\"");
        double[] values = currentNode.getValues();
        for (int c = 0; c < this.m_numOfClasses; c++) {
            if (c > 0) {
                text.append(",");
            }
            text.append(Utils.doubleToString(values[c], 3));
        }
        if (splitOrder == 0) {
            text.append(" (").append(this.legend()).append(")");
        }
        text.append("\" shape=box style=filled]\n");
        for (Enumeration kids = currentNode.children(); kids.hasMoreElements();) {
            Splitter split = (Splitter) kids.nextElement();
            String splitId = "S" + split.orderAdded;
            text.append(nodeId).append("->").append(splitId).append(" [style=dotted]\n");
            text.append(splitId).append(" [label=\"").append(split.orderAdded).append(": ")
                .append(split.attributeString()).append("\"]\n");
            for (int b = 0; b < split.getNumOfBranches(); b++) {
                PredictionNode child = split.getChildForBranch(b);
                if (child == null) {
                    continue;
                }
                text.append(splitId).append("->").append(splitId).append("P").append(b)
                    .append(" [label=\"").append(split.comparisonString(b)).append("\"]\n");
                this.graphTraverse(child, text, split.orderAdded, b);
            }
        }
    }

    /**
     * Returns the class values joined by ", ", in class-index order. This names the
     * components of each prediction node's value vector.
     *
     * @return the legend, or "" if there is no training data or no class attribute
     */
    public String legend() {
        if (this.m_trainInstances == null) {
            return "";
        }
        Attribute classAttribute;
        try {
            classAttribute = this.m_trainInstances.classAttribute();
        }
        catch (Exception x) {
            // No class attribute means nothing to name; previously this was swallowed and
            // led to a NullPointerException below.
            return "";
        }
        // The former "-ve/+ve" special case for m_numOfClasses == 1 read value(1) of a
        // one-value class attribute, which is always out of range; the general loop covers it.
        StringBuilder text = new StringBuilder();
        for (int i = 0; i < this.m_numOfClasses; ++i) {
            if (i > 0) {
                text.append(", ");
            }
            text.append(classAttribute.value(i));
        }
        return text.toString();
    }

    /**
     * Returns the GUI tip text for the numOfBoostingIterations property.
     *
     * @return the tip text
     */
    public String numOfBoostingIterationsTipText() {
        String tip = "The number of boosting iterations to use, which determines the size of the tree.";
        return tip;
    }

    /**
     * Returns the number of boosting iterations run by buildClassifier.
     *
     * @return the iteration count
     */
    public int getNumOfBoostingIterations() {
        return m_boostingIterations;
    }

    /**
     * Sets the number of boosting iterations run by buildClassifier. A value of zero or
     * less leaves only the root prediction node.
     *
     * @param b the iteration count
     */
    public void setNumOfBoostingIterations(int b) {
        m_boostingIterations = b;
    }

    /**
     * Lists the command-line options: -B followed by the options of the superclass.
     *
     * @return an enumeration of Option objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.add(new Option("\tNumber of boosting iterations.\n\t(Default = 10)", "B", 1, "-B <number of boosting iterations>"));
        for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements();) {
            options.add((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options. -B sets the number of boosting iterations and is left unchanged
     * when absent. The rest goes to the superclass, and leftover options are rejected.
     *
     * @param options the option array, consumed in place
     * @throws Exception if an option is malformed (e.g. a non-integer -B) or unrecognised
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String iterations = Utils.getOption('B', options);
        boolean given = iterations.length() > 0;
        if (given) {
            setNumOfBoostingIterations(Integer.parseInt(iterations));
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings: "-B n" followed by the options of the superclass.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        // Query the superclass once instead of three times.
        String[] superOptions = super.getOptions();
        String[] options = new String[2 + superOptions.length];
        options[0] = "-B";
        options[1] = "" + this.getNumOfBoostingIterations();
        // The array is sized exactly. The former trailing fill loop started from the index
        // before the copied block and overwrote every inherited option with "".
        System.arraycopy(superOptions, 0, options, 2, superOptions.length);
        return options;
    }

    /**
     * Returns the total number of nodes (prediction nodes plus splitters).
     *
     * @return the tree size
     */
    public double measureTreeSize() {
        int total = numOfAllNodes(m_root);
        return total;
    }

    /**
     * Returns the number of prediction nodes in the tree.
     *
     * @return the prediction-node count
     */
    public double measureNumLeaves() {
        int predictionNodes = numOfPredictionNodes(m_root);
        return predictionNodes;
    }

    /**
     * Returns the number of prediction nodes that have no splitters below them.
     *
     * @return the leaf count
     */
    public double measureNumPredictionLeaves() {
        int leaves = numOfLeafNodes(m_root);
        return leaves;
    }

    /**
     * Returns how many prediction nodes were visited by the split searches.
     *
     * @return the expanded-node count
     */
    public double measureNodesExpanded() {
        return (double) m_nodesExpanded;
    }

    /**
     * Returns how many instances were processed at visited nodes during the split searches.
     *
     * @return the processed-instance count
     */
    public double measureExamplesCounted() {
        return (double) m_examplesCounted;
    }

    /**
     * Lists the names of the additional measures supported by getMeasure.
     *
     * @return an enumeration of measure names
     */
    @Override
    public Enumeration enumerateMeasures() {
        String[] names = {
            "measureTreeSize",
            "measureNumLeaves",
            "measureNumPredictionLeaves",
            "measureNodesExpanded",
            "measureExamplesCounted"
        };
        Vector<String> measures = new Vector<String>(names.length);
        for (String name : names) {
            measures.addElement(name);
        }
        return measures.elements();
    }

    /**
     * Returns the value of a named additional measure.
     *
     * @param additionalMeasureName one of the names from enumerateMeasures
     * @return the measure value
     * @throws IllegalArgumentException if the name is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equals("measureTreeSize")) {
            return measureTreeSize();
        } else if (additionalMeasureName.equals("measureNumLeaves")) {
            return measureNumLeaves();
        } else if (additionalMeasureName.equals("measureNumPredictionLeaves")) {
            return measureNumPredictionLeaves();
        } else if (additionalMeasureName.equals("measureNodesExpanded")) {
            return measureNodesExpanded();
        } else if (additionalMeasureName.equals("measureExamplesCounted")) {
            return measureExamplesCounted();
        } else {
            throw new IllegalArgumentException(additionalMeasureName + " not supported (ADTree)");
        }
    }

    /**
     * Counts the prediction nodes in the subtree rooted at {@code root}.
     *
     * @param root the subtree root; may be null
     * @return the count, 0 for a null root
     */
    protected int numOfPredictionNodes(PredictionNode root) {
        if (root == null) {
            return 0;
        }
        int count = 1;
        for (Enumeration kids = root.children(); kids.hasMoreElements();) {
            Splitter split = (Splitter) kids.nextElement();
            for (int b = 0; b < split.getNumOfBranches(); b++) {
                count += numOfPredictionNodes(split.getChildForBranch(b));
            }
        }
        return count;
    }

    /**
     * Counts the prediction nodes without any splitter below them in the subtree rooted
     * at {@code root}.
     *
     * @param root the subtree root; may be null
     * @return the leaf count, 0 for a null root
     */
    protected int numOfLeafNodes(PredictionNode root) {
        // Null-safe like numOfPredictionNodes/numOfAllNodes and the toString/graph traversals,
        // which all skip null children; previously this threw a NullPointerException
        // (e.g. measureNumPredictionLeaves before the tree was built).
        if (root == null) {
            return 0;
        }
        if (root.getChildren().size() == 0) {
            return 1;
        }
        int leaves = 0;
        Enumeration e = root.children();
        while (e.hasMoreElements()) {
            Splitter split = (Splitter) e.nextElement();
            for (int i = 0; i < split.getNumOfBranches(); ++i) {
                leaves += this.numOfLeafNodes(split.getChildForBranch(i));
            }
        }
        return leaves;
    }

    /**
     * Counts all nodes (prediction nodes and splitters) in the subtree rooted at {@code root}.
     *
     * @param root the subtree root; may be null
     * @return the count, 0 for a null root
     */
    protected int numOfAllNodes(PredictionNode root) {
        if (root == null) {
            return 0;
        }
        int count = 1;
        for (Enumeration kids = root.children(); kids.hasMoreElements();) {
            Splitter split = (Splitter) kids.nextElement();
            count++; // the splitter itself
            for (int b = 0; b < split.getNumOfBranches(); b++) {
                count += numOfAllNodes(split.getChildForBranch(b));
            }
        }
        return count;
    }

    /**
     * Builds the tree: initialises it from the data, then runs the configured number of
     * boosting iterations.
     *
     * @param instances the training data
     * @throws Exception if the data is unsuitable or boosting fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.initClassifier(instances);
        int remaining = this.m_boostingIterations;
        while (remaining > 0) {
            this.boost();
            remaining--;
        }
    }

    /**
     * Counts the misclassified instances in a test set. An instance whose classification
     * throws an exception counts as an error.
     *
     * @param test the test instances
     * @return the number of errors
     */
    public int predictiveError(Instances test) {
        int errors = 0;
        for (int i = test.numInstances() - 1; i >= 0; i--) {
            Instance inst = test.instance(i);
            boolean correct;
            try {
                correct = this.classifyInstance(inst) == inst.classValue();
            }
            catch (Exception e) {
                correct = false;
            }
            if (!correct) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Merges another tree into this one by merging the root prediction nodes.
     *
     * @param mergeWith the tree to merge in
     * @throws Exception if either tree is uninitialised or the class counts differ
     */
    public void merge(LADTree mergeWith) throws Exception {
        boolean uninitialised = m_root == null || mergeWith.m_root == null;
        if (uninitialised) {
            throw new Exception("Trying to merge an uninitialized tree");
        }
        boolean sameShape = m_numOfClasses == mergeWith.m_numOfClasses;
        if (!sameShape) {
            throw new Exception("Trees not suitable for merge - different sized prediction nodes");
        }
        m_root.merge(mergeWith.m_root);
    }

    /**
     * Reports the kind of graph produced by graph(): a tree.
     *
     * @return Drawable.TREE
     */
    @Override
    public int graphType() {
        return Drawable.TREE;
    }

    /**
     * Returns the revision string of this class.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        String revisionTag = "$Revision: 6035 $";
        return RevisionUtils.extract(revisionTag);
    }

    /**
     * Returns the capabilities of this classifier. It handles nominal, numeric and date
     * attributes, missing values, and a nominal class that may be missing.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities caps = super.getCapabilities();
        caps.disableAll();
        Capabilities.Capability[] supported = {
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.NOMINAL_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : supported) {
            caps.enable(capability);
        }
        return caps;
    }

    /**
     * Command-line entry point; delegates to the standard Weka classifier runner.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        LADTree classifier = new LADTree();
        runClassifier(classifier, argv);
    }

    protected class TwoWayNumericSplit
    extends Splitter
    implements Cloneable {
        private double splitPoint;
        private PredictionNode[] children;

        /**
         * Creates a split "value &lt; splitPoint" vs. "value &gt;= splitPoint" with no children yet.
         *
         * @param _attIndex   index of the numeric attribute
         * @param _splitPoint the threshold
         */
        public TwoWayNumericSplit(int _attIndex, double _splitPoint) {
            this.children = new PredictionNode[2];
            this.attIndex = _attIndex;
            this.splitPoint = _splitPoint;
        }

        public TwoWayNumericSplit(int _attIndex, Instances instances) throws Exception {
            this.attIndex = _attIndex;
            this.splitPoint = this.findSplit(instances, this.attIndex);
            this.children = new PredictionNode[2];
        }

        /**
         * A numeric split always has two branches: below and at-or-above the threshold.
         *
         * @return 2
         */
        @Override
        public int getNumOfBranches() {
            final int branches = 2;
            return branches;
        }

        /**
         * Returns the branch an instance follows.
         *
         * @param inst the instance
         * @return -1 if the attribute value is missing, 0 if it is below the threshold,
         *         otherwise 1
         */
        @Override
        public int branchInstanceGoesDown(Instance inst) {
            if (inst.isMissing(this.attIndex)) {
                return -1;
            }
            return inst.value(this.attIndex) < this.splitPoint ? 0 : 1;
        }

        /**
         * Collects the instances that follow a branch. Branch -1 means missing values,
         * branch 0 means below the threshold, and any other branch means at or above it.
         * Instances with a missing value only appear in branch -1.
         *
         * @param branch    the branch index
         * @param instances the candidate instances
         * @return a ReferenceInstances view holding the selected instances
         */
        @Override
        public Instances instancesDownBranch(int branch, Instances instances) {
            ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1);
            Enumeration e = instances.enumerateInstances();
            while (e.hasMoreElements()) {
                Instance inst = (Instance) e.nextElement();
                boolean missing = inst.isMissing(this.attIndex);
                boolean keep;
                if (branch == -1) {
                    keep = missing;
                } else if (missing) {
                    keep = false;
                } else if (branch == 0) {
                    keep = inst.value(this.attIndex) < this.splitPoint;
                } else {
                    keep = inst.value(this.attIndex) >= this.splitPoint;
                }
                if (keep) {
                    filteredInstances.addReference(inst);
                }
            }
            return filteredInstances;
        }

        @Override
        public String attributeString() {
            return LADTree.this.m_trainInstances.attribute(this.attIndex).name();
        }

        @Override
        public String comparisonString(int branchNum) {
            return (branchNum == 0 ? "< " : ">= ") + Utils.doubleToString(this.splitPoint, 3);
        }

        @Override
        public boolean equalTo(Splitter compare) {
            if (compare instanceof TwoWayNumericSplit) {
                TwoWayNumericSplit compareSame = (TwoWayNumericSplit)compare;
                return this.attIndex == compareSame.attIndex && this.splitPoint == compareSame.splitPoint;
            }
            return false;
        }

        @Override
        public void setChildForBranch(int branchNum, PredictionNode childPredictor) {
            this.children[branchNum] = childPredictor;
        }

        @Override
        public PredictionNode getChildForBranch(int branchNum) {
            return this.children[branchNum];
        }

        @Override
        public Object clone() {
            TwoWayNumericSplit clone = new TwoWayNumericSplit(this.attIndex, this.splitPoint);
            if (this.children[0] != null) {
                clone.setChildForBranch(0, (PredictionNode)this.children[0].clone());
            }
            if (this.children[1] != null) {
                clone.setChildForBranch(1, (PredictionNode)this.children[1].clone());
            }
            return clone;
        }

        private double findSplit(Instances instances, int index) throws Exception {
            Instance inst;
            int i;
            double splitPoint = 0.0;
            double bestVal = Double.MAX_VALUE;
            int numMissing = 0;
            double[][] distribution = new double[3][instances.numClasses()];
            for (i = 0; i < instances.numInstances(); ++i) {
                inst = instances.instance(i);
                if (!inst.isMissing(index)) {
                    double[] dArray = distribution[1];
                    int n = (int)inst.classValue();
                    dArray[n] = dArray[n] + 1.0;
                    continue;
                }
                double[] dArray = distribution[2];
                int n = (int)inst.classValue();
                dArray[n] = dArray[n] + 1.0;
                ++numMissing;
            }
            instances.sort(index);
            for (i = 0; i < instances.numInstances() - (numMissing + 1); ++i) {
                inst = instances.instance(i);
                Instance instPlusOne = instances.instance(i + 1);
                double[] dArray = distribution[0];
                int n = (int)inst.classValue();
                dArray[n] = dArray[n] + inst.weight();
                double[] dArray2 = distribution[1];
                int n2 = (int)inst.classValue();
                dArray2[n2] = dArray2[n2] - inst.weight();
                if (!Utils.sm(inst.value(index), instPlusOne.value(index))) continue;
                double currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
                double currVal = ContingencyTables.entropyConditionedOnRows(distribution);
                if (!Utils.sm(currVal, bestVal)) continue;
                splitPoint = currCutPoint;
                bestVal = currVal;
            }
            return splitPoint;
        }
    }

    /**
     * Binary splitter on a nominal attribute. Instances whose value equals
     * {@code trueSplitValue} go down branch 0, all other non-missing values go
     * down branch 1, and missing values go down the pseudo-branch -1.
     */
    protected class TwoWayNominalSplit
    extends Splitter {
        /** Index of the nominal value that selects branch 0. */
        private int trueSplitValue;
        /** Prediction nodes attached to branch 0 and branch 1 (entries may be null). */
        private PredictionNode[] children;

        /**
         * @param _attIndex index of the nominal attribute tested
         * @param _trueSplitValue value index routed to branch 0
         */
        public TwoWayNominalSplit(int _attIndex, int _trueSplitValue) {
            this.attIndex = _attIndex;
            this.trueSplitValue = _trueSplitValue;
            this.children = new PredictionNode[2];
        }

        @Override
        public int getNumOfBranches() {
            return 2;
        }

        /**
         * Routes an instance: -1 if the attribute is missing, 0 if it equals
         * the split value, 1 otherwise.
         */
        @Override
        public int branchInstanceGoesDown(Instance inst) {
            if (inst.isMissing(this.attIndex)) {
                return -1;
            }
            if (inst.value(this.attIndex) == (double) this.trueSplitValue) {
                return 0;
            }
            return 1;
        }

        /**
         * Returns a reference view of the instances routed down {@code branch}.
         * Branch -1 selects missing values, 0 the matching value, and any other
         * value the non-matching ones. Routing is delegated to
         * {@link #branchInstanceGoesDown} so training and prediction agree.
         */
        @Override
        public Instances instancesDownBranch(int branch, Instances instances) {
            int target = branch == -1 ? -1 : (branch == 0 ? 0 : 1);
            ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1);
            Enumeration e = instances.enumerateInstances();
            while (e.hasMoreElements()) {
                Instance inst = (Instance) e.nextElement();
                if (this.branchInstanceGoesDown(inst) == target) {
                    filteredInstances.addReference(inst);
                }
            }
            return filteredInstances;
        }

        @Override
        public String attributeString() {
            return LADTree.this.m_trainInstances.attribute(this.attIndex).name();
        }

        /**
         * Binary attributes print the concrete value on both branches; others
         * print {@code "= v"} / {@code "!= v"}.
         */
        @Override
        public String comparisonString(int branchNum) {
            Attribute att = LADTree.this.m_trainInstances.attribute(this.attIndex);
            if (att.numValues() != 2) {
                return (branchNum == 0 ? "= " : "!= ") + att.value(this.trueSplitValue);
            }
            return "= " + (branchNum == 0 ? att.value(this.trueSplitValue) : att.value(this.trueSplitValue == 0 ? 1 : 0));
        }

        /** Two nominal splits are equal when they test the same attribute for the same value. */
        @Override
        public boolean equalTo(Splitter compare) {
            if (compare instanceof TwoWayNominalSplit) {
                TwoWayNominalSplit compareSame = (TwoWayNominalSplit) compare;
                return this.attIndex == compareSame.attIndex && this.trueSplitValue == compareSame.trueSplitValue;
            }
            return false;
        }

        @Override
        public void setChildForBranch(int branchNum, PredictionNode childPredictor) {
            this.children[branchNum] = childPredictor;
        }

        @Override
        public PredictionNode getChildForBranch(int branchNum) {
            return this.children[branchNum];
        }

        /**
         * Deep copy: clones both child prediction nodes and keeps the
         * insertion number ({@code orderAdded}). Without this, splitters copied
         * via a recursive tree clone would all revert to 0.
         */
        @Override
        public Object clone() {
            TwoWayNominalSplit clone = new TwoWayNominalSplit(this.attIndex, this.trueSplitValue);
            clone.orderAdded = this.orderAdded;
            if (this.children[0] != null) {
                clone.setChildForBranch(0, (PredictionNode) this.children[0].clone());
            }
            if (this.children[1] != null) {
                clone.setChildForBranch(1, (PredictionNode) this.children[1].clone());
            }
            return clone;
        }
    }

    /**
     * Base type for splitter (decision) nodes. A splitter tests one attribute
     * and routes instances to numbered branches, each of which may carry a
     * child {@link PredictionNode}.
     */
    protected abstract class Splitter
    implements Serializable,
    Cloneable {
        /** Index of the attribute this splitter tests. */
        protected int attIndex;
        /** Sequence number assigned when the splitter is inserted below a prediction node. */
        public int orderAdded;

        protected Splitter() {
        }

        /** @return number of regular branches (excluding the missing-value pseudo-branch) */
        public abstract int getNumOfBranches();

        /**
         * @param instance instance to route
         * @return branch index the instance follows (the implementations in
         *         this file return -1 for a missing attribute value)
         */
        public abstract int branchInstanceGoesDown(Instance instance);

        /**
         * @param branch branch index to select
         * @param data instances to filter
         * @return the subset of {@code data} routed down {@code branch}
         */
        public abstract Instances instancesDownBranch(int branch, Instances data);

        /** @return printable name of the tested attribute */
        public abstract String attributeString();

        /**
         * @param branch branch index
         * @return printable test condition for that branch
         */
        public abstract String comparisonString(int branch);

        /**
         * @param other splitter to compare against
         * @return true if both splitters perform the same test
         */
        public abstract boolean equalTo(Splitter other);

        /**
         * @param branch branch index
         * @param child prediction node to attach
         */
        public abstract void setChildForBranch(int branch, PredictionNode child);

        /**
         * @param branch branch index
         * @return prediction node attached to that branch, possibly null
         */
        public abstract PredictionNode getChildForBranch(int branch);

        /** @return a copy of this splitter including its subtree */
        public abstract Object clone();
    }

    /**
     * Prediction node of the tree. It holds one value per class (length
     * {@code m_numOfClasses}) and an ordered list of child {@link Splitter}s.
     */
    protected class PredictionNode
    implements Serializable,
    Cloneable {
        /** Per-class values; merging adds them element-wise. */
        private double[] values;
        /** Child splitters in insertion order. */
        private FastVector children;

        /** @param newValues initial per-class values (copied, not aliased) */
        public PredictionNode(double[] newValues) {
            this.values = new double[LADTree.this.m_numOfClasses];
            this.setValues(newValues);
            this.children = new FastVector();
        }

        /** Copies the first {@code m_numOfClasses} entries of {@code newValues} into this node. */
        public void setValues(double[] newValues) {
            System.arraycopy(newValues, 0, this.values, 0, LADTree.this.m_numOfClasses);
        }

        /** @return the live per-class value array (not a copy) */
        public double[] getValues() {
            return this.values;
        }

        /** @return the live list of child splitters (not a copy) */
        public FastVector getChildren() {
            return this.children;
        }

        /** @return an enumeration over the child splitters */
        public Enumeration children() {
            return this.children.elements();
        }

        /**
         * Attaches a splitter. If an equal splitter is already present, the
         * branch predictors of {@code newChild} are merged into the existing
         * ones; only branches where both sides are non-null are merged.
         * Otherwise a clone is appended and stamped with the next insertion
         * number.
         */
        public void addChild(Splitter newChild) {
            Splitter existing = this.findEqualChild(newChild);
            if (existing != null) {
                for (int b = 0; b < newChild.getNumOfBranches(); ++b) {
                    PredictionNode target = existing.getChildForBranch(b);
                    PredictionNode source = newChild.getChildForBranch(b);
                    if (target != null && source != null) {
                        target.merge(source);
                    }
                }
                return;
            }
            Splitter copy = (Splitter) newChild.clone();
            copy.orderAdded = ++LADTree.this.m_lastAddedSplitNum;
            this.children.addElement(copy);
        }

        /** Returns the first child splitter for which {@code candidate.equalTo(child)} holds, or null. */
        private Splitter findEqualChild(Splitter candidate) {
            for (Enumeration e = this.children(); e.hasMoreElements();) {
                Splitter child = (Splitter) e.nextElement();
                if (candidate.equalTo(child)) {
                    return child;
                }
            }
            return null;
        }

        /** Deep copy: copies the values and clones every child splitter. */
        @Override
        public Object clone() {
            PredictionNode copy = new PredictionNode(this.values);
            for (Enumeration e = this.children.elements(); e.hasMoreElements();) {
                Splitter child = (Splitter) e.nextElement();
                copy.children.addElement((Splitter) child.clone());
            }
            return copy;
        }

        /** Adds {@code merger}'s values to this node and attaches all of its child splitters. */
        public void merge(PredictionNode merger) {
            for (int c = 0; c < LADTree.this.m_numOfClasses; ++c) {
                this.values[c] += merger.values[c];
            }
            for (Enumeration e = merger.children(); e.hasMoreElements();) {
                this.addChild((Splitter) e.nextElement());
            }
        }
    }

    /**
     * Training instance augmented with per-class working vectors. The vectors
     * are the accumulated scores F, the probabilities P, the working responses
     * Z and the weights W.
     */
    protected class LADInstance
    extends DenseInstance {
        public double[] fVector;
        public double[] wVector;
        public double[] pVector;
        public double[] zVector;

        /**
         * Copies {@code instance} and initialises F to zero and P to uniform.
         * It then derives Z and W from them.
         */
        public LADInstance(Instance instance) {
            super(instance);
            this.setDataset(instance.dataset());
            int numClasses = LADTree.this.m_numOfClasses;
            this.fVector = new double[numClasses];
            this.wVector = new double[numClasses];
            this.pVector = new double[numClasses];
            this.zVector = new double[numClasses];
            double uniform = 1.0 / (double) numClasses;
            for (int c = 0; c < numClasses; ++c) {
                this.pVector[c] = uniform;
            }
            this.updateZVector();
            this.updateWVector();
        }

        /** Adds {@code fVectorIncrement} to F element-wise, then recomputes P, Z and W. */
        public void updateWeights(double[] fVectorIncrement) {
            for (int c = 0; c < this.fVector.length; ++c) {
                this.fVector[c] += fVectorIncrement[c];
            }
            this.updateVectors(this.fVector);
        }

        /** Recomputes P from {@code newFVector}, then Z, then W (in that order). */
        public void updateVectors(double[] newFVector) {
            this.updatePVector(newFVector);
            this.updateZVector();
            this.updateWVector();
        }

        /** Sets P to the softmax of {@code newFVector}, shifted by its maximum for numerical stability. */
        public void updatePVector(double[] newFVector) {
            double shift = newFVector[Utils.maxIndex(newFVector)];
            for (int c = 0; c < this.pVector.length; ++c) {
                this.pVector[c] = Math.exp(newFVector[c] - shift);
            }
            Utils.normalize(this.pVector);
        }

        /** Sets W[c] = (y_c - P[c]) / Z[c]. */
        public void updateWVector() {
            for (int c = 0; c < this.wVector.length; ++c) {
                double residual = this.yVector(c) - this.pVector[c];
                this.wVector[c] = residual / this.zVector[c];
            }
        }

        /**
         * Sets Z[c] = 1/P[c] for the true class, capped at Z_MAX. For the other
         * classes it is -1/(1-P[c]), floored at -Z_MAX.
         */
        public void updateZVector() {
            double cap = LADTree.this.Z_MAX;
            for (int c = 0; c < this.zVector.length; ++c) {
                double z;
                if (this.yVector(c) == 1.0) {
                    z = 1.0 / this.pVector[c];
                    z = z > cap ? cap : z;
                } else {
                    z = -1.0 / (1.0 - this.pVector[c]);
                    z = z < -cap ? -cap : z;
                }
                this.zVector[c] = z;
            }
        }

        /** @return 1.0 if {@code index} equals this instance's class index (truncated to int), else 0.0 */
        public double yVector(int index) {
            if (index == (int) this.classValue()) {
                return 1.0;
            }
            return 0.0;
        }

        /** Copies the underlying instance and all four working vectors. */
        @Override
        public Object copy() {
            Instance base = (Instance) super.copy();
            LADInstance duplicate = new LADInstance(base);
            System.arraycopy(this.fVector, 0, duplicate.fVector, 0, this.fVector.length);
            System.arraycopy(this.wVector, 0, duplicate.wVector, 0, this.wVector.length);
            System.arraycopy(this.pVector, 0, duplicate.pVector, 0, this.pVector.length);
            System.arraycopy(this.zVector, 0, duplicate.zVector, 0, this.zVector.length);
            return duplicate;
        }

        /** Appends F, P and W (three decimals, comma-separated) to the base instance text. */
        @Override
        public String toString() {
            StringBuffer text = new StringBuffer();
            text.append(" * F(");
            this.appendJoined(text, this.fVector);
            text.append(") P(");
            this.appendJoined(text, this.pVector);
            text.append(") W(");
            this.appendJoined(text, this.wVector);
            text.append(")");
            return super.toString() + text.toString();
        }

        /** Appends the entries of {@code vec}, formatted to three decimals and separated by commas. */
        private void appendJoined(StringBuffer buf, double[] vec) {
            for (int i = 0; i < vec.length; ++i) {
                if (i > 0) {
                    buf.append(",");
                }
                buf.append(Utils.doubleToString(vec[i], 3));
            }
        }
    }
}

