package beast.evolution.tree;

import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.tree.coalescent.PopulationFunction;
import beast.math.distributions.MRCAPrior;
import beast.math.distributions.ParametricDistribution;
import beast.util.HeapSort;
import beast.util.Randomizer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.commons.math.MathException;

/**
 * Tree that initialises itself by coalescent simulation under a {@link PopulationFunction}.
 * <p>
 * Every {@link MRCAPrior} that is monophyletic, or whose distribution yields a finite height
 * bound, contributes its taxon set as a clade that the simulated tree will contain; the support
 * of the prior's distribution (plus its offset) bounds the height of that clade's MRCA.
 * After simulation the tree is optionally rescaled towards {@code rootHeight} (node by node, as
 * far as the bounds allow) and, when an {@code initial} tree is configured, copied into it.
 */
@Description("This class provides the basic engine for coalescent simulation of a given demographic model over a given time period. ")
public class RandomTree extends Tree implements StateNodeInitialiser {
    /** Number of taxa; internal node numbers handed out during simulation start at this value. */
    int nrOfTaxa;
    /**
     * While calibrations are registered: the number of monophyletic constraints seen so far.
     * Afterwards: the total number of constraint taxon sets; this index in {@link #children}
     * stands for the virtual root that encloses all constraints.
     */
    int lastMonophyletic;
    /** Constraint taxon sets, reordered so that a taxon set comes before any superset of it. */
    List<Set<String>> taxonSets;
    /** Calibration distribution for each entry of {@link #taxonSets} (may contain null). */
    List<ParametricDistribution> distributions;
    /** Height bounds for each entry of {@link #taxonSets}. */
    List<Bound> m_bounds;
    /** ID of the MRCAPrior behind each entry of {@link #taxonSets}; used in error messages. */
    List<String> taxonSetIDs;
    /** {@code children[j]} lists the constraints directly nested inside constraint {@code j}. */
    List<Integer>[] children;
    /** All taxon names, in input order. */
    Set<String> taxa;
    /** Number to give the next internal node created during simulation. */
    int nextNodeNr;

    public final Input<Alignment> taxaInput = new Input<>("taxa", "set of taxa to initialise tree specified by alignment");
    public final Input<PopulationFunction> populationFunctionInput = new Input<>("populationModel", "population function for generating coalescent???", Input.Validate.REQUIRED);
    public final Input<List<MRCAPrior>> calibrationsInput = new Input<>("constraint", "specifies (monophyletic or height distribution) constraints on internal nodes", new ArrayList<>());
    public final Input<Double> rootHeightInput = new Input<>("rootHeight", "If specified the tree will be scaled to match the root height, if constraints allow this");

    /** Warning logged when a simulation attempt violates a clade constraint. */
    private static final String CONSTRAINT_FAILURE_MESSAGE =
            "\nWARNING: Generating a random tree did not succeed. The most common reasons are:\n"
            + "1. there are conflicting monophyletic constraints, for example if both (A,B) \n"
            + "and (B,C) must be monophyletic no tree will be able to meet these constraints at the same \n"
            + "time. To fix this, carefully check all clade sets, especially the ones that are expected to \n"
            + "be nested clades.\n"
            + "2. clade heights are constrained by an upper and lower bound, but the population size \n"
            + "is too large, so it is very unlikely a generated treed does not violate these constraints. To \n"
            + "fix this you can try to reduce the population size of the population model.\n"
            + "Expect BEAST to crash if this is not fixed.\n";

    /**
     * Working list of one coalescent run: the first {@link #activeNodeCount} entries are "active"
     * (available for coalescence), the rest are still younger than the current height would allow.
     */
    private final List<Node> nodeList = new ArrayList<>();
    private int activeNodeCount = 0;

    /** Height interval [lower, upper] implied by a calibration; unbounded by default. */
    public class Bound {
        Double upper = Double.valueOf(Double.POSITIVE_INFINITY);
        Double lower = Double.valueOf(Double.NEGATIVE_INFINITY);

        Bound() {
        }

        @Override
        public String toString() {
            return "[" + this.lower + "," + this.upper + "]";
        }
    }

    /** Signals that a simulation attempt produced a clade whose height bounds cannot be met. */
    public class ConstraintViolatedException extends Exception {
        private static final long serialVersionUID = 1;

        protected ConstraintViolatedException() {
        }
    }

