/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Boosts a two-class base classifier (by default {@link DecisionStump}) with Real AdaBoost.
 *
 * <p>The additive model always starts with a {@link ZeroR} model (internally "iteration -1") and
 * then adds up to {@code m_Classifiers.length} base models. Each member contributes half of the
 * log-odds of the first class, scaled by {@link #getShrinkage()} for base models (ZeroR is never
 * shrunk). After each member is built, every training instance's weight is multiplied by
 * {@code exp} of that member's contribution taken against the instance's actual class, so
 * instances the member argues against gain weight. Training stops early as soon as the loss
 * computed after an iteration exceeds the loss after the previous one; the member built in that
 * iteration is then not used.
 */
public class RealAdaBoost
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -7378109809933197974L;
    /** Number of base models (not counting ZeroR) that are part of the final ensemble. */
    protected int m_NumIterationsPerformed;
    /** Percentage of total weight mass used to train each base model (100 = use all instances). */
    protected int m_WeightThreshold = 100;
    /** Multiplier applied to each base model's log-odds contribution (never applied to ZeroR). */
    protected double m_Shrinkage = 1.0;
    /** If true, base models are trained on weighted resamples instead of weighted data. */
    protected boolean m_UseResampling;
    /** The initial model of the additive expansion, trained on the unweighted training data. */
    protected Classifier m_ZeroR;
    /** Total weight of the training instances with a known class; used to smooth probabilities. */
    protected double m_SumOfWeights;

    /** Creates a booster that uses a {@link DecisionStump} as its base classifier. */
    public RealAdaBoost() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for boosting a 2-class classifier using the Real Adaboost method.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the Real AdaBoost method.
     *
     * @return the technical information describing the source article
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "J. Friedman and T. Hastie and R. Tibshirani");
        result.setValue(TechnicalInformation.Field.TITLE, "Additive Logistic Regression: a Statistical View of Boosting");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Annals of Statistics");
        // The article appeared in Annals of Statistics, volume 28, issue 2 (2000).
        result.setValue(TechnicalInformation.Field.VOLUME, "28");
        result.setValue(TechnicalInformation.Field.NUMBER, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "337-407");
        result.setValue(TechnicalInformation.Field.YEAR, "2000");
        return result;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return {@code "weka.classifiers.trees.DecisionStump"}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Selects the heaviest instances whose weights together make up (roughly) the given fraction
     * of the total weight mass.
     *
     * <p>Instances are taken in decreasing order of weight. Selection stops once the accumulated
     * weight exceeds {@code quantile} times the total weight. It only stops at a boundary between
     * two different weight values, so a run of equally weighted instances is either selected as a
     * whole or not at all.
     *
     * @param data the weighted instances to select from
     * @param quantile the fraction of the total weight mass to keep, e.g. {@code 0.9}
     * @return a new dataset containing copies of the selected instances
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances trainData = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double sumOfWeights = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            weights[i] = data.instance(i).weight();
            sumOfWeights += weights[i];
        }
        double weightMassToSelect = sumOfWeights * quantile;
        // Utils.sort yields indices in ascending weight order; walk it backwards (heaviest first).
        int[] sortedIndices = Utils.sort(weights);
        double selectedMass = 0.0;
        for (int i = numInstances - 1; i >= 0; --i) {
            int index = sortedIndices[i];
            trainData.add((Instance) data.instance(index).copy());
            selectedMass += weights[index];
            if (selectedMass > weightMassToSelect && i > 0 && weights[index] != weights[sortedIndices[i - 1]]) {
                break;
            }
        }
        if (this.m_Debug) {
            System.err.println("Selected " + trainData.numInstances() + " out of " + numInstances);
        }
        return trainData;
    }

    /**
     * Lists the options specific to this booster, followed by those of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        newVector.addElement(new Option("\tShrinkage parameter.\n\t(default 1)", "H", 1, "-H <num>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses the options {@code -P <num>} (weight threshold, default 100), {@code -H <num>}
     * (shrinkage, default 1) and {@code -Q} (use resampling), then delegates the remaining
     * options to the superclass.
     *
     * @param options the option array; recognised options are removed from it
     * @throws Exception if an option value cannot be parsed or the superclass rejects an option
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String thresholdString = Utils.getOption('P', options);
        if (thresholdString.length() != 0) {
            this.setWeightThreshold(Integer.parseInt(thresholdString));
        } else {
            this.setWeightThreshold(100);
        }
        String shrinkageString = Utils.getOption('H', options);
        if (shrinkageString.length() != 0) {
            // parseDouble avoids the deprecated Double(String) boxing constructor.
            this.setShrinkage(Double.parseDouble(shrinkageString));
        } else {
            this.setShrinkage(1.0);
        }
        this.setUseResampling(Utils.getFlag('Q', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array, followed by the superclass options.
     *
     * @return the option strings
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getUseResampling()) {
            result.add("-Q");
        }
        result.add("-P");
        result.add("" + this.getWeightThreshold());
        result.add("-H");
        result.add("" + this.getShrinkage());
        for (String option : super.getOptions()) {
            result.add(option);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the GUI tip text for the shrinkage property.
     *
     * @return the tip text
     */
    public String shrinkageTipText() {
        return "Shrinkage parameter (use small value like 0.1 to reduce overfitting).";
    }

    /**
     * Returns the multiplier applied to each base model's log-odds contribution.
     *
     * @return the shrinkage value
     */
    public double getShrinkage() {
        return this.m_Shrinkage;
    }

    /**
     * Sets the multiplier applied to each base model's log-odds contribution.
     *
     * @param newShrinkage the new shrinkage value
     */
    public void setShrinkage(double newShrinkage) {
        this.m_Shrinkage = newShrinkage;
    }

    /**
     * Returns the GUI tip text for the weight threshold property.
     *
     * @return the tip text
     */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /**
     * Sets the percentage of weight mass used to train each base model.
     *
     * @param threshold the percentage; values of 100 or more disable weight pruning
     */
    public void setWeightThreshold(int threshold) {
        this.m_WeightThreshold = threshold;
    }

    /**
     * Returns the percentage of weight mass used to train each base model.
     *
     * @return the weight threshold
     */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /**
     * Returns the GUI tip text for the use-resampling property.
     *
     * @return the tip text
     */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /**
     * Sets whether base models are trained on weighted resamples instead of weighted data.
     *
     * @param r true to use resampling
     */
    public void setUseResampling(boolean r) {
        this.m_UseResampling = r;
    }

    /**
     * Returns whether base models are trained on weighted resamples instead of weighted data.
     *
     * @return true if resampling is used
     */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /**
     * Returns the base classifier's capabilities restricted to binary classes.
     *
     * @return the capabilities of this booster
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        // A fresh copy is queried because result has just had all class capabilities removed.
        if (super.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        return result;
    }

    /**
     * Builds the boosted ensemble.
     *
     * <p>Instances with a missing class are removed from a copy of {@code data}. Weighted
     * training is used when resampling is off and the base classifier implements
     * {@link WeightedInstancesHandler}; otherwise base models are trained on weighted resamples.
     *
     * @param data the training data (not modified)
     * @throws Exception if the data is not supported or a base model fails to build
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        this.m_SumOfWeights = data.sumOfWeights();
        if (!this.m_UseResampling && this.m_Classifier instanceof WeightedInstancesHandler) {
            this.buildClassifierWithWeights(data);
        } else {
            this.buildClassifierUsingResampling(data);
        }
    }

    /**
     * Runs the boosting loop, training each base model on a resample drawn according to the
     * current (normalized, optionally pruned) instance weights.
     *
     * @param data the training data with missing-class instances removed
     * @throws Exception if a model fails to build or predict
     */
    protected void buildClassifierUsingResampling(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        double minLoss = Double.MAX_VALUE;
        // Boosting weights accumulate here across iterations; each iteration trains on a
        // normalized copy of this dataset.
        Instances trainingWeightsNotNormalized = new Instances(data, 0, numInstances);
        this.m_NumIterationsPerformed = -1;
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Instances training = new Instances(trainingWeightsNotNormalized);
            this.normalizeWeights(training, 1.0);
            Instances trainData = this.m_WeightThreshold < 100
                    ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                    : new Instances(training);
            double[] weights = new double[trainData.numInstances()];
            for (int i = 0; i < weights.length; ++i) {
                weights[i] = trainData.instance(i).weight();
            }
            // Note: a sample is also drawn (consuming random numbers) in the ZeroR step, where
            // it goes unused.
            Instances sample = trainData.resampleWithWeights(randomInstance, weights);
            if (this.m_NumIterationsPerformed == -1) {
                this.m_ZeroR = new ZeroR();
                this.m_ZeroR.buildClassifier(data);
            } else {
                this.m_Classifiers[this.m_NumIterationsPerformed].buildClassifier(sample);
            }
            this.setWeights(trainingWeightsNotNormalized, this.m_NumIterationsPerformed);
            double loss = logScaleLoss(trainingWeightsNotNormalized);
            if (this.m_Debug) {
                System.err.println("Current loss on log scale: " + loss);
            }
            if (this.m_NumIterationsPerformed > -1 && loss > minLoss) {
                // The model just built is not counted, so it is ignored at prediction time.
                if (this.m_Debug) {
                    System.err.println("Loss has increased: bailing out.");
                }
                break;
            }
            minLoss = loss;
            ++this.m_NumIterationsPerformed;
        }
    }

    /**
     * Multiplies each instance's weight by {@code exp} of the given model's contribution, signed
     * against the instance's actual class.
     *
     * @param training the instances whose weights are updated in place
     * @param iteration the model index; {@code -1} selects the ZeroR model
     * @throws Exception if the model fails to produce a distribution
     */
    protected void setWeights(Instances training, int iteration) throws Exception {
        for (Instance instance : training) {
            double contribution = this.halfLogOdds(instance, iteration);
            // A positive contribution favours class 0, so instances of class 1 gain weight
            // when it is positive and instances of class 0 gain weight when it is negative.
            double reweight = instance.classValue() == 1.0 ? contribution : -contribution;
            instance.setWeight(instance.weight() * Math.exp(reweight));
        }
    }

    /**
     * Rescales all instance weights so that they sum to {@code oldSumOfWeights}.
     *
     * @param training the instances whose weights are rescaled in place
     * @param oldSumOfWeights the desired total weight
     * @throws Exception declared for subclasses; not thrown here
     */
    protected void normalizeWeights(Instances training, double oldSumOfWeights) throws Exception {
        double newSumOfWeights = training.sumOfWeights();
        for (Instance instance : training) {
            instance.setWeight(instance.weight() * oldSumOfWeights / newSumOfWeights);
        }
    }

    /**
     * Runs the boosting loop, passing the current (normalized, optionally pruned) instance
     * weights directly to each base model.
     *
     * @param data the training data with missing-class instances removed
     * @throws Exception if a model fails to build or predict
     */
    protected void buildClassifierWithWeights(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        double minLoss = Double.MAX_VALUE;
        // Boosting weights accumulate here across iterations; each iteration trains on a copy
        // rescaled to the original total weight.
        Instances trainingWeightsNotNormalized = new Instances(data, 0, numInstances);
        this.m_NumIterationsPerformed = -1;
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Instances training = new Instances(trainingWeightsNotNormalized);
            this.normalizeWeights(training, this.m_SumOfWeights);
            Instances trainData = this.m_WeightThreshold < 100
                    ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                    : new Instances(training, 0, numInstances);
            if (this.m_NumIterationsPerformed == -1) {
                this.m_ZeroR = new ZeroR();
                this.m_ZeroR.buildClassifier(data);
            } else {
                Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
                if (current instanceof Randomizable) {
                    ((Randomizable) current).setSeed(randomInstance.nextInt());
                }
                current.buildClassifier(trainData);
            }
            this.setWeights(trainingWeightsNotNormalized, this.m_NumIterationsPerformed);
            double loss = logScaleLoss(trainingWeightsNotNormalized);
            if (this.m_Debug) {
                System.err.println("Current loss on log scale: " + loss);
            }
            if (this.m_NumIterationsPerformed > -1 && loss > minLoss) {
                // The model just built is not counted, so it is ignored at prediction time.
                if (this.m_Debug) {
                    System.err.println("Loss has increased: bailing out.");
                }
                break;
            }
            minLoss = loss;
            ++this.m_NumIterationsPerformed;
        }
    }

    /**
     * Computes the early-stopping criterion used by both boosting loops: the sum of the natural
     * logarithms of the instance weights.
     *
     * <p>NOTE(review): this is the sum of log-weights, not the log of the summed weights (which
     * would be the exponential loss on a log scale). Behaviour kept as-is; confirm intent.
     *
     * @param training the instances carrying the current boosting weights
     * @return the sum of {@code Math.log(weight)} over all instances
     */
    private static double logScaleLoss(Instances training) {
        double loss = 0.0;
        for (Instance inst : training) {
            loss += Math.log(inst.weight());
        }
        return loss;
    }

    /**
     * Computes one ensemble member's additive contribution for an instance: (shrinkage times)
     * half the log-odds of the first class.
     *
     * <p>ZeroR's probability is used as returned and is never shrunk. NOTE(review): this assumes
     * ZeroR never returns exactly 0 or 1; confirm. Base-model probabilities are smoothed toward
     * 1/2 using the total training weight, which keeps them away from 0 and 1.
     *
     * @param instance the instance to score
     * @param iteration the model index; {@code -1} selects the ZeroR model
     * @return the contribution to the first class's log-score
     * @throws Exception if the model fails to produce a distribution
     */
    private double halfLogOdds(Instance instance, int iteration) throws Exception {
        double prob;
        double shrinkage;
        if (iteration == -1) {
            prob = this.m_ZeroR.distributionForInstance(instance)[0];
            shrinkage = 1.0;
        } else {
            prob = this.m_Classifiers[iteration].distributionForInstance(instance)[0];
            prob = (this.m_SumOfWeights * prob + 1.0) / (this.m_SumOfWeights + 2.0);
            shrinkage = this.m_Shrinkage;
        }
        return shrinkage * 0.5 * (Math.log(prob) - Math.log(1.0 - prob));
    }

    /**
     * Returns the class distribution for an instance. The members' contributions are summed into
     * a log-score for the first class, the second class gets its negation, and both are
     * converted to probabilities.
     *
     * @param instance the instance to classify
     * @return the two-class probability distribution
     * @throws Exception if a member fails to produce a distribution
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        for (int i = -1; i < this.m_NumIterationsPerformed; ++i) {
            sums[0] = sums[0] + this.halfLogOdds(instance, i);
        }
        sums[1] = -sums[0];
        return Utils.logs2probs(sums);
    }

    /**
     * Returns a textual description of the ensemble: ZeroR followed by each base model used.
     *
     * @return the description, or a notice if no model has been built yet
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_ZeroR == null) {
            text.append("No model built yet.\n\n");
        } else {
            text.append("RealAdaBoost: Base classifiers: \n\n");
            text.append(this.m_ZeroR.toString() + "\n\n");
            for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
                text.append(this.m_Classifiers[i].toString() + "\n\n");
            }
            text.append("Number of performed Iterations: " + this.m_NumIterationsPerformed + "\n");
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 6136 $");
    }

    /**
     * Runs the booster from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        RealAdaBoost.runClassifier(new RealAdaBoost(), argv);
    }
}

