Skip to content

Punctuation Forced Aligner

speechline.aligners.punctuation_forced_aligner.PunctuationForcedAligner

Force-aligns predicted phoneme offsets with a ground-truth transcript that contains punctuation.

Parameters:

Name Type Description Default
g2p Callable[[str], List[str]]

Callable grapheme-to-phoneme function.

required
punctuations Optional[List[str]]

List of punctuations to include. Defaults to None.

None
Source code in speechline/aligners/punctuation_forced_aligner.py
class PunctuationForcedAligner:
    """
    Force-align predicted phoneme offsets with a ground-truth transcript
    that contains punctuation.

    Args:
        g2p (Callable[[str], List[str]]):
            Callable grapheme-to-phoneme function.
        punctuations (Optional[List[str]], optional):
            List of punctuations to include. Defaults to `None`,
            which falls back to `["?", ",", ".", "!", ";"]`.
    """

    def __init__(
        self, g2p: Callable[[str], List[str]], punctuations: Optional[List[str]] = None
    ):
        # `not punctuations` also treats an empty list as "use the default set"
        self.punctuations = (
            ["?", ",", ".", "!", ";"] if not punctuations else punctuations
        )
        self.g2p = g2p

    def __call__(
        self, offsets: List[Dict[str, Union[str, float]]], text: str
    ) -> List[Dict[str, Union[str, float]]]:
        """
        Performs punctuation-forced alignment on output offsets
        from phoneme-recognition models like wav2vec 2.0.

        ### Example
        ```pycon title="example_punctuation_forced_aligner.py"
        >>> from gruut import sentences
        >>> def g2p(text):
        ...     phonemes = []
        ...     for words in sentences(text):
        ...         for word in words:
        ...             if word.is_major_break or word.is_minor_break:
        ...                 phonemes += word.text
        ...             elif word.phonemes:
        ...                 phonemes += word.phonemes
        ...     return phonemes
        >>> pfa = PunctuationForcedAligner(g2p)
        >>> offsets = [
        ...     {"text": "h", "start_time": 0.0, "end_time": 0.2},
        ...     {"text": "ɚ", "start_time": 0.24, "end_time": 0.28},
        ...     {"text": "i", "start_time": 0.42, "end_time": 0.44},
        ...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
        ...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
        ...     {"text": "ʌ", "start_time": 0.64, "end_time": 0.66},
        ...     {"text": "m", "start_time": 0.7, "end_time": 0.74},
        ...     {"text": "b", "start_time": 0.78, "end_time": 0.82},
        ...     {"text": "ɹ", "start_time": 0.84, "end_time": 0.9},
        ...     {"text": "ɛ", "start_time": 0.92, "end_time": 0.94},
        ...     {"text": "l", "start_time": 1.0, "end_time": 1.04},
        ...     {"text": "ə", "start_time": 1.08, "end_time": 1.12},
        ... ]
        >>> transcript = "Her red, umbrella."
        >>> pfa(offsets, transcript)
        [
            {'text': 'h', 'start_time': 0.0, 'end_time': 0.2},
            {'text': 'ɚ', 'start_time': 0.24, 'end_time': 0.28},
            {'text': 'i', 'start_time': 0.42, 'end_time': 0.44},
            {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
            {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
            {'text': ',', 'start_time': 0.54, 'end_time': 0.64},
            {'text': 'ʌ', 'start_time': 0.64, 'end_time': 0.66},
            {'text': 'm', 'start_time': 0.7, 'end_time': 0.74},
            {'text': 'b', 'start_time': 0.78, 'end_time': 0.82},
            {'text': 'ɹ', 'start_time': 0.84, 'end_time': 0.9},
            {'text': 'ɛ', 'start_time': 0.92, 'end_time': 0.94},
            {'text': 'l', 'start_time': 1.0, 'end_time': 1.04},
            {'text': 'ə', 'start_time': 1.08, 'end_time': 1.12},
            {'text': '.', 'start_time': 1.12, 'end_time': 1.12}
        ]
        ```

        Args:
            offsets (List[Dict[str, Union[str, float]]]):
                List of offsets containing information of phonemes
                and their respective start and end times
            text (str):
                ground truth transcript which contains punctuations

        Returns:
            List[Dict[str, Union[str, float]]]:
                List of newly updated offsets which includes punctuations.
                Returned unchanged if no candidate alignment exists.
        """
        updated_offsets = offsets[:]
        predicted_phonemes = [offset["text"] for offset in updated_offsets]
        ground_truth_phonemes = self.g2p(text)

        # segment phonemes based on `self.punctuations`
        segments, cleaned_segments = self.segment_phonemes_punctuations(
            ground_truth_phonemes
        )

        # generate all possible segments from predicted phonemes
        potential_segments = self.generate_partitions(
            predicted_phonemes, n=len(cleaned_segments)
        )

        # if there are multiple possible partitions
        if len(cleaned_segments) > 1:
            # filter for highly probable candidates
            potential_segments = self._filter_candidates_stdev(
                cleaned_segments, potential_segments
            )

        # find most similar predicted segment to actual segments.
        # BUGFIX: compare against `cleaned_segments` (punctuation-free) —
        # candidate partitions contain only phoneme segments, so zipping
        # against `segments` would pair hypotheses with punctuation tokens.
        max_similarity, aligned_segments = -1.0, None
        for potential in potential_segments:
            similarity = sum(
                self.similarity(" ".join(hyp), seg)
                for hyp, seg in zip(potential, cleaned_segments)
            ) / len(cleaned_segments)
            if similarity > max_similarity:
                max_similarity = similarity
                aligned_segments = potential

        # no viable partition (e.g. fewer predicted phonemes than segments,
        # or empty candidate pool): return offsets unchanged as best effort
        if aligned_segments is None:
            return updated_offsets

        # insert punctuations from real segment to predicted segments
        for idx, token in enumerate(segments):
            if token in self.punctuations:
                aligned_segments.insert(idx, [token])

        # add punctuations to offsets
        idx = 0
        for segment in aligned_segments:
            token = segment[0]
            # skip non-punctuation segments
            if token not in self.punctuations:
                idx += len(segment)
                continue

            # start of punctuation is end time of previous token;
            # a leading punctuation has no previous token, so start at 0.0
            start = updated_offsets[idx - 1]["end_time"] if idx > 0 else 0.0

            # end of punctuation is start time of next token
            if idx < len(updated_offsets):
                end = updated_offsets[idx]["start_time"]
            else:
                end = start  # if it's last, end = start

            offset = {"text": token, "start_time": start, "end_time": end}
            updated_offsets.insert(idx, offset)
            idx += 1

        return updated_offsets

    def _filter_candidates_stdev(
        self,
        ground_truth_segments: List[List[str]],
        potential_segments: List[List[List[str]]],
        k: int = 1,
    ) -> List[List[List[str]]]:
        """
        Filters potential segment candidates based on range of
        standard deviation of segment lengths.

        Args:
            ground_truth_segments (List[List[str]]):
                Ground truth segments.
            potential_segments (List[List[List[str]]]):
                List of potential segment candidates to filter.
            k (int, optional):
                Acceptable upper/lower bounds of standard deviation.
                Defaults to `1`.

        Returns:
            List[List[List[str]]]:
                List of filtered segment candidates; falls back to the
                unfiltered candidates when the filter would discard all.
        """
        target_stdev = stdev([len(x.split()) for x in ground_truth_segments])
        stdev_lengths = [
            stdev([len(x) for x in segment]) for segment in potential_segments
        ]

        candidate_idxs = [
            i
            for i, x in enumerate(stdev_lengths)
            if target_stdev - k <= x <= target_stdev + k
        ]
        candidates = [potential_segments[i] for i in candidate_idxs]

        # keep the unfiltered pool rather than returning nothing,
        # so the caller can still pick the most similar partition
        return candidates if candidates else potential_segments

    def segment_phonemes_punctuations(
        self, phonemes: List[str]
    ) -> Tuple[List[str], List[str]]:
        """
        Segment/group list of phonemes consecutively, up to a punctuation.

        Args:
            phonemes (List[str]):
                List of phonemes.

        Returns:
            Tuple[List[str], List[str]]:
                Pair of equivalently segmented phonemes.
                Second index returns segments without punctuations.
        """
        phoneme_string = " ".join(phonemes)
        # BUGFIX: `re.escape` escapes only regex metacharacters; blindly
        # prefixing a backslash corrupts alphanumeric tokens (e.g. `\d`)
        pattern = "|".join(re.escape(p) for p in self.punctuations)
        segments = re.split(f"({pattern})", phoneme_string)
        segments = [s.strip() for s in segments if s.strip() != ""]
        cleaned_segments = [s for s in segments if s not in self.punctuations]
        return segments, cleaned_segments

    def similarity(self, a: str, b: str) -> float:
        """Return the `SequenceMatcher` similarity ratio of two strings, in [0, 1]."""
        return SequenceMatcher(None, a, b).ratio()

    def generate_partitions(self, lst: List, n: int) -> List[List[List]]:
        """
        Generate all possible `n` consecutive partitions.
        Source: [StackOverflow](https://stackoverflow.com/a/73356868).

        Args:
            lst (List):
                List to be partitioned.
            n (int):
                Number of partitions to generate.

        Returns:
            List[List[List]]:
                List of all possible list of segments.
        """
        result = []
        # every choice of `n - 1` interior cut points yields one partition
        for indices in combinations(range(1, len(lst)), n - 1):
            splits = []
            start = 0
            for stop in indices:
                splits.append(lst[start:stop])
                start = stop
            splits.append(lst[start:])
            result.append(splits)
        return result

