package beast.evolution.speciation;

import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.parameter.RealParameter;
import beast.core.util.Log;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.speciation.SpeciesTreePrior;
import beast.evolution.tree.Node;
import beast.evolution.tree.TreeDistribution;
import beast.evolution.tree.TreeInterface;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * Log-probability of a gene tree embedded in a species tree (multi-species coalescent).
 *
 * <p>On every {@link #calculateLogP()} call the gene tree is walked bottom-up: each gene-tree
 * leaf is mapped to a species-tree node, lineages are moved up the species tree as long as the
 * coalescent event is higher than the parent of their current species node, and the height of
 * each coalescent event is recorded for the species node it ends up in. If the two child
 * lineages of a coalescent event end up in different species nodes, the log-probability is
 * {@code -infinity}. Otherwise the species tree is walked and each node adds a contribution
 * that depends on the population function taken from the {@link SpeciesTreePrior}.
 */
@Description("Calculates probability of gene tree conditioned on a species tree (multi-species coalescent)")
public class GeneTreeForSpeciesTreeDistribution extends TreeDistribution {
    /** ID of the single placeholder node for which {@link #initAndValidate()} skips initialisation. */
    private static final String DUMMY_TAXON_SET_ID = "Beauti2DummyTaxonSet";

    /** Below this absolute difference between top and bottom sizes a branch is treated as constant. */
    private static final double FLAT_POP_SIZE_TOLERANCE = 1.0E-10d;

    public final Input<TreeInterface> speciesTreeInput = new Input<>("speciesTree", "species tree containing the associated gene tree", Input.Validate.REQUIRED);
    public final Input<Double> ploidyInput = new Input<>("ploidy", "ploidy (copy number) for this gene, typically a whole number or half (default 2 for autosomal_nuclear)", 2.0);
    public final Input<SpeciesTreePrior> speciesTreePriorInput = new Input<>("speciesTreePrior", "defines population function and its parameters", Input.Validate.REQUIRED);
    public final Input<TreeTopFinder> treeTopFinderInput = new Input<>("treetop", "calculates height of species tree, required only for linear *beast analysis");

    /** Per species node: heights of the gene-tree coalescent events assigned to that node. */
    private PriorityQueue<Double>[] intervals;
    /** Per species node: number of gene lineages counted for that node by the last traversal. */
    private int[] nrOfLineages;
    /** Per gene-tree leaf (by array position): index of the species-tree node it belongs to. */
    protected int[] nrOfLineageToSpeciesMap;
    /** Population function read from the species tree prior (despite the name, not a boolean). */
    SpeciesTreePrior.TreePopSizeFunction isConstantPopFunction;
    RealParameter popSizesBottom;
    RealParameter popSizesTop;
    private double ploidy;

    public GeneTreeForSpeciesTreeDistribution() {
        this.treeInput.setRule(Input.Validate.REQUIRED);
    }

    /**
     * Checks the inputs and builds the gene-leaf to species-node map.
     *
     * @throws IllegalArgumentException if a gene-tree leaf has non-zero height, if a leaf cannot
     *         be assigned to a species-tree node, or if the population function is linear but
     *         no {@code treetop} input was supplied
     */
    @Override // beast.core.Distribution, beast.core.BEASTInterface
    public void initAndValidate() {
        ploidy = ploidyInput.get();

        final Node[] geneNodes = treeInput.get().getNodesAsArray();
        final int geneLeafCount = treeInput.get().getLeafNodeCount();
        final Node[] speciesNodes = speciesTreeInput.get().getNodesAsArray();
        final int speciesNodeCount = speciesTreeInput.get().getNodeCount();

        // A species "tree" consisting only of the dummy placeholder node is left uninitialised.
        // Constant-first equals: a node without an ID must not cause a NullPointerException.
        if (speciesNodeCount <= 1 && DUMMY_TAXON_SET_ID.equals(speciesNodes[0].getID())) {
            return;
        }

        // Generic array creation is unavoidable here; every slot is filled with a typed queue.
        @SuppressWarnings("unchecked")
        final PriorityQueue<Double>[] queues = new PriorityQueue[speciesNodeCount];
        for (int i = 0; i < speciesNodeCount; i++) {
            queues[i] = new PriorityQueue<>();
        }
        intervals = queues;

        for (int i = 0; i < geneLeafCount; i++) {
            if (geneNodes[i].getHeight() != 0.0d) {
                throw new IllegalArgumentException("Cannot deal with taxon " + geneNodes[i].getID() + ", which has non-zero height + " + geneNodes[i].getHeight());
            }
        }

        nrOfLineageToSpeciesMap = new int[geneLeafCount];
        Arrays.fill(nrOfLineageToSpeciesMap, -1);
        for (int leaf = 0; leaf < geneLeafCount; leaf++) {
            final String speciesID = getSetID(geneNodes[leaf].getID());
            if (speciesID == null) {
                throw new IllegalArgumentException("Cannot find species for lineage " + geneNodes[leaf].getID());
            }
            // The first species-tree node carrying the species ID wins.
            for (int species = 0; species < speciesNodeCount; species++) {
                if (speciesID.equals(speciesNodes[species].getID())) {
                    nrOfLineageToSpeciesMap[leaf] = species;
                    break;
                }
            }
            if (nrOfLineageToSpeciesMap[leaf] < 0) {
                throw new IllegalArgumentException("Cannot find species with name " + speciesID + " in species tree");
            }
        }

        nrOfLineages = new int[speciesNodeCount];

        final SpeciesTreePrior speciesTreePrior = speciesTreePriorInput.get();
        isConstantPopFunction = speciesTreePrior.popFunctionInput.get();
        popSizesBottom = speciesTreePrior.popSizesBottomInput.get();
        popSizesTop = speciesTreePrior.popSizesTopInput.get();

        // The linear population function reads the tree top finder at the species root, so fail
        // here with a clear message rather than relying on an assertion that is off by default.
        if (isConstantPopFunction == SpeciesTreePrior.TreePopSizeFunction.linear && treeTopFinderInput.get() == null) {
            throw new IllegalArgumentException("treetop must be specified when the population function is linear");
        }
    }

    /**
     * Looks up the taxon set that contains a gene lineage.
     *
     * @param str ID of a gene-tree leaf
     * @return ID of the first taxon set of the species tree prior that has a member with this
     *         ID, or {@code null} if no taxon set contains it
     */
    String getSetID(String str) {
        final List<Taxon> speciesSets = speciesTreePriorInput.get().taxonSetInput.get().taxonsetInput.get();
        for (final Taxon speciesSet : speciesSets) {
            for (final Taxon member : ((TaxonSet) speciesSet).taxonsetInput.get()) {
                if (member.getID().equals(str)) {
                    return speciesSet.getID();
                }
            }
        }
        return null;
    }

    /**
     * Recomputes the log-probability from the current gene tree and species tree.
     *
     * @return the log-probability; {@code -infinity} if the gene tree does not fit the species tree
     */
    @Override // beast.core.Distribution
    public double calculateLogP() {
        logP = 0.0d;
        for (final PriorityQueue<Double> queue : intervals) {
            queue.clear();
        }
        Arrays.fill(nrOfLineages, 0);

        final TreeInterface speciesTree = speciesTreeInput.get();
        traverseLineageTree(speciesTree.getNodesAsArray(), treeInput.get().getRoot());

        // traverseLineageTree sets logP to -infinity when the gene tree does not fit.
        if (logP == 0.0d) {
            traverseSpeciesTree(speciesTree.getRoot());
        }
        return logP;
    }

    /**
     * Adds the contributions of {@code node} and all of its descendants to {@code logP}.
     *
     * <p>For each node a time array is built: the node height, the recorded coalescent heights
     * in ascending order, and the height at which the branch ends.
     */
    private void traverseSpeciesTree(Node node) {
        if (!node.isLeaf()) {
            traverseSpeciesTree(node.getLeft());
            traverseSpeciesTree(node.getRight());
        }

        final int nodeIndex = node.getNr();
        final PriorityQueue<Double> events = intervals[nodeIndex];
        final int k = events.size();
        final double[] times = new double[k + 2];
        times[0] = node.getHeight();
        for (int i = 1; i <= k; i++) {
            times[i] = events.poll();
        }
        if (!node.isRoot()) {
            times[k + 1] = node.getParent().getHeight();
        } else if (isConstantPopFunction == SpeciesTreePrior.TreePopSizeFunction.linear) {
            times[k + 1] = treeTopFinderInput.get().getHighestTreeHeight();
        } else {
            times[k + 1] = Math.max(node.getHeight(), treeInput.get().getRoot().getHeight());
        }

        // Sanity check: times must be non-decreasing. Recomputing from here (as was done before)
        // would hit the same times again and recurse without end, so reject the state instead.
        for (int i = 0; i <= k; i++) {
            if (times[i] > times[i + 1]) {
                Log.warning.println("invalid times");
                logP = Double.NEGATIVE_INFINITY;
                return;
            }
        }

        final int lineagesBottom = nrOfLineages[nodeIndex];
        switch (isConstantPopFunction) {
            case constant:
                calcConstantPopSizeContribution(lineagesBottom, popSizesBottom.getValue(nodeIndex), times, k);
                return;
            case linear:
                logP += calcLinearPopSizeContributionJH(lineagesBottom, nodeIndex, times, k, node);
                return;
            case linear_with_constant_root:
                if (node.isRoot()) {
                    final double rootPopSize = getTopPopSize(node.getLeft().getNr()) + getTopPopSize(node.getRight().getNr());
                    calcConstantPopSizeContribution(lineagesBottom, rootPopSize, times, k);
                } else {
                    logP += calcLinearPopSizeContribution(lineagesBottom, nodeIndex, times, k, node);
                }
                return;
            default:
                return;
        }
    }

    /**
     * Adds the contribution of one branch with constant population size directly to {@code logP}.
     *
     * @param lineagesBottom number of lineages at the bottom of the branch
     * @param popSize        population size, multiplied by the ploidy inside this method
     * @param times          branch start, coalescent heights, branch end ({@code k + 2} entries)
     * @param k              number of coalescent events on the branch
     */
    private void calcConstantPopSizeContribution(int lineagesBottom, double popSize, double[] times, int k) {
        final double scaledPopSize = popSize * ploidy;
        logP += -k * Math.log(scaledPopSize);
        for (int i = 0; i <= k; i++) {
            final int n = lineagesBottom - i;
            logP += -(n * (n - 1.0d) / 2.0d) * (times[i + 1] - times[i]) / scaledPopSize;
        }
    }

    /**
     * Contribution of one branch whose population size changes linearly from the bottom size to
     * the top size over the span {@code times[0] .. times[k + 1]}.
     *
     * @return the value to add to {@code logP}
     */
    private double calcLinearPopSizeContribution(int lineagesBottom, int nodeIndex, double[] times, int k, Node node) {
        final double popBottom = scaledBottomPopSize(node, nodeIndex);
        final double popTop = getTopPopSize(nodeIndex) * ploidy;
        final double timeBottom = times[0];
        final double slope = (popTop - popBottom) / (times[k + 1] - timeBottom);

        double lp = 0.0d;
        // One -log(population size) term per coalescent event.
        for (int i = 0; i < k; i++) {
            lp += -Math.log(slope * (times[i + 1] - timeBottom) + popBottom);
        }

        // When top and bottom sizes are (nearly) equal, dividing by the slope is avoided.
        final boolean flat = Math.abs(popTop - popBottom) < FLAT_POP_SIZE_TOLERANCE;
        for (int i = 0; i <= k; i++) {
            final int n = lineagesBottom - i;
            final double pairs = -(n * (n - 1.0d) / 2.0d);
            if (flat) {
                final double popSize = slope * (times[i + 1] - timeBottom) + popBottom;
                lp += pairs * (times[i + 1] - times[i]) / popSize;
            } else {
                final double ratio = (slope * (times[i + 1] - timeBottom) + popBottom)
                        / (slope * (times[i] - timeBottom) + popBottom);
                lp += pairs * Math.log(ratio) / slope;
            }
        }
        return lp;
    }

    /**
     * Variant of {@link #calcLinearPopSizeContribution} used for the {@code linear} population
     * function; it accumulates the same kinds of terms in a single loop per case.
     *
     * @return the value to add to {@code logP}
     */
    private double calcLinearPopSizeContributionJH(int lineagesBottom, int nodeIndex, double[] times, int k, Node node) {
        final double popBottom = scaledBottomPopSize(node, nodeIndex);
        final double popDelta = getTopPopSize(nodeIndex) * ploidy - popBottom;
        final double timeBottom = times[0];
        final double slope = popDelta / (times[k + 1] - timeBottom);

        double lp = 0.0d;
        if (Math.abs(popDelta) < FLAT_POP_SIZE_TOLERANCE) {
            // (Nearly) constant population size: use the interval length, not a log ratio.
            for (int i = 0; i <= k; i++) {
                final double timeNext = times[i + 1];
                final double popSize = slope * (timeNext - timeBottom) + popBottom;
                if (i < k) {
                    lp += -Math.log(popSize);
                }
                final int n = lineagesBottom - i;
                lp -= n * (n - 1.0d) / 2.0d * (timeNext - times[i]) / popSize;
            }
        } else {
            // Population size as a function of absolute time t is slope * t + intercept.
            final double intercept = popBottom - slope * timeBottom;
            for (int i = 0; i <= k; i++) {
                final double popSize = slope * times[i + 1] + intercept;
                if (i < k) {
                    lp += -Math.log(popSize);
                }
                final double ratio = popSize / (slope * times[i] + intercept);
                final int n = lineagesBottom - i;
                lp += -(n * (n - 1.0d) / 2.0d) * Math.log(ratio) / slope;
            }
        }
        return lp;
    }

    /**
     * Population size at the bottom of a branch, multiplied by the ploidy: the bottom parameter
     * for a leaf, the sum of the two children's top sizes for an internal node.
     */
    private double scaledBottomPopSize(Node node, int nodeIndex) {
        final double unscaled = node.isLeaf()
                ? popSizesBottom.getValue(nodeIndex)
                : getTopPopSize(node.getLeft().getNr()) + getTopPopSize(node.getRight().getNr());
        return unscaled * ploidy;
    }

    /**
     * Post-order walk of the gene tree that fills {@code nrOfLineages} and {@code intervals}.
     *
     * @param speciesNodes species-tree nodes, indexed by node number
     * @param node         gene-tree node to process
     * @return index of the species node the lineage above {@code node} is in
     */
    private int traverseLineageTree(Node[] speciesNodes, Node node) {
        if (node.isLeaf()) {
            final int species = nrOfLineageToSpeciesMap[node.getNr()];
            nrOfLineages[species]++;
            return species;
        }

        final int speciesBelowLeft = traverseLineageTree(speciesNodes, node.getLeft());
        final int speciesBelowRight = traverseLineageTree(speciesNodes, node.getRight());
        final double height = node.getHeight();

        final int speciesLeft = climbSpeciesTree(speciesNodes, speciesBelowLeft, height);
        final int speciesRight = climbSpeciesTree(speciesNodes, speciesBelowRight, height);

        if (speciesLeft != speciesRight) {
            // Both child lineages must be in the same species node to coalesce.
            logP = Double.NEGATIVE_INFINITY;
        }
        intervals[speciesRight].add(height);
        return speciesRight;
    }

    /**
     * Moves a lineage up the species tree while {@code height} exceeds the height of the parent
     * of its current species node, counting the lineage in every node it enters.
     *
     * @return index of the species node the lineage is in at {@code height}
     */
    private int climbSpeciesTree(Node[] speciesNodes, int species, double height) {
        int current = species;
        while (!speciesNodes[current].isRoot() && height > speciesNodes[current].getParent().getHeight()) {
            current = speciesNodes[current].getParent().getNr();
            nrOfLineages[current]++;
        }
        return current;
    }

    /**
     * Top population size for species node {@code i}; when {@code i} is outside the dimension of
     * the top-size parameter, the entry at the species root's node number is used instead.
     */
    private double getTopPopSize(int i) {
        if (i < popSizesTop.getDimension()) {
            return popSizesTop.getArrayValue(i);
        }
        return popSizesTop.getArrayValue(speciesTreeInput.get().getRoot().getNr());
    }

    /** Always reports that a recalculation is needed. */
    @Override // beast.evolution.tree.TreeDistribution, beast.core.CalculationNode
    public boolean requiresRecalculation() {
        return true;
    }

    /** Not provided by this distribution; returns {@code null}. */
    @Override // beast.evolution.tree.TreeDistribution, beast.core.Distribution
    public List<String> getArguments() {
        return null;
    }

    /** Not provided by this distribution; returns {@code null}. */
    @Override // beast.evolution.tree.TreeDistribution, beast.core.Distribution
    public List<String> getConditions() {
        return null;
    }

    /** Sampling is not implemented; this method does nothing. */
    @Override // beast.evolution.tree.TreeDistribution, beast.core.Distribution
    public void sample(State state, Random random) {
    }
}
