/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.ml.tksvmlight.model;

import com.google.common.annotations.Beta;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang.NotImplementedException;
import org.cleartk.ml.tksvmlight.TreeFeatureVector;
import org.cleartk.ml.tksvmlight.model.SimpleDepTree;
import org.cleartk.util.treebank.TopTreebankNode;
import org.cleartk.util.treebank.TreebankFormatParser;
import org.cleartk.util.treebank.TreebankNode;

@Beta
public class TreeKernel {
    public static final int MAX_CHILDREN = 5;
    public static final double LAMBDA_DEFAULT = 0.4;
    private double lambda = 0.4;
    private double lambdaSquared = this.lambda * this.lambda;
    private double[] lambdaPowers = new double[10];
    public static final double MU_DEFAULT = 0.4;
    private double mu = 0.4;
    private boolean normalize = false;
    private boolean useCache = true;
    HashMap<SimpleDepTree, HashMap<SimpleDepTree, Double>> cache = new HashMap();
    private ConcurrentHashMap<String, Double> normalizers = new ConcurrentHashMap();
    private ForestSumMethod sumMethod = ForestSumMethod.SEQUENTIAL;
    private KernelType kernelType;
    HashMap<String, TopTreebankNode> trees = null;
    HashMap<String, SimpleDepTree> depTrees = null;

    public TreeKernel(double lambda, ForestSumMethod sumMethod, KernelType kernelType, boolean normalize) {
        this.lambda = lambda;
        this.lambdaSquared = lambda * lambda;
        this.sumMethod = sumMethod;
        this.kernelType = kernelType;
        this.normalize = normalize;
        this.trees = new HashMap();
        this.depTrees = new HashMap();
        this.initExponents();
    }

    public void initExponents() {
        for (int i = 0; i < this.lambdaPowers.length; ++i) {
            this.lambdaPowers[i] = Math.pow(this.lambda, i);
        }
    }

    public double evaluate(TreeFeatureVector fv1, TreeFeatureVector fv2) {
        double sim = 0.0;
        if (this.sumMethod == ForestSumMethod.SEQUENTIAL) {
            ArrayList<String> fv1Trees = new ArrayList<String>(fv1.getTrees().values());
            ArrayList<String> fv2Trees = new ArrayList<String>(fv2.getTrees().values());
            for (int i = 0; i < fv1Trees.size(); ++i) {
                String tree1Str = (String)fv1Trees.get(i);
                String tree2Str = (String)fv2Trees.get(i);
                if (this.kernelType == KernelType.SUBSET) {
                    sim += this.sst(tree1Str, tree2Str);
                    continue;
                }
                if (this.kernelType == KernelType.PARTIAL) {
                    sim += this.ptk(tree1Str, tree2Str);
                    continue;
                }
                throw new NotImplementedException("The only kernel types implemented are SST and PTK!");
            }
        } else {
            throw new NotImplementedException("The only summation method implemented is Sequential!");
        }
        return sim;
    }

    private double sst(String tree1Str, String tree2Str) {
        TopTreebankNode node1 = null;
        if (!this.trees.containsKey(tree1Str)) {
            node1 = TreebankFormatParser.parse((String)tree1Str);
            this.trees.put(tree1Str, node1);
        } else {
            node1 = this.trees.get(tree1Str);
        }
        TopTreebankNode node2 = null;
        if (!this.trees.containsKey(tree2Str)) {
            node2 = TreebankFormatParser.parse((String)tree2Str);
            this.trees.put(tree2Str, node2);
        } else {
            node2 = this.trees.get(tree2Str);
        }
        double norm1 = 0.0;
        double norm2 = 0.0;
        if (this.normalize) {
            double norm;
            if (!this.normalizers.containsKey(tree1Str)) {
                norm = this.sim((TreebankNode)node1, (TreebankNode)node1);
                this.normalizers.put(tree1Str, norm);
            }
            if (!this.normalizers.containsKey(tree2Str)) {
                norm = this.sim((TreebankNode)node2, (TreebankNode)node2);
                this.normalizers.put(tree2Str, norm);
            }
            norm1 = this.normalizers.get(tree1Str);
            norm2 = this.normalizers.get(tree2Str);
        }
        if (this.normalize) {
            return this.sim((TreebankNode)node1, (TreebankNode)node2) / Math.sqrt(norm1 * norm2);
        }
        return this.sim((TreebankNode)node1, (TreebankNode)node2);
    }

    private double sim(TreebankNode node1, TreebankNode node2) {
        double sim = 0.0;
        List<TreebankNode> N1 = TreeKernel.getNodeList(node1);
        List<TreebankNode> N2 = TreeKernel.getNodeList(node2);
        for (TreebankNode n1 : N1) {
            for (TreebankNode n2 : N2) {
                sim += this.numCommonSubtrees(n1, n2);
            }
        }
        return sim;
    }

    private double numCommonSubtrees(TreebankNode n1, TreebankNode n2) {
        int c2size;
        double retVal = 1.0;
        List children1 = n1.getChildren();
        List children2 = n2.getChildren();
        int c1size = children1.size();
        if (c1size != (c2size = children2.size())) {
            retVal = 0.0;
        } else if (!n1.getType().equals(n2.getType())) {
            retVal = 0.0;
        } else if (n1.isLeaf() && n2.isLeaf()) {
            retVal = n1.getValue().equals(n2.getValue()) ? this.lambda : 0.0;
        } else {
            int i;
            boolean sameProd = true;
            for (i = 0; i < c1size; ++i) {
                String l2;
                String l1 = ((TreebankNode)children1.get(i)).getType();
                if (l1.equals(l2 = ((TreebankNode)children2.get(i)).getType())) continue;
                sameProd = false;
                break;
            }
            if (sameProd) {
                for (i = 0; i < c1size; ++i) {
                    retVal *= 1.0 + this.numCommonSubtrees((TreebankNode)children1.get(i), (TreebankNode)children2.get(i));
                }
                retVal = this.lambda * retVal;
            } else {
                retVal = 0.0;
            }
        }
        return retVal;
    }

