package beast.evolution.likelihood;

import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.branchratemodel.BranchRateModel;
import beast.evolution.branchratemodel.StrictClockModel;
import beast.evolution.sitemodel.SiteModelInterface;
import beast.evolution.substitutionmodel.SubstitutionModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.TreeInterface;
import beast.util.XMLParser;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

@Description("Calculates the probability of sequence data on a beast.tree given a site and substitution model using a variant of the 'peeling algorithm'. For details, see Felsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.")
/**
 * Tree likelihood computed with a variant of Felsenstein's pruning ("peeling") algorithm.
 *
 * <p>{@link #initAndValidate()} first tries to set up a {@link BeagleTreeLikelihood}. If that
 * yields a BEAGLE instance, {@link #calculateLogP()}, {@link #requiresRecalculation()},
 * {@link #store()} and {@link #restore()} are delegated to it. Otherwise a Java
 * {@link LikelihoodCore} is used: {@code BeerLikelihoodCore4} for data with a maximum of 4
 * states, {@code BeerLikelihoodCore} for anything else.
 */
public class TreeLikelihood extends GenericTreeLikelihood {

    /** {@link #hasDirt} level: nothing forced to recompute. */
    private static final int IS_CLEAN = 0;
    /** {@link #hasDirt} level: transition matrices and partials along the traversal are recomputed. */
    private static final int IS_DIRTY = 1;
    /** {@link #hasDirt} level: as {@link #IS_DIRTY}, and node states are also flagged for update. */
    private static final int IS_FILTHY = 2;

