package marytts.dnn.normaliser;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import marytts.MaryException;
import marytts.data.Sequence;
import marytts.dnn.FeatureNormaliser;
import marytts.features.FeatureMap;
import org.tensorflow.Tensor;

/**
 * Feature normaliser driven by an HTK-style question file.
 *
 * <p>Each relevant line of the question file is one of:
 * <ul>
 *   <li>{@code QS "name" {answer,answer,...}}: a categorical question. It is encoded one-hot,
 *       with one output column per answer.</li>
 *   <li>{@code CQS "name" ...}: a continuous question. It is encoded as a single numeric
 *       column.</li>
 * </ul>
 * Blank lines and lines starting with {@code #} are ignored. Any other line is rejected. Output
 * columns follow the order in which the questions are declared.
 */
public class HTKQuestionNormaliser implements FeatureNormaliser {
    /** Kept for subclasses; not used by this class itself. */
    protected final String FEAT_QS_SEP = "==";
    /** Separator between a question name and one of its answers in the header. */
    protected final String POS_SEP = "#";

    /** {@code QS "name" {answers}}: group 1 is the name, group 2 the comma-separated answers. */
    private static final Pattern QS_PATTERN =
        Pattern.compile("QS[ \t]*\"([^\"]*)\"[ \t]*\\{(.*)\\}");
    /** {@code CQS "name" ...} (found inside the "CQS" keyword): group 1 is the name. */
    private static final Pattern CQS_PATTERN = Pattern.compile("QS[ \t]*\"([^\"]*)\"[ \t]*.*");
    /**
     * Extracts the first alphanumeric run of an answer pattern (e.g. {@code a} from
     * {@code *-a+*}). A non-alphanumeric character must follow that run.
     */
    private static final Pattern ANSWER_PATTERN =
        Pattern.compile("[^a-zA-Z0-9]*([a-zA-Z0-9]*)[^a-zA-Z0-9].*");

    /** Categorical question name to its ordered list of answer values. */
    protected HashMap<String, ArrayList<String>> qs_map = new HashMap<>();
    /** Names of the continuous (CQS) questions. */
    protected Set<String> cqs_set = new HashSet<>();
    /** All question names, in file order; this order defines the output columns. */
    protected ArrayList<String> list_questions = new ArrayList<>();

    /**
     * Builds the normaliser from a UTF-8 question file on disk.
     *
     * @param str path of the question file
     * @throws IOException if the file cannot be read
     * @throws MaryException if a line of the file is not valid
     */
    public HTKQuestionNormaliser(String str) throws IOException, MaryException {
        parseQuestionLines(Files.readAllLines(Paths.get(str), StandardCharsets.UTF_8));
    }

    /**
     * Builds the normaliser from a UTF-8 question stream.
     *
     * <p>The stream is read to the end but not closed. Closing the wrapping reader would also close
     * {@code inputStream}, which is left to the caller.
     *
     * @param inputStream stream containing the question file
     * @throws IOException if the stream cannot be read
     * @throws MaryException if a line of the stream is not valid
     */
    public HTKQuestionNormaliser(InputStream inputStream) throws IOException, MaryException {
        BufferedReader reader =
            new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        List<String> lines;
        try {
            lines = reader.lines().collect(Collectors.toList());
        } catch (UncheckedIOException e) {
            // BufferedReader.lines() wraps read errors; surface them as the declared IOException.
            throw e.getCause();
        }
        parseQuestionLines(lines);
    }

    /**
     * Returns the column names of the encoded output.
     *
     * <p>A CQS question contributes its own name. A QS question contributes
     * {@code name + POS_SEP + answer} for each of its answers.
     *
     * @return the header, in output column order
     */
    @Override // marytts.dnn.FeatureNormaliser
    public ArrayList<String> getHeader() {
        ArrayList<String> header = new ArrayList<>();
        for (String question : this.list_questions) {
            if (this.cqs_set.contains(question)) {
                header.add(question);
            } else if (this.qs_map.containsKey(question)) {
                for (String answer : this.qs_map.get(question)) {
                    header.add(question + POS_SEP + answer);
                }
            }
        }
        return header;
    }