    /**
     * Collects the taxon names (from {@code taxa} if given, otherwise from the inherited taxon
     * set input), simulates the tree via {@link #initStateNodes()}, then runs the standard
     * {@link Tree} initialisation.
     */
    @Override
    public void initAndValidate() {
        this.taxa = new LinkedHashSet<>(sourceTaxonNames());
        this.nrOfTaxa = this.taxa.size();
        initStateNodes();
        super.initAndValidate();
    }

    /** Taxon names from the {@code taxa} alignment if set, otherwise from the inherited taxon set. */
    private List<String> sourceTaxonNames() {
        return this.taxaInput.get() != null
                ? this.taxaInput.get().getTaxaNames()
                : this.m_taxonset.get().asStringList();
    }

    /** Exchanges the elements at positions {@code i} and {@code j}. */
    private static <T> void swap(final List<T> list, final int i, final int j) {
        final T tmp = list.get(i);
        list.set(i, list.get(j));
        list.set(j, tmp);
    }

    /**
     * Builds the clade-constraint structure, simulates a tree that satisfies it and installs the
     * result as this tree (and in the {@code initial} tree, if one is set).
     * <p>
     * Constraints come from the {@code constraint} input, plus any MRCAPrior found among the
     * outputs of this tree or of the {@code initial} tree.
     *
     * @throws IllegalArgumentException if a constraint names an unknown taxon, if two constraint
     *         taxon sets overlap without one containing the other, or if a monophyletic
     *         constraint has no taxon set
     * @throws RuntimeException if no tree satisfying the constraints could be simulated
     */
    @Override
    public void initStateNodes() {
        this.taxonSets = new ArrayList<>();
        this.m_bounds = new ArrayList<>();
        this.distributions = new ArrayList<>();
        this.taxonSetIDs = new ArrayList<>();
        this.lastMonophyletic = 0;
        this.taxa.addAll(sourceTaxonNames());

        final List<MRCAPrior> calibrations = gatherCalibrations();
        for (final MRCAPrior calibration : calibrations) {
            registerCalibration(calibration);
        }
        // every registered taxon set (monophyletic or merely height-bounded) is simulated as a clade
        this.lastMonophyletic = this.taxonSets.size();
        orderNestedConstraints();
        linkNestedConstraints();

        simulateTree(this.taxa, this.populationFunctionInput.get());
        if (this.rootHeightInput.get() != null) {
            scaleToFit(this.rootHeightInput.get() / this.root.getHeight(), this.root);
        }

        this.nodeCount = 2 * this.taxa.size() - 1;
        this.internalNodeCount = this.taxa.size() - 1;
        this.leafNodeCount = this.taxa.size();

        // retries inside simulateCoalescentWithMax keep incrementing nextNodeNr, so internal node
        // numbers can run past the valid range; renumber the whole tree
        setNodesNrs(this.root, 0, new int[1], initialTaxonNumbering());
        initArrays();
        if (this.m_initial.get() != null) {
            this.m_initial.get().assignFromWithoutID(this);
        }
        checkMonophyleticTaxonSets(calibrations);
    }

    /** All MRCAPriors to honour: the constraint input first, then any extra ones among outputs. */
    private List<MRCAPrior> gatherCalibrations() {
        final List<MRCAPrior> calibrations = new ArrayList<>(this.calibrationsInput.get());
        addMissingCalibrations(calibrations, getOutputs());
        if (this.m_initial.get() != null) {
            addMissingCalibrations(calibrations, this.m_initial.get().getOutputs());
        }
        return calibrations;
    }

    /** Appends every MRCAPrior in {@code candidates} that is not already in {@code calibrations}. */
    private static void addMissingCalibrations(final List<MRCAPrior> calibrations,
                                               final Iterable<? extends BEASTInterface> candidates) {
        for (final BEASTInterface candidate : candidates) {
            if (candidate instanceof MRCAPrior && !calibrations.contains(candidate)) {
                calibrations.add((MRCAPrior) candidate);
            }
        }
    }

