LSTM
g2p_id.lstm.LSTM
Phoneme-level LSTM model for sequence-to-sequence phonemization. Trained with Keras and exported to ONNX; the ONNX Runtime engine is used during inference.
Source code in g2p_id/lstm.py
class LSTM:
    """Phoneme-level LSTM model for sequence-to-sequence phonemization.

    Trained with [Keras](https://keras.io/examples/nlp/lstm_seq2seq/),
    and exported to ONNX. The ONNX Runtime engine is used during inference.
    """

    def __init__(self):
        """Load the ONNX encoder/decoder sessions and the JSON vocabularies/config.

        All files are read from ``model_path``, a name this class does not
        define (presumably a module-level model directory elsewhere in
        ``g2p_id/lstm.py`` -- confirm).
        """
        encoder_model_path = os.path.join(model_path, "encoder_model.onnx")
        decoder_model_path = os.path.join(model_path, "decoder_model.onnx")
        g2id_path = os.path.join(model_path, "g2id.json")
        p2id_path = os.path.join(model_path, "p2id.json")
        config_path = os.path.join(model_path, "config.json")

        self.encoder = WrapInferenceSession(
            encoder_model_path,
            providers=onnxruntime.get_available_providers(),
        )
        self.decoder = WrapInferenceSession(
            decoder_model_path,
            providers=onnxruntime.get_available_providers(),
        )

        # Character -> index of its one-hot slot in the encoder input.
        with open(g2id_path, encoding="utf-8") as file:
            self.g2id = json.load(file)
        # Phoneme -> decoder one-hot index, plus the reverse map used to turn
        # the decoder's argmax index back into an output phoneme.
        with open(p2id_path, encoding="utf-8") as file:
            self.p2id = json.load(file)
        self.id2p = {v: k for k, v in self.p2id.items()}
        # Sequence lengths, vocabulary sizes and the pad/bos/eos tokens.
        with open(config_path, encoding="utf-8") as file:
            self.config = json.load(file)

    def predict(self, text: str) -> str:
        """Performs LSTM inference, predicting phonemes of a given word.

        The word is one-hot encoded character by character over a fixed window
        of ``config["max_encoder_seq_length"]`` positions (positions after the
        word get the pad token) and run through the encoder. Decoding is then
        greedy: at every step the argmax phoneme is appended to the output and
        fed back into the decoder together with the updated LSTM states.

        Args:
            text (str): Word to convert to phonemes.

        Returns:
            str: Word in phonemes, with the end-of-sequence token removed.

        Raises:
            IndexError: If ``text`` is longer than
                ``config["max_encoder_seq_length"]``.
            KeyError: If ``text`` contains a character missing from ``g2id``.
        """
        max_encoder_len = self.config["max_encoder_seq_length"]
        if len(text) > max_encoder_len:
            # Without this guard the one-hot loop below fails with an opaque
            # numpy out-of-bounds IndexError; same exception type, clear message.
            raise IndexError(
                f"Input of length {len(text)} exceeds "
                f"max_encoder_seq_length={max_encoder_len}: {text!r}"
            )

        input_seq = np.zeros(
            (
                1,
                max_encoder_len,
                self.config["num_encoder_tokens"],
            ),
            dtype="float32",
        )
        for idx, char in enumerate(text):
            input_seq[0, idx, self.g2id[char]] = 1.0
        # Every position after the word is marked with the pad token.
        input_seq[0, len(text) :, self.g2id[self.config["pad_token"]]] = 1.0

        encoder_inputs = {"input_1": input_seq}
        # The encoder's first two outputs seed the decoder's two state inputs.
        states_value = self.encoder.run(None, encoder_inputs)

        # Decoding starts from a one-hot begin-of-sequence token.
        target_seq = np.zeros((1, 1, self.config["num_decoder_tokens"]), dtype="float32")
        target_seq[0, 0, self.p2id[self.config["bos_token"]]] = 1.0

        stop_condition = False
        decoded_sentence = ""
        while not stop_condition:
            decoder_inputs = {
                "input_2": target_seq,
                "input_3": states_value[0],
                "input_4": states_value[1],
            }
            output_tokens, state_memory, state_carry = self.decoder.run(None, decoder_inputs)

            sampled_token_index = np.argmax(output_tokens[0, -1, :])
            sampled_char = self.id2p[sampled_token_index]
            decoded_sentence += sampled_char

            # Stop on the end-of-sequence token, or once the output has grown
            # past max_decoder_seq_length characters (so at most max + 1).
            if (
                sampled_char == self.config["eos_token"]
                or len(decoded_sentence) > self.config["max_decoder_seq_length"]
            ):
                stop_condition = True

            # Feed the sampled phoneme and the new states into the next step.
            target_seq = np.zeros((1, 1, self.config["num_decoder_tokens"]), dtype="float32")
            target_seq[0, 0, sampled_token_index] = 1.0
            states_value = [state_memory, state_carry]

        return decoded_sentence.replace(self.config["eos_token"], "")