    /**
     * Parses all lines of a question file.
     *
     * <p>Lines are trimmed first. Blank lines and {@code #} comments are skipped.
     *
     * @param list the raw lines of the question file
     * @throws MaryException if a line is neither a QS, a CQS, a comment nor blank
     */
    protected void parseQuestionLines(List<String> list) throws MaryException {
        for (String rawLine : list) {
            String line = rawLine.trim();
            if (line.isEmpty() || line.startsWith("#")) {
                continue;
            }
            if (line.startsWith("QS")) {
                parseQS(line);
            } else if (line.startsWith("CQS")) {
                parseCQS(line);
            } else {
                throw new MaryException("the following line is not valid: " + rawLine);
            }
        }
    }

    /**
     * Registers a continuous question from a {@code CQS "name" ...} line.
     *
     * @param str the CQS line
     * @throws MaryException if the line does not contain a quoted question name
     */
    protected void parseCQS(String str) throws MaryException {
        Matcher matcher = CQS_PATTERN.matcher(str);
        if (!matcher.find()) {
            throw new MaryException("the following is not a valid question file line " + str);
        }
        String name = matcher.group(1);
        this.list_questions.add(name);
        this.cqs_set.add(name);
    }

    /**
     * Registers a categorical question from a {@code QS "name" {answer,...}} line.
     *
     * @param str the QS line
     * @throws MaryException if the line or one of its answers is malformed
     */
    protected void parseQS(String str) throws MaryException {
        Matcher matcher = QS_PATTERN.matcher(str);
        if (!matcher.find()) {
            throw new MaryException("the following is not a valid question file line " + str);
        }
        String name = matcher.group(1);
        // Strip the context letter from markers such as "/A:" so that the first alphanumeric
        // run found by ANSWER_PATTERN is the answer value rather than the marker letter.
        String[] answerPatterns = matcher.group(2).replaceAll("/[A-Z]:", "/:").split(",");
        ArrayList<String> answers = new ArrayList<>(answerPatterns.length);
        for (String answerPattern : answerPatterns) {
            Matcher answerMatcher = ANSWER_PATTERN.matcher(answerPattern);
            if (!answerMatcher.find()) {
                throw new MaryException("\"" + answerPattern + "\" is not a valid answer");
            }
            answers.add(answerMatcher.group(1));
        }
        this.list_questions.add(name);
        this.qs_map.put(name, answers);
    }

    /**
     * Encodes a sequence of feature maps into a {@code [sequence.size()][numColumns]} float tensor.
     *
     * <p>A CQS column receives the feature's value, which must be a {@link Number}. A QS question
     * sets to 1 the column of the answer equal to the feature's string value. If the value matches
     * no answer, all of that question's columns stay 0.
     *
     * @param sequence the feature maps to encode, one row per element
     * @return the encoded tensor
     * @throws MaryException if a declared question is unknown or a feature cannot be encoded
     */
    @Override // marytts.dnn.FeatureNormaliser
    public Tensor<Float> normalise(Sequence<FeatureMap> sequence) throws MaryException {
        int width = computeWidth();
        try {
            float[][] encoded = new float[sequence.size()][width];
            for (int row = 0; row < sequence.size(); row++) {
                FeatureMap featureMap = (FeatureMap) sequence.get(row);
                int col = 0;
                for (String question : this.list_questions) {
                    if (this.cqs_set.contains(question)) {
                        encoded[row][col] = ((Number) featureMap.get(question).getValue()).floatValue();
                        col++;
                    } else if (this.qs_map.containsKey(question)) {
                        ArrayList<String> answers = this.qs_map.get(question);
                        int answerIndex = answers.indexOf(featureMap.get(question).getStringValue());
                        if (answerIndex >= 0) {
                            encoded[row][col + answerIndex] = 1.0f;
                        }
                        col += answers.size();
                    }
                }
            }
            return Tensor.create(encoded, Float.class);
        } catch (Exception e) {
            throw new MaryException("Problem with encoding", e);
        }
    }

    /** Counts the output columns: one per CQS question plus one per answer of each QS question. */
    private int computeWidth() throws MaryException {
        int width = 0;
        for (String question : this.list_questions) {
            if (this.cqs_set.contains(question)) {
                width++;
            } else if (this.qs_map.containsKey(question)) {
                width += this.qs_map.get(question).size();
            } else {
                throw new MaryException(question + " is not available in qs map and cqs set");
            }
        }
        return width;
    }
}
