package beast.evolution.speciation;

import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Function;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.speciation.CalibratedYuleModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.RandomTree;
import beast.evolution.tree.Tree;
import beast.evolution.tree.coalescent.ConstantPopulation;
import beast.math.distributions.MRCAPrior;
import beast.util.ClusterTree;
import beast.util.OutputUtils;
import beast.util.XMLParser;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.MathException;

/**
 * State-node initialiser for a *BEAST analysis.
 *
 * <p>Depending on its inputs it sets starting values for the species tree, the gene trees and,
 * for the point-estimate strategy, the birth rate, population-mean and population-size
 * parameters. The strategy is chosen in {@link #initStateNodes()}:
 * <ul>
 *   <li>{@code calibratedYule} given: the calibrated species tree is set to a tree compatible
 *       with its calibrations;</li>
 *   <li>MRCA priors attached to the species tree: a constrained {@link RandomTree} is used;</li>
 *   <li>otherwise {@code method} selects the point estimate from the alignments or the
 *       "random" state.</li>
 * </ul>
 */
@Description("Set a starting point for a *BEAST analysis from gene alignment data.")
public class StarBeastStartState extends Tree implements StateNodeInitialiser {

    /** Initialisation strategies selectable through the {@code method} input. */
    enum Method {
        /** Species and gene trees estimated by UPGMA from the gene alignments. */
        POINT("point-estimate"),
        /** Species tree rescaled from the birth rate; gene trees reset with makeCaterpillar. */
        ALL_RANDOM("random");

        private final String ename;

        Method(final String ename) {
            this.ename = ename;
        }

        @Override
        public String toString() {
            return ename;
        }
    }

    public final Input<Method> initMethod = new Input<>("method", "Initialise either with a totally random state or a point estimate based on alignments data (default point-estimate)", Method.POINT, Method.values());
    public final Input<Tree> speciesTreeInput = new Input<>("speciesTree", "The species tree to initialize");
    public final Input<List<Tree>> genes = new Input<>("gene", "Gene trees to initialize", new ArrayList<>());
    public final Input<CalibratedYuleModel> calibratedYule = new Input<>("calibratedYule", "The species tree (with calibrations) to initialize", Input.Validate.XOR, this.speciesTreeInput);
    public final Input<RealParameter> popMean = new Input<>("popMean", "Population mean hyper prior to initialse");
    public final Input<RealParameter> birthRate = new Input<>("birthRate", "Tree prior birth rate to initialize");
    public final Input<SpeciesTreePrior> speciesTreePriorInput = new Input<>("speciesTreePrior", "Population size parameters to initialise");
    public final Input<Function> muInput = new Input<>("baseRate", "Main clock rate used to scale trees (default 1).");

    /** True when the {@code calibratedYule} input was supplied (set in {@link #initAndValidate()}). */
    private boolean hasCalibrations;

    @Override
    public void initAndValidate() {
        super.initAndValidate();
        hasCalibrations = calibratedYule.get() != null;
    }

    /**
     * Initialises the species tree, the gene trees and any hyper-parameters.
     *
     * @throws IllegalArgumentException if MRCA priors are combined with a calibrated Yule model,
     *         or if building the calibrated starting tree fails
     */
    @Override
    public void initStateNodes() {
        // Use the tree that is actually being initialised: speciesTree and calibratedYule are
        // XOR-validated, so speciesTreeInput is null whenever calibrations are used.
        final List<MRCAPrior> mrcaPriors = new ArrayList<>();
        for (final BEASTInterface output : initialisedSpeciesTree().getOutputs()) {
            if (output instanceof MRCAPrior) {
                mrcaPriors.add((MRCAPrior) output);
            }
        }

        if (hasCalibrations) {
            if (!mrcaPriors.isEmpty()) {
                throw new IllegalArgumentException("Not implemented: mix of calibrated yule and MRCA priors: place all priors in the calibrated Yule");
            }
            try {
                initWithCalibrations();
            } catch (MathException e) {
                throw new IllegalArgumentException(e);
            }
            return;
        }

        if (!mrcaPriors.isEmpty()) {
            initWithMRCACalibrations(mrcaPriors);
            return;
        }

        switch (initMethod.get()) {
            case POINT:
                fullInit();
                break;
            case ALL_RANDOM:
                randomInit();
                break;
            default:
                break;
        }
    }

    /**
     * Returns the species tree this initialiser writes to.
     *
     * <p>That is the calibrated Yule model's tree when {@code calibratedYule} is set, otherwise
     * the {@code speciesTree} input.
     */
    private Tree initialisedSpeciesTree() {
        return hasCalibrations
                ? (Tree) calibratedYule.get().treeInput.get()
                : speciesTreeInput.get();
    }