    /**
     * Records the taxon set, distribution and bounds of one calibration. Monophyletic calibrations
     * are inserted before all non-monophyletic ones; non-monophyletic calibrations are only kept
     * when at least one bound is finite. Calibrations without a taxon set, or flagged to use tips
     * only, are ignored.
     */
    private void registerCalibration(final MRCAPrior calibration) {
        final TaxonSet taxonSet = calibration.taxonsetInput.get();
        if (taxonSet == null || calibration.onlyUseTipsInput.get()) {
            return;
        }
        final Set<String> cladeTaxa = new LinkedHashSet<>();
        if (taxonSet.asStringList() == null) {
            taxonSet.initAndValidate();
        }
        for (final String taxonID : taxonSet.asStringList()) {
            if (!this.taxa.contains(taxonID)) {
                throw new IllegalArgumentException("Taxon <" + taxonID + "> could not be found in list of taxa. Choose one of " + this.taxa);
            }
            cladeTaxa.add(taxonID);
        }
        final ParametricDistribution distr = calibration.distInput.get();
        final Bound bound = calibrationBound(distr);

        if (calibration.isMonophyleticInput.get()) {
            // insert the ID at the same index as the parallel lists so that
            // taxonSetIDs.get(k) always names taxonSets.get(k)
            this.taxonSets.add(this.lastMonophyletic, cladeTaxa);
            this.distributions.add(this.lastMonophyletic, distr);
            this.m_bounds.add(this.lastMonophyletic, bound);
            this.taxonSetIDs.add(this.lastMonophyletic, calibration.getID());
            this.lastMonophyletic++;
        } else if (!Double.isInfinite(bound.lower) || !Double.isInfinite(bound.upper)) {
            this.taxonSets.add(cladeTaxa);
            this.distributions.add(distr);
            this.m_bounds.add(bound);
            this.taxonSetIDs.add(calibration.getID());
        }
    }

    /**
     * Height bounds implied by {@code distr}: its quantiles at 0 and 1 shifted by its offset.
     * Returns an unbounded interval when {@code distr} is null; if a quantile cannot be computed
     * a warning is logged and the corresponding bound(s) stay unbounded.
     */
    private Bound calibrationBound(final ParametricDistribution distr) {
        final Bound bound = new Bound();
        if (distr == null) {
            return bound;
        }
        // initialise the distribution's predecessors (last first) before querying its quantiles
        final List<BEASTInterface> predecessors = new ArrayList<>();
        distr.getPredecessors(predecessors);
        for (int k = predecessors.size() - 1; k >= 0; k--) {
            predecessors.get(k).initAndValidate();
        }
        try {
            bound.lower = distr.inverseCumulativeProbability(0.0) + distr.offsetInput.get();
            bound.upper = distr.inverseCumulativeProbability(1.0) + distr.offsetInput.get();
        } catch (MathException e) {
            Log.warning.println("At RandomTree::initStateNodes, bound on MRCAPrior could not be set " + e.getMessage());
        }
        return bound;
    }

    /**
     * Reorders the constraint lists so that a taxon set precedes every overlapping superset.
     *
     * @throws IllegalArgumentException if two taxon sets overlap but neither contains the other
     */
    private void orderNestedConstraints() {
        for (int i = 0; i < this.lastMonophyletic; i++) {
            for (int j = i + 1; j < this.lastMonophyletic; j++) {
                // re-read position i every time: an earlier swap may have replaced it
                final Set<String> first = this.taxonSets.get(i);
                final Set<String> second = this.taxonSets.get(j);
                final Set<String> shared = new LinkedHashSet<>(first);
                shared.retainAll(second);
                if (shared.isEmpty()) {
                    continue;
                }
                final boolean firstContainsSecond = first.containsAll(second);
                final boolean secondContainsFirst = second.containsAll(first);
                if (!firstContainsSecond && !secondContainsFirst) {
                    throw new IllegalArgumentException("333: Don't know how to generate a Random Tree for taxon sets that intersect, but are not inclusive. Taxonset " + this.taxonSetIDs.get(i) + " and " + this.taxonSetIDs.get(j));
                }
                if (firstContainsSecond) {
                    swap(this.taxonSets, i, j);
                    swap(this.distributions, i, j);
                    swap(this.m_bounds, i, j);
                    swap(this.taxonSetIDs, i, j);
                }
            }
        }
    }

