package dp;

import java.io.File;
import java.io.FileInputStream;
import java.util.Iterator;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dist.UniformDistribution;
import org.biojava.bio.dp.BaumWelchTrainer;
import org.biojava.bio.dp.DP;
import org.biojava.bio.dp.DPFactory;
import org.biojava.bio.dp.EmissionState;
import org.biojava.bio.dp.MagicalState;
import org.biojava.bio.dp.MarkovModel;
import org.biojava.bio.dp.ProfileHMM;
import org.biojava.bio.dp.ScoreType;
import org.biojava.bio.dp.SimpleModelTrainer;
import org.biojava.bio.dp.State;
import org.biojava.bio.dp.StatePath;
import org.biojava.bio.dp.StoppingCriteria;
import org.biojava.bio.dp.TrainingAlgorithm;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.db.HashSequenceDB;
import org.biojava.bio.seq.db.IDMaker;
import org.biojava.bio.seq.db.SequenceDB;
import org.biojava.bio.seq.io.FastaDescriptionLineParser;
import org.biojava.bio.seq.io.FastaFormat;
import org.biojava.bio.seq.io.SimpleSequenceBuilder;
import org.biojava.bio.seq.io.StreamReader;
import org.biojava.bio.seq.io.SymbolTokenization;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;

/* loaded from: input_file:biojava-live_1.6/demos-live.jar:dp/SearchProfile.class */
/**
 * Command-line demo: reads protein sequences from a FASTA file, builds a randomly
 * initialised profile HMM whose length is estimated from the sequences, trains it
 * with Baum-Welch, and prints the Viterbi state path for every sequence.
 *
 * <p>Usage: {@code java dp.SearchProfile <fasta-file>}
 */
public class SearchProfile {
    /** Uniform distribution over the protein alphabet; assigned by {@link #main} and not read in this class. */
    public static Distribution nullModel;

    /** Number of Baum-Welch cycles after which training stops. */
    private static final int TRAINING_CYCLES = 5;

    /** Null-model weight handed to both the Baum-Welch trainer and the randomising trainer. */
    private static final double NULL_MODEL_WEIGHT = 5.0d;

    /** Number of path positions printed per output line. */
    private static final int LINE_WIDTH = 60;

    /**
     * Runs the demo. Any failure is reported as a stack trace on standard error;
     * no exception escapes this method.
     *
     * @param strArr command-line arguments; {@code strArr[0]} is the path of the FASTA file
     */
    public static void main(String[] strArr) {
        if (strArr == null || strArr.length < 1) {
            System.err.println("Usage: java dp.SearchProfile <fasta-file>");
            return;
        }
        try {
            File file = new File(strArr[0]);
            FiniteAlphabet alphabet = ProteinTools.getAlphabet();
            nullModel = new UniformDistribution(alphabet);

            System.out.println("Loading sequences");
            SequenceDB readSequenceDB = readSequenceDB(file, alphabet);
            if (readSequenceDB.ids().size() == 0) {
                // Nothing to estimate a profile length from, nothing to align.
                System.err.println("No sequences found in " + file);
                return;
            }

            System.out.println("Creating profile HMM");
            ProfileHMM createProfile = createProfile(readSequenceDB, alphabet);

            System.out.println("make dp object");
            DP createDP = DPFactory.DEFAULT.createDP(createProfile);
            dumpDP(createDP);

            // Score the first sequence against the untrained model.
            Sequence[] sequenceArr = {readSequenceDB.sequenceIterator().nextSequence()};
            System.out.println("Viterbi: " + createDP.viterbi(sequenceArr, ScoreType.PROBABILITY).getScore());
            System.out.println("Forward: " + createDP.forward(sequenceArr, ScoreType.PROBABILITY));
            System.out.println("Backward: " + createDP.backward(sequenceArr, ScoreType.PROBABILITY));

            System.out.println("Training whole profile");
            new BaumWelchTrainer(createDP).train(readSequenceDB, NULL_MODEL_WEIGHT, new StoppingCriteria() {
                /** Logs the finished cycle and its score; stops once enough cycles have run. */
                @Override
                public boolean isTrainingComplete(TrainingAlgorithm trainingAlgorithm) {
                    System.out.println("Cycle " + trainingAlgorithm.getCycle() + " completed");
                    System.out.println("Score: " + trainingAlgorithm.getCurrentScore());
                    return trainingAlgorithm.getCycle() >= TRAINING_CYCLES;
                }
            });

            System.out.println("Alignining sequences to the model");
            // The tokenization does not depend on the sequence, so look it up once.
            SymbolTokenization tokenization = alphabet.getTokenization("token");
            SequenceIterator sequenceIterator = readSequenceDB.sequenceIterator();
            while (sequenceIterator.hasNext()) {
                Sequence nextSequence = sequenceIterator.nextSequence();
                SymbolList[] symbolListArr = {nextSequence};
                StatePath viterbi = createDP.viterbi(symbolListArr, ScoreType.PROBABILITY);
                System.out.println(nextSequence.getName()
                        + " viterbi: " + viterbi.getScore()
                        + ", forwards: " + createDP.forward(symbolListArr, ScoreType.PROBABILITY)
                        + ", backwards: " + createDP.backward(symbolListArr, ScoreType.PROBABILITY));
                printAlignment(viterbi, tokenization);
            }
        } catch (Throwable th) {
            // Top-level boundary of a demo program: report whatever went wrong and end normally.
            th.printStackTrace();
        }
    }