    private double ptk(String tree1Str, String tree2Str) {
        SimpleDepTree node1 = null;
        if (!this.depTrees.containsKey(tree1Str)) {
            node1 = SimpleDepTree.fromString(tree1Str);
            this.depTrees.put(tree1Str, node1);
        } else {
            node1 = this.depTrees.get(tree1Str);
        }
        SimpleDepTree node2 = null;
        if (!this.depTrees.containsKey(tree2Str)) {
            node2 = SimpleDepTree.fromString(tree2Str);
            this.depTrees.put(tree2Str, node2);
        } else {
            node2 = this.depTrees.get(tree2Str);
        }
        double norm1 = 0.0;
        double norm2 = 0.0;
        if (this.normalize) {
            double norm;
            if (!this.normalizers.containsKey(tree1Str)) {
                norm = this.ptkSim(node1, node1);
                this.normalizers.put(tree1Str, norm);
            }
            if (!this.normalizers.containsKey(tree2Str)) {
                norm = this.ptkSim(node2, node2);
                this.normalizers.put(tree2Str, norm);
            }
            norm1 = this.normalizers.get(tree1Str);
            norm2 = this.normalizers.get(tree2Str);
            return this.ptkSim(node1, node2) / Math.sqrt(norm1 * norm2);
        }
        return this.ptkSim(node1, node2);
    }

    private double ptkSim(SimpleDepTree t1, SimpleDepTree t2) {
        double sim = 0.0;
        List<SimpleDepTree> t1Nodes = TreeKernel.getDepNodeList(t1);
        List<SimpleDepTree> t2Nodes = TreeKernel.getDepNodeList(t2);
        for (int i = 0; i < t1Nodes.size(); ++i) {
            SimpleDepTree t1Node = t1Nodes.get(i);
            for (int j = 0; j < t2Nodes.size(); ++j) {
                SimpleDepTree t2Node = t2Nodes.get(j);
                double nodeSim = 0.0;
                if (t1Node.cat.equals(t2Node.cat)) {
                    if (t1Node.isLeaf()) {
                        if (t1Node.cat.equals(t2Node.cat)) {
                            nodeSim += this.mu * this.lambdaSquared;
                        }
                    } else {
                        nodeSim = this.ptkDelta(t1Node, t2Node);
                    }
                }
                sim += nodeSim;
            }
        }
        return sim;
    }

    private double ptkDelta(SimpleDepTree node1, SimpleDepTree node2) {
        double delta = 0.0;
        if (!node1.cat.equals(node2.cat)) {
            return 0.0;
        }
        if (this.useCache && this.cache.containsKey(node1) && this.cache.get(node1).containsKey(node2)) {
            return this.cache.get(node1).get(node2);
        }
        int l1 = node1.children.size();
        int l2 = node2.children.size();
        delta = 1.0;
        for (int p = 1; p <= Math.min(Math.min(l1, l2), 5); ++p) {
            double contrP = this.ptkDeltaP(node1.children, node2.children, p);
            delta += contrP;
        }
        double score = this.mu * this.lambdaSquared * delta;
        if (this.useCache) {
            if (!this.cache.containsKey(node1)) {
                this.cache.put(node1, new HashMap());
            }
            this.cache.get(node1).put(node2, score);
        }
        return score;
    }

    private double ptkDeltaP(List<SimpleDepTree> c1, List<SimpleDepTree> c2, int p) {
        return this.ptkDeltaP(c1, c2, c1.size() - 1, c2.size() - 1, p);
    }

    private double ptkDeltaP(List<SimpleDepTree> c1, List<SimpleDepTree> c2, int s1len, int s2len, int p) {
        double delta = this.ptkDelta(c1.get(s1len), c2.get(s2len));
        for (int i = 0; i < s1len; ++i) {
            for (int r = 0; r < s2len; ++r) {
                int exp = s1len - i + s2len - r;
                double lambdaPow = exp < this.lambdaPowers.length - 1 ? this.lambdaPowers[exp] : Math.pow(this.lambda, exp);
                delta += lambdaPow * this.ptkDeltaP(c1, c2, i, r, p - 1);
            }
        }
        return delta;
    }

    private static final List<TreebankNode> getNodeList(TreebankNode tree) {
        ArrayList<TreebankNode> list = new ArrayList<TreebankNode>();
        list.add(tree);
        for (int i = 0; i < list.size(); ++i) {
            list.addAll(list.get(i).getChildren());
        }
        return list;
    }

    private static final List<SimpleDepTree> getDepNodeList(SimpleDepTree tree) {
        ArrayList list = Lists.newArrayList();
        list.add(tree);
        for (int i = 0; i < list.size(); ++i) {
            list.addAll(((SimpleDepTree)list.get((int)i)).children);
        }
        return list;
    }

    public void setUseCache(boolean use) {
        this.useCache = use;
    }

    public boolean getUseCache() {
        return this.useCache;
    }

    public static enum KernelType {
        SUBSET,
        SUBTREE,
        SUBSET_BOW,
        PARTIAL;

    }

    public static enum ForestSumMethod {
        SEQUENTIAL,
        ALL_PAIRS;

    }
}

