Audio Spectrogram Transformer Classifier
speechline.classifiers.ast.ASTClassifier (AudioMultiLabelClassifier)
Audio classifier with feature extractor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_checkpoint | str | HuggingFace model hub checkpoint. | required |
Source code in speechline/classifiers/ast.py
class ASTClassifier(AudioMultiLabelClassifier):
    """
    Audio classifier with feature extractor.

    Args:
        model_checkpoint (str):
            HuggingFace model hub checkpoint.
    """

    def __init__(self, model_checkpoint: str) -> None:
        super().__init__(model_checkpoint)

    def predict(
        self, dataset: Dataset, threshold: float = 0.5
    ) -> List[Dict[str, Union[str, float]]]:
        """
        Performs audio classification (inference) on `dataset`.
        Preprocesses datasets, performs inference, then returns predictions.

        Args:
            dataset (Dataset):
                Dataset to be inferred.
            threshold (float):
                Threshold probability for predicted labels.
                Anything above this threshold will be considered as a valid prediction.

        Returns:
            List[Dict[str, Union[str, float]]]:
                List of predictions in the format of dictionaries,
                consisting of the predicted label and probability.
        """
        return self.inference(dataset, threshold)
predict(self, dataset, threshold=0.5)
Performs audio classification (inference) on `dataset`.
Preprocesses datasets, performs inference, then returns predictions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Dataset to be inferred. | required |
threshold | float | Threshold probability for predicted labels. Anything above this threshold will be considered as a valid prediction. | 0.5 |
Returns:
Type | Description |
---|---|
List[Dict[str, Union[str, float]]] | List of predictions in the format of dictionaries, consisting of the predicted label and probability. |
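
To illustrate what the `threshold` parameter means in a multi-label setting, here is a minimal, self-contained sketch. It is not the library's implementation (that lives in the parent class's `inference` method, which is not shown here). It assumes per-label probabilities come from an independent sigmoid over each logit, which is common for multi-label classifiers. The helper `threshold_predictions`, the dictionary keys `"label"` and `"score"`, and the example label names are all hypothetical.

```python
import math
from typing import Dict, List, Union


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def threshold_predictions(
    logits: Dict[str, float], threshold: float = 0.5
) -> List[Dict[str, Union[str, float]]]:
    """Keep every label whose probability is strictly above `threshold`.

    Hypothetical helper: the key names "label" and "score" are assumptions,
    not the documented output format of ASTClassifier.predict.
    """
    predictions: List[Dict[str, Union[str, float]]] = []
    for label, logit in logits.items():
        probability = sigmoid(logit)
        if probability > threshold:
            predictions.append({"label": label, "score": probability})
    # Most confident labels first.
    return sorted(predictions, key=lambda p: p["score"], reverse=True)


# Hypothetical logits for a single audio clip.
example_logits = {"child": 2.0, "adult": -1.0, "noise": 0.0}
result = threshold_predictions(example_logits, threshold=0.5)
```

With the default threshold of 0.5, only labels whose probability is strictly greater than 0.5 survive in this sketch. Lowering the threshold admits more labels per clip, and raising it makes the output stricter.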
Source code in speechline/classifiers/ast.py
def predict(
    self, dataset: Dataset, threshold: float = 0.5
) -> List[Dict[str, Union[str, float]]]:
    """
    Performs audio classification (inference) on `dataset`.
    Preprocesses datasets, performs inference, then returns predictions.

    Args:
        dataset (Dataset):
            Dataset to be inferred.
        threshold (float):
            Threshold probability for predicted labels.
            Anything above this threshold will be considered as a valid prediction.

    Returns:
        List[Dict[str, Union[str, float]]]:
            List of predictions in the format of dictionaries,
            consisting of the predicted label and probability.
    """
    return self.inference(dataset, threshold)