    /**
     * Prints a state path in chunks of {@link #LINE_WIDTH} positions. Each chunk is a line of
     * sequence tokens, a line holding the first character of each state's name, and a blank line.
     * A path of length zero prints nothing.
     *
     * @param path         the state path to print
     * @param tokenization converts the symbols of the path's sequence row into text
     * @throws Exception if a symbol cannot be tokenized or a position cannot be read
     */
    private static void printAlignment(StatePath path, SymbolTokenization tokenization) throws Exception {
        int pathLength = path.length();
        // Stepping by chunk start (instead of counting to length / LINE_WIDTH inclusive)
        // avoids an empty trailing chunk when the length is a multiple of LINE_WIDTH.
        for (int start = 0; start < pathLength; start += LINE_WIDTH) {
            int end = Math.min(start + LINE_WIDTH, pathLength);
            // symbolAt is addressed from 1, hence the + 1 on the zero-based column.
            for (int col = start; col < end; col++) {
                System.out.print(tokenization.tokenizeSymbol(path.symbolAt(StatePath.SEQUENCE, col + 1)));
            }
            System.out.print("\n");
            for (int col = start; col < end; col++) {
                System.out.print(path.symbolAt(StatePath.STATES, col + 1).getName().charAt(0));
            }
            System.out.print("\n");
            System.out.print("\n");
        }
    }

    /**
     * Builds a randomly initialised profile HMM. The number of columns is the geometric mean
     * of the sequence lengths (exp of the mean log-length), truncated to an int.
     *
     * @param sequenceDB sequences used to estimate the profile length; must not be empty
     * @param alphabet   alphabet the profile is built over
     * @return the randomised profile HMM
     * @throws IllegalArgumentException if {@code sequenceDB} holds no sequences
     * @throws Exception                if the model cannot be built or randomised
     */
    private static ProfileHMM createProfile(SequenceDB sequenceDB, Alphabet alphabet) throws Exception {
        int sequenceCount = sequenceDB.ids().size();
        if (sequenceCount == 0) {
            // Without this the mean below is 0/0 = NaN, which silently casts to length 0.
            throw new IllegalArgumentException("Cannot estimate a profile length from an empty sequence database");
        }
        double logLengthSum = 0.0d;
        SequenceIterator sequenceIterator = sequenceDB.sequenceIterator();
        while (sequenceIterator.hasNext()) {
            logLengthSum += Math.log(sequenceIterator.nextSequence().length());
        }
        int exp = (int) Math.exp(logLengthSum / sequenceCount);
        System.out.println("Estimating alignment as having length " + exp);
        ProfileHMM profileHMM = new ProfileHMM(alphabet, exp, DistributionFactory.DEFAULT, DistributionFactory.DEFAULT);
        randomize(profileHMM);
        return profileHMM;
    }

