Skip to content

Audio Classifier

speechline.modules.audio_classifier.AudioClassifier (AudioModule)

Generic AudioClassifier Module. Performs padded audio classification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_checkpoint` | `str` | HuggingFace Hub model checkpoint. | required |
Source code in speechline/modules/audio_classifier.py
class AudioClassifier(AudioModule):
    """
    Generic AudioClassifier Module. Performs padded audio classification.

    Builds an `audio-classification` pipeline that uses
    `AudioClassificationWithPaddingPipeline` as its pipeline class and hands
    it to the parent `AudioModule`.

    Args:
        model_checkpoint (str):
            HuggingFace Hub model checkpoint.
        **kwargs:
            Additional keyword arguments forwarded unchanged to `pipeline`.
    """

    def __init__(self, model_checkpoint: str, **kwargs) -> None:
        # Device index 0 when CUDA is available, otherwise -1
        # (presumably CPU, per the HuggingFace pipeline convention).
        device = 0 if torch.cuda.is_available() else -1
        classifier = pipeline(
            "audio-classification",
            model=model_checkpoint,
            device=device,
            pipeline_class=AudioClassificationWithPaddingPipeline,
            **kwargs,
        )
        super().__init__(pipeline=classifier)

    def inference(self, dataset: Dataset) -> List[str]:
        """
        Inference function for audio classification.

        Args:
            dataset (Dataset):
                Dataset to run inference on. Each item is read as
                `item["audio"]["array"]`, and `len(dataset)` is used as the
                progress bar total.

        Returns:
            List[str]:
                List of predicted labels, one per pipeline output, in the
                order the pipeline yields them.
        """
        # Feed the pipeline lazily so audio arrays are pulled one item at a
        # time instead of being collected up front.
        audio_arrays = (item["audio"]["array"] for item in dataset)
        outputs = self.pipeline(audio_arrays, top_k=1)
        progress = tqdm(
            outputs,
            total=len(dataset),
            desc="Classifying Audios",
        )
        # top_k=1 is requested, so the first entry of each output carries
        # the label that is kept.
        return [out[0]["label"] for out in progress]

inference(self, dataset)

Inference function for audio classification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Dataset` | Dataset to run inference on. | required |

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | List of predicted labels. |

Source code in speechline/modules/audio_classifier.py
def inference(self, dataset: Dataset) -> List[str]:
    """
    Inference function for audio classification.

    Args:
        dataset (Dataset):
            Dataset to run inference on. Each item is read as
            `item["audio"]["array"]`, and `len(dataset)` is used as the
            progress bar total.

    Returns:
        List[str]:
            List of predicted labels, one per pipeline output, in the
            order the pipeline yields them.
    """
    # Feed the pipeline lazily so audio arrays are pulled one item at a
    # time instead of being collected up front.
    audio_arrays = (item["audio"]["array"] for item in dataset)
    outputs = self.pipeline(audio_arrays, top_k=1)
    progress = tqdm(
        outputs,
        total=len(dataset),
        desc="Classifying Audios",
    )
    # top_k=1 is requested, so the first entry of each output carries the
    # label that is kept.
    return [out[0]["label"] for out in progress]