__call__(self, offsets, text) special

Performs punctuation-forced alignment on output offsets from phoneme-recognition models like wav2vec 2.0.

Example
example_punctuation_forced_aligner.py
>>> from gruut import sentences
>>> def g2p(text):
...     phonemes = []
...     for words in sentences(text):
...         for word in words:
...             if word.is_major_break or word.is_minor_break:
...                 phonemes += word.text
...             elif word.phonemes:
...                 phonemes += word.phonemes
...     return phonemes
>>> pfa = PunctuationForcedAligner(g2p)
>>> offsets = [
...     {"text": "h", "start_time": 0.0, "end_time": 0.2},
...     {"text": "ɚ", "start_time": 0.24, "end_time": 0.28},
...     {"text": "i", "start_time": 0.42, "end_time": 0.44},
...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
...     {"text": "ʌ", "start_time": 0.64, "end_time": 0.66},
...     {"text": "m", "start_time": 0.7, "end_time": 0.74},
...     {"text": "b", "start_time": 0.78, "end_time": 0.82},
...     {"text": "ɹ", "start_time": 0.84, "end_time": 0.9},
...     {"text": "ɛ", "start_time": 0.92, "end_time": 0.94},
...     {"text": "l", "start_time": 1.0, "end_time": 1.04},
...     {"text": "ə", "start_time": 1.08, "end_time": 1.12},
... ]
>>> transcript = "Her red, umbrella."
>>> pfa(offsets, transcript)
[
    {'text': 'h', 'start_time': 0.0, 'end_time': 0.2},
    {'text': 'ɚ', 'start_time': 0.24, 'end_time': 0.28},
    {'text': 'i', 'start_time': 0.42, 'end_time': 0.44},
    {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
    {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
    {'text': ',', 'start_time': 0.54, 'end_time': 0.64},
    {'text': 'ʌ', 'start_time': 0.64, 'end_time': 0.66},
    {'text': 'm', 'start_time': 0.7, 'end_time': 0.74},
    {'text': 'b', 'start_time': 0.78, 'end_time': 0.82},
    {'text': 'ɹ', 'start_time': 0.84, 'end_time': 0.9},
    {'text': 'ɛ', 'start_time': 0.92, 'end_time': 0.94},
    {'text': 'l', 'start_time': 1.0, 'end_time': 1.04},
    {'text': 'ə', 'start_time': 1.08, 'end_time': 1.12},
    {'text': '.', 'start_time': 1.12, 'end_time': 1.12}
]

Parameters:

Name Type Description Default
offsets List[Dict[str, Union[str, float]]]

List of offsets containing information of phonemes and their respective start and end times

required
text str

ground truth transcript which contains punctuations

required

Returns:

Type Description
List[Dict[str, Union[str, float]]]

List of newly updated offsets which includes punctuations

Source code in speechline/aligners/punctuation_forced_aligner.py
def __call__(
    self, offsets: List[Dict[str, Union[str, float]]], text: str
) -> List[Dict[str, Union[str, float]]]:
    """
    Performs punctuation-forced alignment on output offsets
    from phoneme-recognition models like wav2vec 2.0.

    ### Example
    ```pycon title="example_punctuation_forced_aligner.py"
    >>> from gruut import sentences
    >>> def g2p(text):
    ...     phonemes = []
    ...     for words in sentences(text):
    ...         for word in words:
    ...             if word.is_major_break or word.is_minor_break:
    ...                 phonemes += word.text
    ...             elif word.phonemes:
    ...                 phonemes += word.phonemes
    ...     return phonemes
    >>> pfa = PunctuationForcedAligner(g2p)
    >>> offsets = [
    ...     {"text": "h", "start_time": 0.0, "end_time": 0.2},
    ...     {"text": "ɚ", "start_time": 0.24, "end_time": 0.28},
    ...     {"text": "i", "start_time": 0.42, "end_time": 0.44},
    ...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
    ...     {"text": "d", "start_time": 0.5, "end_time": 0.54},
    ...     {"text": "ʌ", "start_time": 0.64, "end_time": 0.66},
    ...     {"text": "m", "start_time": 0.7, "end_time": 0.74},
    ...     {"text": "b", "start_time": 0.78, "end_time": 0.82},
    ...     {"text": "ɹ", "start_time": 0.84, "end_time": 0.9},
    ...     {"text": "ɛ", "start_time": 0.92, "end_time": 0.94},
    ...     {"text": "l", "start_time": 1.0, "end_time": 1.04},
    ...     {"text": "ə", "start_time": 1.08, "end_time": 1.12},
    ... ]
    >>> transcript = "Her red, umbrella."
    >>> pfa(offsets, transcript)
    [
        {'text': 'h', 'start_time': 0.0, 'end_time': 0.2},
        {'text': 'ɚ', 'start_time': 0.24, 'end_time': 0.28},
        {'text': 'i', 'start_time': 0.42, 'end_time': 0.44},
        {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
        {'text': 'd', 'start_time': 0.5, 'end_time': 0.54},
        {'text': ',', 'start_time': 0.54, 'end_time': 0.64},
        {'text': 'ʌ', 'start_time': 0.64, 'end_time': 0.66},
        {'text': 'm', 'start_time': 0.7, 'end_time': 0.74},
        {'text': 'b', 'start_time': 0.78, 'end_time': 0.82},
        {'text': 'ɹ', 'start_time': 0.84, 'end_time': 0.9},
        {'text': 'ɛ', 'start_time': 0.92, 'end_time': 0.94},
        {'text': 'l', 'start_time': 1.0, 'end_time': 1.04},
        {'text': 'ə', 'start_time': 1.08, 'end_time': 1.12},
        {'text': '.', 'start_time': 1.12, 'end_time': 1.12}
    ]
    ```

    Args:
        offsets (List[Dict[str, Union[str, float]]]):
            List of offsets containing information of phonemes
            and their respective start and end times
        text (str):
            ground truth transcript which contains punctuations

    Returns:
        List[Dict[str, Union[str, float]]]:
            List of newly updated offsets which includes punctuations.
            Returned unchanged if no candidate alignment exists.
    """
    updated_offsets = offsets[:]
    predicted_phonemes = [offset["text"] for offset in updated_offsets]
    ground_truth_phonemes = self.g2p(text)

    # segment phonemes based on `self.punctuations`
    segments, cleaned_segments = self.segment_phonemes_punctuations(
        ground_truth_phonemes
    )

    # generate all possible segments from predicted phonemes
    potential_segments = self.generate_partitions(
        predicted_phonemes, n=len(cleaned_segments)
    )

    # if there are multiple possible partitions
    if len(cleaned_segments) > 1:
        # filter for highly probable candidates
        potential_segments = self._filter_candidates_stdev(
            cleaned_segments, potential_segments
        )

    # find most similar predicted segment to actual segments.
    # BUGFIX: compare against `cleaned_segments` (punctuation-free) —
    # candidate partitions contain only phoneme segments, so zipping
    # against `segments` would pair hypotheses with punctuation tokens.
    max_similarity, aligned_segments = -1.0, None
    for potential in potential_segments:
        similarity = sum(
            self.similarity(" ".join(hyp), seg)
            for hyp, seg in zip(potential, cleaned_segments)
        ) / len(cleaned_segments)
        if similarity > max_similarity:
            max_similarity = similarity
            aligned_segments = potential

    # no viable partition (e.g. fewer predicted phonemes than segments,
    # or empty candidate pool): return offsets unchanged as best effort
    if aligned_segments is None:
        return updated_offsets

    # insert punctuations from real segment to predicted segments
    for idx, token in enumerate(segments):
        if token in self.punctuations:
            aligned_segments.insert(idx, [token])

    # add punctuations to offsets
    idx = 0
    for segment in aligned_segments:
        token = segment[0]
        # skip non-punctuation segments
        if token not in self.punctuations:
            idx += len(segment)
            continue

        # start of punctuation is end time of previous token;
        # a leading punctuation has no previous token, so start at 0.0
        start = updated_offsets[idx - 1]["end_time"] if idx > 0 else 0.0

        # end of punctuation is start time of next token
        if idx < len(updated_offsets):
            end = updated_offsets[idx]["start_time"]
        else:
            end = start  # if it's last, end = start

        offset = {"text": token, "start_time": start, "end_time": end}
        updated_offsets.insert(idx, offset)
        idx += 1

    return updated_offsets

generate_partitions(self, lst, n)

Generate all possible n consecutive partitions. Source: StackOverflow.

Parameters:

Name Type Description Default
lst List

List to be partitioned.

required
n int

Number of partitions to generate.

required

Returns:

Type Description
List[List[List]]

List of all possible list of segments.

Source code in speechline/aligners/punctuation_forced_aligner.py
def generate_partitions(self, lst: List, n: int) -> List[List[List]]:
    """
    Generate all possible `n` consecutive partitions of `lst`.
    Source: [StackOverflow](https://stackoverflow.com/a/73356868).

    Args:
        lst (List):
            List to be partitioned.
        n (int):
            Number of partitions to generate.

    Returns:
        List[List[List]]:
            List of all possible list of segments.
    """
    partitions = []
    # every choice of `n - 1` interior cut points defines one partition;
    # pad the cuts with the outer bounds and slice between adjacent pairs
    for cuts in combinations(range(1, len(lst)), n - 1):
        bounds = (0, *cuts, len(lst))
        partitions.append([lst[lo:hi] for lo, hi in zip(bounds, bounds[1:])])
    return partitions

segment_phonemes_punctuations(self, phonemes)

Segment/group list of phonemes consecutively, up to a punctuation.

Parameters:

Name Type Description Default
phonemes List[str]

List of phonemes.

required

Returns:

Type Description
Tuple[List[str], List[str]]

Pair of segmented phoneme strings; the second element contains the same segments with punctuation tokens removed.

Source code in speechline/aligners/punctuation_forced_aligner.py
def segment_phonemes_punctuations(
    self, phonemes: List[str]
) -> Tuple[List[str], List[str]]:
    """
    Segment/group list of phonemes consecutively, up to a punctuation.

    Args:
        phonemes (List[str]):
            List of phonemes.

    Returns:
        Tuple[List[str], List[str]]:
            Pair of equivalently segmented phonemes.
            Second index returns segments without punctuations.
    """
    phoneme_string = " ".join(phonemes)
    # BUGFIX: `re.escape` escapes only regex metacharacters; blindly
    # prefixing a backslash corrupts alphanumeric tokens (e.g. `\d`)
    pattern = "|".join(re.escape(p) for p in self.punctuations)
    # capturing group keeps the punctuation tokens in the split output
    segments = re.split(f"({pattern})", phoneme_string)
    segments = [s.strip() for s in segments if s.strip() != ""]
    cleaned_segments = [s for s in segments if s not in self.punctuations]
    return segments, cleaned_segments