BERT
g2p_id.bert.BERT
Phoneme-level BERT model for predicting the correct phoneme for the letter `e`.
Trained with Keras and exported to ONNX; the ONNX Runtime engine is used during inference.
Source code in g2p_id/bert.py
class BERT:
    """Phoneme-level BERT model for predicting the correct phoneme for the letter `e`.

    Trained with [Keras](https://keras.io/examples/nlp/masked_language_modeling/),
    and exported to ONNX. ONNX Runtime engine used during inference.
    """

    def __init__(self):
        """Loads the ONNX model, its config, and the token vocabulary from `model_path`."""
        bert_model_path = os.path.join(model_path, "bert_mlm.onnx")
        token2id_path = os.path.join(model_path, "token2id.json")
        config_path = os.path.join(model_path, "config.json")
        self.model = WrapInferenceSession(bert_model_path, providers=onnxruntime.get_available_providers())
        with open(config_path, encoding="utf-8") as file:
            self.config = json.load(file)
        with open(token2id_path, encoding="utf-8") as file:
            self.token2id = json.load(file)
        # Reverse vocabulary, used to turn predicted ids back into characters.
        self.id2token = {v: k for k, v in self.token2id.items()}

    def predict(self, text: str) -> str:
        """Performs BERT inference, predicting the correct phoneme for the letter `e`.

        Every `e` in the word is replaced by the configured mask token, and the
        model's highest-scoring token at each masked position is substituted back.
        Whitespace characters are skipped. If the word contains no `e`, the model
        is not run at all.

        Args:
            text (str): Word to predict from.

        Returns:
            str: Word after prediction, with padding tokens removed.
        """
        pad_token_id = self.token2id[self.config["pad_token"]]
        mask_token_id = self.token2id[self.config["mask_token"]]

        # `x` is currently OOV, so we replace it with its spelled-out form `ks`.
        text = text.replace("x", "ks")

        # Tokenize character by character, masking every `e` with the
        # *configured* mask token (the same one used to locate masks below).
        tokens = [
            mask_token_id if char == "e" else self.token2id[char]
            for char in text
            if not char.isspace()
        ]
        masked_index = [i for i, token in enumerate(tokens) if token == mask_token_id]

        if masked_index:
            # Pad to the model's fixed input length.
            # NOTE(review): words longer than `max_seq_length` are passed through
            # unpadded and untruncated; confirm the exported graph accepts that.
            padding = [pad_token_id] * (self.config["max_seq_length"] - len(tokens))
            input_ids = np.array([tokens + padding], dtype="int64")
            # The exported graph's input is named `input_1`; take the first
            # output for the single batch element.
            logits = self.model.run(None, {"input_1": input_ids})[0][0]
            predicted_ids = np.argmax(logits[masked_index], axis=1)
            # Replace each mask with the predicted token (as a plain int, to
            # match the JSON-loaded id2token keys).
            for idx, predicted_id in zip(masked_index, predicted_ids):
                tokens[idx] = int(predicted_id)

        # Drop padding using the configured pad id rather than assuming it is 0.
        return "".join(self.id2token[t] for t in tokens if t != pad_token_id)
predict(self, text)
Performs BERT inference, predicting the correct phoneme for the letter `e`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| text | str | Word to predict from. | required |
Returns:
| Type | Description |
|---|---|
| str | Word after prediction. |
Source code in g2p_id/bert.py
def predict(self, text: str) -> str:
    """Performs BERT inference, predicting the correct phoneme for the letter `e`.

    Every `e` in the word is replaced by the configured mask token, and the
    model's highest-scoring token at each masked position is substituted back.
    Whitespace characters are skipped. If the word contains no `e`, the model
    is not run at all.

    Args:
        text (str): Word to predict from.

    Returns:
        str: Word after prediction, with padding tokens removed.
    """
    pad_token_id = self.token2id[self.config["pad_token"]]
    mask_token_id = self.token2id[self.config["mask_token"]]

    # `x` is currently OOV, so we replace it with its spelled-out form `ks`.
    text = text.replace("x", "ks")

    # Tokenize character by character, masking every `e` with the
    # *configured* mask token (the same one used to locate masks below).
    tokens = [
        mask_token_id if char == "e" else self.token2id[char]
        for char in text
        if not char.isspace()
    ]
    masked_index = [i for i, token in enumerate(tokens) if token == mask_token_id]

    if masked_index:
        # Pad to the model's fixed input length.
        # NOTE(review): words longer than `max_seq_length` are passed through
        # unpadded and untruncated; confirm the exported graph accepts that.
        padding = [pad_token_id] * (self.config["max_seq_length"] - len(tokens))
        input_ids = np.array([tokens + padding], dtype="int64")
        # The exported graph's input is named `input_1`; take the first
        # output for the single batch element.
        logits = self.model.run(None, {"input_1": input_ids})[0][0]
        predicted_ids = np.argmax(logits[masked_index], axis=1)
        # Replace each mask with the predicted token (as a plain int, to
        # match the JSON-loaded id2token keys).
        for idx, predicted_id in zip(masked_index, predicted_ids):
            tokens[idx] = int(predicted_id)

    # Drop padding using the configured pad id rather than assuming it is 0.
    return "".join(self.id2token[t] for t in tokens if t != pad_token_id)