    /**
     * Builds {@link #children}: constraint {@code i}'s parent is the first later constraint whose
     * taxon set contains it, or the virtual root {@code lastMonophyletic}. Then lowers each
     * child's upper bound to just below its parent's when it exceeds it.
     */
    @SuppressWarnings("unchecked")
    private void linkNestedConstraints() {
        final int[] parent = new int[this.lastMonophyletic];
        this.children = new List[this.lastMonophyletic + 1];
        for (int i = 0; i <= this.lastMonophyletic; i++) {
            this.children[i] = new ArrayList<>();
        }
        for (int i = 0; i < this.lastMonophyletic; i++) {
            int j = i + 1;
            while (j < this.lastMonophyletic && !this.taxonSets.get(j).containsAll(this.taxonSets.get(i))) {
                j++;
            }
            parent[i] = j;
            this.children[j].add(i);
        }
        for (int i = this.lastMonophyletic - 1; i >= 0; i--) {
            if (parent[i] < this.lastMonophyletic) {
                final Bound own = this.m_bounds.get(i);
                final Bound enclosing = this.m_bounds.get(parent[i]);
                if (own.upper > enclosing.upper) {
                    // nextDown gives a value strictly below the parent's bound; subtracting a tiny
                    // constant such as 1e-100 is absorbed by rounding for any realistic height
                    own.upper = Math.nextDown(enclosing.upper);
                }
            }
        }
    }

    /**
     * Leaf numbering to apply after simulation: the position in {@link #getTaxaNames()} when no
     * initial tree is set; the initial tree's leaf numbers when its leaf count matches; otherwise
     * null, meaning leaves are numbered in traversal order.
     */
    private Map<String, Integer> initialTaxonNumbering() {
        if (this.m_initial.get() == null) {
            final Map<String, Integer> numbering = new HashMap<>();
            final String[] names = getTaxaNames();
            for (int k = 0; k < names.length; k++) {
                numbering.put(names[k], k);
            }
            return numbering;
        }
        if (this.leafNodeCount == this.m_initial.get().getLeafNodeCount()) {
            final Map<String, Integer> numbering = new HashMap<>();
            for (final Node leaf : this.m_initial.get().getExternalNodes()) {
                numbering.put(leaf.getID(), leaf.getNr());
            }
            return numbering;
        }
        return null;
    }

    /**
     * Rejects monophyletic calibrations that lack a taxon set.
     *
     * @throws IllegalArgumentException for the first such calibration found
     */
    private static void checkMonophyleticTaxonSets(final List<MRCAPrior> calibrations) {
        for (final MRCAPrior calibration : calibrations) {
            if (calibration.isMonophyleticInput.get() && calibration.taxonsetInput.get() == null) {
                throw new IllegalArgumentException("Something is wrong with constraint " + calibration.getID() + " -- a taxonset must be specified if a monophyletic constraint is enforced.");
            }
        }
    }

    /**
     * Renumbers the subtree below {@code node} in post-order.
     *
     * @param internalCount number of internal nodes numbered so far; internal nodes get
     *                      {@code nrOfTaxa + internalCount}
     * @param leafCounter   single-element counter used for leaves when {@code leafNumbers} is null
     * @param leafNumbers   taxon name to leaf number, or null for sequential numbering
     * @return the updated internal-node count
     */
    private int setNodesNrs(final Node node, int internalCount, final int[] leafCounter,
                            final Map<String, Integer> leafNumbers) {
        if (node.isLeaf()) {
            if (leafNumbers != null) {
                node.setNr(leafNumbers.get(node.getID()));
            } else {
                node.setNr(leafCounter[0]++);
            }
            return internalCount;
        }
        for (final Node child : node.getChildren()) {
            internalCount = setNodesNrs(child, internalCount, leafCounter, leafNumbers);
        }
        node.setNr(this.nrOfTaxa + internalCount);
        return internalCount + 1;
    }

    /**
     * Multiplies internal node heights by {@code scale}, top-down. A node whose scaled height
     * would leave its calibration bounds keeps its old height, and its subtree is left unscaled.
     * A node that ends up below a child is lifted slightly above the older child.
     */
    private void scaleToFit(final double scale, final Node node) {
        if (node.isLeaf()) {
            return;
        }
        final double oldHeight = node.getHeight();
        node.height *= scale;
        final Integer constraint = getDistrConstraint(node);
        if (constraint != null) {
            final Bound bound = this.m_bounds.get(constraint);
            if (node.height < bound.lower || node.height > bound.upper) {
                node.height = oldHeight;
                return;
            }
        }
        scaleToFit(scale, node.getLeft());
        scaleToFit(scale, node.getRight());
        final double oldestChild = Math.max(node.getLeft().getHeight(), node.getRight().getHeight());
        if (node.height < oldestChild) {
            node.height = 1.0000001 * oldestChild;
        }
    }

