package beast.math.distributions;

import beast.core.Description;
import beast.core.Distribution;
import beast.core.Input;
import beast.core.State;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.util.XMLParser;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * Prior over a set of taxa in a tree.
 *
 * <p>Depending on its inputs it can require the taxon set to be monophyletic, place a
 * distribution on the date of the taxon set's MRCA (or, with {@code useOriginate}, of the MRCA's
 * parent), or place a distribution on the dates of the individual tips ({@code tipsonly}). It is
 * also a logger for the monophyly indicator, the log density and the MRCA date.
 */
@Description("Prior over set of taxa, useful for defining monophyletic constraints and distributions over MRCA times or (sets of) tips of trees")
public class MRCAPrior extends Distribution {
    /** Distribution over the MRCA/tip date; {@code null} when only a monophyly constraint is wanted. */
    ParametricDistribution dist;
    /** Tree containing the taxon set. */
    Tree tree;
    /**
     * Positions of the taxon-set members in {@code tree.getTaxaNames()}, which this class also uses
     * as node numbers for {@code tree.getNode(int)}. Filled lazily by {@link #initialise()}.
     */
    int[] taxonIndex;
    /** Scratch flags, indexed by node number, marking nodes visited by the common-ancestor search. */
    boolean[] nodesTraversed;
    /** Number of distinct nodes marked in {@link #nodesTraversed} during the last search. */
    int nseen;
    public final Input<Tree> treeInput = new Input<>(XMLParser.TREE_ELEMENT, "the tree containing the taxon set", Input.Validate.REQUIRED);
    public final Input<TaxonSet> taxonsetInput = new Input<>("taxonset", "set of taxa for which prior information is available");
    public final Input<Boolean> isMonophyleticInput = new Input<>("monophyletic", "whether the taxon set is monophyletic (forms a clade without other taxa) or nor. Default is false.", false);
    public final Input<ParametricDistribution> distInput = new Input<>("distr", "distribution used to calculate prior over MRCA time, e.g. normal, beta, gamma. If not specified, monophyletic must be true");
    public final Input<Boolean> onlyUseTipsInput = new Input<>("tipsonly", "flag to indicate tip dates are to be used instead of the MRCA node. If set to true, the prior is applied to the height of all tips in the taxonset and the monophyletic flag is ignored. Default is false.", false);
    public final Input<Boolean> useOriginateInput = new Input<>("useOriginate", "Use parent of clade instead of clade. Cannot be used with tipsonly, or on the root.", false);
    /** Number of taxa in the taxon set, or in the whole tree when no taxon set is given. */
    int nrOfTaxa = -1;
    /** Names of the taxa in the taxon set; filled by {@link #initialise()}. */
    Set<String> isInTaxaSet = new LinkedHashSet<>();
    /** Date computed by the last calculation (MRCA, MRCA parent, root, or last tip, depending on mode). */
    double MRCATime = -1.0d;
    /** Value of {@link #MRCATime} saved by {@link #store()}. */
    double storedMRCATime = -1.0d;
    /** Whether the taxon set formed a clade at the last calculation. */
    boolean isMonophyletic = false;
    /** Value of {@link #isMonophyletic} saved by {@link #store()}. */
    boolean storedIsMonophyletic = false;
    /** Apply the distribution to every tip date instead of the MRCA date. */
    boolean onlyUseTips = false;
    /** True when the taxon set has as many taxa as the tree has leaves, so its MRCA is the root. */
    boolean useRoot = false;
    /** Use the date of the MRCA's parent instead of the MRCA itself. */
    boolean useOriginate = false;
    /** Whether {@link #initialise()} has run since the last {@link #initAndValidate()}. */
    boolean initialised = false;

