Audio Classifier
speechline.modules.audio_classifier.AudioClassifier (AudioModule)
Generic AudioClassifier Module. Performs padded audio classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_checkpoint | str | HuggingFace Hub model checkpoint. | required |
Source code in speechline/modules/audio_classifier.py
class AudioClassifier(AudioModule):
    """
    Generic AudioClassifier Module. Performs padded audio classification.

    Args:
        model_checkpoint (str):
            HuggingFace Hub model checkpoint.
        **kwargs:
            Additional keyword arguments forwarded unchanged to `pipeline`.
    """

    def __init__(self, model_checkpoint: str, **kwargs) -> None:
        # Device index 0 when CUDA is available, otherwise -1.
        device_index = 0 if torch.cuda.is_available() else -1
        super().__init__(
            pipeline=pipeline(
                "audio-classification",
                model=model_checkpoint,
                device=device_index,
                pipeline_class=AudioClassificationWithPaddingPipeline,
                **kwargs,
            )
        )

    def inference(self, dataset: Dataset) -> List[str]:
        """
        Classify every audio sample of a dataset.

        Args:
            dataset (Dataset):
                Dataset to be inferred. The `["audio"]["array"]` entry of
                each item is what gets passed to the pipeline.

        Returns:
            List[str]:
                List of predicted labels.
        """
        # Feed the audio arrays lazily instead of materializing them first.
        waveforms = (sample["audio"]["array"] for sample in dataset)
        # top_k=1 is requested, so only the first entry of each output is read.
        predictions = self.pipeline(waveforms, top_k=1)
        progress = tqdm(
            predictions,
            total=len(dataset),
            desc="Classifying Audios",
        )
        return [top[0]["label"] for top in progress]
inference(self, dataset)
Inference function for audio classification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Dataset to be inferred. | required |
Returns:
Type | Description |
---|---|
List[str] | List of predicted labels. |
Source code in speechline/modules/audio_classifier.py
def inference(self, dataset: Dataset) -> List[str]:
    """
    Classify every audio sample of a dataset.

    Args:
        dataset (Dataset):
            Dataset to be inferred. The `["audio"]["array"]` entry of
            each item is what gets passed to the pipeline.

    Returns:
        List[str]:
            List of predicted labels.
    """
    # Feed the audio arrays lazily instead of materializing them first.
    waveforms = (sample["audio"]["array"] for sample in dataset)
    # top_k=1 is requested, so only the first entry of each output is read.
    predictions = self.pipeline(waveforms, top_k=1)
    progress = tqdm(
        predictions,
        total=len(dataset),
        desc="Classifying Audios",
    )
    return [top[0]["label"] for top in progress]