    /** Java likelihood core; only used when no BEAGLE instance could be created. */
    protected LikelihoodCore likelihoodCore;
    /** BEAGLE-backed delegate, or {@code null} when the Java core is in use. */
    BeagleTreeLikelihood beagle;
    SubstitutionModel substitutionModel;
    /** Site model; only assigned on the Java-core path of {@link #initAndValidate()}. */
    protected SiteModelInterface.Base m_siteModel;
    protected BranchRateModel.Base branchRateModel;
    /** Global dirt level (one of IS_CLEAN / IS_DIRTY / IS_FILTHY) OR-ed into every node in {@link #traverse(Node)}. */
    protected int hasDirt;
    /** Per node number: node length times branch rate last used to build that node's transition matrices. */
    protected double[] m_branchLengths;
    protected double[] storedBranchLengths;
    /** Per-pattern log likelihoods produced at the root. */
    protected double[] patternLogLikelihoods;
    /** Root partials, indexed {@code pattern * stateCount + state}. */
    protected double[] m_fRootPartials;
    /** Scratch buffer for transition probabilities, sized {@code (stateCount + 1)^2}. */
    double[] probabilities;
    int matrixSize;
    public final Input<Boolean> m_useAmbiguities = new Input<>("useAmbiguities", "flag to indicate that sites containing ambiguous states should be handled instead of ignored (the default)", false);
    public final Input<Boolean> m_useTipLikelihoods = new Input<>("useTipLikelihoods", "flag to indicate that partial likelihoods are provided at the tips", false);
    public final Input<Scaling> scaling = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());
    boolean useAscertainedSitePatterns = false;
    double proportionInvariant = 0.0d;
    /** Indices into {@link #m_fRootPartials} ({@code pattern * stateCount + state}) of constant-site entries. */
    List<Integer> constantPattern = null;
    /** Current scaling factor; multiplied by 1.01 each time scaling is (re)enabled. */
    double m_fScale = 1.01d;
    /** Number of {@link #calculateLogP()} calls since scaling was last (re)enabled. */
    int m_nScale = 0;
    /** Call-count threshold used together with {@link #m_nScale} when scaling is already on. */
    int X = 100;

    /**
     * Scaling mode. The value is passed as a string to {@link BeagleTreeLikelihood}. On the
     * Java path, {@code none} prevents {@link #calculateLogP()} from switching scaling on.
     */
    public enum Scaling {
        none,
        always,
        _default
    }

    /**
     * Validates inputs and sets up either a BEAGLE delegate or the Java likelihood core.
     *
     * @throws IllegalArgumentException if the taxon count differs from the tree's leaf count,
     *     or (Java path) if the site model is not a {@link SiteModelInterface.Base}
     */
    @Override
    public void initAndValidate() {
        final Alignment data = dataInput.get();
        final TreeInterface tree = treeInput.get();
        if (data.getTaxonCount() != tree.getLeafNodeCount()) {
            throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
        }

        beagle = new BeagleTreeLikelihood();
        try {
            beagle.initByName(
                    XMLParser.DATA_ELEMENT, data,
                    XMLParser.TREE_ELEMENT, tree,
                    "siteModel", siteModelInput.get(),
                    "branchRateModel", branchRateModelInput.get(),
                    "useAmbiguities", m_useAmbiguities.get(),
                    "useTipLikelihoods", m_useTipLikelihoods.get(),
                    "scaling", scaling.get().toString());
            if (beagle.beagle != null) {
                // A BEAGLE instance is available: delegate everything to it.
                return;
            }
        } catch (Exception ignored) {
            // Deliberate best-effort: any failure to set up BEAGLE falls back to the Java core below.
        }
        beagle = null;

        final int nodeCount = tree.getNodeCount();
        if (!(siteModelInput.get() instanceof SiteModelInterface.Base)) {
            throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
        }
        m_siteModel = (SiteModelInterface.Base) siteModelInput.get();
        m_siteModel.setDataType(data.getDataType());
        substitutionModel = m_siteModel.substModelInput.get();

        if (branchRateModelInput.get() != null) {
            branchRateModel = branchRateModelInput.get();
        } else {
            branchRateModel = new StrictClockModel();
        }
        m_branchLengths = new double[nodeCount];
        storedBranchLengths = new double[nodeCount];

        final int stateCount = data.getMaxStateCount();
        final int patternCount = data.getPatternCount();
        if (stateCount == 4) {
            likelihoodCore = new BeerLikelihoodCore4();
        } else {
            likelihoodCore = new BeerLikelihoodCore(stateCount);
        }

        Log.info.println(getClass().getSimpleName() + "(" + getID() + ") uses " + likelihoodCore.getClass().getSimpleName());
        Log.info.println("  " + data.toString(true));

        proportionInvariant = m_siteModel.getProportionInvariant();
        m_siteModel.setPropInvariantIsCategory(false);
        if (proportionInvariant > 0.0d) {
            calcConstantPatternIndices(patternCount, stateCount);
        }

        initCore();

        patternLogLikelihoods = new double[patternCount];
        m_fRootPartials = new double[patternCount * stateCount];
        matrixSize = (stateCount + 1) * (stateCount + 1);
        probabilities = new double[matrixSize];
        Arrays.fill(probabilities, 1.0d);

        if (data.isAscertained) {
            useAscertainedSitePatterns = true;
        }
    }

    /**
     * Records in {@link #constantPattern} every (pattern, state) pair for which all considered
     * codes of the pattern include that state.
     *
     * <p>Ambiguous codes are skipped unless {@code useAmbiguities} is set.
     *
     * @param patternCount number of site patterns to scan
     * @param stateCount number of states per pattern in {@link #m_fRootPartials}
     */
    public void calcConstantPatternIndices(final int patternCount, final int stateCount) {
        final Alignment data = dataInput.get();
        final boolean useAmbiguities = m_useAmbiguities.get();
        constantPattern = new ArrayList<>();
        for (int patternIndex = 0; patternIndex < patternCount; patternIndex++) {
            final boolean[] isInvariant = new boolean[stateCount];
            Arrays.fill(isInvariant, true);
            for (final int code : data.getPattern(patternIndex)) {
                final boolean[] stateSet = data.getStateSet(code);
                if (useAmbiguities || !data.getDataType().isAmbiguousState(code)) {
                    for (int state = 0; state < stateCount; state++) {
                        isInvariant[state] &= stateSet[state];
                    }
                }
            }
            for (int state = 0; state < stateCount; state++) {
                if (isInvariant[state]) {
                    constantPattern.add(patternIndex * stateCount + state);
                }
            }
        }
    }

    /**
     * Initialises the Java likelihood core: loads tip data (states or partials) and allocates
     * partials for the internal nodes.
     */
    protected void initCore() {
        final TreeInterface tree = treeInput.get();
        final int nodeCount = tree.getNodeCount();
        final int patternCount = dataInput.get().getPatternCount();
        likelihoodCore.initialize(nodeCount, patternCount, m_siteModel.getCategoryCount(), true, m_useAmbiguities.get());

        // NOTE(review): assumes leaves are numbered [0, extNodeCount) and internal nodes follow.
        final int extNodeCount = nodeCount / 2 + 1;
        final int intNodeCount = nodeCount / 2;

        if (m_useAmbiguities.get() || m_useTipLikelihoods.get()) {
            setPartials(tree.getRoot(), patternCount);
        } else {
            setStates(tree.getRoot(), patternCount);
        }
        hasDirt = IS_FILTHY;
        for (int i = 0; i < intNodeCount; i++) {
            likelihoodCore.createNodePartials(extNodeCount + i);
        }
    }

    /** Always throws: the alignment is fixed, so it cannot be sampled. */
    @Override
    public void sample(State state, Random random) {
        throw new UnsupportedOperationException("Can't sample a fixed alignment!");
    }

    /**
     * Recursively loads per-pattern tip states into the likelihood core for every leaf below
     * {@code node}.
     *
     * <p>A code that maps to a single state is stored as that state. Any other code is stored
     * as the raw code.
     *
     * @param node subtree root
     * @param patternCount number of site patterns
     */
    protected void setStates(Node node, int patternCount) {
        if (!node.isLeaf()) {
            setStates(node.getLeft(), patternCount);
            setStates(node.getRight(), patternCount);
            return;
        }
        final Alignment data = dataInput.get();
        final int[] states = new int[patternCount];
        final int taxonIndex = getTaxonIndex(node.getID(), data);
        for (int i = 0; i < patternCount; i++) {
            final int code = data.getPattern(taxonIndex, i);
            final int[] statesForCode = data.getDataType().getStatesForCode(code);
            states[i] = statesForCode.length == 1 ? statesForCode[0] : code;
        }
        likelihoodCore.setNodeStates(node.getNr(), states);
    }

    /**
     * Looks up a taxon in the alignment. If the id is not found and it starts with a quote
     * character, the lookup is retried without its first and last characters.
     *
     * @throws RuntimeException if the taxon cannot be found
     */
    private int getTaxonIndex(String taxon, Alignment data) {
        int taxonIndex = data.getTaxonIndex(taxon);
        // Length guard: a lone quote character would otherwise make substring(1, 0) throw.
        if (taxonIndex == -1 && taxon.length() >= 2 && (taxon.startsWith("'") || taxon.startsWith("\""))) {
            taxonIndex = data.getTaxonIndex(taxon.substring(1, taxon.length() - 1));
        }
        if (taxonIndex == -1) {
            throw new RuntimeException("Could not find sequence " + taxon + " in the alignment");
        }
        return taxonIndex;
    }

    /**
     * Recursively loads tip partials into the likelihood core for every leaf below {@code node}.
     *
     * <p>Per pattern, the alignment's tip likelihoods are used when present. Otherwise each state
     * gets 1.0 if it is in the code's state set and 0.0 if not.
     *
     * @param node subtree root
     * @param patternCount number of site patterns
     */
    protected void setPartials(Node node, int patternCount) {
        if (!node.isLeaf()) {
            setPartials(node.getLeft(), patternCount);
            setPartials(node.getRight(), patternCount);
            return;
        }
        final Alignment data = dataInput.get();
        final int stateCount = data.getDataType().getStateCount();
        final double[] partials = new double[patternCount * stateCount];
        final int taxonIndex = getTaxonIndex(node.getID(), data);
        int k = 0;
        for (int patternIndex = 0; patternIndex < patternCount; patternIndex++) {
            final double[] tipLikelihoods = data.getTipLikelihoods(taxonIndex, patternIndex);
            if (tipLikelihoods != null) {
                for (int state = 0; state < stateCount; state++) {
                    partials[k++] = tipLikelihoods[state];
                }
            } else {
                final boolean[] stateSet = data.getStateSet(data.getPattern(taxonIndex, patternIndex));
                for (int state = 0; state < stateCount; state++) {
                    partials[k++] = stateSet[state] ? 1.0d : 0.0d;
                }
            }
        }
        likelihoodCore.setNodePartials(node.getNr(), partials);
    }

    /**
     * Computes the tree log likelihood, delegating to BEAGLE when available.
     *
     * <p>On the Java path, scaling is switched on (or its factor increased) and the likelihood
     * recomputed when the result is {@code -Infinity}. This happens only when the scaling
     * factor is below 10, scaling is not {@link Scaling#none}, and scaling has not been on for
     * more than {@link #X} calls.
     *
     * @return the log likelihood; {@code -Infinity} if an {@link ArithmeticException} occurs
     */
    @Override
    public double calculateLogP() {
        if (beagle != null) {
            logP = beagle.calculateLogP();
            return logP;
        }
        final TreeInterface tree = treeInput.get();
        try {
            if (traverse(tree.getRoot()) != IS_CLEAN) {
                calcLogP();
            }
            m_nScale++;
            final boolean scalingOnLongEnough = likelihoodCore.getUseScaling() && m_nScale > X;
            if (logP == Double.NEGATIVE_INFINITY
                    && !scalingOnLongEnough
                    && m_fScale < 10.0d
                    && !scaling.get().equals(Scaling.none)) {
                m_nScale = 0;
                m_fScale *= 1.01d;
                Log.warning.println("Turning on scaling to prevent numeric instability " + m_fScale);
                likelihoodCore.setUseScaling(m_fScale);
                likelihoodCore.unstore();
                hasDirt = IS_FILTHY;
                traverse(tree.getRoot());
                calcLogP();
            }
            return logP;
        } catch (ArithmeticException e) {
            // Keep the stored value consistent with what is returned.
            logP = Double.NEGATIVE_INFINITY;
            return logP;
        }
    }

    /**
     * Sums weighted per-pattern log likelihoods into {@code logP}. When ascertained site
     * patterns are used, the alignment's ascertainment correction is subtracted from every
     * pattern first.
     */
    void calcLogP() {
        final Alignment data = dataInput.get();
        final int patternCount = data.getPatternCount();
        final double correction = useAscertainedSitePatterns
                ? data.getAscertainmentCorrection(patternLogLikelihoods)
                : 0.0d;
        logP = 0.0d;
        for (int i = 0; i < patternCount; i++) {
            logP += (patternLogLikelihoods[i] - correction) * data.getPatternWeight(i);
        }
    }

    /**
     * Post-order traversal that refreshes transition matrices and partials where needed. At the
     * root it also fills {@link #patternLogLikelihoods}.
     *
     * @param node subtree root
     * @return the dirt level of this subtree (non-zero if anything was recomputed)
     */
    int traverse(Node node) {
        int update = node.isDirty() | hasDirt;
        final int nodeIndex = node.getNr();
        final double branchRate = branchRateModel.getRateForBranch(node);
        final double branchTime = node.getLength() * branchRate;

        // Rebuild this branch's matrices if anything is dirty or its effective length changed.
        if (!node.isRoot() && (update != IS_CLEAN || branchTime != m_branchLengths[nodeIndex])) {
            m_branchLengths[nodeIndex] = branchTime;
            final Node parent = node.getParent();
            likelihoodCore.setNodeMatrixForUpdate(nodeIndex);
            for (int category = 0; category < m_siteModel.getCategoryCount(); category++) {
                final double jointRate = m_siteModel.getRateForCategory(category, node) * branchRate;
                substitutionModel.getTransitionProbabilities(node, parent.getHeight(), node.getHeight(), jointRate, probabilities);
                likelihoodCore.setNodeMatrix(nodeIndex, category, probabilities);
            }
            update |= IS_DIRTY;
        }

        if (!node.isLeaf()) {
            final Node child1 = node.getLeft();
            final int update1 = traverse(child1);
            final Node child2 = node.getRight();
            final int update2 = traverse(child2);

            if (update1 != IS_CLEAN || update2 != IS_CLEAN) {
                final int childIndex1 = child1.getNr();
                final int childIndex2 = child2.getNr();
                likelihoodCore.setNodePartialsForUpdate(nodeIndex);
                update |= (update1 | update2);
                if (update >= IS_FILTHY) {
                    likelihoodCore.setNodeStatesForUpdate(nodeIndex);
                }
                if (!m_siteModel.integrateAcrossCategories()) {
                    throw new RuntimeException("Error TreeLikelihood 201: Site categories not supported");
                }
                likelihoodCore.calculatePartials(childIndex1, childIndex2, nodeIndex);

                if (node.isRoot()) {
                    final double[] frequencies = substitutionModel.getFrequencies();
                    final double[] proportions = m_siteModel.getCategoryProportions(node);
                    likelihoodCore.integratePartials(nodeIndex, proportions, m_fRootPartials);
                    if (constantPattern != null) {
                        // Add the invariant-site proportion to each constant (pattern, state) entry.
                        proportionInvariant = m_siteModel.getProportionInvariant();
                        for (final int index : constantPattern) {
                            m_fRootPartials[index] += proportionInvariant;
                        }
                    }
                    likelihoodCore.calculateLogLikelihoods(m_fRootPartials, frequencies, patternLogLikelihoods);
                }
            }
        }
        return update;
    }

    /**
     * Reports whether the likelihood must be recomputed and sets {@link #hasDirt} accordingly.
     *
     * <p>Dirty data sets {@code hasDirt} to IS_FILTHY, and a dirty site model sets it to
     * IS_DIRTY. A dirty branch rate model or tree leaves {@code hasDirt} at IS_CLEAN and relies
     * on per-node dirtiness instead.
     */
    @Override
    public boolean requiresRecalculation() {
        if (beagle != null) {
            return beagle.requiresRecalculation();
        }
        hasDirt = IS_CLEAN;
        if (dataInput.get().isDirtyCalculation()) {
            hasDirt = IS_FILTHY;
            return true;
        }
        if (m_siteModel.isDirtyCalculation()) {
            hasDirt = IS_DIRTY;
            return true;
        }
        if (branchRateModel != null && branchRateModel.isDirtyCalculation()) {
            return true;
        }
        return treeInput.get().somethingIsDirty();
    }

    /** Stores the core (or BEAGLE delegate) state and snapshots the branch lengths. */
    @Override
    public void store() {
        if (beagle != null) {
            beagle.store();
            super.store();
            return;
        }
        if (likelihoodCore != null) {
            likelihoodCore.store();
        }
        super.store();
        System.arraycopy(m_branchLengths, 0, storedBranchLengths, 0, m_branchLengths.length);
    }

    /** Restores the core (or BEAGLE delegate) state and swaps back the stored branch lengths. */
    @Override
    public void restore() {
        if (beagle != null) {
            beagle.restore();
            super.restore();
            return;
        }
        if (likelihoodCore != null) {
            likelihoodCore.restore();
        }
        super.restore();
        // Swap rather than copy: the current array becomes the next stored buffer.
        final double[] current = m_branchLengths;
        m_branchLengths = storedBranchLengths;
        storedBranchLengths = current;
    }

    /** @return a singleton list holding the alignment's id */
    @Override
    public List<String> getArguments() {
        return Collections.singletonList(dataInput.get().getID());
    }

    /** @return the site model's conditions */
    @Override
    public List<String> getConditions() {
        // m_siteModel is only assigned on the Java-core path; read the input directly otherwise
        // so the BEAGLE path does not dereference null.
        final SiteModelInterface.Base siteModel = m_siteModel != null
                ? m_siteModel
                : (SiteModelInterface.Base) siteModelInput.get();
        return siteModel.getConditions();
    }
}