    /**
     * For every unordered pair of distinct species, finds the height of the lowest node of
     * {@code tree} that has lineages of both species in different child subtrees.
     *
     * @param tree         tree to scan; internal nodes are asserted to be binary
     * @param tipToSpecies maps each tip ID of {@code tree} to a species index in [0, speciesCount)
     * @param speciesCount number of species
     * @return array of length speciesCount*(speciesCount-1)/2 indexed by {@link #getDMindex};
     *         pairs that never meet keep {@link Double#MAX_VALUE}
     * @throws IllegalArgumentException if a tip ID has no species assignment
     */
    private double[] firstMeetings(final Tree tree, final Map<String, Integer> tipToSpecies, final int speciesCount) {
        final Node[] nodes = tree.listNodesPostOrder(null, null);

        // Species present below each node, indexed by node number.
        final List<Set<Integer>> speciesBelow = new ArrayList<>(nodes.length);
        for (int k = 0; k < nodes.length; k++) {
            speciesBelow.add(new HashSet<>());
        }

        final double[] dmin = new double[(speciesCount * (speciesCount - 1)) / 2];
        Arrays.fill(dmin, Double.MAX_VALUE);

        for (final Node node : nodes) {
            if (node.isLeaf()) {
                final Integer species = tipToSpecies.get(node.getID());
                if (species == null) {
                    throw new IllegalArgumentException("Tip " + node.getID() + " of tree " + tree.getID()
                            + " is not assigned to any species");
                }
                speciesBelow.get(node.getNr()).add(species);
                continue;
            }

            assert node.getChildCount() == 2;
            final Set<Integer> left = speciesBelow.get(node.getChild(0).getNr());
            final Set<Integer> right = speciesBelow.get(node.getChild(1).getNr());

            // A species present in both subtrees already met every species of either subtree
            // lower down, so only left-only x right-only pairs can have their first meeting here.
            final Set<Integer> shared = new HashSet<>(left);
            shared.retainAll(right);
            left.removeAll(shared);
            right.removeAll(shared);

            for (final int s1 : left) {
                for (final int s2 : right) {
                    final int k = getDMindex(speciesCount, s1, s2);
                    dmin[k] = Math.min(dmin[k], node.getHeight());
                }
            }

            shared.addAll(left);
            shared.addAll(right);
            speciesBelow.set(node.getNr(), shared);
        }
        return dmin;
    }

    /**
     * Index of the unordered species pair {s1, s2}, s1 != s2, in a row-major upper-triangular
     * array of size speciesCount*(speciesCount-1)/2 that excludes the diagonal.
     *
     * @param speciesCount number of species
     * @param s1           first species index
     * @param s2           second species index (order does not matter)
     * @return array index of the pair
     */
    public int getDMindex(final int speciesCount, final int s1, final int s2) {
        final int lo = Math.min(s1, s2);
        return ((lo * ((2 * speciesCount - 1) - lo)) / 2) + (Math.abs(s1 - s2) - 1);
    }

