Wav2Vec2 Classifier
speechline.classifiers.wav2vec2.Wav2Vec2Classifier (AudioClassifier)
Audio classifier with feature extractor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_checkpoint | str | HuggingFace model hub checkpoint. | required |
max_duration_s | float | Maximum audio duration in seconds. | required |
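To make `max_duration_s` more concrete, the sketch below shows one common way a duration cap is applied to raw audio: convert the duration to a sample count and truncate. This is an illustration only. Whether speechline applies the cap exactly this way is an assumption, and the function name `truncate_to_max_duration` and the 16 kHz sampling rate are hypothetical.

```python
from typing import List


def truncate_to_max_duration(
    samples: List[float], sampling_rate: int, max_duration_s: float
) -> List[float]:
    """Keep at most `max_duration_s` seconds of audio (hypothetical helper)."""
    # Convert the duration cap from seconds to a number of samples.
    max_length = int(sampling_rate * max_duration_s)
    # Slicing past the end is safe: shorter clips are returned unchanged.
    return samples[:max_length]


# Assumed example input: 3 seconds of silence at 16 kHz.
waveform = [0.0] * 48_000
clipped = truncate_to_max_duration(waveform, sampling_rate=16_000, max_duration_s=1.5)
print(len(clipped))  # 24000
```

Clips shorter than the cap pass through untouched, so the cap only affects audio that exceeds it.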
Source code in speechline/classifiers/wav2vec2.py
class Wav2Vec2Classifier(AudioClassifier):
    """
    Audio classifier with feature extractor.

    Args:
        model_checkpoint (str):
            HuggingFace model hub checkpoint.
        max_duration_s (float):
            Maximum audio duration in seconds.
    """

    def __init__(self, model_checkpoint: str, max_duration_s: float) -> None:
        super().__init__(model_checkpoint, max_duration_s=max_duration_s)

    def predict(self, dataset: Dataset) -> List[str]:
        """
        Performs audio classification (inference) on `dataset`.
        Preprocesses datasets, performs inference, then returns predictions.

        Args:
            dataset (Dataset):
                Dataset to be inferred.

        Returns:
            List[str]:
                List of predictions (as strings of labels).
        """
        return self.inference(dataset)
predict(self, dataset)
Performs audio classification (inference) on `dataset`.
Preprocesses datasets, performs inference, then returns predictions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Dataset to be inferred. | required |
Returns:
Type | Description |
---|---|
List[str] | List of predictions (as strings of labels). |
Source code in speechline/classifiers/wav2vec2.py
def predict(self, dataset: Dataset) -> List[str]:
    """
    Performs audio classification (inference) on `dataset`.
    Preprocesses datasets, performs inference, then returns predictions.

    Args:
        dataset (Dataset):
            Dataset to be inferred.

    Returns:
        List[str]:
            List of predictions (as strings of labels).
    """
    return self.inference(dataset)
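As the source shows, `predict` is a thin wrapper that delegates to the inherited `inference` method. The sketch below mirrors that delegation pattern with stand-in classes so it runs without speechline, a model checkpoint, or network access. `StubAudioClassifier`, `StubWav2Vec2Classifier`, the checkpoint string, and the placeholder `inference` body are all hypothetical and are not speechline's real implementation.

```python
from typing import Dict, List


class StubAudioClassifier:
    """Hypothetical stand-in for speechline's AudioClassifier base class."""

    def __init__(self, model_checkpoint: str, max_duration_s: float) -> None:
        self.model_checkpoint = model_checkpoint
        self.max_duration_s = max_duration_s

    def inference(self, dataset: List[Dict[str, str]]) -> List[str]:
        # Placeholder: a real classifier would preprocess audio and run a model.
        # This stub simply reads a precomputed label from each example.
        return [example["label"] for example in dataset]


class StubWav2Vec2Classifier(StubAudioClassifier):
    """Mirrors the documented class: predict() delegates to inference()."""

    def predict(self, dataset: List[Dict[str, str]]) -> List[str]:
        return self.inference(dataset)


classifier = StubWav2Vec2Classifier("hypothetical/checkpoint", max_duration_s=3.0)
predictions = classifier.predict([{"label": "adult"}, {"label": "child"}])
print(predictions)  # ['adult', 'child']
```

With the real class, the call shape would be the same: construct the classifier with a checkpoint and a duration cap, then pass a `Dataset` to `predict` to get one label string per example.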