package marytts.dnn.modules;

import marytts.MaryException;
import marytts.config.MaryConfiguration;
import marytts.data.Sequence;
import marytts.data.Utterance;
import marytts.data.item.phonology.Phone;
import marytts.data.item.phonology.Phoneme;
import marytts.dnn.DNNPredictor;
import marytts.dnn.FeatureNormaliser;
import marytts.dnn.normaliser.QuinphoneNormaliser;
import marytts.dnn.normaliser.QuinphoneWithDictNormaliser;
import marytts.exceptions.MaryConfigurationException;
import marytts.io.MaryIOException;
import marytts.modules.MaryModule;
import org.tensorflow.Tensor;

/* loaded from: input_file:marytts/dnn/modules/DNNDurationPrediction.class */
/**
 * Duration module that predicts a duration for each phone with a DNN.
 *
 * <p>It rewrites the utterance's {@code "PHONE"} sequence as timed {@link Phone} items. Each item
 * starts where the previous one ends. The sequence is framed by a leading and a trailing
 * {@code "_"} pause phone of length {@link #PAUSE_DURATION}.
 *
 * <p>The predictor model and the feature normaliser are set through the setters below.
 * {@link #process} calls {@code maryConfiguration.applyConfiguration(this)} before predicting,
 * which presumably invokes those setters (NOTE(review): confirm against MaryConfiguration).
 */
public class DNNDurationPrediction extends MaryModule {
    /**
     * Length of the leading and trailing {@code "_"} pause phones, also used as the start offset
     * of the first real phone. NOTE(review): the unit is whatever the DNN predicts in — confirm.
     */
    private static final double PAUSE_DURATION = 15.0d;

    /** Name accepted by {@link #setNormaliser(String)} for {@link QuinphoneNormaliser}. */
    private static final String QUINPHONE_NORMALISER = "QuinphoneNormaliser";

    /** Name accepted by {@link #setNormaliser(String)} for {@link QuinphoneWithDictNormaliser}. */
    private static final String QUINPHONE_WITH_DICT_NORMALISER = "QuinphoneWithDictNormaliser";

    private DNNPredictor dnn_pred;
    private FeatureNormaliser normaliser;
    private String dict_filename;

    /**
     * Creates the module under the name {@code "duration"}. The dictionary filename starts unset.
     *
     * @throws Exception if the {@link MaryModule} super constructor fails
     */
    public DNNDurationPrediction() throws Exception {
        super("duration");
        this.dict_filename = null;
    }

    /** @return the dictionary filename for {@code QuinphoneWithDictNormaliser}, or {@code null} if unset */
    public String getDictFilename() {
        return this.dict_filename;
    }

    /** @return the configured feature normaliser, or {@code null} if none has been set yet */
    public FeatureNormaliser getNormaliser() {
        return this.normaliser;
    }

    /**
     * Sets the dictionary filename used by {@code QuinphoneWithDictNormaliser}.
     * It must be set before {@link #setNormaliser(String)} selects that normaliser.
     *
     * @param str the dictionary filename
     */
    public void setDictFilename(String str) {
        this.dict_filename = str;
    }

    /**
     * Selects and instantiates the feature normaliser by name.
     *
     * @param str {@code "QuinphoneNormaliser"} or {@code "QuinphoneWithDictNormaliser"}
     * @throws MaryIOException if the name is unknown or {@code null}; if the dict-based normaliser
     *     is requested without a dict filename; or if constructing the normaliser fails (the
     *     failure is attached as the cause)
     */
    public void setNormaliser(String str) throws MaryIOException {
        // Validate outside the try block so these specific messages are not re-wrapped as
        // "Cannot set normaliser". The constant-first equals() also makes null input safe.
        final boolean withDict = QUINPHONE_WITH_DICT_NORMALISER.equals(str);
        if (!withDict && !QUINPHONE_NORMALISER.equals(str)) {
            throw new MaryIOException("Unknown normaliser: " + str);
        }
        if (withDict && getDictFilename() == null) {
            throw new MaryIOException("QuinphoneWithDictNormaliser needs a dict filename");
        }

        try {
            if (withDict) {
                this.normaliser = new QuinphoneWithDictNormaliser(getDictFilename());
            } else {
                this.normaliser = new QuinphoneNormaliser();
            }
        } catch (Exception e) {
            throw new MaryIOException("Cannot set normaliser", e);
        }
    }

