/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Hidden Naive Bayes (HNB) classifier.
 *
 * <p>Each non-class attribute ("son") is conditioned on a hidden parent that is a weighted
 * mixture of all other non-class attributes. The weight of each parent is the conditional
 * mutual information between son and parent given the class, estimated from training counts.
 * Only nominal attributes and a nominal class are supported (see {@link #getCapabilities()}).
 */
public class HNB
extends AbstractClassifier
implements TechnicalInformationHandler {
    static final long serialVersionUID = -4503874444306113214L;

    /** Power of two by which a running class score is scaled up once it becomes very small. */
    private static final int RESCALE_EXPONENT = 512;
    /** Class scores whose magnitude drops below this are rescaled to avoid underflow. */
    private static final double RESCALE_THRESHOLD = Math.scalb(1.0, -RESCALE_EXPONENT);
    /** Natural logarithm of 2, used to convert natural logarithms to base 2. */
    private static final double LN2 = Math.log(2.0);

    // NOTE: instance field names are part of the serialized form of saved models; do not rename.
    /** Number of training instances (with a known class) seen for each class value. */
    private double[] m_ClassCounts;
    /**
     * Joint counts indexed by [class][flat value index][flat value index]. The diagonal entry
     * [c][v][v] is the number of instances having value v together with class c.
     */
    private double[][][] m_ClassAttAttCounts;
    /** Number of values per attribute; the class attribute slot holds the number of classes. */
    private int[] m_NumAttValues;
    /** Total number of values over all non-class attributes (size of the flat value index). */
    private int m_TotalAttValues;
    /** Number of class values. */
    private int m_NumClasses;
    /** Number of attributes, including the class attribute. */
    private int m_NumAttributes;
    /** Number of training instances with a known class value. */
    private int m_NumInstances;
    /** Index of the class attribute. */
    private int m_ClassIndex;
    /** Offset of each attribute's first value in the flat value index; -1 for the class. */
    private int[] m_StartAttIndex;
    /** Conditional mutual information I(son; parent | class), indexed [son][parent]. */
    private double[][] m_condiMutualInfo;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, followed by the bibliographic reference
     */
    public String globalInfo() {
        return "Constructs Hidden Naive Bayes classification model with high classification accuracy and AUC.\n\nFor more information refer to:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the HNB algorithm.
     *
     * @return the technical information about the originating paper
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "H. Zhang and L. Jiang and J. Su");
        result.setValue(TechnicalInformation.Field.TITLE, "Hidden Naive Bayes");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Twentieth National Conference on Artificial Intelligence");
        result.setValue(TechnicalInformation.Field.YEAR, "2005");
        result.setValue(TechnicalInformation.Field.PAGES, "919-924");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "AAAI Press");
        return result;
    }

    /**
     * Returns the capabilities of this classifier: nominal attributes, a nominal class, and
     * training instances with a missing class value (these are discarded during training).
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the classifier: collects class/attribute-pair counts and the conditional mutual
     * information between every pair of non-class attributes.
     *
     * @param instances the training data; it is copied, not modified
     * @throws Exception if the data does not satisfy {@link #getCapabilities()}
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        // Work on a copy so that deleting missing-class rows does not affect the caller.
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        this.m_NumClasses = instances.numClasses();
        this.m_ClassIndex = instances.classIndex();
        this.m_NumAttributes = instances.numAttributes();
        this.m_NumInstances = instances.numInstances();
        this.indexAttributeValues(instances);
        this.countJointOccurrences(instances);
        this.computeConditionalMutualInfo();
    }

    /** Assigns every non-class attribute a contiguous range in the flat value index. */
    private void indexAttributeValues(Instances instances) {
        this.m_TotalAttValues = 0;
        this.m_StartAttIndex = new int[this.m_NumAttributes];
        this.m_NumAttValues = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; ++att) {
            if (att == this.m_ClassIndex) {
                this.m_StartAttIndex[att] = -1;
                this.m_NumAttValues[att] = this.m_NumClasses;
            } else {
                this.m_StartAttIndex[att] = this.m_TotalAttValues;
                this.m_NumAttValues[att] = instances.attribute(att).numValues();
                this.m_TotalAttValues += this.m_NumAttValues[att];
            }
        }
    }

    /** Fills the per-class counts and the per-class joint counts of every pair of values. */
    private void countJointOccurrences(Instances instances) {
        this.m_ClassCounts = new double[this.m_NumClasses];
        this.m_ClassAttAttCounts = new double[this.m_NumClasses][this.m_TotalAttValues][this.m_TotalAttValues];
        // Reused for every instance; all entries are overwritten each iteration.
        int[] attIndex = new int[this.m_NumAttributes];
        for (int k = 0; k < this.m_NumInstances; ++k) {
            Instance current = instances.instance(k);
            int classVal = (int) current.classValue();
            this.m_ClassCounts[classVal] += 1.0;
            for (int att = 0; att < this.m_NumAttributes; ++att) {
                attIndex[att] = att == this.m_ClassIndex ? -1 : this.m_StartAttIndex[att] + (int) current.value(att);
            }
            double[][] counts = this.m_ClassAttAttCounts[classVal];
            // Both orders (and att1 == att2) are counted, so the matrix is symmetric and its
            // diagonal holds the single-value counts per class.
            for (int att1 = 0; att1 < this.m_NumAttributes; ++att1) {
                if (attIndex[att1] == -1) {
                    continue;
                }
                double[] row = counts[attIndex[att1]];
                for (int att2 = 0; att2 < this.m_NumAttributes; ++att2) {
                    if (attIndex[att2] != -1) {
                        row[attIndex[att2]] += 1.0;
                    }
                }
            }
        }
    }

    /** Computes I(son; parent | class) for every ordered pair of distinct non-class attributes. */
    private void computeConditionalMutualInfo() {
        this.m_condiMutualInfo = new double[this.m_NumAttributes][this.m_NumAttributes];
        for (int son = 0; son < this.m_NumAttributes; ++son) {
            if (son == this.m_ClassIndex) {
                continue;
            }
            for (int parent = 0; parent < this.m_NumAttributes; ++parent) {
                if (parent != this.m_ClassIndex && parent != son) {
                    this.m_condiMutualInfo[son][parent] = this.conditionalMutualInfo(son, parent);
                }
            }
        }
    }

    /**
     * Estimates the conditional mutual information between two attributes given the class,
     * using relative frequencies over all training instances.
     *
     * @param son index of the first attribute
     * @param parent index of the second attribute
     * @return the estimated conditional mutual information in bits
     */
    private double conditionalMutualInfo(int son, int parent) {
        double total = (double) this.m_NumInstances;
        int sIndex = this.m_StartAttIndex[son];
        int pIndex = this.m_StartAttIndex[parent];
        double condiMutualInfo = 0.0;
        for (int c = 0; c < this.m_NumClasses; ++c) {
            double pClass = this.m_ClassCounts[c] / total;
            double[][] counts = this.m_ClassAttAttCounts[c];
            for (int j = 0; j < this.m_NumAttValues[parent]; ++j) {
                double pClassParent = counts[pIndex + j][pIndex + j] / total;
                for (int k = 0; k < this.m_NumAttValues[son]; ++k) {
                    double pClassSon = counts[sIndex + k][sIndex + k] / total;
                    double pClassParentSon = counts[pIndex + j][sIndex + k] / total;
                    condiMutualInfo += pClassParentSon * this.log2(pClassParentSon * pClass, pClassParent * pClassSon);
                }
            }
        }
        return condiMutualInfo;
    }

    /**
     * Returns log2(x / y), or 0 when either argument is below 1e-6, so that (near-)empty
     * cells contribute nothing to the mutual information sum.
     */
    private double log2(double x, double y) {
        if (x < 1.0E-6 || y < 1.0E-6) {
            return 0.0;
        }
        return Math.log(x / y) / LN2;
    }

    /**
     * Computes the class membership distribution for an instance.
     *
     * <p>Attributes whose value is missing are ignored: they contribute no son factor and are
     * excluded from every hidden-parent mixture. Scores are kept in a rescaled form while
     * multiplying, so that many small factors cannot underflow every class to zero.
     *
     * @param instance the instance to classify
     * @return the normalised class distribution
     * @throws Exception if the distribution cannot be normalised
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] probs = new double[this.m_NumClasses];
        // Per-class power-of-two exponent removed from probs[c] by rescaling.
        int[] exponents = new int[this.m_NumClasses];
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; ++att) {
            // Casting a missing value (NaN) to int would yield 0 and silently select the
            // attribute's first value, so missing values are marked as absent instead.
            attIndex[att] = att == this.m_ClassIndex || instance.isMissing(att) ? -1 : this.m_StartAttIndex[att] + (int) instance.value(att);
        }
        for (int classVal = 0; classVal < this.m_NumClasses; ++classVal) {
            double score = (this.m_ClassCounts[classVal] + 1.0 / (double) this.m_NumClasses) / ((double) this.m_NumInstances + 1.0);
            int exponent = 0;
            for (int son = 0; son < this.m_NumAttributes; ++son) {
                int sIndex = attIndex[son];
                if (sIndex == -1) {
                    continue;
                }
                score *= this.sonProbability(classVal, son, sIndex, attIndex);
                // Scaling by a power of two is exact, so this never changes the final ratios.
                if (score != 0.0 && Math.abs(score) < RESCALE_THRESHOLD) {
                    score = Math.scalb(score, RESCALE_EXPONENT);
                    exponent -= RESCALE_EXPONENT;
                }
            }
            probs[classVal] = score;
            exponents[classVal] = exponent;
        }
        alignExponents(probs, exponents);
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Estimates P(son = sIndex | hidden parent, class), the mixture of the Laplace-smoothed
     * conditionals given each present parent, weighted by conditional mutual information.
     * Falls back to the smoothed P(son | class) when the weights do not sum to a positive value.
     */
    private double sonProbability(int classVal, int son, int sIndex, int[] attIndex) {
        double[][] counts = this.m_ClassAttAttCounts[classVal];
        double smoothing = 1.0 / (double) this.m_NumAttValues[son];
        double prob = 0.0;
        double condiMutualInfoSum = 0.0;
        for (int parent = 0; parent < this.m_NumAttributes; ++parent) {
            int pIndex = attIndex[parent];
            if (parent == son || pIndex == -1) {
                continue;
            }
            double weight = this.m_condiMutualInfo[son][parent];
            condiMutualInfoSum += weight;
            prob += weight * (counts[pIndex][sIndex] + smoothing) / (counts[pIndex][pIndex] + 1.0);
        }
        if (condiMutualInfoSum > 0.0) {
            return prob / condiMutualInfoSum;
        }
        return (counts[sIndex][sIndex] + smoothing) / (this.m_ClassCounts[classVal] + 1.0);
    }

    /**
     * Brings all class scores onto the largest common power-of-two exponent, so that their
     * ratios match the unscaled products. Leaves the array untouched if every score is zero.
     */
    private static void alignExponents(double[] probs, int[] exponents) {
        boolean found = false;
        int maxExponent = 0;
        for (int c = 0; c < probs.length; ++c) {
            if (probs[c] != 0.0 && (!found || exponents[c] > maxExponent)) {
                maxExponent = exponents[c];
                found = true;
            }
        }
        if (!found) {
            return;
        }
        for (int c = 0; c < probs.length; ++c) {
            if (exponents[c] != maxExponent) {
                probs[c] = Math.scalb(probs[c], exponents[c] - maxExponent);
            }
        }
    }

    /**
     * Returns a short name for this classifier.
     *
     * @return the classifier name
     */
    @Override
    public String toString() {
        return "HNB (Hidden Naive Bayes)";
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5928 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        HNB.runClassifier(new HNB(), args);
    }
}