predict(self, text)
Performs LSTM inference, predicting phonemes of a given word.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `text` | `str` | Word to convert to phonemes. | required |
Returns:
| Type | Description |
|---|---|
| `str` | Word in phonemes. |
Source code in g2p_id/lstm.py
def predict(self, text: str) -> str:
    """Performs LSTM inference, predicting phonemes of a given word.

    The word is one-hot encoded character by character over a fixed window
    of ``config["max_encoder_seq_length"]`` positions (positions after the
    word get the pad token) and run through the encoder. Decoding is then
    greedy: at every step the argmax phoneme is appended to the output and
    fed back into the decoder together with the updated LSTM states.

    Args:
        text (str): Word to convert to phonemes.

    Returns:
        str: Word in phonemes, with the end-of-sequence token removed.

    Raises:
        IndexError: If ``text`` is longer than
            ``config["max_encoder_seq_length"]``.
        KeyError: If ``text`` contains a character missing from ``g2id``.
    """
    max_encoder_len = self.config["max_encoder_seq_length"]
    if len(text) > max_encoder_len:
        # Without this guard the one-hot loop below fails with an opaque
        # numpy out-of-bounds IndexError; same exception type, clear message.
        raise IndexError(
            f"Input of length {len(text)} exceeds "
            f"max_encoder_seq_length={max_encoder_len}: {text!r}"
        )

    input_seq = np.zeros(
        (
            1,
            max_encoder_len,
            self.config["num_encoder_tokens"],
        ),
        dtype="float32",
    )
    for idx, char in enumerate(text):
        input_seq[0, idx, self.g2id[char]] = 1.0
    # Every position after the word is marked with the pad token.
    input_seq[0, len(text) :, self.g2id[self.config["pad_token"]]] = 1.0

    encoder_inputs = {"input_1": input_seq}
    # The encoder's first two outputs seed the decoder's two state inputs.
    states_value = self.encoder.run(None, encoder_inputs)

    # Decoding starts from a one-hot begin-of-sequence token.
    target_seq = np.zeros((1, 1, self.config["num_decoder_tokens"]), dtype="float32")
    target_seq[0, 0, self.p2id[self.config["bos_token"]]] = 1.0

    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        decoder_inputs = {
            "input_2": target_seq,
            "input_3": states_value[0],
            "input_4": states_value[1],
        }
        output_tokens, state_memory, state_carry = self.decoder.run(None, decoder_inputs)

        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = self.id2p[sampled_token_index]
        decoded_sentence += sampled_char

        # Stop on the end-of-sequence token, or once the output has grown
        # past max_decoder_seq_length characters (so at most max + 1).
        if (
            sampled_char == self.config["eos_token"]
            or len(decoded_sentence) > self.config["max_decoder_seq_length"]
        ):
            stop_condition = True

        # Feed the sampled phoneme and the new states into the next step.
        target_seq = np.zeros((1, 1, self.config["num_decoder_tokens"]), dtype="float32")
        target_seq[0, 0, sampled_token_index] = 1.0
        states_value = [state_memory, state_carry]

    return decoded_sentence.replace(self.config["eos_token"], "")