    /**
     * Reports the state node this initialiser fills in: the {@code initial} tree.
     * NOTE(review): this adds null when no initial tree is configured — confirm callers tolerate it.
     */
    @Override
    public void getInitialisedStateNodes(List<StateNode> list) {
        list.add(this.m_initial.get());
    }

    /**
     * Simulates a coalescent tree over {@code taxonNames} that respects the registered clade
     * constraints and stores it in {@link #root}. Does nothing for an empty taxon set.
     *
     * @param taxonNames   leaf names; leaves are created at height 0 in iteration order
     * @param popFunction  demographic model driving coalescent waiting times
     * @throws RuntimeException if 1000 attempts all violate a constraint
     */
    public void simulateTree(Set<String> taxonNames, PopulationFunction popFunction) {
        if (taxonNames.isEmpty()) {
            return;
        }
        String failure = "Failed to generate a random tree (probably a bug).";
        for (int attempt = 0; attempt < 1000; attempt++) {
            try {
                this.nextNodeNr = this.nrOfTaxa;
                final Set<Node> leaves = new LinkedHashSet<>();
                int leafNr = 0;
                for (final String taxon : taxonNames) {
                    final Node leaf = newNode();
                    leaf.setNr(leafNr++);
                    leaf.setID(taxon);
                    leaf.setHeight(0.0);
                    leaves.add(leaf);
                }
                if (this.m_initial.get() != null) {
                    processCandidateTraits(leaves, this.m_initial.get().m_traitList.get());
                } else {
                    processCandidateTraits(leaves, this.m_traitList.get());
                }
                final Map<String, Node> leavesByName = new TreeMap<>();
                for (final Node leaf : leaves) {
                    leavesByName.put(leaf.getID(), leaf);
                }
                this.root = simulateCoalescent(this.lastMonophyletic, leavesByName, leaves, popFunction);
                return;
            } catch (ConstraintViolatedException e) {
                // retry; report the problem once instead of on each of up to 1000 attempts
                if (attempt == 0) {
                    Log.err.println(CONSTRAINT_FAILURE_MESSAGE);
                }
                failure = CONSTRAINT_FAILURE_MESSAGE;
            }
        }
        throw new RuntimeException(failure);
    }

    /** Stores each trait value of each node's taxon as node metadata under the trait's name. */
    private void processCandidateTraits(final Set<Node> nodes, final List<TraitSet> traits) {
        for (final TraitSet trait : traits) {
            for (final Node node : nodes) {
                node.setMetaData(trait.getTraitName(), Double.valueOf(trait.getValue(node.getID())));
            }
        }
    }

    /**
     * Builds the subtree for constraint {@code constraint}: first each nested constraint
     * recursively, then coalesces those subtrees with the remaining nodes of {@code cladeNodes}
     * below the constraint's upper bound (unbounded for the virtual root).
     */
    private Node simulateCoalescent(final int constraint, final Map<String, Node> leavesByName,
                                    final Set<Node> cladeNodes, final PopulationFunction popFunction)
            throws ConstraintViolatedException {
        final List<Node> subtrees = new ArrayList<>();
        final Set<String> coveredTaxa = new TreeSet<>();
        for (final int nested : this.children[constraint]) {
            final Set<String> nestedTaxa = this.taxonSets.get(nested);
            final Set<Node> nestedLeaves = new LinkedHashSet<>();
            for (final String taxon : nestedTaxa) {
                nestedLeaves.add(leavesByName.get(taxon));
            }
            subtrees.add(simulateCoalescent(nested, leavesByName, nestedLeaves, popFunction));
            coveredTaxa.addAll(nestedTaxa);
        }
        for (final Node node : cladeNodes) {
            if (!coveredTaxa.contains(node.getID())) {
                subtrees.add(node);
            }
        }
        final double upper = constraint < this.m_bounds.size()
                ? this.m_bounds.get(constraint).upper
                : Double.POSITIVE_INFINITY;
        return simulateCoalescentWithMax(subtrees, popFunction, upper);
    }

