# Audio Multilabel Classifier
## speechline.modules.audio_multilabel_classifier.AudioMultiLabelClassifier (AudioModule)
Generic audio multilabel classifier module. Performs padded audio classification.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`model_checkpoint` | `str` | HuggingFace Hub model checkpoint. | *required* |
Source code in `speechline/modules/audio_multilabel_classifier.py`
```python
import numpy as np
import torch
from datasets import Dataset
from tqdm import tqdm
from transformers import pipeline
from typing import Dict, Iterator, List, Union

# AudioModule and AudioMultiLabelClassificationPipeline are
# in-package speechline imports, omitted here for brevity


class AudioMultiLabelClassifier(AudioModule):
    """
    Generic audio multilabel classifier module. Performs padded audio classification.

    Args:
        model_checkpoint (str):
            HuggingFace Hub model checkpoint.
    """

    def __init__(self, model_checkpoint: str, **kwargs) -> None:
        classifier = pipeline(
            "audio-classification",
            model=model_checkpoint,
            feature_extractor=model_checkpoint,
            # use the first GPU if available, otherwise fall back to CPU
            device=0 if torch.cuda.is_available() else -1,
            pipeline_class=AudioMultiLabelClassificationPipeline,
            **kwargs,
        )
        super().__init__(pipeline=classifier)

    def inference(
        self, dataset: Dataset, threshold: float = 0.5
    ) -> List[Dict[str, Union[str, float]]]:
        """
        Inference function for audio classification.

        Args:
            dataset (Dataset):
                Dataset to be inferred.
            threshold (float):
                Threshold probability for predicted labels.
                Any label with probability at or above this threshold
                is considered a valid prediction.

        Returns:
            List[Dict[str, Union[str, float]]]:
                List of predictions as dictionaries,
                each consisting of the predicted label and its probability.
        """

        def _get_audio_array(dataset: Dataset) -> Iterator[np.ndarray]:
            # lazily yield raw waveforms so the pipeline can stream the dataset
            for item in dataset:
                yield item["audio"]["array"]

        results = []
        for out in tqdm(
            self.pipeline(_get_audio_array(dataset), top_k=1),
            total=len(dataset),
            desc="Classifying Audios",
        ):
            # indices of all labels whose probability clears the threshold
            ids = np.where(out >= threshold)[0].tolist()
            if len(ids) > 0:
                prediction = [
                    {
                        "label": self.pipeline.model.config.id2label[id],
                        "score": out[id],
                    }
                    for id in ids
                ]
                results.append(prediction)
        return results
```
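For context, a minimal instantiation sketch; `my-org/audio-multilabel` is a placeholder checkpoint name, not a model shipped with speechline:

```python
from speechline.modules.audio_multilabel_classifier import AudioMultiLabelClassifier

# placeholder checkpoint name; substitute a real multilabel
# audio classification model from the HuggingFace Hub
classifier = AudioMultiLabelClassifier(model_checkpoint="my-org/audio-multilabel")
```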
### inference(self, dataset, threshold=0.5)
Inference function for audio classification.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Dataset` | Dataset to be inferred. | *required* |
`threshold` | `float` | Threshold probability for predicted labels. Any label with probability at or above this threshold is considered a valid prediction. | `0.5` |
Returns:

Type | Description |
---|---|
`List[Dict[str, Union[str, float]]]` | List of predictions as dictionaries, each consisting of the predicted label and its probability. |
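To see the thresholding step in isolation: every label whose probability is at or above `threshold` becomes one `{"label", "score"}` dict. A self-contained sketch with made-up labels and scores:

```python
import numpy as np

# made-up label map and sigmoid outputs for a four-label model
id2label = {0: "speech", 1: "music", 2: "noise", 3: "silence"}
out = np.array([0.91, 0.12, 0.55, 0.07])
threshold = 0.5

# same selection logic as the method body below
ids = np.where(out >= threshold)[0].tolist()
prediction = [{"label": id2label[i], "score": float(out[i])} for i in ids]
print(prediction)
# [{'label': 'speech', 'score': 0.91}, {'label': 'noise', 'score': 0.55}]
```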
Source code in `speechline/modules/audio_multilabel_classifier.py`
```python
def inference(
    self, dataset: Dataset, threshold: float = 0.5
) -> List[Dict[str, Union[str, float]]]:
    """
    Inference function for audio classification.

    Args:
        dataset (Dataset):
            Dataset to be inferred.
        threshold (float):
            Threshold probability for predicted labels.
            Any label with probability at or above this threshold
            is considered a valid prediction.

    Returns:
        List[Dict[str, Union[str, float]]]:
            List of predictions as dictionaries,
            each consisting of the predicted label and its probability.
    """

    def _get_audio_array(dataset: Dataset) -> Iterator[np.ndarray]:
        # lazily yield raw waveforms so the pipeline can stream the dataset
        for item in dataset:
            yield item["audio"]["array"]

    results = []
    for out in tqdm(
        self.pipeline(_get_audio_array(dataset), top_k=1),
        total=len(dataset),
        desc="Classifying Audios",
    ):
        # indices of all labels whose probability clears the threshold
        ids = np.where(out >= threshold)[0].tolist()
        if len(ids) > 0:
            prediction = [
                {
                    "label": self.pipeline.model.config.id2label[id],
                    "score": out[id],
                }
                for id in ids
            ]
            results.append(prediction)
    return results
```
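An end-to-end usage sketch, assuming `classifier` was constructed as above and that `clip.wav` exists locally; the file name and sampling rate are illustrative:

```python
from datasets import Audio, Dataset

# build a one-clip dataset whose rows expose item["audio"]["array"],
# which is what inference's internal generator reads
dataset = Dataset.from_dict({"audio": ["clip.wav"]})
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

predictions = classifier.inference(dataset, threshold=0.5)
# one list of {"label": ..., "score": ...} dicts per clip that had
# at least one label at or above the threshold
```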