Skip to content

BERT

g2p_id.bert.BERT

Phoneme-level BERT model for predicting the correct phoneme for the letter e. The model is trained with Keras and exported to ONNX; the ONNX Runtime engine is used during inference.

Source code in g2p_id/bert.py
class BERT:
    """Phoneme-level BERT model for predicting the correct phoneme for the letter `e`.
    Trained with [Keras](https://keras.io/examples/nlp/masked_language_modeling/),
    and exported to ONNX. ONNX Runtime engine used during inference.
    """

    def __init__(self):
        # `model_path` is expected to be defined at module level; it is the
        # directory holding the ONNX model, its config, and its vocabulary.
        bert_model_path = os.path.join(model_path, "bert_mlm.onnx")
        token2id = os.path.join(model_path, "token2id.json")
        config_path = os.path.join(model_path, "config.json")
        # CUDA is requested first, with CPU listed as the alternative provider.
        self.model = ort.InferenceSession(
            bert_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        with open(config_path, encoding="utf-8") as file:
            self.config = json.load(file)
        with open(token2id, encoding="utf-8") as file:
            self.token2id = json.load(file)
        # Reverse vocabulary, used to turn predicted ids back into tokens.
        self.id2token = {v: k for k, v in self.token2id.items()}

    def predict(self, text: str) -> str:
        """Performs BERT inference, predicting the correct phoneme for the letter `e`.

        Every `e` in the word is masked, the model predicts a token for each
        masked position, and the word is reassembled with those predictions.

        Args:
            text (str): Word to predict from.

        Returns:
            str: Word after prediction.

        Raises:
            KeyError: If the word contains a character that is not in the
                vocabulary (`token2id`).
        """
        # Resolve the special tokens once from the config so that tokenization,
        # mask lookup, and padding removal all agree with each other.
        pad_token_id = self.token2id[self.config["pad_token"]]
        mask_token = self.config["mask_token"]
        mask_token_id = self.token2id[mask_token]

        # `x` is currently OOV, we replace it with `ks`
        text = text.replace("x", "ks")
        # mask `e`'s; the join/split round trip also drops whitespace characters
        text = " ".join(mask_token if c == "e" else c for c in text)

        # tokenize and pad to max length
        tokens = [self.token2id[c] for c in text.split()]
        # a negative count yields no padding, so over-long words are left as-is
        tokens += [pad_token_id] * (self.config["max_seq_length"] - len(tokens))

        input_ids = np.array([tokens], dtype="int64")
        inputs = {"input_1": input_ids}
        prediction = self.model.run(None, inputs)

        # find masked idx token (column indices of the single batch row)
        masked_index = np.where(input_ids == mask_token_id)[1]

        # get prediction at masked indices
        mask_prediction = prediction[0][0][masked_index]
        predicted_ids = np.argmax(mask_prediction, axis=1)

        # replace mask with predicted token
        for idx, predicted_id in zip(masked_index, predicted_ids):
            tokens[idx] = int(predicted_id)

        # drop padding by its configured id rather than assuming it is 0
        return "".join(self.id2token[t] for t in tokens if t != pad_token_id)

predict(self, text)

Performs BERT inference, predicting the correct phoneme for the letter e.

Parameters:

Name Type Description Default
text str

Word to predict from.

required

Returns:

Type Description
str

Word after prediction.

Source code in g2p_id/bert.py
def predict(self, text: str) -> str:
    """Performs BERT inference, predicting the correct phoneme for the letter `e`.

    Every `e` in the word is masked, the model predicts a token for each
    masked position, and the word is reassembled with those predictions.

    Args:
        text (str): Word to predict from.

    Returns:
        str: Word after prediction.

    Raises:
        KeyError: If the word contains a character that is not in the
            vocabulary (`token2id`).
    """
    # Resolve the special tokens once from the config so that tokenization,
    # mask lookup, and padding removal all agree with each other.
    pad_token_id = self.token2id[self.config["pad_token"]]
    mask_token = self.config["mask_token"]
    mask_token_id = self.token2id[mask_token]

    # `x` is currently OOV, we replace it with `ks`
    text = text.replace("x", "ks")
    # mask `e`'s; the join/split round trip also drops whitespace characters
    text = " ".join(mask_token if c == "e" else c for c in text)

    # tokenize and pad to max length
    tokens = [self.token2id[c] for c in text.split()]
    # a negative count yields no padding, so over-long words are left as-is
    tokens += [pad_token_id] * (self.config["max_seq_length"] - len(tokens))

    input_ids = np.array([tokens], dtype="int64")
    inputs = {"input_1": input_ids}
    prediction = self.model.run(None, inputs)

    # find masked idx token (column indices of the single batch row)
    masked_index = np.where(input_ids == mask_token_id)[1]

    # get prediction at masked indices
    mask_prediction = prediction[0][0][masked_index]
    predicted_ids = np.argmax(mask_prediction, axis=1)

    # replace mask with predicted token
    for idx, predicted_id in zip(masked_index, predicted_ids):
        tokens[idx] = int(predicted_id)

    # drop padding by its configured id rather than assuming it is 0
    return "".join(self.id2token[t] for t in tokens if t != pad_token_id)

Usage

# Example words containing the letter `e`.
texts = ["mengembangkannya", "merdeka", "pecel", "lele"]
# Instantiating BERT loads the ONNX model, its config, and its vocabulary from disk.
bert = BERT()
for text in texts:
    # Prints the word with each `e` replaced by the phoneme the model predicts.
    print(bert.predict(text))
>> məngəmbangkannya
>> mərdeka
>> pəcel
>> lele