/*
 * Decompiled with CFR 0.152.
 */
package beast.evolution.likelihood;

import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.branchratemodel.BranchRateModel;
import beast.evolution.branchratemodel.StrictClockModel;
import beast.evolution.likelihood.BeagleTreeLikelihood;
import beast.evolution.likelihood.BeerLikelihoodCore;
import beast.evolution.likelihood.BeerLikelihoodCore4;
import beast.evolution.likelihood.GenericTreeLikelihood;
import beast.evolution.likelihood.LikelihoodCore;
import beast.evolution.sitemodel.SiteModelInterface;
import beast.evolution.substitutionmodel.SubstitutionModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.TreeInterface;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/**
 * Tree likelihood computed with a Felsenstein-style pruning ("peeling") traversal.
 *
 * <p>During {@link #initAndValidate()} a {@link BeagleTreeLikelihood} is tried first. If it obtains a
 * BEAGLE instance, every likelihood call is delegated to it. Otherwise this class computes the
 * likelihood itself with a {@link BeerLikelihoodCore} (a {@link BeerLikelihoodCore4} when the alignment
 * has at most four states).
 */
@Description(value="Calculates the probability of sequence data on a beast.tree given a site and substitution model using a variant of the 'peeling algorithm'. For details, seeFelsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.")
public class TreeLikelihood
extends GenericTreeLikelihood {
    public final Input<Boolean> m_useAmbiguities = new Input<Boolean>("useAmbiguities", "flag to indicate that sites containing ambiguous states should be handled instead of ignored (the default)", false);
    public final Input<Boolean> m_useTipLikelihoods = new Input<Boolean>("useTipLikelihoods", "flag to indicate that partial likelihoods are provided at the tips", false);
    public final Input<String> implementationInput = new Input<String>("implementation", "name of class that implements this treelikelihood potentially more efficiently. This class will be tried first, with the TreeLikelihood as fallback implementation. When multi-threading, multiple objects can be created.", "beast.evolution.likelihood.BeagleTreeLikelihood");
    public final Input<Scaling> scaling = new Input<Scaling>("scaling", "type of scaling to use, one of " + Arrays.toString((Object[])Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());

    /*
     * Update levels used by traverse() and hasDirt. They are combined with bitwise OR and compared
     * with >=. Level 2 forces the node's states to be marked for update. Presumably these mirror
     * BEAST's Tree.IS_CLEAN / IS_DIRTY / IS_FILTHY; TODO confirm.
     */
    private static final int IS_CLEAN = 0;
    private static final int IS_DIRTY = 1;
    private static final int IS_FILTHY = 2;

    /** Java likelihood core; only used when BEAGLE is not available. */
    protected LikelihoodCore likelihoodCore;
    /** Non-null when a BEAGLE-backed implementation was successfully set up; all work is then delegated to it. */
    BeagleTreeLikelihood beagle;
    SubstitutionModel substitutionModel;
    protected SiteModelInterface.Base m_siteModel;
    protected BranchRateModel.Base branchRateModel;
    /** Update level forced onto every node during the next traversal (see IS_* constants). */
    protected int hasDirt;
    /** Per node number: the branch length times branch rate last used to build that node's transition matrices. */
    protected double[] m_branchLengths;
    protected double[] storedBranchLengths;
    /** Per-pattern log-likelihoods produced at the root by the last traversal. */
    protected double[] patternLogLikelihoods;
    /** Root partials integrated over rate categories; laid out as pattern * stateCount + state. */
    protected double[] m_fRootPartials;
    /** Scratch buffer that receives transition probabilities for one rate category at a time. */
    double[] probabilities;
    int matrixSize;
    boolean useAscertainedSitePatterns = false;
    double proportionInvariant = 0.0;
    /** Indices into m_fRootPartials (pattern * stateCount + state) of patterns that can be constant in that state; null when there are no invariant sites. */
    List<Integer> constantPattern = null;
    /** Current scaling factor; grown by 1% each time a -Infinity log-likelihood triggers a rescale. */
    double m_fScale = 1.01;
    /** Number of calculateLogP() calls since the last scaling change. */
    int m_nScale = 0;
    int X = 100;

    /**
     * Validates the inputs and sets up either the BEAGLE delegate or the Java likelihood core.
     *
     * @throws IllegalArgumentException if the taxon count does not match the number of tree leaves, or
     *         (for the Java core) the site model is not a {@link SiteModelInterface.Base}
     */
    @Override
    public void initAndValidate() {
        Alignment data = (Alignment)this.dataInput.get();
        TreeInterface tree = (TreeInterface)this.treeInput.get();
        if (data.getTaxonCount() != tree.getLeafNodeCount()) {
            throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
        }

        // Try BEAGLE first; keep it only if it actually obtained a BEAGLE instance.
        this.beagle = new BeagleTreeLikelihood();
        try {
            this.beagle.initByName("data", this.dataInput.get(), "tree", this.treeInput.get(), "siteModel", this.siteModelInput.get(), "branchRateModel", this.branchRateModelInput.get(), "useAmbiguities", this.m_useAmbiguities.get(), "useTipLikelihoods", this.m_useTipLikelihoods.get(), "scaling", this.scaling.get().toString());
            if (this.beagle.beagle != null) {
                return;
            }
        }
        catch (Exception ignored) {
            // Deliberate best-effort: any failure to set up BEAGLE falls back to the Java core below.
        }
        this.beagle = null;

        int nodeCount = tree.getNodeCount();
        if (!(this.siteModelInput.get() instanceof SiteModelInterface.Base)) {
            throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
        }
        this.m_siteModel = (SiteModelInterface.Base)this.siteModelInput.get();
        this.m_siteModel.setDataType(data.getDataType());
        this.substitutionModel = this.m_siteModel.substModelInput.get();
        this.branchRateModel = this.branchRateModelInput.get() != null ? (BranchRateModel.Base)this.branchRateModelInput.get() : new StrictClockModel();
        this.m_branchLengths = new double[nodeCount];
        this.storedBranchLengths = new double[nodeCount];

        int stateCount = data.getMaxStateCount();
        int patternCount = data.getPatternCount();
        this.likelihoodCore = stateCount == 4 ? new BeerLikelihoodCore4() : new BeerLikelihoodCore(stateCount);
        String className = this.getClass().getSimpleName();
        Log.info.println(className + "(" + this.getID() + ") uses " + this.likelihoodCore.getClass().getSimpleName());
        Log.info.println("  " + data.toString(true));

        // Invariant sites are handled by adding proportionInvariant to the root partials of
        // constant patterns (see traverse), not as an extra rate category.
        this.proportionInvariant = this.m_siteModel.getProportionInvariant();
        this.m_siteModel.setPropInvariantIsCategory(false);
        if (this.proportionInvariant > 0.0) {
            this.calcConstantPatternIndices(patternCount, stateCount);
        }

        this.initCore();
        this.patternLogLikelihoods = new double[patternCount];
        this.m_fRootPartials = new double[patternCount * stateCount];
        this.matrixSize = (stateCount + 1) * (stateCount + 1);
        this.probabilities = new double[this.matrixSize];
        Arrays.fill(this.probabilities, 1.0);
        if (data.isAscertained) {
            this.useAscertainedSitePatterns = true;
        }
    }

    /**
     * Collects into {@link #constantPattern} every (pattern, state) pair for which all taxa in the
     * pattern admit that state.
     *
     * <p>Ambiguous states are skipped unless {@code useAmbiguities} is set.
     *
     * @param patternCount number of site patterns in the alignment
     * @param stateCount   number of states per site
     */
    void calcConstantPatternIndices(int patternCount, int stateCount) {
        Alignment data = (Alignment)this.dataInput.get();
        boolean useAmbiguities = this.m_useAmbiguities.get();
        this.constantPattern = new ArrayList<Integer>();
        for (int patternIndex = 0; patternIndex < patternCount; ++patternIndex) {
            boolean[] isInvariant = new boolean[stateCount];
            Arrays.fill(isInvariant, true);
            for (int state : data.getPattern(patternIndex)) {
                if (!useAmbiguities && data.getDataType().isAmbiguousState(state)) {
                    continue;
                }
                boolean[] stateSet = data.getStateSet(state);
                for (int k = 0; k < stateCount; ++k) {
                    isInvariant[k] &= stateSet[k];
                }
            }
            for (int k = 0; k < stateCount; ++k) {
                if (isInvariant[k]) {
                    this.constantPattern.add(patternIndex * stateCount + k);
                }
            }
        }
    }

    /**
     * Initializes the likelihood core.
     *
     * <p>Loads tip data (as partials when ambiguities or tip likelihoods are used, otherwise as states),
     * allocates partial buffers for internal nodes, and forces a full recomputation.
     */
    protected void initCore() {
        TreeInterface tree = (TreeInterface)this.treeInput.get();
        int nodeCount = tree.getNodeCount();
        int patternCount = ((Alignment)this.dataInput.get()).getPatternCount();
        this.likelihoodCore.initialize(nodeCount, patternCount, this.m_siteModel.getCategoryCount(), true, this.m_useAmbiguities.get());
        // For a binary tree: nodeCount / 2 + 1 leaves and nodeCount / 2 internal nodes. Internal nodes are
        // assumed to be numbered after the leaves; TODO confirm against the tree implementation.
        int leafCount = nodeCount / 2 + 1;
        int internalCount = nodeCount / 2;
        if (this.m_useAmbiguities.get() || this.m_useTipLikelihoods.get()) {
            this.setPartials(tree.getRoot(), patternCount);
        } else {
            this.setStates(tree.getRoot(), patternCount);
        }
        this.hasDirt = IS_FILTHY;
        for (int i = 0; i < internalCount; ++i) {
            this.likelihoodCore.createNodePartials(leafCount + i);
        }
    }

    /**
     * Not supported: the alignment is fixed.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void sample(State state, Random random) {
        throw new UnsupportedOperationException("Can't sample a fixed alignment!");
    }

    /**
     * Recursively loads tip states into the likelihood core for every leaf below {@code node}.
     *
     * <p>A code that maps to exactly one state is stored as that state. Other (ambiguous) codes are
     * stored unchanged.
     *
     * @param node         subtree root
     * @param patternCount number of site patterns
     */
    protected void setStates(Node node, int patternCount) {
        if (!node.isLeaf()) {
            this.setStates(node.getLeft(), patternCount);
            this.setStates(node.getRight(), patternCount);
            return;
        }
        Alignment data = (Alignment)this.dataInput.get();
        int[] states = new int[patternCount];
        int taxonIndex = this.getTaxonIndex(node.getID(), data);
        for (int patternIndex = 0; patternIndex < patternCount; ++patternIndex) {
            int code = data.getPattern(taxonIndex, patternIndex);
            int[] statesForCode = data.getDataType().getStatesForCode(code);
            states[patternIndex] = statesForCode.length == 1 ? statesForCode[0] : code;
        }
        this.likelihoodCore.setNodeStates(node.getNr(), states);
    }

    /**
     * Looks up a taxon in the alignment. If the plain ID is not found and it is wrapped in quotes
     * (starts with {@code '} or {@code "}), the first and last characters are stripped and the
     * lookup is retried.
     *
     * @return the taxon index in {@code alignment}
     * @throws RuntimeException if no matching sequence exists
     */
    private int getTaxonIndex(String string, Alignment alignment) {
        int index = alignment.getTaxonIndex(string);
        // Require at least two characters so a lone quote does not make substring() throw.
        if (index == -1 && string.length() >= 2 && (string.startsWith("'") || string.startsWith("\""))) {
            index = alignment.getTaxonIndex(string.substring(1, string.length() - 1));
        }
        if (index == -1) {
            throw new RuntimeException("Could not find sequence " + string + " in the alignment");
        }
        return index;
    }

    /**
     * Recursively loads tip partials into the likelihood core for every leaf below {@code node}.
     *
     * <p>Tip likelihoods from the alignment are used when present. Otherwise each state in the
     * observed code's state set gets 1.0 and every other state gets 0.0.
     *
     * @param node         subtree root
     * @param patternCount number of site patterns
     */
    protected void setPartials(Node node, int patternCount) {
        if (!node.isLeaf()) {
            this.setPartials(node.getLeft(), patternCount);
            this.setPartials(node.getRight(), patternCount);
            return;
        }
        Alignment data = (Alignment)this.dataInput.get();
        int stateCount = data.getDataType().getStateCount();
        double[] partials = new double[patternCount * stateCount];
        int taxonIndex = this.getTaxonIndex(node.getID(), data);
        int offset = 0;
        for (int patternIndex = 0; patternIndex < patternCount; ++patternIndex) {
            double[] tipLikelihoods = data.getTipLikelihoods(taxonIndex, patternIndex);
            if (tipLikelihoods != null) {
                System.arraycopy(tipLikelihoods, 0, partials, offset, stateCount);
                offset += stateCount;
                continue;
            }
            boolean[] stateSet = data.getStateSet(data.getPattern(taxonIndex, patternIndex));
            for (int k = 0; k < stateCount; ++k) {
                partials[offset++] = stateSet[k] ? 1.0 : 0.0;
            }
        }
        this.likelihoodCore.setNodePartials(node.getNr(), partials);
    }

    /**
     * Computes the tree log-likelihood and caches it in {@code logP}.
     *
     * <p>Delegates to BEAGLE when available. Otherwise it re-traverses only what is dirty. If the
     * result is -Infinity, it enables (or strengthens) scaling and recomputes from scratch. This
     * happens only while the scale factor is below 10 and scaling is not {@link Scaling#none}.
     *
     * @return the log-likelihood, or -Infinity if the core raised an {@link ArithmeticException}
     */
    @Override
    public double calculateLogP() {
        if (this.beagle != null) {
            this.logP = this.beagle.calculateLogP();
            return this.logP;
        }
        TreeInterface tree = (TreeInterface)this.treeInput.get();
        try {
            if (this.traverse(tree.getRoot()) != IS_CLEAN) {
                this.calcLogP();
            }
        }
        catch (ArithmeticException arithmeticException) {
            // Keep the cached value consistent with the value reported to the caller.
            this.logP = Double.NEGATIVE_INFINITY;
            return this.logP;
        }
        ++this.m_nScale;
        boolean scalingNotNeeded = this.logP > 0.0 || this.likelihoodCore.getUseScaling() && this.m_nScale > this.X;
        boolean underflowed = this.logP == Double.NEGATIVE_INFINITY;
        if (!scalingNotNeeded && underflowed && this.m_fScale < 10.0 && !this.scaling.get().equals(Scaling.none)) {
            this.m_nScale = 0;
            this.m_fScale *= 1.01;
            Log.warning.println("Turning on scaling to prevent numeric instability " + this.m_fScale);
            this.likelihoodCore.setUseScaling(this.m_fScale);
            this.likelihoodCore.unstore();
            this.hasDirt = IS_FILTHY;
            this.traverse(tree.getRoot());
            this.calcLogP();
        }
        return this.logP;
    }

    /**
     * Sums the weighted per-pattern log-likelihoods into {@code logP}.
     *
     * <p>When ascertained site patterns are in use, the alignment's ascertainment correction is first
     * subtracted from each pattern.
     */
    void calcLogP() {
        Alignment data = (Alignment)this.dataInput.get();
        // Subtracting 0.0 leaves every double unchanged, so one loop serves both cases.
        double correction = this.useAscertainedSitePatterns ? data.getAscertainmentCorrection(this.patternLogLikelihoods) : 0.0;
        this.logP = 0.0;
        for (int i = 0; i < data.getPatternCount(); ++i) {
            this.logP += (this.patternLogLikelihoods[i] - correction) * (double)data.getPatternWeight(i);
        }
    }

    /**
     * Post-order traversal that refreshes transition matrices and partials for changed parts of the tree.
     *
     * <p>At the root it also integrates over rate categories, applies invariant sites, and computes
     * per-pattern log-likelihoods.
     *
     * @param node subtree root
     * @return the combined update level of this subtree ({@code IS_CLEAN} when nothing changed)
     * @throws RuntimeException if the site model does not integrate across categories
     */
    int traverse(Node node) {
        int update = node.isDirty() | this.hasDirt;
        int nodeIndex = node.getNr();
        double branchRate = this.branchRateModel.getRateForBranch(node);
        double branchTime = node.getLength() * branchRate;

        // Rebuild this branch's transition matrices when the node is dirty or its effective length changed.
        if (!node.isRoot() && (update != IS_CLEAN || branchTime != this.m_branchLengths[nodeIndex])) {
            this.m_branchLengths[nodeIndex] = branchTime;
            Node parent = node.getParent();
            this.likelihoodCore.setNodeMatrixForUpdate(nodeIndex);
            for (int category = 0; category < this.m_siteModel.getCategoryCount(); ++category) {
                double jointRate = this.m_siteModel.getRateForCategory(category, node) * branchRate;
                this.substitutionModel.getTransitionProbabilities(node, parent.getHeight(), node.getHeight(), jointRate, this.probabilities);
                this.likelihoodCore.setNodeMatrix(nodeIndex, category, this.probabilities);
            }
            update |= IS_DIRTY;
        }

        if (!node.isLeaf()) {
            Node child1 = node.getLeft();
            int update1 = this.traverse(child1);
            Node child2 = node.getRight();
            int update2 = this.traverse(child2);
            if (update1 != IS_CLEAN || update2 != IS_CLEAN) {
                this.likelihoodCore.setNodePartialsForUpdate(nodeIndex);
                update |= update1 | update2;
                if (update >= IS_FILTHY) {
                    this.likelihoodCore.setNodeStatesForUpdate(nodeIndex);
                }
                if (!this.m_siteModel.integrateAcrossCategories()) {
                    throw new RuntimeException("Error TreeLikelihood 201: Site categories not supported");
                }
                this.likelihoodCore.calculatePartials(child1.getNr(), child2.getNr(), nodeIndex);
                if (node.isRoot()) {
                    double[] frequencies = this.substitutionModel.getFrequencies();
                    double[] proportions = this.m_siteModel.getCategoryProportions(node);
                    this.likelihoodCore.integratePartials(node.getNr(), proportions, this.m_fRootPartials);
                    if (this.constantPattern != null) {
                        // Some fraction of sites is invariant: boost the root partials of constant patterns.
                        this.proportionInvariant = this.m_siteModel.getProportionInvariant();
                        for (int index : this.constantPattern) {
                            this.m_fRootPartials[index] += this.proportionInvariant;
                        }
                    }
                    this.likelihoodCore.calculateLogLikelihoods(this.m_fRootPartials, frequencies, this.patternLogLikelihoods);
                }
            }
        }
        return update;
    }

    /** Returns a copy of the per-pattern log-likelihoods (delegated to BEAGLE when in use). */
    public double[] getPatternLogLikelihoods() {
        if (this.beagle != null) {
            return this.beagle.getPatternLogLikelihoods();
        }
        return this.patternLogLikelihoods.clone();
    }

    /**
     * Decides whether a recalculation is needed and sets {@code hasDirt} for the next traversal.
     *
     * <p>An alignment change forces a full update. A site-model change marks every branch dirty.
     * Branch-rate or tree changes rely on per-node dirtiness.
     */
    @Override
    protected boolean requiresRecalculation() {
        if (this.beagle != null) {
            return this.beagle.requiresRecalculation();
        }
        this.hasDirt = IS_CLEAN;
        if (((Alignment)this.dataInput.get()).isDirtyCalculation()) {
            this.hasDirt = IS_FILTHY;
            return true;
        }
        if (this.m_siteModel.isDirtyCalculation()) {
            this.hasDirt = IS_DIRTY;
            return true;
        }
        if (this.branchRateModel != null && this.branchRateModel.isDirtyCalculation()) {
            return true;
        }
        return ((TreeInterface)this.treeInput.get()).somethingIsDirty();
    }

    /** Stores the core state and branch lengths (or the BEAGLE delegate's state) for a later restore. */
    @Override
    public void store() {
        if (this.beagle != null) {
            this.beagle.store();
            super.store();
            return;
        }
        if (this.likelihoodCore != null) {
            this.likelihoodCore.store();
        }
        super.store();
        System.arraycopy(this.m_branchLengths, 0, this.storedBranchLengths, 0, this.m_branchLengths.length);
    }

    /** Restores the state saved by {@link #store()}; branch-length arrays are swapped rather than copied. */
    @Override
    public void restore() {
        if (this.beagle != null) {
            this.beagle.restore();
            super.restore();
            return;
        }
        if (this.likelihoodCore != null) {
            this.likelihoodCore.restore();
        }
        super.restore();
        double[] swap = this.m_branchLengths;
        this.m_branchLengths = this.storedBranchLengths;
        this.storedBranchLengths = swap;
    }

    /** Returns the alignment ID as the single argument of this distribution. */
    @Override
    public List<String> getArguments() {
        return Collections.singletonList(((Alignment)this.dataInput.get()).getID());
    }

    /**
     * Returns the site model's conditions.
     *
     * <p>When BEAGLE is in use, {@link #initAndValidate()} returns before {@code m_siteModel} is
     * assigned. In that case the site model is read from its input instead.
     */
    @Override
    public List<String> getConditions() {
        SiteModelInterface.Base siteModel = this.m_siteModel != null ? this.m_siteModel : (SiteModelInterface.Base)this.siteModelInput.get();
        return siteModel.getConditions();
    }

    /** Scaling mode passed to BEAGLE; {@code none} also disables the Java core's automatic rescaling. */
    public static enum Scaling {
        none,
        always,
        _default;

    }
}