    /**
     * Loads the DNN predictor from the given model reference.
     *
     * @param str model identifier passed to {@link DNNPredictor}'s constructor
     */
    public void setPredictorModel(String str) {
        this.dnn_pred = new DNNPredictor(str);
    }

    /** No startup checks: the model and normaliser are configured per call in {@link #process}. */
    public void checkStartup() throws MaryConfigurationException {
    }

    /**
     * Checks that the utterance carries the sequences this module consumes.
     *
     * @param utterance the input utterance
     * @throws MaryException if the {@code "PHONE"} or {@code "FEATURES"} sequence is missing
     */
    public void checkInput(Utterance utterance) throws MaryException {
        if (!utterance.hasSequence("PHONE")) {
            throw new MaryException("Phone sequence is missing", (Throwable) null);
        }
        if (!utterance.hasSequence("FEATURES")) {
            throw new MaryException("Feature sequence is missing", (Throwable) null);
        }
        // Only enforced when assertions are enabled (-ea); one feature vector is expected per phone.
        assert utterance.getSequence("PHONE").size() == utterance.getSequence("FEATURES").size()
            : "PHONE and FEATURES sequences differ in size";
    }

    /**
     * Predicts a duration for each phone and replaces the {@code "PHONE"} sequence in place.
     *
     * <p>Each phone i becomes {@code new Phone(phoneme_i, start_i, duration_i)}, where
     * {@code start_0 = PAUSE_DURATION} and {@code start_{i+1} = start_i + duration_i}. A
     * {@code "_"} pause of length {@link #PAUSE_DURATION} is then inserted at the front, starting
     * at 0. Another is appended at the end, starting where the last phone ends.
     *
     * @param utterance the utterance to annotate; it is modified and returned
     * @param maryConfiguration configuration applied to this module before predicting
     * @return the same {@code utterance}
     * @throws MaryException wrapping any failure, including an unconfigured predictor or normaliser
     */
    public Utterance process(Utterance utterance, MaryConfiguration maryConfiguration) throws MaryException {
        try {
            maryConfiguration.applyConfiguration(this);
            if (this.dnn_pred == null) {
                throw new MaryException("No DNN predictor model configured", (Throwable) null);
            }
            if (this.normaliser == null) {
                throw new MaryException("No feature normaliser configured", (Throwable) null);
            }

            Sequence sequence = utterance.getSequence("PHONE");
            float[][] durations = new float[sequence.size()][1];

            // Tensors own native memory: release the prediction as soon as it has been copied out.
            // NOTE(review): the tensor returned by normalise() may also need closing, depending on
            // its type and on whether DNNPredictor.predict takes ownership of it — confirm.
            try (Tensor<Float> prediction =
                     this.dnn_pred.predict(this.normaliser.normalise(utterance.getSequence("FEATURES")))) {
                prediction.copyTo(durations);
            }

            double start = PAUSE_DURATION;
            for (int i = 0; i < sequence.size(); i++) {
                double duration = durations[i][0];
                sequence.set(i, new Phone((Phoneme) sequence.get(i), start, duration));
                start += duration;
            }

            // Framing pauses. The third Phone argument is a duration (as in the loop above),
            // so the trailing pause gets PAUSE_DURATION rather than an absolute end time.
            sequence.add(0, new Phone("_", 0.0d, PAUSE_DURATION));
            sequence.add(new Phone("_", start, PAUSE_DURATION));
            return utterance;
        } catch (Exception e) {
            throw new MaryException("Cannot predict duration", e);
        }
    }

    /** Sets this module's human-readable description. */
    public void setDescription() {
        this.description = "DNN-based duration prediction which predicts a duration for each phone.";
    }
}