    /**
     * Point-estimate initialisation.
     *
     * <ol>
     *   <li>Builds a UPGMA tree for every gene from its alignment and rescales it by 1/mu.</li>
     *   <li>Builds the species tree by UPGMA on twice the mean (over genes) first-meeting height
     *       of each species pair.</li>
     *   <li>Re-clusters each gene tree in which some species pair meets no higher than in the
     *       species tree.</li>
     *   <li>Sets birth-rate and population-size parameters from the species tree.</li>
     * </ol>
     *
     * @throws RuntimeException if a gene tree has no lineages for some species
     */
    private void fullInit() {
        final Function muFunction = muInput.get();
        final double mu = muFunction != null ? muFunction.getArrayValue() : 1.0d;

        final Tree speciesTree = speciesTreeInput.get();
        final TaxonSet species = speciesTree.m_taxonset.get();
        final List<String> speciesNames = species.asStringList();
        final int speciesCount = speciesNames.size();
        final List<Tree> geneTrees = genes.get();

        final double maxSiteCount = clusterGeneTrees(geneTrees, mu);
        final Map<String, Integer> geneTipToSpecies = mapGeneTipsToSpecies(species, speciesCount);

        // Per gene: first-meeting heights of every species pair, also summed over genes.
        final double[][] geneMeetings = new double[geneTrees.size()][];
        final double[] speciesDistances = new double[(speciesCount * (speciesCount - 1)) / 2];
        for (int g = 0; g < geneTrees.size(); g++) {
            final Tree geneTree = geneTrees.get(g);
            final double[] dmin = firstMeetings(geneTree, geneTipToSpecies, speciesCount);
            checkAllSpeciesMeet(geneTree, speciesTree, dmin, speciesCount);
            geneMeetings[g] = dmin;
            for (int k = 0; k < dmin.length; k++) {
                speciesDistances[k] += dmin[k];
            }
        }

        // NOTE(review): with no gene trees the means below are 0/0 = NaN; callers presumably
        // always supply genes for the point estimate - confirm.
        for (int k = 0; k < speciesDistances.length; k++) {
            final double meanHeight = speciesDistances[k] / geneTrees.size();
            // A zero mean is replaced by a small positive distance (0.5/maxSiteCount, scaled by
            // 1/mu). Otherwise the height is doubled to turn it into a pairwise distance.
            speciesDistances[k] = meanHeight == 0.0d
                    ? (0.5d / maxSiteCount) * (1.0d / mu)
                    : meanHeight * 2.0d;
        }

        new ClusterTree().initByName("initial", speciesTree, "taxonset", species, "clusterType", "upgma",
                "distance", new Distance() {
                    @Override
                    public double pairwiseDistance(final int s1, final int s2) {
                        return speciesDistances[getDMindex(speciesCount, s1, s2)];
                    }
                });

        final Map<String, Integer> speciesIndex = new HashMap<>();
        for (int k = 0; k < speciesCount; k++) {
            speciesIndex.put(speciesNames.get(k), k);
        }
        final double[] speciesMeetings = firstMeetings(speciesTree, speciesIndex, speciesCount);

        for (int g = 0; g < geneTrees.size(); g++) {
            if (!isCompatible(geneMeetings[g], speciesMeetings)) {
                reclusterGeneTree(geneTrees.get(g), geneTipToSpecies, speciesMeetings, speciesCount, mu);
            }
        }

        initHyperParameters(speciesTree, speciesCount);
    }

    /**
     * Builds a UPGMA starting tree for each gene from its alignment and rescales it by 1/mu.
     *
     * @return the largest site count over the gene alignments (0 if there are none)
     */
    private static double clusterGeneTrees(final List<Tree> geneTrees, final double mu) {
        double maxSiteCount = 0.0d;
        for (final Tree geneTree : geneTrees) {
            final Alignment alignment = geneTree.m_taxonset.get().alignmentInput.get();
            new ClusterTree().initByName("initial", geneTree, "clusterType", "upgma", "taxa", alignment);
            geneTree.scale(1.0d / mu);
            maxSiteCount = Math.max(maxSiteCount, alignment.getSiteCount());
        }
        return maxSiteCount;
    }

    /**
     * Maps every gene-tip taxon ID to the index of the species taxon set containing it.
     *
     * @param species      the species tree's taxon set, whose members are themselves taxon sets
     * @param speciesCount number of species to read from {@code species}
     * @return map from gene tip ID to species index
     */
    private static Map<String, Integer> mapGeneTipsToSpecies(final TaxonSet species, final int speciesCount) {
        final Map<String, Integer> tipToSpecies = new HashMap<>();
        final List<Taxon> speciesTaxa = species.taxonsetInput.get();
        for (int k = 0; k < speciesCount; k++) {
            for (final Taxon tip : ((TaxonSet) speciesTaxa.get(k)).taxonsetInput.get()) {
                tipToSpecies.put(tip.getID(), k);
            }
        }
        return tipToSpecies;
    }

    /**
     * Throws if some species pair never meets in {@code geneTree}.
     *
     * <p>That happens when the gene tree has no lineages for a species. The message tries to
     * name that species.
     *
     * @throws RuntimeException naming the gene tree and the suspected species
     */
    private static void checkAllSpeciesMeet(final Tree geneTree, final Tree speciesTree, final double[] dmin, final int speciesCount) {
        for (int k = 0; k < dmin.length; k++) {
            if (dmin[k] != Double.MAX_VALUE) {
                continue;
            }
            // Entries 0..speciesCount-2 are the pairs (0, k+1), so species k+1 is the likely culprit.
            // NOTE(review): assumes getExternalNodes() order matches species index order - confirm.
            String id = k < speciesCount - 1 ? speciesTree.getExternalNodes().get(k + 1).getID() : "unknown taxon";
            if (k == 0) {
                // If every pair (0, j) is missing, species 0 itself has no lineages.
                boolean allMissing = true;
                for (int j = 1; allMissing && j < speciesCount - 1; j++) {
                    allMissing = dmin[j] == Double.MAX_VALUE;
                }
                if (allMissing) {
                    id = speciesTree.getExternalNodes().get(0).getID();
                }
            }
            throw new RuntimeException("Gene tree " + geneTree.getID() + " has no lineages for species taxon " + id + OutputUtils.SPACE);
        }
    }