    /**
     * Reads the inputs and checks that they form a consistent configuration.
     *
     * <p>A taxon set with a single taxon switches on tips-only mode unless {@code useOriginate} is
     * set. Resolution of taxon names into tree indices is deferred to {@link #initialise()}.
     *
     * @throws IllegalArgumentException if fewer than two taxa are given for an MRCA prior, no taxon
     *     set is given outside tips-only mode, {@code useOriginate} and {@code tipsonly} are both
     *     set, {@code useOriginate} is used with an empty taxon set, or {@code useOriginate} is
     *     requested for a taxon set covering every leaf
     */
    @Override
    public void initAndValidate() {
        dist = distInput.get();
        tree = treeInput.get();
        final TaxonSet taxonSet = taxonsetInput.get();
        nrOfTaxa = taxonSet != null ? taxonSet.asStringList().size() : tree.getTaxaNames().length;

        onlyUseTips = onlyUseTipsInput.get();
        useOriginate = useOriginateInput.get();
        if (nrOfTaxa == 1 && !useOriginate && !onlyUseTips) {
            // With a single taxon, fall back to a prior on that tip's date.
            onlyUseTips = true;
        }
        if (!onlyUseTips && !useOriginate && nrOfTaxa < 2) {
            throw new IllegalArgumentException("At least two taxa are required in a taxon set");
        }
        if (!onlyUseTips && taxonSet == null) {
            throw new IllegalArgumentException("Taxonset must be specified OR tipsonly be set to true");
        }
        if (useOriginate && onlyUseTips) {
            throw new IllegalArgumentException("'useOriginate' and 'tipsOnly' cannot be both true");
        }
        if (useOriginate && nrOfTaxa < 1) {
            // Without this check the common-ancestor search fails later on taxonIndex[0].
            throw new IllegalArgumentException("At least one taxon is required in a taxon set when useOriginate is true");
        }
        useRoot = nrOfTaxa == tree.getLeafNodeCount();
        if (useOriginate && useRoot) {
            throw new IllegalArgumentException("Cannot use originate of root. You can set useOriginate to false to fix this");
        }
        initialised = false;
    }

    /**
     * Returns the most recent common ancestor of two nodes by repeatedly stepping the lower one
     * towards the root. Every node visited is marked in {@link #nodesTraversed} and counted once in
     * {@link #nseen}; the caller must have allocated {@link #nodesTraversed}.
     *
     * @param node  first node
     * @param node2 second node
     * @return the common ancestor of {@code node} and {@code node2}
     */
    protected Node getCommonAncestor(Node node, Node node2) {
        markTraversed(node);
        markTraversed(node2);
        while (node != node2) {
            final double height = node.getHeight();
            final double height2 = node2.getHeight();
            if (height < height2) {
                node = node.getParent();
                markTraversed(node);
            } else if (height2 < height) {
                node2 = node2.getParent();
                markTraversed(node2);
            } else {
                // Heights tie (e.g. across zero-length branches). Step node2 when node has a
                // positive branch length, node when only node2 has one; if neither does, step the
                // lower of the two, i.e. node when node2 is one of its ancestors, node2 otherwise.
                final boolean stepFirst;
                if (node.getLength() > 0.0d) {
                    stepFirst = false;
                } else if (node2.getLength() > 0.0d) {
                    stepFirst = true;
                } else {
                    stepFirst = isOnPathToRoot(node2, node);
                }
                if (stepFirst) {
                    node = node.getParent();
                    markTraversed(node);
                } else {
                    node2 = node2.getParent();
                    markTraversed(node2);
                }
            }
        }
        return node;
    }

    /** Marks {@code node} in {@link #nodesTraversed}, counting it in {@link #nseen} the first time. */
    private void markTraversed(final Node node) {
        final int nr = node.getNr();
        if (!nodesTraversed[nr]) {
            nodesTraversed[nr] = true;
            nseen++;
        }
    }