    /**
     * Coalesces {@code nodes} into a single tree whose root is younger than {@code maxHeight}.
     * Tries the coalescent up to 1000 times; if all fail and {@code maxHeight} is finite, the
     * nodes left by the last attempt are joined as a caterpillar with strictly increasing heights
     * between the oldest of them and {@code maxHeight}.
     *
     * @throws IllegalArgumentException if {@code nodes} is empty
     * @throws RuntimeException if all attempts fail and {@code maxHeight} is not finite
     * @throws ConstraintViolatedException if a nested calibration cannot be satisfied
     */
    public Node simulateCoalescentWithMax(List<Node> nodes, PopulationFunction popFunction, double maxHeight)
            throws ConstraintViolatedException {
        if (nodes.isEmpty()) {
            throw new IllegalArgumentException("empty nodes set");
        }
        for (int attempt = 0; attempt < 1000; attempt++) {
            final List<Node> remaining = simulateCoalescent(nodes, popFunction, 0.0, maxHeight);
            if (remaining.size() == 1) {
                return remaining.get(0);
            }
        }
        if (!Double.isFinite(maxHeight)) {
            throw new RuntimeException("failed to merge trees after 1000 tries!");
        }
        double height = -1.0;
        for (final Node node : this.nodeList) {
            height = Math.max(height, node.getHeight());
        }
        assert height < maxHeight;
        // n nodes need n-1 merges; with step = gap/(n+1) every merge height stays below maxHeight,
        // and advancing the height per merge keeps each parent strictly older than its children
        final double step = (maxHeight - height) / (this.nodeList.size() + 1);
        while (this.nodeList.size() > 1) {
            final int last = this.nodeList.size() - 1;
            final Node left = this.nodeList.remove(last);
            final Node right = this.nodeList.get(last - 1);
            height += step;
            final Node parent = newNode();
            parent.setNr(this.nextNodeNr++);
            parent.setHeight(height);
            parent.setLeft(left);
            left.setParent(parent);
            parent.setRight(right);
            right.setParent(parent);
            this.nodeList.set(last - 1, parent);
        }
        assert this.nodeList.size() == 1;
        return this.nodeList.get(0);
    }

    /**
     * Runs one coalescent pass over {@code nodes}, starting at {@code currentHeight}, and stops
     * once a single node remains or the next coalescence would reach {@code maxHeight}.
     *
     * @return {@code nodes} itself if it has one element; otherwise the internal working list
     *         holding the remaining roots (overwritten by the next call)
     * @throws ConstraintViolatedException if a merged node's calibration bounds cannot be met
     */
    public List<Node> simulateCoalescent(List<Node> nodes, PopulationFunction popFunction,
                                         double currentHeight, double maxHeight)
            throws ConstraintViolatedException {
        if (nodes.size() == 1) {
            return nodes;
        }
        // order nodes by height; the code below treats nodeList[activeNodeCount] as the
        // youngest inactive node
        final double[] heights = new double[nodes.size()];
        for (int k = 0; k < nodes.size(); k++) {
            heights[k] = nodes.get(k).getHeight();
        }
        final int[] order = new int[nodes.size()];
        HeapSort.sort(heights, order);
        this.nodeList.clear();
        this.activeNodeCount = 0;
        for (final int k : order) {
            this.nodeList.add(nodes.get(k));
        }

        setCurrentHeight(currentHeight);
        currentHeight = ensureTwoActiveNodes(currentHeight);
        double nextCoalescence = currentHeight
                + PopulationFunction.Utils.getSimulatedInterval(popFunction, getActiveNodeCount(), currentHeight);
        while (nextCoalescence < maxHeight && this.nodeList.size() > 1) {
            if (nextCoalescence >= getMinimumInactiveHeight()) {
                // another node becomes active before the next coalescence happens
                currentHeight = getMinimumInactiveHeight();
                setCurrentHeight(currentHeight);
            } else {
                currentHeight = coalesceTwoActiveNodes(currentHeight, nextCoalescence);
            }
            if (this.nodeList.size() > 1) {
                currentHeight = ensureTwoActiveNodes(currentHeight);
                nextCoalescence = currentHeight
                        + PopulationFunction.Utils.getSimulatedInterval(popFunction, getActiveNodeCount(), currentHeight);
            }
        }
        return this.nodeList;
    }

    /** Advances the current height through inactive nodes until at least two nodes are active. */
    private double ensureTwoActiveNodes(double height) {
        while (getActiveNodeCount() < 2) {
            height = getMinimumInactiveHeight();
            setCurrentHeight(height);
        }
        return height;
    }