    /** True when every species pair meets strictly higher in the gene tree than in the species tree. */
    private static boolean isCompatible(final double[] geneMeetings, final double[] speciesMeetings) {
        for (int k = 0; k < speciesMeetings.length; k++) {
            if (geneMeetings[k] <= speciesMeetings[k]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Rebuilds {@code geneTree} by UPGMA on Jukes-Cantor distances divided by mu.
     *
     * <p>Each distance between tips of different species is raised to 1.001 times twice the
     * species-tree meeting height whenever it does not exceed twice that height.
     */
    private void reclusterGeneTree(final Tree geneTree, final Map<String, Integer> tipToSpecies,
                                   final double[] speciesMeetings, final int speciesCount, final double mu) {
        final TaxonSet geneTaxa = geneTree.m_taxonset.get();
        final Alignment alignment = geneTaxa.alignmentInput.get();
        final List<String> taxaNames = alignment.getTaxaNames();

        // Alignment taxon index -> species index, precomputed for the distance callback.
        final Map<Integer, Integer> tipIndexToSpecies = new HashMap<>();
        for (int t = 0; t < taxaNames.size(); t++) {
            tipIndexToSpecies.put(t, tipToSpecies.get(taxaNames.get(t)));
        }

        final JukesCantorDistance jukesCantor = new JukesCantorDistance();
        jukesCantor.setPatterns(alignment);
        new ClusterTree().initByName("initial", geneTree, "taxonset", geneTaxa, "clusterType", "upgma",
                "distance", new Distance() {
                    @Override
                    public double pairwiseDistance(final int t1, final int t2) {
                        final int s1 = tipIndexToSpecies.get(t1);
                        final int s2 = tipIndexToSpecies.get(t2);
                        double d = jukesCantor.pairwiseDistance(t1, t2) / mu;
                        if (s1 != s2) {
                            final double minDistance = 2.0d * speciesMeetings[getDMindex(speciesCount, s1, s2)];
                            if (d <= minDistance) {
                                d = minDistance * 1.001d;
                            }
                        }
                        return d;
                    }
                });
    }

    /**
     * Sets the optional hyper-parameters from the species tree.
     *
     * <ul>
     *   <li>The birth rate becomes sum_{k=2..n} 1/k divided by the root height.</li>
     *   <li>popMean and popSizesTop become the total non-root branch length divided by
     *       2*(nodeCount-1).</li>
     *   <li>popSizesBottom becomes twice that value.</li>
     * </ul>
     */
    private void initHyperParameters(final Tree speciesTree, final int speciesCount) {
        final RealParameter lambda = birthRate.get();
        if (lambda != null) {
            final double rootHeight = speciesTree.getRoot().getHeight();
            lambda.setValue(Double.valueOf((1.0d / rootHeight) * harmonicSumFrom2(speciesCount)));
        }

        final Node[] nodes = speciesTree.getNodesAsArray();
        double totalBranchLength = 0.0d;
        for (final Node node : nodes) {
            if (!node.isRoot()) {
                totalBranchLength += node.getLength();
            }
        }
        final double meanPop = totalBranchLength / (2 * (nodes.length - 1));

        final RealParameter popm = popMean.get();
        if (popm != null) {
            popm.setValue(Double.valueOf(meanPop));
        }

        final SpeciesTreePrior speciesTreePrior = speciesTreePriorInput.get();
        if (speciesTreePrior == null) {
            return;
        }
        final RealParameter popBottom = speciesTreePrior.popSizesBottomInput.get();
        if (popBottom != null) {
            for (int i = 0; i < popBottom.getDimension(); i++) {
                popBottom.setValue(i, Double.valueOf(2.0d * meanPop));
            }
        }
        final RealParameter popTop = speciesTreePrior.popSizesTopInput.get();
        if (popTop != null) {
            for (int i = 0; i < popTop.getDimension(); i++) {
                popTop.setValue(i, Double.valueOf(meanPop));
            }
        }
    }

    /** Returns sum_{k=2..n} 1/k, accumulated in ascending k (0 when n &lt; 2). */
    private static double harmonicSumFrom2(final int n) {
        double sum = 0.0d;
        for (int k = 2; k <= n; k++) {
            sum += 1.0d / k;
        }
        return sum;
    }

    /**
     * Resets every gene tree with {@link Tree#makeCaterpillar}.
     *
     * <p>The given height is passed as the first argument and height / internalNodeCount as the
     * second.
     */
    private void randomInitGeneTrees(final double rootHeight) {
        for (final Tree geneTree : genes.get()) {
            geneTree.makeCaterpillar(rootHeight, rootHeight / geneTree.getInternalNodeCount(), true);
        }
    }

    /**
     * "random" initialisation.
     *
     * <p>Scales the species tree so its root height is (1/lambda) * sum_{k=2..n} 1/k, using
     * lambda = birthRate or 1, then resets the gene trees to that height.
     */
    private void randomInit() {
        final RealParameter lambda = birthRate.get();
        final double rate = lambda != null ? lambda.getArrayValue() : 1.0d;
        final Tree speciesTree = speciesTreeInput.get();
        final int speciesCount = speciesTree.m_taxonset.get().asStringList().size();
        final double rootHeight = (1.0d / rate) * harmonicSumFrom2(speciesCount);
        speciesTree.scale(rootHeight / speciesTree.getRoot().getHeight());
        randomInitGeneTrees(rootHeight);
    }

    /**
     * Initialises the calibrated Yule model's tree from a tree compatible with its calibrations.
     *
     * <p>It then resets the gene trees to the new root height and re-initialises the model.
     *
     * @throws MathException propagated from building the compatible tree
     */
    private void initWithCalibrations() throws MathException {
        final CalibratedYuleModel cYule = calibratedYule.get();
        final Tree spTree = (Tree) cYule.treeInput.get();
        final List<CalibrationPoint> calibrations = cYule.calibrationsInput.get();

        // Temporary model with the same calibrations, used only to build a compatible tree.
        final CalibratedYuleModel helper = new CalibratedYuleModel();
        helper.getOutputs().addAll(cYule.getOutputs());
        for (final CalibrationPoint calibration : calibrations) {
            helper.setInputValue("calibrations", calibration);
        }
        helper.setInputValue(XMLParser.TREE_ELEMENT, spTree);
        helper.setInputValue("type", CalibratedYuleModel.Type.NONE);
        helper.initAndValidate();

        final Tree compatible = helper.compatibleInitialTree();
        assert spTree.getLeafNodeCount() == compatible.getLeafNodeCount();
        spTree.assignFromWithoutID(compatible);

        randomInitGeneTrees(spTree.getRoot().getHeight());
        cYule.initAndValidate();
    }

    /**
     * Initialises the species tree from a {@link RandomTree} constrained by the given MRCA priors.
     *
     * <p>The random tree uses a constant population of size 1. The gene trees are then reset to
     * the new root height.
     */
    private void initWithMRCACalibrations(final List<MRCAPrior> list) {
        final Tree speciesTree = speciesTreeInput.get();
        final RandomTree randomTree = new RandomTree();
        randomTree.setInputValue("taxonset", speciesTree.getTaxonset());
        for (final MRCAPrior prior : list) {
            randomTree.setInputValue("constraint", prior);
        }
        final ConstantPopulation population = new ConstantPopulation();
        population.setInputValue("popSize", new RealParameter("1.0"));
        randomTree.setInputValue("populationModel", population);
        randomTree.initAndValidate();

        speciesTree.assignFromWithoutID(randomTree);
        randomInitGeneTrees(speciesTree.getRoot().getHeight());
    }

    /**
     * Reports the state nodes this initialiser touches.
     *
     * <p>These are the species tree, every gene tree, and popMean, birthRate and the
     * species-tree prior's population-size parameters when present.
     */
    @Override
    public void getInitialisedStateNodes(final List<StateNode> list) {
        list.add(initialisedSpeciesTree());
        list.addAll(genes.get());
        addIfPresent(list, popMean.get());
        addIfPresent(list, birthRate.get());

        final SpeciesTreePrior speciesTreePrior = speciesTreePriorInput.get();
        if (speciesTreePrior != null) {
            addIfPresent(list, speciesTreePrior.popSizesBottomInput.get());
            addIfPresent(list, speciesTreePrior.popSizesTopInput.get());
        }
    }

    /** Appends {@code node} to {@code nodes} unless it is null. */
    private static void addIfPresent(final List<StateNode> nodes, final StateNode node) {
        if (node != null) {
            nodes.add(node);
        }
    }
}
