package marytts.dnn;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/**
 * Runs inference with a TensorFlow SavedModel stored at a filesystem path.
 *
 * <p>Each call to {@link #predict(Tensor)} loads the SavedModel from {@link #model_path}
 * using the configured tag. It feeds the input tensor under the configured input layer name,
 * fetches the configured output layer, and closes the loaded bundle before returning.
 *
 * <p>Defaults set by the constructor are: input layer {@code "input"}, output layer
 * {@code "output"}, tag {@code "serve"}.
 */
public class DNNPredictor {
    /**
     * Bundle loaded by the most recent {@link #predict(Tensor)} call. It has already been
     * closed by the time {@code predict} returns, so it must not be used afterwards.
     */
    protected SavedModelBundle model;

    /** Name of the graph node the input tensor is fed into. */
    protected String input_layer_name;

    /** Name of the graph node whose value is fetched as the prediction. */
    protected String output_layer_name;

    /** SavedModel tag used when loading the bundle. */
    protected String tag;

    /** Directory (export dir) the SavedModel is loaded from. */
    protected String model_path;

    /**
     * Creates a predictor for the SavedModel at the given path. The input and output layer
     * names default to {@code "input"} / {@code "output"} and the tag to {@code "serve"}.
     *
     * @param str path of the SavedModel export directory
     */
    public DNNPredictor(String str) {
        setInputLayerName("input");
        setOutputLayerName("output");
        setTag("serve");
        this.model_path = str;
    }

    /**
     * Loads the model, runs one forward pass on {@code tensor}, and returns the first fetched
     * output.
     *
     * <p>The model is loaded on every call and always closed before this method returns or
     * throws. Loading per call is presumably costly for repeated predictions; callers doing
     * many predictions may want to revisit this design.
     *
     * <p>The input tensor is not closed here. The returned tensor is also not closed here; the
     * caller owns it.
     *
     * @param tensor input fed under {@link #getInputLayerName()}
     * @return the value of the node named {@link #getOutputLayerName()}. The element type is
     *     assumed to be float; this is not checked here.
     * @throws Exception if loading the SavedModel or running the session fails
     */
    public Tensor<Float> predict(Tensor<Float> tensor) throws Exception {
        // Keep a local reference so this call always closes the bundle it opened. This holds
        // even if another call overwrites the shared field in the meantime.
        SavedModelBundle bundle = SavedModelBundle.load(this.model_path, new String[]{getTag()});
        this.model = bundle;
        try {
            // run() yields untyped tensors; the model's output node is assumed to be float.
            @SuppressWarnings("unchecked")
            Tensor<Float> output = (Tensor<Float>) bundle.session().runner()
                    .feed(this.input_layer_name, tensor)
                    .fetch(this.output_layer_name)
                    .run()
                    .get(0);
            return output;
        } finally {
            // Release the native graph/session even when inference throws.
            bundle.close();
        }
    }

    /**
     * Returns the SavedModel tag used when loading.
     *
     * @return the tag
     */
    public String getTag() {
        return this.tag;
    }

    /**
     * Returns the name of the fetched output node.
     *
     * @return the output layer name
     */
    public String getOutputLayerName() {
        return this.output_layer_name;
    }

    /**
     * Returns the name of the node the input is fed into.
     *
     * @return the input layer name
     */
    public String getInputLayerName() {
        return this.input_layer_name;
    }

    /**
     * Sets the SavedModel tag used by subsequent {@link #predict(Tensor)} calls.
     *
     * @param str the tag
     */
    public void setTag(String str) {
        this.tag = str;
    }

    /**
     * Sets the name of the node fetched by subsequent {@link #predict(Tensor)} calls.
     *
     * @param str the output layer name
     */
    public void setOutputLayerName(String str) {
        this.output_layer_name = str;
    }

    /**
     * Sets the name of the node fed by subsequent {@link #predict(Tensor)} calls.
     *
     * @param str the input layer name
     */
    public void setInputLayerName(String str) {
        this.input_layer_name = str;
    }
}
