Unverified commit 8d17108b authored by: K KP, committed by: GitHub

Refactor code in paddleaudio/models (#5300)

* Refactor code in paddleaudio/models

* Upgrade __getitem__ to support on the fly feature extraction
Parent 8a0045aa
......@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dcase import TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet
from .dcase import UrbanAcousticScenes
from .esc50 import ESC50
from .gtzan import GTZAN
from .ravdess import RAVDESS
from .urban_sound import UrbanSound8K
__all__ = [
'ESC50',
'UrbanSound8K',
'GTZAN',
'TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet',
'UrbanAcousticScenes',
'RAVDESS',
]
......@@ -35,12 +35,19 @@ class AudioClassificationDataset(paddle.io.Dataset):
'log_spect': log_spect,
}
def __init__(self, files: List[str], labels: List[int], sample_rate: int, feat_type: str = 'raw', **kwargs):
def __init__(self,
files: List[str],
labels: List[int],
sample_rate: int,
duration: float,
feat_type: str = 'raw',
**kwargs):
"""
Ags:
files (:obj:`List[str]`): A list of absolute path of audio files.
labels (:obj:`List[int]`): Labels of audio files.
sample_rate (:obj:`int`): Sample rate of audio files.
duration (:obj:`float`): Duration of audio files.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
......@@ -49,32 +56,39 @@ class AudioClassificationDataset(paddle.io.Dataset):
if feat_type not in self._feat_func:
    raise RuntimeError(
        f"Unknown feat_type: {feat_type}, it must be one of {list(self._feat_func.keys())}")
self.feat_type = feat_type
self.files = files
self.labels = labels
self.records = self._convert_to_records(sample_rate, **kwargs)
self.sample_rate = sample_rate
self.duration = duration
self.feat_type = feat_type
self.feat_config = kwargs # Pass keyword arguments to customize feature config
def _get_data(self, input_file: str):
raise NotImplementedError
def _convert_to_records(self, sample_rate: int, **kwargs) -> List[dict]:
records = []
feat_func = self._feat_func[self.feat_type]
def _convert_to_record(self, idx):
file, label = self.files[idx], self.labels[idx]
waveform, _ = librosa.load(file, sr=self.sample_rate)
normal_length = int(self.sample_rate * self.duration)  # duration may be a float; cast so slicing and padding get an int
if len(waveform) > normal_length:
waveform = waveform[:normal_length]
else:
waveform = np.pad(waveform, (0, normal_length - len(waveform)))
logger.info('Start extracting features from audio files.')
for file, label in tqdm(zip(self.files, self.labels), total=len(self.files)):
record = {}
waveform, _ = librosa.load(file, sr=sample_rate)
record['feat'] = feat_func(waveform, **kwargs) if feat_func else waveform
record['label'] = label
records.append(record)
feat_func = self._feat_func[self.feat_type]
return records
record = {}
record['feat'] = feat_func(waveform, sample_rate=self.sample_rate,
                           **self.feat_config) if feat_func else waveform
record['label'] = label
return record
def __getitem__(self, idx):
record = self.records[idx]
return np.array(record['feat']), np.array(record['label'], dtype=np.int64)
record = self._convert_to_record(idx)
return np.array(record['feat']).transpose(), np.array(record['label'], dtype=np.int64)
def __len__(self):
return len(self.records)
return len(self.files)
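With this refactor, feature extraction happens lazily in __getitem__ instead of in an upfront pass over every file. A minimal usage sketch (assuming the package path paddleaudio.datasets and a 'mel_spect' entry in _feat_func):
from paddleaudio.datasets import ESC50  # any concrete subclass works

train_ds = ESC50(mode='train', feat_type='mel_spect')  # extra kwargs become self.feat_config
feat, label = train_ds[0]  # load -> crop/pad to sample_rate * duration -> extract feature
print(feat.shape, label)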
......@@ -18,17 +18,23 @@ from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import AudioClassificationDataset
__all__ = ['TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet']
__all__ = ['UrbanAcousticScenes']
class TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet(AudioClassificationDataset):
class UrbanAcousticScenes(AudioClassificationDataset):
"""
TAU Urban Acoustic Scenes 2020 Mobile Development dataset
This dataset is used in DCASE2020 - Task 1, Acoustic scene classification / Subtask A / Development
TAU Urban Acoustic Scenes 2020 Mobile Development dataset contains recordings from
12 European cities in 10 different acoustic scenes using 4 different devices.
Additionally, synthetic data for 11 mobile devices was created based on the original
recordings. Of the 12 cities, two are present only in the evaluation set.
Reference:
A multi-device dataset for urban acoustic scene classification
https://arxiv.org/abs/1807.09840
"""
source_url = 'https://zenodo.org/record/3819968/files/'
base_name = 'TAU-urban-acoustic-scenes-2020-mobile-development'
archieves = [
......@@ -125,12 +131,12 @@ class TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet(AudioClassificationDatas
It identifies the feature type that the user wants to extract from an audio file.
"""
files, labels = self._get_data(mode)
super(TAUUrbanAcousticScenes_2020_Mobile_DevelopmentSet, \
self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
feat_type=feat_type,
**kwargs)
super(UrbanAcousticScenes, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
def _get_meta_info(self, subset: str = None, skip_header: bool = True) -> List[collections.namedtuple]:
if subset is None:
......
......@@ -18,7 +18,6 @@ from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import AudioClassificationDataset
__all__ = ['ESC50']
......@@ -26,13 +25,20 @@ __all__ = ['ESC50']
class ESC50(AudioClassificationDataset):
"""
Environmental Sound Classification Dataset
The ESC-50 dataset is a labeled collection of 2000 environmental audio recordings
suitable for benchmarking methods of environmental sound classification. The dataset
consists of 5-second-long recordings organized into 50 semantic classes (with
40 examples per class).
Reference:
ESC: Dataset for Environmental Sound Classification
http://dx.doi.org/10.1145/2733373.2806390
"""
archieves = [
{
'url': 'https://github.com/karoldvl/ESC-50/archive/master.zip',
'md5': '70aba3bada37d2674b8f6cd5afd5f065',
'md5': '1fdc5dd87626d5eb91be20ed53c9aed9',
},
]
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
......@@ -56,6 +62,7 @@ class ESC50(AudioClassificationDataset):
super(ESC50, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
......
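Because _convert_to_record crops or pads every clip to the same sample_rate * duration length, items batch cleanly with the default collate function. A sketch (module path assumed as in the dataset example above):
import paddle
from paddleaudio.datasets import ESC50

train_ds = ESC50(mode='train', feat_type='mel_spect')
train_loader = paddle.io.DataLoader(train_ds, batch_size=16, shuffle=True)
for feats, labels in train_loader:
    print(feats.shape, labels.shape)  # features are computed on the fly, batch by batch
    break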
......@@ -19,7 +19,6 @@ from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import AudioClassificationDataset
__all__ = ['GTZAN']
......@@ -27,7 +26,13 @@ __all__ = ['GTZAN']
class GTZAN(AudioClassificationDataset):
"""
GTZAN Dataset
The GTZAN dataset consists of 1000 audio tracks each 30 seconds long. It contains 10 genres,
each represented by 100 tracks. The dataset is the most-used public dataset for evaluation
in machine listening research for music genre recognition (MGR).
Reference:
Musical genre classification of audio signals
https://ieeexplore.ieee.org/document/1021072/
"""
archieves = [
......@@ -40,8 +45,8 @@ class GTZAN(AudioClassificationDataset):
meta = os.path.join('genres', 'input.mf')
meta_info = collections.namedtuple('META_INFO', ('file_path', 'label'))
audio_path = 'genres'
sample_rate = 22050 # 44.1 khz
duration = 30 # 5s
sample_rate = 22050
duration = 30
def __init__(self, mode='train', seed=0, n_folds=5, split=1, feat_type='raw', **kwargs):
"""
......@@ -62,6 +67,7 @@ class GTZAN(AudioClassificationDataset):
super(GTZAN, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import random
from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['RAVDESS']
class RAVDESS(AudioClassificationDataset):
"""
The RAVDESS contains 24 professional actors (12 female, 12 male), vocalizing two
lexically-matched statements in a neutral North American accent. Speech emotions
include calm, happy, sad, angry, fearful, surprise, and disgust expressions.
Each expression is produced at two levels of emotional intensity (normal, strong),
with an additional neutral expression.
Reference:
The Ryerson Audio-Visual Database of Emotional Speech and Song (RAVDESS):
A dynamic, multimodal set of facial and vocal expressions in North American English
https://doi.org/10.1371/journal.pone.0196391
"""
archieves = [
{
'url': 'https://zenodo.org/record/1188976/files/Audio_Song_Actors_01-24.zip',
'md5': '5411230427d67a21e18aa4d466e6d1b9',
},
{
'url': 'https://zenodo.org/record/1188976/files/Audio_Speech_Actors_01-24.zip',
'md5': 'bc696df654c87fed845eb13823edef8a',
},
]
label_list = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']
meta_info = collections.namedtuple(
'META_INFO', ('modality', 'vocal_channel', 'emotion', 'emotion_intensity', 'statement', 'repetition', 'actor'))
speech_path = os.path.join(DATA_HOME, 'Audio_Speech_Actors_01-24')
song_path = os.path.join(DATA_HOME, 'Audio_Song_Actors_01-24')
sample_rate = 44100 # 44.1 khz
duration = 5 # 5s
def __init__(self, mode='train', seed=0, n_folds=5, split=1, feat_type='raw', **kwargs):
"""
Args:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
seed (:obj:`int`, `optional`, defaults to 0):
Set the random seed to shuffle samples.
n_folds (:obj:`int`, `optional`, defaults to 5):
Split the dataset into n folds: 1 fold for the dev dataset and n-1 for the train dataset.
split (:obj:`int`, `optional`, defaults to 1):
It specifies which fold is used as the dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that the user wants to extract from an audio file.
"""
assert split <= n_folds, f'The selected split should not be larger than n_folds, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split)
super(RAVDESS, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
def _get_meta_info(self, files) -> List[collections.namedtuple]:
ret = []
for file in files:
basename_without_extend = os.path.basename(file)[:-4]
ret.append(self.meta_info(*basename_without_extend.split('-')))
return ret
def _get_data(self, mode, seed, n_folds, split) -> Tuple[List[str], List[int]]:
if not os.path.isdir(self.speech_path) and not os.path.isdir(self.song_path):
download_and_decompress(self.archieves, DATA_HOME)
wav_files = []
for root, _, files in os.walk(self.speech_path):
for file in files:
if file.endswith('.wav'):
wav_files.append(os.path.join(root, file))
for root, _, files in os.walk(self.song_path):
for file in files:
if file.endswith('.wav'):
wav_files.append(os.path.join(root, file))
random.seed(seed) # shuffle samples to split data
random.shuffle(wav_files) # make sure using the same seed to create train and dev dataset
meta_info = self._get_meta_info(wav_files)
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
_, _, emotion, _, _, _, _ = sample
target = int(emotion) - 1
fold = idx // n_samples_per_fold + 1
if mode == 'train' and int(fold) != split:
files.append(wav_files[idx])
labels.append(target)
if mode != 'train' and int(fold) == split:
files.append(wav_files[idx])
labels.append(target)
return files, labels
if __name__ == "__main__":
train_ds = RAVDESS(mode='train', feat_type='mel_spect')
dev_ds = RAVDESS(mode='dev', feat_type='mel_spect')
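A worked example of the fold assignment in _get_data (clip count assumed from the public RAVDESS release: 1440 speech + 1012 song = 2452 files):
# With n_folds=5, n_samples_per_fold = 2452 // 5 = 490, so indices 0..489 map
# to fold 1, 490..979 to fold 2, and so on. The trailing remainder (idx >= 2450)
# maps to fold 6, which never equals any split in 1..5, so those samples always
# land in the training set.
n_folds, n_samples = 5, 2452
for idx in (0, 489, 490, 2449, 2450):
    print(idx, idx // (n_samples // n_folds) + 1)  # -> 1, 1, 2, 5, 6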
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import random
from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import AudioClassificationDataset
__all__ = ['TESS']
class TESS(AudioClassificationDataset):
"""
TESS Dataset
"""
archieves = []
label_list = [
'angry',
'disgust',
'fear',
'happy',
'neutral',
'ps', # pleasant surprise
'sad',
]
meta_info = collections.namedtuple('META_INFO', ('speaker', 'word', 'emotion'))
audio_path = os.path.join(DATA_HOME, 'TESS Toronto emotional speech set data')
sample_rate = 24414
duration = 2
def __init__(self, mode='train', seed=0, n_folds=5, split=1, feat_type='raw', **kwargs):
"""
Args:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
seed (:obj:`int`, `optional`, defaults to 0):
Set the random seed to shuffle samples.
n_folds (:obj:`int`, `optional`, defaults to 5):
Split the dataset into n folds: 1 fold for the dev dataset and n-1 for the train dataset.
split (:obj:`int`, `optional`, defaults to 1):
It specifies which fold is used as the dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that the user wants to extract from an audio file.
"""
assert split <= n_folds, f'The selected split should not be larger than n_folds, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split)
super(TESS, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
def _get_meta_info(self, files) -> List[collections.namedtuple]:
ret = []
for file in files:
basename_without_extend = os.path.basename(file)[:-4]
ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret
def _get_data(self, mode, seed, n_folds, split) -> Tuple[List[str], List[int]]:
if not os.path.isdir(self.audio_path):
download_and_decompress(self.archieves, DATA_HOME)
wav_files = []
for root, _, files in os.walk(self.audio_path):
for file in files:
if file.endswith('.wav'):
wav_files.append(os.path.join(root, file))
random.seed(seed) # shuffle samples to split data
random.shuffle(wav_files) # make sure using the same seed to create train and dev dataset
meta_info = self._get_meta_info(wav_files)
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
_, _, emotion = sample
target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1
if mode == 'train' and int(fold) != split:
files.append(wav_files[idx])
labels.append(target)
if mode != 'train' and int(fold) == split:
files.append(wav_files[idx])
labels.append(target)
return files, labels
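The positional unpacking in _get_meta_info assumes the TESS filename convention '<speaker>_<word>_<emotion>.wav'. A sketch with an example name taken to be representative of the release:
import collections

META_INFO = collections.namedtuple('META_INFO', ('speaker', 'word', 'emotion'))
info = META_INFO(*'OAF_back_angry'.split('_'))  # from 'OAF_back_angry.wav'
print(info.speaker, info.word, info.emotion)    # OAF back angry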
......@@ -18,7 +18,6 @@ from typing import List, Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import AudioClassificationDataset
__all__ = ['UrbanSound8K']
......@@ -26,7 +25,14 @@ __all__ = ['UrbanSound8K']
class UrbanSound8K(AudioClassificationDataset):
"""
UrbanSound8K Dataset
UrbanSound8K dataset contains 8732 labeled sound excerpts (<=4s) of urban
sounds from 10 classes: air_conditioner, car_horn, children_playing, dog_bark,
drilling, engine_idling, gun_shot, jackhammer, siren, and street_music. The
classes are drawn from the urban sound taxonomy.
Reference:
A Dataset and Taxonomy for Urban Sound Research
https://dl.acm.org/doi/10.1145/2647868.2655045
"""
archieves = [
......@@ -47,6 +53,7 @@ class UrbanSound8K(AudioClassificationDataset):
super(UrbanSound8K, self).__init__(files=files,
labels=labels,
sample_rate=self.sample_rate,
duration=self.duration,
feat_type=feat_type,
**kwargs)
"""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cnn6 import CNN6
from .cnn10 import CNN10
from .cnn14 import CNN14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock
class CNN10(nn.Layer):
"""
The CNN10 (10-layer CNN) mainly consists of 4 convolutional blocks, each of which
consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN10, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
logger.info(f'Loaded CNN10 pretrained parameters from: {checkpoint}')
else:
logger.warning('No valid checkpoints for CNN10. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock
class CNN14(nn.Layer):
"""
The CNN14 (14-layer CNN) mainly consists of 6 convolutional blocks, each of which
consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 2048
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN14, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.fc1 = nn.Linear(2048, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
logger.info(f'Loaded CNN14 pretrained parameters from: {checkpoint}')
else:
logger.warning('No valid checkpoints for CNN14. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
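# Input is assumed to be a batch of log-mel features laid out as
# [N, 1, time_steps, 64], with the 64 mel bins in the last axis. The transpose
# pair below moves the mel axis into the channel position so that bn0
# normalizes each frequency bin, then restores the original layout.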
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...utils.log import logger
from .conv import ConvBlock5x5
class CNN6(nn.Layer):
"""
The CNN6 (6-layer CNN) mainly consists of 4 convolutional blocks, each of which
consists of 1 convolutional layer with a kernel size of 5 × 5.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN6, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
logger.info(f'Loaded CNN6 pretrained parameters from: {checkpoint}')
else:
logger.warning('No valid checkpoints for CNN6. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ConvBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.conv2 = nn.Conv2D(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
self.bn2 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class ConvBlock5x5(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock5x5, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
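A quick shape check of the two blocks (a sketch to run inside this module; the sizes follow from the 'same'-style padding and from the pooling stride defaulting to the kernel size):
import paddle

block = ConvBlock(in_channels=1, out_channels=64)
x = paddle.randn([8, 1, 100, 64])                # [N, C, time, mel]
y = block(x, pool_size=(2, 2), pool_type='avg')
print(y.shape)                                   # [8, 64, 50, 32]

block5x5 = ConvBlock5x5(in_channels=1, out_channels=64)
print(block5x5(x, pool_type='avg+max').shape)    # [8, 64, 50, 32]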
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .PANNs import CNN14, CNN10, CNN6
__all__ = [
'CNN14',
'CNN10',
'CNN6',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..utils.download import load_state_dict_from_url
from ..utils.env import MODEL_HOME
from ..utils.log import logger
__all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6']
pretrained_model_urls = {
# TODO: replace test urls
'cnn14': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams',
'cnn10': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams',
'cnn6': 'https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams',
}
class ConvBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.conv2 = nn.Conv2D(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
self.bn2 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class ConvBlock5x5(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock5x5, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class CNN14(nn.Layer):
"""
The CNN14 (14-layer CNN) mainly consists of 6 convolutional blocks, each of which
consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 2048
def __init__(self, extract_embedding: bool = True):
super(CNN14, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.fc1 = nn.Linear(2048, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
class CNN10(nn.Layer):
"""
The CNN10 (10-layer CNN) mainly consists of 4 convolutional blocks, each of which
consists of 2 convolutional layers with a kernel size of 3 × 3.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True):
super(CNN10, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
class CNN6(nn.Layer):
"""
The CNN6 (6-layer CNN) mainly consists of 4 convolutional blocks, each of which
consists of 1 convolutional layer with a kernel size of 5 × 5.
Reference:
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
https://arxiv.org/pdf/1912.10211.pdf
"""
emb_size = 512
def __init__(self, extract_embedding: bool = True):
super(CNN6, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
def cnn14(pretrained: bool = False, extract_embedding: bool = True) -> CNN14:
model = CNN14(extract_embedding=extract_embedding)
if pretrained:
state_dict = load_state_dict_from_url(url=pretrained_model_urls['cnn14'],
path=os.path.join(MODEL_HOME, 'panns'))
model.set_state_dict(state_dict)
return model
def cnn10(pretrained: bool = False, extract_embedding: bool = True) -> CNN10:
model = CNN10(extract_embedding=extract_embedding)
if pretrained:
state_dict = load_state_dict_from_url(url=pretrained_model_urls['cnn10'],
path=os.path.join(MODEL_HOME, 'panns'))
model.set_state_dict(state_dict)
return model
def cnn6(pretrained: bool = False, extract_embedding: bool = True) -> CNN6:
model = CNN6(extract_embedding=extract_embedding)
if pretrained:
state_dict = load_state_dict_from_url(url=pretrained_model_urls['cnn6'], path=os.path.join(MODEL_HOME, 'panns'))
model.set_state_dict(state_dict)
return model
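A sketch of loading a pretrained backbone and extracting clip embeddings; per the TODO above, the URLs are test placeholders, so pretrained=True presumes they resolve to real checkpoints:
import paddle

model = cnn14(pretrained=True, extract_embedding=True)
model.eval()
x = paddle.randn([4, 1, 100, 64])  # hypothetical feature batch: [N, 1, time, mel]
emb = model(x)                     # -> [4, 2048] embedding per clip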
......@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, List
from paddle.framework import load as load_state_dict
from paddle.utils import download
from .log import logger
......@@ -25,9 +27,22 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str):
"""
Download archives and decompress them to the specified path.
"""
if not os.path.isdir(path):
os.makedirs(path)
for archive in archives:
assert 'url' in archive and 'md5' in archive, \
    f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
logger.info(f'Downloading from: {archive["url"]}')
download.get_path_from_url(archive['url'], path, archive['md5'])
def load_state_dict_from_url(url: str, path: str, md5: str = None):
"""
Download and load a state dict from a URL.
"""
if not os.path.isdir(path):
os.makedirs(path)
download.get_path_from_url(url, path, md5)
return load_state_dict(os.path.join(path, os.path.basename(url)))
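A sketch of the helpers above (url and md5 values are hypothetical): each archive dict must carry both 'url' and 'md5', otherwise the assert in download_and_decompress fires.
archives = [{
    'url': 'https://example.com/data/example-dataset.zip',
    'md5': 'd41d8cd98f00b204e9800998ecf8427e',
}]
download_and_decompress(archives, '/tmp/paddleaudio/data')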
colorama
colorlog
easydict
filelock
gitpython
numpy
packaging
Pillow
pyyaml
pyzmq
rarfile
tqdm
scipy
librosa