    /** Height of the first inactive node, or +infinity when every node is active. */
    private double getMinimumInactiveHeight() {
        if (this.activeNodeCount < this.nodeList.size()) {
            return this.nodeList.get(this.activeNodeCount).getHeight();
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Activates every inactive node at or below {@code height}. */
    private void setCurrentHeight(final double height) {
        while (getMinimumInactiveHeight() <= height) {
            this.activeNodeCount++;
        }
    }

    private int getActiveNodeCount() {
        return this.activeNodeCount;
    }

    /**
     * Joins two distinct random active nodes under a new parent at {@code height}. If the parent
     * is the MRCA of a calibrated taxon set and {@code height} lies outside
     * [max(lower, minHeight), upper], a height inside that interval is chosen instead.
     *
     * @return the height actually assigned to the new parent
     * @throws ConstraintViolatedException if the calibration interval is empty
     */
    private double coalesceTwoActiveNodes(final double minHeight, double height) throws ConstraintViolatedException {
        final int first = Randomizer.nextInt(this.activeNodeCount);
        int second;
        do {
            second = Randomizer.nextInt(this.activeNodeCount);
        } while (second == first);

        final Node left = this.nodeList.get(first);
        final Node right = this.nodeList.get(second);
        final Node parent = newNode();
        parent.setNr(this.nextNodeNr++);
        parent.setHeight(height);
        parent.setLeft(left);
        left.setParent(parent);
        parent.setRight(right);
        right.setParent(parent);
        this.nodeList.remove(left);
        this.nodeList.remove(right);
        this.activeNodeCount -= 2;
        this.nodeList.add(this.activeNodeCount, parent);
        this.activeNodeCount++;

        final Integer constraint = getDistrConstraint(parent);
        if (constraint != null) {
            final Bound bound = this.m_bounds.get(constraint);
            final double lower = Math.max(bound.lower, minHeight);
            final double upper = bound.upper;
            if (upper < lower) {
                throw new ConstraintViolatedException();
            }
            if (height < lower || height > upper) {
                height = upper == Double.POSITIVE_INFINITY
                        ? lower + 0.1
                        : lower + Randomizer.nextDouble() * (upper - lower);
                parent.setHeight(height);
            }
        }
        if (getMinimumInactiveHeight() < height) {
            throw new RuntimeException("This should never happen! Somehow the current active node is older than the next inactive node!\nOne possible solution you can try is to increase the population size of the population model.");
        }
        return height;
    }

    /**
     * Index of the first calibrated (non-null distribution) taxon set for which {@code node} is
     * the MRCA as determined by {@link #traverse}, or null if there is none.
     */
    private Integer getDistrConstraint(final Node node) {
        for (int i = 0; i < this.distributions.size(); i++) {
            if (this.distributions.get(i) != null) {
                final Set<String> cladeTaxa = this.taxonSets.get(i);
                if (traverse(node, cladeTaxa, cladeTaxa.size(), new int[1]) == mrcaMarker()) {
                    return i;
                }
            }
        }
        return null;
    }

    /** Sentinel returned by {@link #traverse} for the node identified as a taxon set's MRCA. */
    private int mrcaMarker() {
        return this.nrOfTaxa + 127;
    }

    /**
     * Counts taxa of {@code cladeTaxa} below {@code node}. The lowest internal node whose count
     * reaches {@code cladeSize} returns {@link #mrcaMarker()} instead of the count; ancestors of
     * it return a value different from the marker.
     *
     * @param leafCount single-element scratch counter for visited leaves
     */
    int traverse(final Node node, final Set<String> cladeTaxa, final int cladeSize, final int[] leafCount) {
        if (node.isLeaf()) {
            leafCount[0]++;
            return cladeTaxa.contains(node.getID()) ? 1 : 0;
        }
        int matches = traverse(node.getLeft(), cladeTaxa, cladeSize, leafCount);
        final int leftLeaves = leafCount[0];
        leafCount[0] = 0;
        if (node.getRight() != null) {
            matches += traverse(node.getRight(), cladeTaxa, cladeSize, leafCount);
            leafCount[0] = leftLeaves + leafCount[0];
        }
        final int marker = mrcaMarker();
        if (matches == marker) {
            // a child already returned the marker; make sure this ancestor does not
            matches++;
        }
        return matches == cladeSize ? marker : matches;
    }

    /** Taxon names from {@code taxa} if set, otherwise from the inherited taxon set; cached. */
    @Override
    public String[] getTaxaNames() {
        if (this.m_sTaxaNames == null) {
            final List<String> names = sourceTaxonNames();
            this.m_sTaxaNames = names.toArray(new String[names.size()]);
        }
        return this.m_sTaxaNames;
    }
}
