package beast.evolution.likelihood;

import beast.app.BeastMCMC;
import beast.app.beauti.Beauti;
import beast.app.util.Arguments;
import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.FilteredAlignment;
import beast.evolution.sitemodel.SiteModelInterface;
import beast.evolution.substitutionmodel.SubstitutionModel;
import beast.util.XMLParser;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;

/**
 * Tree likelihood that splits the alignment's sites into contiguous ranges and evaluates each
 * range with its own {@link TreeLikelihood} over a {@link FilteredAlignment}, summing the
 * per-range log likelihoods computed on a fixed-size thread pool.
 *
 * <p>When only one thread is available, or when the alignment is ascertained, a single
 * {@link TreeLikelihood} over the full alignment is used and no thread pool is created.
 */
@Description("Calculates the likelihood of sequence data on a beast.tree given a site and substitution model using a variant of the 'peeling algorithm'. For details, see Felsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.")
public class ThreadedTreeLikelihood extends GenericTreeLikelihood {
    /** Child likelihoods, one per thread; the array length always equals {@link #threadCount}. */
    private TreeLikelihood[] treelikelihood;
    /** Number of child likelihoods in use; always at least 1 after initAndValidate. */
    private int threadCount;
    /**
     * Site boundaries, length threadCount + 1: child t is filtered with the range built from
     * (patternPoints[t] + 1) and patternPoints[t + 1].
     */
    private int[] patternPoints;
    public final Input<Boolean> useAmbiguitiesInput = new Input<>("useAmbiguities", "flag to indicate leafs that sites containing ambiguous states should be handled instead of ignored (the default)", false);
    public final Input<Integer> maxNrOfThreadsInput = new Input<>("threads", "maximum number of threads to use, if less than 1 the number of threads in BeastMCMC is used (default -1)", -1);
    public final Input<String> proportionsInput = new Input<>("proportions", "specifies proportions of patterns used per thread as space delimited string. This is useful when using a mixture of BEAGLE devices that run at different speeds, e.g GPU and CPU. The string is duplicated if there are more threads than proportions specified. For example, '1 2' as well as '33 66' with 2 threads specifies that the first thread gets a third of the patterns and the second two thirds. With 3 threads, it is interpreted as '1 2 1' = 25%, 50%, 25% and with 7 threads it is '1 2 1 2 1 2 1' = 10% 20% 10% 20% 10% 20% 10%. If not specified, all threads get the same proportion of patterns.");
    public final Input<Scaling> scalingInput = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());
    private final Input<List<TreeLikelihood>> likelihoodsInput = new Input<>("*", "", new ArrayList<>());
    /** Worker pool; only created when more than one thread is used. */
    private ExecutorService pool = null;
    /** One task per child likelihood, submitted together on every evaluation. */
    private final List<Callable<Double>> likelihoodCallers = new ArrayList<>();

    /** Scaling choices; the chosen value is passed by name to each child's "scaling" input. */
    enum Scaling {
        none,
        always,
        _default
    }

    /** Evaluates one child likelihood on a worker thread and returns its log likelihood. */
    class TreeLikelihoodCaller implements Callable<Double> {
        private final TreeLikelihood likelihood;
        private final int threadNr;

        public TreeLikelihoodCaller(TreeLikelihood treeLikelihood, int i) {
            this.likelihood = treeLikelihood;
            this.threadNr = i;
        }

        /**
         * Computes the child's log likelihood.
         *
         * @throws IllegalStateException wrapping any runtime failure, tagged with the thread index
         */
        @Override
        public Double call() {
            try {
                return likelihood.calculateLogP();
            } catch (RuntimeException e) {
                throw new IllegalStateException("Something went wrong in thread " + threadNr, e);
            }
        }
    }

    /**
     * Adds the internal child-likelihood input ("*") to the listed inputs, except inside BEAUti
     * or when the "beast.is.junit.testing" system property is set.
     */
    @Override
    public List<Input<?>> listInputs() {
        List<Input<?>> listInputs = super.listInputs();
        if (!Beauti.isInBeauti() && System.getProperty("beast.is.junit.testing") == null) {
            listInputs.add(likelihoodsInput);
        }
        return listInputs;
    }

    /**
     * Decides how many threads to use and builds one child {@link TreeLikelihood} per thread.
     *
     * @throws IllegalArgumentException if the alignment's taxon count differs from the tree's
     *     leaf count
     */
    @Override
    public void initAndValidate() {
        int requestedThreads = resolveThreadCount();

        Alignment data = dataInput.get();
        if (data.getTaxonCount() != treeInput.get().getLeafNodeCount()) {
            throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
        }
        if (data.isAscertained) {
            Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained");
            requestedThreads = 1;
        }
        int siteCount = data.getSiteCount();
        // Never use more threads than sites (a thread without sites would get an empty filter
        // range), and never fewer than one.
        threadCount = Math.max(1, Math.min(requestedThreads, siteCount));

        // Allocate only after threadCount is final so no slot is left null.
        treelikelihood = new TreeLikelihood[threadCount];

        if (threadCount == 1) {
            treelikelihood[0] = new TreeLikelihood();
            treelikelihood[0].setID(getID() + "0");
            treelikelihood[0].initByName(XMLParser.DATA_ELEMENT, data, XMLParser.TREE_ELEMENT, treeInput.get(), "siteModel", siteModelInput.get(), "branchRateModel", branchRateModelInput.get(), "useAmbiguities", useAmbiguitiesInput.get(), "scaling", scalingInput.get() + "");
            treelikelihood[0].getOutputs().add(this);
            likelihoodsInput.get().add(treelikelihood[0]);
            return;
        }

        pool = Executors.newFixedThreadPool(threadCount);
        calcPatternPoints(siteCount);
        for (int i = 0; i < threadCount; i++) {
            // Ascertained alignments never reach this loop (they force a single thread above).
            String filter = (patternPoints[i] + 1) + Arguments.ARGUMENT_CHARACTER + patternPoints[i + 1];

            treelikelihood[i] = new TreeLikelihood();
            treelikelihood[i].setID(getID() + i);
            treelikelihood[i].getOutputs().add(this);
            likelihoodsInput.get().add(treelikelihood[i]);

            FilteredAlignment filteredAlignment = new FilteredAlignment();
            // Only the first chunk carries the constant-site weights, so they are counted once.
            if (i == 0 && data instanceof FilteredAlignment && ((FilteredAlignment) data).constantSiteWeightsInput.get() != null) {
                filteredAlignment.initByName(XMLParser.DATA_ELEMENT, data, "filter", filter, "constantSiteWeights", ((FilteredAlignment) data).constantSiteWeightsInput.get());
            } else {
                filteredAlignment.initByName(XMLParser.DATA_ELEMENT, data, "filter", filter);
            }
            // Each child gets its own copy of the site and branch-rate models.
            treelikelihood[i].initByName(XMLParser.DATA_ELEMENT, filteredAlignment, XMLParser.TREE_ELEMENT, treeInput.get(), "siteModel", duplicate((BEASTInterface) siteModelInput.get(), i), "branchRateModel", duplicate(branchRateModelInput.get(), i), "useAmbiguities", useAmbiguitiesInput.get(), "scaling", scalingInput.get() + "");
            likelihoodCallers.add(new TreeLikelihoodCaller(treelikelihood[i], i));
        }
    }

    /**
     * Returns the requested number of threads before any clamping.
     *
     * <p>Starts from BeastMCMC.m_nThreads and caps it by the "threads" input when that input is
     * positive. A non-blank "beast.instance.count" system property overrides both.
     */
    private int resolveThreadCount() {
        int count = BeastMCMC.m_nThreads;
        int maxThreads = maxNrOfThreadsInput.get();
        if (maxThreads > 0) {
            count = Math.min(maxThreads, BeastMCMC.m_nThreads);
        }
        String instanceCount = System.getProperty("beast.instance.count");
        if (instanceCount != null && !instanceCount.trim().isEmpty()) {
            count = Integer.parseInt(instanceCount.trim());
        }
        return count;
    }

    /**
     * Creates a shallow re-instantiation of {@code original} for thread {@code index}.
     *
     * <p>Input values are copied over with these rules:
     * <ul>
     *   <li>list inputs: only their {@link BEASTInterface} elements are copied;</li>
     *   <li>{@link SubstitutionModel} inputs: duplicated recursively;</li>
     *   <li>all other non-null values: shared with the original.</li>
     * </ul>
     *
     * @param original object to copy; may be null
     * @param index thread index, appended to the copy's ID as "_index"
     * @return the initialised copy, or null if {@code original} is null
     * @throws RuntimeException if the class cannot be instantiated reflectively
     */
    private BEASTInterface duplicate(BEASTInterface original, int index) {
        if (original == null) {
            return null;
        }
        BEASTInterface copy;
        try {
            copy = original.getClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Programmer error: every object in the model should have a default constructor that is publicly accessible: " + original.getClass().getName(), e);
        }
        copy.setID(original.getID() + "_" + index);
        for (Input<?> input : original.listInputs()) {
            Object value = input.get();
            if (value == null) {
                continue;
            }
            if (value instanceof List) {
                for (Object element : (List<?>) value) {
                    if (element instanceof BEASTInterface) {
                        copy.setInputValue(input.getName(), element);
                    }
                }
            } else if (value instanceof SubstitutionModel) {
                copy.setInputValue(input.getName(), duplicate((BEASTInterface) value, index));
            } else {
                copy.setInputValue(input.getName(), value);
            }
        }
        copy.initAndValidate();
        return copy;
    }

    /**
     * Fills {@link #patternPoints} with threadCount + 1 boundaries covering {@code siteCount} sites.
     *
     * <p>Without the "proportions" input the sites are split evenly, and the last thread takes the
     * remainder. Otherwise the space-separated weights are cycled over the threads, normalised,
     * accumulated and rounded to the nearest site.
     *
     * @throws IllegalArgumentException if the proportions do not sum to a positive value
     */
    private void calcPatternPoints(int siteCount) {
        patternPoints = new int[threadCount + 1];
        patternPoints[threadCount] = siteCount;

        String spec = proportionsInput.get();
        if (spec == null) {
            int chunk = siteCount / threadCount;
            for (int t = 1; t < threadCount; t++) {
                patternPoints[t] = chunk * t;
            }
            return;
        }

        // trim() avoids an empty leading token when the string starts with whitespace.
        String[] tokens = spec.trim().split("\\s+");
        double[] fraction = new double[threadCount];
        for (int t = 0; t < threadCount; t++) {
            fraction[t] = Double.parseDouble(tokens[t % tokens.length]);
        }
        double total = 0.0d;
        for (double weight : fraction) {
            total += weight;
        }
        if (!(total > 0.0d)) {
            throw new IllegalArgumentException("proportions must sum to a positive value: '" + spec + "'");
        }
        for (int t = 0; t < threadCount; t++) {
            fraction[t] /= total;
        }
        for (int t = 1; t < threadCount; t++) {
            fraction[t] += fraction[t - 1];
        }
        // The last boundary stays at siteCount, so rounding can never drop trailing sites.
        for (int t = 0; t < threadCount - 1; t++) {
            patternPoints[t + 1] = (int) ((fraction[t] * siteCount) + 0.5d);
        }
    }

    @Override
    public void sample(State state, Random random) {
        throw new UnsupportedOperationException("Can't sample a fixed alignment!");
    }

    /** Computes the summed log likelihood of all child likelihoods and caches it in logP. */
    @Override
    public double calculateLogP() {
        logP = calculateLogPByBeagle();
        return logP;
    }

    /**
     * Evaluates the child likelihoods, in parallel when more than one thread is used, and sums
     * the results in thread order.
     *
     * @throws RuntimeException if a child evaluation fails; the child's exception is rethrown
     *     as-is when it is unchecked
     */
    private double calculateLogPByBeagle() {
        if (threadCount <= 1) {
            logP = treelikelihood[0].calculateLogP();
            return logP;
        }
        List<Future<Double>> results;
        try {
            results = pool.invokeAll(likelihoodCallers);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while computing the threaded tree likelihood", e);
        }
        // invokeAll returns futures in task order, so the sum is accumulated in thread order.
        double sum = 0.0d;
        for (Future<Double> result : results) {
            sum += awaitResult(result);
        }
        logP = sum;
        return logP;
    }

    /** Unwraps a completed future, rethrowing the task's own failure instead of swallowing it. */
    private static double awaitResult(Future<Double> result) {
        try {
            return result.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while computing the threaded tree likelihood", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException(cause);
        }
    }

    /**
     * Reports whether any child likelihood needs recalculation.
     *
     * <p>Every child is queried on purpose: the non-short-circuit {@code |=} runs each child's
     * check, in case those checks do per-child bookkeeping.
     */
    @Override
    public boolean requiresRecalculation() {
        boolean dirty = false;
        for (TreeLikelihood child : treelikelihood) {
            dirty |= child.requiresRecalculation();
        }
        return dirty;
    }

    @Override
    public List<String> getArguments() {
        return Collections.singletonList(dataInput.get().getID());
    }

    @Override
    public List<String> getConditions() {
        return ((SiteModelInterface.Base) siteModelInput.get()).getConditions();
    }
}
