/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.evaluate;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Trial;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.logging.Logger;

/**
 * A confusion matrix tallied from the classifications held by a {@link Trial}.
 *
 * <p>Cell {@code values[t][p]} counts the classifications whose true label (the best index of the
 * instance's labeling) is {@code t} and whose predicted label (the best index of the
 * classification's label vector) is {@code p}. Rows are true classes, columns are predicted
 * classes.
 */
public class ConfusionMatrix {
    private static final Logger logger = MalletLogger.getLogger(ConfusionMatrix.class.getName());

    /** Size of the label alphabet; the matrix is {@code numClasses x numClasses}. */
    int numClasses;
    /** The classifications that were tallied (this is the trial itself). */
    ArrayList<Classification> classifications;
    /** Counts indexed as {@code values[trueClass][predictedClass]}. */
    int[][] values;
    /** The trial this matrix was built from; used for label names and accuracy in toString. */
    Trial trial;

    /**
     * Builds the confusion matrix by tallying every classification in the trial.
     *
     * <p>The number of classes comes from the label alphabet of the first classification's
     * labeling. If the trial is empty, it comes from the trial classifier's label alphabet
     * instead, so an empty trial yields an all-zero matrix rather than an exception.
     *
     * @param t the trial whose classifications are tallied
     * @throws IllegalArgumentException if an instance's labeling has no best (true) label index
     */
    public ConfusionMatrix(Trial t) {
        this.trial = t;
        this.classifications = t;
        if (this.classifications.isEmpty()) {
            // No classification to inspect; the classifier's alphabet is the same one toString uses.
            this.numClasses = t.getClassifier().getLabelAlphabet().size();
        } else {
            Labeling firstLabeling = this.classifications.get(0).getLabeling();
            this.numClasses = firstLabeling.getLabelAlphabet().size();
        }
        this.values = new int[this.numClasses][this.numClasses];
        for (int i = 0; i < this.classifications.size(); ++i) {
            Classification classification = this.classifications.get(i);
            LabelVector predicted = classification.getLabelVector();
            Instance instance = classification.getInstance();
            int predictedIndex = predicted.getBestIndex();
            int trueIndex = instance.getLabeling().getBestIndex();
            // An explicit check: a bare assert is skipped without -ea and would then fail with an
            // opaque ArrayIndexOutOfBoundsException(-1) on the next line.
            if (trueIndex == -1) {
                throw new IllegalArgumentException(
                        "Classification " + i + " has an instance with no true label index");
            }
            this.values[trueIndex][predictedIndex]++;
        }
    }

    /**
     * Returns the count in cell (true class {@code i}, predicted class {@code j}).
     *
     * @param i true-class index
     * @param j predicted-class index
     * @return the number of classifications with true class {@code i} predicted as {@code j}
     */
    double value(int i, int j) {
        assert (i >= 0 && j >= 0 && i < this.numClasses && j < this.numClasses);
        return this.values[i][j];
    }

    /**
     * Appends {@code i} right-aligned in a 3-character field (values of 1000 or more are not
     * padded and overflow the field).
     *
     * @param sb destination
     * @param i the value to append
     * @param zeroDot if true, a zero is rendered as {@code "."} to make the matrix easier to scan
     */
    private static void appendJustifiedInt(StringBuilder sb, int i, boolean zeroDot) {
        if (i < 100) {
            sb.append(' ');
        }
        if (i < 10) {
            sb.append(' ');
        }
        if (i == 0 && zeroDot) {
            sb.append('.');
        } else {
            sb.append(i);
        }
    }

    /** Appends {@code count} spaces; appends nothing when {@code count <= 0}. */
    private static void appendSpaces(StringBuilder sb, int count) {
        for (int k = 0; k < count; ++k) {
            sb.append(' ');
        }
    }

    /**
     * Renders the matrix as a text table.
     *
     * <p>The table has a header line with the trial accuracy and a line of column (predicted
     * class) numbers. Then there is one row per true class with its index, label name, counts
     * (zeros shown as {@code "."}) and the row total.
     */
    @Override
    public String toString() {
        LabelAlphabet labelAlphabet = this.trial.getClassifier().getLabelAlphabet();
        int maxLabelNameLength = 0;
        for (int c = 0; c < this.numClasses; ++c) {
            maxLabelNameLength =
                    Math.max(maxLabelNameLength, labelAlphabet.lookupLabel(c).toString().length());
        }

        StringBuilder sb = new StringBuilder();
        sb.append("Confusion Matrix, row=true, column=predicted  accuracy=")
                .append(this.trial.getAccuracy())
                .append('\n');

        // Each row starts with a 3-char index, padding, " name ": maxLabelNameLength + 5 chars.
        // Padding by (max - 1) before the 5-char "label" leaves the column numbers below
        // right-aligned over their 3-wide cells.
        appendSpaces(sb, maxLabelNameLength - 1);
        sb.append("label");
        for (int c = 0; c < this.numClasses; ++c) {
            // Always 4 chars wide (for c < 1000), matching the "nnn " cells in each row.
            sb.append(' ');
            appendJustifiedInt(sb, c, false);
        }
        sb.append("  |total\n");

        for (int c = 0; c < this.numClasses; ++c) {
            appendJustifiedInt(sb, c, false);
            String labelName = labelAlphabet.lookupLabel(c).toString();
            appendSpaces(sb, maxLabelNameLength - labelName.length());
            sb.append(' ').append(labelName).append(' ');
            for (int p = 0; p < this.numClasses; ++p) {
                appendJustifiedInt(sb, this.values[c][p], true);
                sb.append(' ');
            }
            sb.append(" |").append(MatrixOps.sum(this.values[c])).append('\n');
        }
        return sb.toString();
    }

    /** Returns the number of classifications predicted as {@code predictedClassIndex}. */
    private int columnTotal(int predictedClassIndex) {
        int total = 0;
        for (int[] row : this.values) {
            total += row[predictedClassIndex];
        }
        return total;
    }

    /**
     * Returns the precision for a predicted class: the fraction of classifications predicted as
     * that class whose true class is also that class.
     *
     * @param predictedClassIndex the predicted-class index
     * @return the precision, or 0.0 if nothing was predicted as this class
     */
    public double getPrecision(int predictedClassIndex) {
        int total = columnTotal(predictedClassIndex);
        if (total == 0) {
            return 0.0;
        }
        return (double) this.values[predictedClassIndex][predictedClassIndex] / (double) total;
    }

    /**
     * Returns the fraction of classifications predicted as {@code class1} whose true class is
     * {@code class2}.
     *
     * @param class1 the predicted-class index
     * @param class2 the true-class index
     * @return that fraction, or 0.0 if nothing was predicted as {@code class1}
     */
    public double getConfusionBetween(int class1, int class2) {
        int total = columnTotal(class1);
        if (total == 0) {
            return 0.0;
        }
        return (double) this.values[class2][class1] / (double) total;
    }

    /**
     * Returns the fraction of all tallied classifications whose true class is
     * {@code classIndex}.
     *
     * @param classIndex the true-class index
     * @return that fraction, or 0.0 if there are no classifications
     */
    public double getClassPrior(int classIndex) {
        int size = this.classifications.size();
        if (size == 0) {
            // Avoid returning NaN (0/0), consistent with the zero-guards above.
            return 0.0;
        }
        int sum = 0;
        for (int count : this.values[classIndex]) {
            sum += count;
        }
        return (double) sum / (double) size;
    }
}

