
Audio Spectrogram Transformer Classifier

speechline.classifiers.ast.ASTClassifier (AudioMultiLabelClassifier)

Multi-label audio classifier with a feature extractor, loaded from an Audio Spectrogram Transformer (AST) checkpoint.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_checkpoint | str | HuggingFace model hub checkpoint. | required |
Source code in speechline/classifiers/ast.py
```python
class ASTClassifier(AudioMultiLabelClassifier):
    """
    Audio classifier with feature extractor.

    Args:
        model_checkpoint (str):
            HuggingFace model hub checkpoint.
    """

    def __init__(self, model_checkpoint: str) -> None:
        super().__init__(model_checkpoint)

    def predict(
        self, dataset: Dataset, threshold: float = 0.5
    ) -> List[Dict[str, Union[str, float]]]:
        """
        Performs audio classification (inference) on `dataset`.
        Preprocesses datasets, performs inference, then returns predictions.

        Args:
            dataset (Dataset):
                Dataset to be inferred.
            threshold (float):
                Threshold probability for predicted labels.
                Anything above this threshold will be considered as a valid prediction.

        Returns:
            List[Dict[str, Union[str, float]]]:
                List of predictions in the format of dictionaries,
                consisting of the predicted label and probability.
        """
        return self.inference(dataset, threshold)
```

predict(self, dataset, threshold=0.5)

Performs audio classification (inference) on dataset: preprocesses the dataset, runs inference, then returns the predictions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset | Dataset to run inference on. | required |
| threshold | float | Probability threshold for predicted labels. Any label whose probability is above this threshold is considered a valid prediction. | 0.5 |

Returns:

| Type | Description |
|---|---|
| List[Dict[str, Union[str, float]]] | List of predictions in the format of dictionaries, each consisting of the predicted label and its probability. |
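Because the return value is a plain list of dictionaries, it can be post-processed with ordinary Python. A minimal sketch, assuming each dictionary stores its label under a hypothetical "label" key and its probability under a hypothetical "score" key (the actual key names are not documented on this page); top_label is likewise a hypothetical helper.

```python
from typing import Dict, List, Optional, Union

Prediction = Dict[str, Union[str, float]]


def top_label(predictions: List[Prediction]) -> Optional[str]:
    """Return the most probable label, or None if the list is empty.

    Assumes hypothetical "label" and "score" keys in each dictionary.
    """
    if not predictions:
        return None
    best = max(predictions, key=lambda p: float(p["score"]))
    return str(best["label"])


example: List[Prediction] = [
    {"label": "Speech", "score": 0.88},
    {"label": "Silence", "score": 0.52},
]
print(top_label(example))  # Speech
```

Returning None instead of raising keeps the helper usable when no label clears the threshold, which thresholded multi-label output of this kind can produce.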