    /** Returns true if {@code target} is {@code start} or is reached by following parents from it. */
    private static boolean isOnPathToRoot(final Node target, final Node start) {
        for (Node current = start; current != null; current = current.getParent()) {
            if (current == target) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the most recent common ancestor of all taxa in {@link #taxonIndex}, marking every
     * node on the leaf-to-MRCA paths. Requires {@link #nodesTraversed} to be allocated and
     * {@link #taxonIndex} to be non-empty.
     */
    public Node getCommonAncestor() {
        Node ancestor = tree.getNode(taxonIndex[0]);
        for (int k = 1; k < taxonIndex.length; k++) {
            ancestor = getCommonAncestor(ancestor, tree.getNode(taxonIndex[k]));
        }
        return ancestor;
    }

    /**
     * Computes the log prior for the current tree.
     *
     * <ul>
     *   <li>Tips-only mode: the sum of {@code dist.logDensity} over the dates of all tips in the
     *       taxon set (0 without a distribution).</li>
     *   <li>Taxon set spans every leaf: {@code dist.logDensity} of the root date (0 without a
     *       distribution).</li>
     *   <li>Otherwise: locates the MRCA (or its parent with {@code useOriginate}); returns negative
     *       infinity if monophyly is required but violated, else {@code dist.logDensity} of that
     *       date (0 without a distribution).</li>
     * </ul>
     *
     * <p>Also updates {@link #MRCATime} and, outside tips-only mode, {@link #isMonophyletic}.
     */
    @Override
    public double calculateLogP() {
        if (!initialised) {
            initialise();
        }
        logP = 0.0d;

        if (onlyUseTips) {
            if (dist == null) {
                return logP;
            }
            for (final int nodeNr : taxonIndex) {
                MRCATime = tree.getNode(nodeNr).getDate();
                logP += dist.logDensity(MRCATime);
            }
            return logP;
        }

        if (useRoot) {
            // The set covers every leaf: its MRCA is the root and it is trivially a clade.
            isMonophyletic = true;
            MRCATime = tree.getRoot().getDate();
            if (dist != null) {
                logP += dist.logDensity(MRCATime);
            }
            return logP;
        }

        final Node mrca;
        if (taxonIndex.length == 1) {
            // Reached only with useOriginate (initAndValidate switches other single-taxon sets to
            // tips-only mode): the "clade" is the tip itself.
            isMonophyletic = true;
            mrca = tree.getNode(taxonIndex[0]);
        } else {
            nodesTraversed = new boolean[tree.getNodeCount()];
            nseen = 0;
            mrca = getCommonAncestor();
            // In a binary tree a clade of k tips has exactly 2k - 1 nodes; visiting more means the
            // leaf-to-MRCA paths passed through nodes that join other lineages.
            isMonophyletic = nseen == 2 * taxonIndex.length - 1;
        }

        if (useOriginate && !mrca.isRoot()) {
            MRCATime = mrca.getParent().getDate();
        } else {
            MRCATime = mrca.getDate();
        }

        if (isMonophyleticInput.get() && !isMonophyletic) {
            logP = Double.NEGATIVE_INFINITY;
            return logP;
        }
        if (dist != null) {
            logP = dist.logDensity(MRCATime);
        }
        return logP;
    }

    /**
     * Resolves the taxon set against the tree's taxa names, filling {@link #taxonIndex} and
     * {@link #isInTaxaSet}. Without a taxon set, {@link #taxonIndex} becomes 0..nrOfTaxa-1.
     *
     * @throws RuntimeException if a taxon is not present in the tree or is listed more than once
     */
    protected void initialise() {
        final List<String> taxaInSet = taxonsetInput.get() != null ? taxonsetInput.get().asStringList() : null;
        taxonIndex = new int[nrOfTaxa];
        if (taxaInSet != null) {
            final List<String> treeTaxa = new ArrayList<>();
            for (final String name : tree.getTaxaNames()) {
                treeTaxa.add(name);
            }
            isInTaxaSet.clear();
            int k = 0;
            for (final String name : taxaInSet) {
                final int index = treeTaxa.indexOf(name);
                if (index < 0) {
                    throw new RuntimeException("Cannot find taxon " + name + " in data");
                }
                if (!isInTaxaSet.add(name)) {
                    throw new RuntimeException("Taxon " + name + " is defined multiple times, while they should be unique");
                }
                taxonIndex[k++] = index;
            }
        } else {
            for (int k = 0; k < nrOfTaxa; k++) {
                taxonIndex[k] = k;
            }
        }
        initialised = true;
    }

    /**
     * Recursively counts taxon-set members below {@code node} and, at the first node whose subtree
     * contains all of them, records {@link #MRCATime} and {@link #isMonophyletic}.
     *
     * @param node the subtree root to visit
     * @param iArr single-element accumulator; on return {@code iArr[0]} holds the number of leaves
     *     below {@code node} (callers pass a fresh zeroed array)
     * @return the number of taxon-set members below {@code node}, plus one once the MRCA has been
     *     found so that ancestors do not match again
     */
    int calcMRCAtime(Node node, int[] iArr) {
        if (node.isLeaf()) {
            iArr[0]++;
            return isInTaxaSet.contains(node.getID()) ? 1 : 0;
        }
        int taxaBelow = calcMRCAtime(node.getLeft(), iArr);
        final int leftLeaves = iArr[0];
        iArr[0] = 0;
        if (node.getRight() == null) {
            // Single-child node: restore the left subtree's leaf count so ancestors still see it.
            iArr[0] = leftLeaves;
            return taxaBelow;
        }
        taxaBelow += calcMRCAtime(node.getRight(), iArr);
        iArr[0] += leftLeaves;
        if (taxaBelow != nrOfTaxa) {
            return taxaBelow;
        }
        if (nrOfTaxa == 1 && useOriginate) {
            // node is the first binary ancestor of the single taxon.
            MRCATime = node.getDate();
            isMonophyletic = true;
            return taxaBelow + 1;
        }
        if (useOriginate) {
            final Node parent = node.getParent();
            MRCATime = parent != null ? parent.getDate() : node.getDate();
        } else {
            MRCATime = node.getDate();
        }
        isMonophyletic = iArr[0] == nrOfTaxa;
        return taxaBelow + 1;
    }

    /** Saves {@link #MRCATime} and {@link #isMonophyletic} before delegating to the superclass. */
    @Override
    public void store() {
        storedMRCATime = MRCATime;
        storedIsMonophyletic = isMonophyletic;
        super.store();
    }

    /** Restores {@link #MRCATime} and {@link #isMonophyletic} before delegating to the superclass. */
    @Override
    public void restore() {
        MRCATime = storedMRCATime;
        isMonophyletic = storedIsMonophyletic;
        super.restore();
    }

    /**
     * Delegates to {@code super.requiresRecalculation()}. NOTE(review): the decompiler reports the
     * original was {@code protected}; it stays {@code public} so existing callers keep compiling.
     */
    @Override
    public boolean requiresRecalculation() {
        return super.requiresRecalculation();
    }

    /**
     * Writes the log-file column headers.
     *
     * <p>In MRCA mode: an optional monophyly indicator (only when monophyly is not enforced), an
     * optional log density, and the MRCA date. In tips-only mode: an optional log density and one
     * column per tip.
     */
    @Override
    public void init(PrintStream printStream) {
        if (!initialised) {
            initialise();
        }
        if (onlyUseTips) {
            if (dist != null) {
                printStream.print("logP(mrca(" + getID() + "))\t");
            }
            final String[] taxaNames = tree.getTaxaNames();
            for (final int index : taxonIndex) {
                printStream.print("height(" + taxaNames[index] + ")\t");
            }
            return;
        }
        final String setId = taxonsetInput.get().getID();
        if (!isMonophyleticInput.get()) {
            printStream.print("monophyletic(" + setId + ")\t");
        }
        if (dist != null) {
            printStream.print("logP(mrca(" + setId + "))\t");
        }
        printStream.print("mrcatime(" + setId + (useOriginate ? ".originate" : "") + ")\t");
    }

    /** Writes one log-file row matching the columns produced by {@link #init(PrintStream)}. */
    @Override
    public void log(int sample, PrintStream printStream) {
        if (onlyUseTips) {
            if (dist != null) {
                printStream.print(getCurrentLogP() + "\t");
            }
            for (final int nodeNr : taxonIndex) {
                printStream.print(tree.getNode(nodeNr).getDate() + "\t");
            }
            return;
        }
        if (dist == null) {
            // Without a distribution the cached values may not reflect the current tree (this
            // object may be used purely as a logger), so refresh the MRCA date and monophyly flag
            // BEFORE printing either of them.
            calcMRCAtime(tree.getRoot(), new int[1]);
        }
        if (!isMonophyleticInput.get()) {
            printStream.print((isMonophyletic ? 1 : 0) + "\t");
        }
        if (dist != null) {
            printStream.print(getCurrentLogP() + "\t");
        }
        printStream.print(MRCATime + "\t");
    }

    /** Nothing to finalize in the log. */
    @Override
    public void close(PrintStream printStream) {
    }

    /** Two values are exposed: index 0 is the log prior, index 1 is {@link #MRCATime}. */
    @Override
    public int getDimension() {
        return 2;
    }

    /** Returns the log prior, computing it first if it is currently NaN. */
    @Override
    public double getArrayValue() {
        ensureLogPCalculated();
        return logP;
    }

    /**
     * Returns the log prior (index 0) or {@link #MRCATime} (index 1), computing the log prior first
     * if it is currently NaN; any other index yields 0.
     */
    @Override
    public double getArrayValue(int i) {
        ensureLogPCalculated();
        switch (i) {
            case 0:
                return logP;
            case 1:
                return MRCATime;
            default:
                return 0.0d;
        }
    }

    /**
     * Recomputes {@link #logP} if it is NaN. Failures are deliberately tolerated: the value is left
     * as NaN instead of propagating the exception to the caller.
     */
    private void ensureLogPCalculated() {
        if (Double.isNaN(logP)) {
            try {
                calculateLogP();
            } catch (Exception ignored) {
                // Best effort: report NaN rather than failing the caller.
                logP = Double.NaN;
            }
        }
    }

    /** Sampling is not supported by this prior; this method does nothing. */
    @Override
    public void sample(State state, Random random) {
    }

    /** Returns {@code null}: no argument names are reported. */
    @Override
    public List<String> getArguments() {
        return null;
    }

    /** Returns {@code null}: no condition names are reported. */
    @Override
    public List<String> getConditions() {
        return null;
    }
}
