package marytts.dnn.normaliser;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import marytts.MaryException;
import marytts.data.Sequence;
import marytts.dnn.FeatureNormaliser;
import marytts.features.FeatureMap;
import org.tensorflow.Tensor;

/* loaded from: input_file:marytts/dnn/normaliser/TSVNormaliser.class */
/**
 * Feature normaliser configured by a tab-separated description file.
 *
 * <p>Each non-blank line of the description has the form
 * {@code ID<TAB>FEATURE<TAB>TYPE[<TAB>ANSWER...]}, where {@code TYPE} is either:
 * <ul>
 *   <li>{@code CONT}: the feature value (read as a {@link Number}) fills one float column;</li>
 *   <li>{@code DISC}: the feature's string value is one-hot encoded over the listed answers
 *       (one column per answer; an unlisted value leaves all those columns at 0).</li>
 * </ul>
 * Output columns follow the order of the ids in the description file.
 */
public class TSVNormaliser implements FeatureNormaliser {
    /** Column separator of the description file (a tab character). */
    protected final String SEP = "\t";

    /** For each DISC id, the ordered list of possible answers (one output column per answer). */
    protected HashMap<String, ArrayList<String>> map_id_answers = new HashMap<>();

    /** Ids declared with type CONT. */
    protected Set<String> cts_ids = new HashSet<>();

    /** Maps each id to the name of the feature looked up in the {@link FeatureMap}. */
    protected HashMap<String, String> map_id_feature = new HashMap<>();

    /** Ids in description-file order; this fixes the column order of the output tensor. */
    protected ArrayList<String> list_ids = new ArrayList<>();

    /**
     * Builds a normaliser from a UTF-8 description file on disk.
     *
     * @param str path of the description file
     * @throws IOException if the file cannot be read
     * @throws MaryException if the description is malformed
     */
    public TSVNormaliser(String str) throws IOException, MaryException {
        loadInformations(Files.readAllLines(Paths.get(str), StandardCharsets.UTF_8));
    }

    /**
     * Builds a normaliser from a UTF-8 description read from a stream.
     *
     * <p>The stream is read to the end but not closed; closing it remains the caller's job.
     *
     * @param inputStream stream containing the description
     * @throws IOException if reading the stream fails
     * @throws MaryException if the description is malformed
     */
    public TSVNormaliser(InputStream inputStream) throws IOException, MaryException {
        BufferedReader reader =
            new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        List<String> lines;
        try {
            lines = reader.lines().collect(Collectors.toList());
        } catch (UncheckedIOException e) {
            // BufferedReader.lines() wraps read failures; surface them as the declared IOException.
            throw e.getCause();
        }
        loadInformations(lines);
    }

    /**
     * Returns the feature ids in description-file order.
     *
     * <p>This is the order in which their columns appear in the tensor returned by
     * {@link #normalise(Sequence)}. A copy is returned, so callers cannot alter the layout.
     *
     * @return a new list with the ids
     */
    @Override // marytts.dnn.FeatureNormaliser
    public ArrayList<String> getHeader() {
        return new ArrayList<>(this.list_ids);
    }

    /**
     * Parses the description lines and fills the id/feature/type tables.
     *
     * @param list the description lines; blank lines are ignored
     * @throws MaryException if a line has fewer than 3 columns or an unknown type
     */
    protected void loadInformations(List<String> list) throws MaryException {
        int lineNumber = 0;
        for (String line : list) {
            lineNumber++;
            if (line.trim().isEmpty()) {
                // Tolerate blank lines, e.g. a trailing empty line at the end of the file.
                continue;
            }
            String[] columns = line.split(SEP);
            if (columns.length < 3) {
                throw new MaryException("Line " + lineNumber
                    + " should contain at least 3 tab-separated columns (id, feature, type) but got: \""
                    + line + "\"");
            }
            String id = columns[0];
            String type = columns[2];
            if (type.equals("CONT")) {
                this.cts_ids.add(id);
            } else if (type.equals("DISC")) {
                // Everything after the type column is the list of possible answers.
                ArrayList<String> answers = new ArrayList<>(columns.length - 3);
                for (int i = 3; i < columns.length; i++) {
                    answers.add(columns[i]);
                }
                this.map_id_answers.put(id, answers);
            } else {
                throw new MaryException(type + " is an unknown type. It should be (CONT or DISC)");
            }
            this.list_ids.add(id);
            this.map_id_feature.put(id, columns[1]);
        }
    }

    /**
     * Encodes a sequence of feature maps into a 2-D float tensor.
     *
     * <p>Row {@code i} encodes {@code sequence.get(i)}. Columns follow {@link #getHeader()}:
     * one column per CONT id and one column per answer of each DISC id.
     *
     * @param sequence the feature maps to encode
     * @return a float tensor of shape [sequence size, total column count]
     * @throws MaryException if a feature is missing or cannot be encoded
     */
    @Override // marytts.dnn.FeatureNormaliser
    public Tensor<Float> normalise(Sequence<FeatureMap> sequence) throws MaryException {
        try {
            float[][] encoded = new float[sequence.size()][computeWidth()];
            for (int row = 0; row < sequence.size(); row++) {
                encodeRow((FeatureMap) sequence.get(row), encoded[row]);
            }
            return Tensor.create(encoded, Float.class);
        } catch (MaryException e) {
            throw e;
        } catch (Exception e) {
            throw new MaryException("Problem with encoding", e);
        }
    }

    /** Total number of output columns: 1 per CONT id plus the answer count of each DISC id. */
    private int computeWidth() {
        int width = 0;
        for (String id : this.list_ids) {
            width += this.cts_ids.contains(id) ? 1 : this.map_id_answers.get(id).size();
        }
        return width;
    }

    /** Writes the encoding of one feature map into {@code row} (assumed zero-initialised). */
    private void encodeRow(FeatureMap featureMap, float[] row) throws MaryException {
        int column = 0;
        for (String id : this.list_ids) {
            String featureName = this.map_id_feature.get(id);
            if (featureMap.get(featureName) == null) {
                throw new MaryException("Feature \"" + featureName + "\" (id \"" + id
                    + "\") is missing from the feature map");
            }
            if (this.cts_ids.contains(id)) {
                row[column] = ((Number) featureMap.get(featureName).getValue()).floatValue();
                column++;
            } else {
                List<String> answers = this.map_id_answers.get(id);
                int index = answers.indexOf(featureMap.get(featureName).getStringValue());
                // An unknown answer leaves this one-hot group all zeros.
                if (index >= 0) {
                    row[column + index] = 1.0f;
                }
                column += answers.size();
            }
        }
    }
}