    /**
     * Reads every sequence of a FASTA file into an in-memory database keyed by sequence name.
     * The file is closed before this method returns, whether or not reading succeeded.
     *
     * @param file     FASTA file to read
     * @param alphabet alphabet whose {@code "token"} tokenization is used to parse the residues
     * @return a database holding all sequences read from {@code file}
     * @throws Exception if the file cannot be opened or parsed, or a sequence cannot be added
     */
    public static SequenceDB readSequenceDB(File file, Alphabet alphabet) throws Exception {
        HashSequenceDB hashSequenceDB = new HashSequenceDB(IDMaker.byName);
        FastaDescriptionLineParser.Factory factory = new FastaDescriptionLineParser.Factory(SimpleSequenceBuilder.FACTORY);
        FileInputStream input = new FileInputStream(file);
        try {
            StreamReader streamReader = new StreamReader(input, new FastaFormat(), alphabet.getTokenization("token"), factory);
            while (streamReader.hasNext()) {
                hashSequenceDB.addSequence(streamReader.nextSequence());
            }
        } finally {
            // Release the file handle even when parsing fails part-way through.
            input.close();
        }
        return hashSequenceDB;
    }

    /**
     * Gives a model random parameters: a random count is registered for every symbol of every
     * non-magical emission state and for every outgoing transition of every state, then the
     * trainer is run once over those counts.
     *
     * @param markovModel the model to randomise
     * @throws Exception if the model cannot be registered, counted or trained
     */
    private static void randomize(MarkovModel markovModel) throws Exception {
        SimpleModelTrainer simpleModelTrainer = new SimpleModelTrainer();
        simpleModelTrainer.registerModel(markovModel);
        simpleModelTrainer.setNullModelWeight(NULL_MODEL_WEIGHT);
        Iterator states = markovModel.stateAlphabet().iterator();
        while (states.hasNext()) {
            State state = (State) states.next();
            if ((state instanceof EmissionState) && !(state instanceof MagicalState)) {
                Distribution emissions = ((EmissionState) state).getDistribution();
                Iterator symbols = ((FiniteAlphabet) emissions.getAlphabet()).iterator();
                while (symbols.hasNext()) {
                    simpleModelTrainer.addCount(emissions, (Symbol) symbols.next(), Math.random());
                }
            }
            Distribution weights = markovModel.getWeights(state);
            Iterator targets = markovModel.transitionsFrom(state).iterator();
            while (targets.hasNext()) {
                simpleModelTrainer.addCount(weights, (State) targets.next(), Math.random());
            }
        }
        simpleModelTrainer.train();
        simpleModelTrainer.clearCounts();
    }

    /**
     * Prints the state names of a DP object followed by, for each state index, the indices of
     * its forward transitions with their odds scores, one state per line.
     *
     * @param dp2 the DP object to describe
     */
    private static void dumpDP(DP dp2) {
        State[] states = dp2.getStates();
        System.out.print("states: ");
        for (State state : states) {
            System.out.print(" " + state.getName());
        }
        System.out.println("\n");
        int[][] forwardTransitions = dp2.getForwardTransitions();
        double[][] forwardTransitionScores = dp2.getForwardTransitionScores(ScoreType.ODDS);
        for (int from = 0; from < states.length; from++) {
            System.out.print("Transitions from " + from + ": ");
            for (int k = 0; k < forwardTransitions[from].length; k++) {
                System.out.print(" " + forwardTransitions[from][k] + "(" + forwardTransitionScores[from][k] + ")");
            }
            // End the row so rows, and whatever is printed next, do not run together.
            System.out.println();
        }
    }

    /**
     * Prints the name of every symbol in a symbol list, each followed by a space, then a newline.
     *
     * @param symbolList the symbols to print
     */
    static void Print(SymbolList symbolList) {
        Iterator it = symbolList.iterator();
        while (it.hasNext()) {
            System.out.print(((Symbol) it.next()).getName() + " ");
        }
        System.out.println();
    }
}
