Unverified commit 23493590, authored by KP, committed via GitHub

Add audio classification module and ESC50 dataset.

Parent c849198a
# PaddleHub Sound Classification
This demo shows how to use the PaddleHub Fine-tune API together with pretrained models such as CNN14 to perform sound classification and audio tagging tasks.
For details on the CNN14 family of pretrained models, see the paper [PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf) and the code repository [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn).
## How to Start Fine-tuning
We use the public environmental sound classification dataset [ESC50](https://github.com/karolpiczak/ESC-50) as the example dataset. Run the commands below to train the model on the training set (train.npz) and validate it on the dev set (dev.npz).
```shell
# Select the GPU card to use
export CUDA_VISIBLE_DEVICES=0
python train.py
```
## Code Walkthrough
Fine-tuning with the PaddleHub Fine-tune API takes four steps.
### Step 1: Choose a Model
```python
import paddle
import paddlehub as hub
from paddlehub.datasets import ESC50
model = hub.Module(name='panns_cnn14', version='1.0.0', task='sound-cls', num_class=ESC50.num_class)
```
Parameters:
- `name`: model name; one of `panns_cnn14`, `panns_cnn10`, or `panns_cnn6`. Details of each model are listed in the table below.
- `version`: module version.
- `task`: the task the model performs. `sound-cls` selects the sound classification task; `None` selects the audio tagging task (see the sketch after this list).
- `num_class`: the number of classes in the sound classification task, determined by the dataset in use.
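For reference, a minimal sketch of the tagging variant, which mirrors the Audio Tagging section later in this document:
```python
# With task=None, the pretrained model is loaded for Audioset tagging,
# so num_class is not needed.
tagging_model = hub.Module(name='panns_cnn14', version='1.0.0', task=None)
```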
Currently available pretrained models:
Model | PaddleHub Module
-----------| :------:
CNN14 | `hub.Module(name='panns_cnn14')`
CNN10 | `hub.Module(name='panns_cnn10')`
CNN6 | `hub.Module(name='panns_cnn6')`
### Step 2: Load the Dataset
```python
train_dataset = ESC50(mode='train')
dev_dataset = ESC50(mode='dev')
```
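`ESC50` also accepts a `feat_type` argument (`'mel'` by default, or `'raw'` for raw waveforms), as defined in the dataset's constructor; a minimal sketch:
```python
# Load the training split with raw waveform features instead of mel spectrograms.
train_raw_dataset = ESC50(mode='train', feat_type='raw')
```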
### Step 3: Choose an Optimization Strategy and Runtime Configuration
```python
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir='./', use_gpu=True)
```
#### Optimization Strategy
Paddle 2.0 provides a variety of optimizers, such as `SGD`, `AdamW`, and `Adamax`; see the [optimizer overview](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html) for details.
For `AdamW`:
- `learning_rate`: the global learning rate. Defaults to 1e-3.
- `parameters`: the model parameters to optimize.
For the remaining configurable parameters, see [AdamW](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/adamw/AdamW_cn.html#cn-api-paddle-optimizer-adamw).
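Any scheduler from `paddle.optimizer.lr` can also be passed as the learning rate. A sketch with a linear warmup; the scheduler choice and step count are illustrative, not required by this demo:
```python
# Warm the learning rate up linearly from 0 to 5e-5 over the first 100 steps.
scheduler = paddle.optimizer.lr.LinearWarmup(learning_rate=5e-5, warmup_steps=100, start_lr=0.0, end_lr=5e-5)
optimizer = paddle.optimizer.AdamW(learning_rate=scheduler, parameters=model.parameters())
```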
#### Runtime Configuration
`Trainer` controls the fine-tuning run and accepts the following parameters:
- `model`: the model to be optimized;
- `optimizer`: the optimizer to use;
- `use_vdl`: whether to visualize the training process with VisualDL;
- `checkpoint_dir`: the directory where model parameters are saved;
- `compare_metrics`: the metric used to select the best model (a fuller configuration sketch follows this list).
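A fuller configuration sketch, reusing the `model` and `optimizer` objects from the steps above (the `use_vdl` setting is illustrative):
```python
# Log training curves for VisualDL and keep checkpoints in ./checkpoint.
trainer = hub.Trainer(model,
                      optimizer,
                      use_vdl=True,
                      checkpoint_dir='./checkpoint',
                      use_gpu=True)
```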
### Step 4: Train and Evaluate the Model
```python
trainer.train(
train_dataset,
epochs=50,
batch_size=16,
eval_dataset=dev_dataset,
save_interval=10,
)
trainer.evaluate(dev_dataset, batch_size=16)
```
`trainer.train` runs model training; its parameters control the training process (see the sketch after this list). The main parameters are:
- `train_dataset`: the dataset used for training;
- `epochs`: the number of training epochs;
- `batch_size`: the number of samples per training step; when training on GPU, adjust batch_size to fit your hardware;
- `num_workers`: the number of worker processes, 0 by default;
- `eval_dataset`: the validation dataset;
- `log_interval`: the logging interval, measured in training steps;
- `save_interval`: the checkpoint saving interval, measured in training epochs.
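A sketch of the same training run with the optional arguments spelled out (the worker count and intervals are illustrative):
```python
trainer.train(train_dataset,
              epochs=50,
              batch_size=16,
              num_workers=2,             # DataLoader worker processes
              eval_dataset=dev_dataset,
              log_interval=10,           # log every 10 training steps
              save_interval=10)          # save a checkpoint every 10 epochs
```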
`trainer.evaluate` runs model evaluation. Its main parameters are:
- `eval_dataset`: the dataset used for evaluation;
- `batch_size`: the number of samples per evaluation step; when evaluating on GPU, adjust batch_size to fit your hardware.
## Prediction
After fine-tuning, the model that performed best on the validation set is saved under `${CHECKPOINT_DIR}/best_model`, where `${CHECKPOINT_DIR}` is the checkpoint directory chosen for fine-tuning.
The following code takes the local audio file `./cat.wav` as prediction input, classifies it with the fine-tuned model, and prints the result.
```python
import os
import librosa
import paddlehub as hub
from paddlehub.datasets import ESC50
wav = './cat.wav'  # local wav file to predict
sr = 44100  # sample rate of the audio file
checkpoint = './best_model/model.pdparams'  # model checkpoint
label_map = {idx: label for idx, label in enumerate(ESC50.label_list)}
model = hub.Module(name='panns_cnn14',
version='1.0.0',
task='sound-cls',
num_class=ESC50.num_class,
label_map=label_map,
load_checkpoint=checkpoint)
data = [librosa.load(wav, sr=sr)[0]]
result = model.predict(data, sample_rate=sr, batch_size=1, feat_type='mel', use_gpu=True)
print(result[0])  # result[0] maps each class to the probability that the audio belongs to it
```
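`result[0]` is an ordered mapping sorted from the highest probability to the lowest, so the top prediction can be read off its first item:
```python
# The first (label, probability) pair is the predicted class.
top_label, top_prob = next(iter(result[0].items()))
print(f'Predicted: {top_label} ({top_prob:.4f})')
```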
## Audio Tagging
The model used here is pretrained on the [Audioset dataset](https://research.google.com/audioset/). Besides the fine-tuning task on a specific sound classification dataset shown above, it also supports tagging audio with Audioset's 527 labels.
The following code takes the local audio file `./cat.wav` as input, scores it with the pretrained model, and prints the top-10 labels with their scores.
```python
import os
import librosa
import numpy as np
import paddlehub as hub
from paddlehub.env import MODULE_HOME
wav = './cat.wav'  # local wav file to predict
sr = 44100  # sample rate of the audio file
topk = 10  # show the top-10 labels and their scores
model = hub.Module(name='panns_cnn14', version='1.0.0', task=None)
# Read the label file of the Audioset dataset
label_file = os.path.join(MODULE_HOME, 'panns_cnn14', 'audioset_labels.txt')
label_map = {}
with open(label_file, 'r') as f:
for i, l in enumerate(f.readlines()):
label_map[i] = l.strip()
data = [librosa.load(wav, sr=sr)[0]]
result = model.predict(data, sample_rate=sr, batch_size=1, feat_type='mel', use_gpu=True)
# Print the top-k labels and their scores
msg = f'[{wav}]\n'
for label, score in list(result[0].items())[:topk]:
    msg += f'{label}: {score}\n'
print(msg)
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import librosa
import numpy as np
import paddlehub as hub
from paddlehub.env import MODULE_HOME
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
parser.add_argument("--sr", type=int, default=32000, help="Sample rate of inference audio.")
parser.add_argument("--model_type", type=str, default='panns_cnn14', help="Select model to to inference.")
parser.add_argument("--topk", type=int, default=10, help="Show top k results of audioset labels.")
args = parser.parse_args()
if __name__ == '__main__':
label_file = os.path.join(MODULE_HOME, args.model_type, 'audioset_labels.txt')
label_map = {}
with open(label_file, 'r') as f:
for i, l in enumerate(f.readlines()):
label_map[i] = l.strip()
model = hub.Module(name=args.model_type, version='1.0.0', task=None, label_map=label_map)
    data = [librosa.load(args.wav, sr=args.sr)[0]]  # a list containing one 1-D waveform array
result = model.predict(data, sample_rate=args.sr, batch_size=1, feat_type='mel', use_gpu=True)
msg = f'[{args.wav}]\n'
for label, score in list(result[0].items())[:args.topk]:
msg += f'{label}: {score}\n'
print(msg)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import librosa
import paddlehub as hub
from paddlehub.datasets import ESC50
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
parser.add_argument("--sr", type=int, default=44100, help="Sample rate of inference audio.")
parser.add_argument("--model_type", type=str, default='panns_cnn14', help="Select model to to inference.")
parser.add_argument("--topk", type=int, default=1, help="Show top k results of prediction labels.")
parser.add_argument("--checkpoint",
type=str,
default='./checkpoint/best_model/model.pdparams',
help="Checkpoint of model.")
args = parser.parse_args()
if __name__ == '__main__':
label_map = {idx: label for idx, label in enumerate(ESC50.label_list)}
model = hub.Module(name=args.model_type,
version='1.0.0',
task='sound-cls',
num_class=ESC50.num_class,
label_map=label_map,
load_checkpoint=args.checkpoint)
data = [librosa.load(args.wav, sr=args.sr)[0]]
result = model.predict(data, sample_rate=args.sr, batch_size=1, feat_type='mel', use_gpu=True)
msg = f'[{args.wav}]\n'
for label, score in list(result[0].items())[:args.topk]:
msg += f'{label}: {score}\n'
print(msg)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import paddle
import paddlehub as hub
from paddlehub.datasets import ESC50
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=50, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu",
type=ast.literal_eval,
default=True,
help="Whether use GPU for fine-tuning, input should be True or False")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to model checkpoint")
parser.add_argument("--save_interval", type=int, default=10, help="Save checkpoint every n epoch.")
args = parser.parse_args()
if __name__ == "__main__":
model = hub.Module(name='panns_cnn14', task='sound-cls', num_class=ESC50.num_class)
train_dataset = ESC50(mode='train')
dev_dataset = ESC50(mode='dev')
optimizer = paddle.optimizer.AdamW(learning_rate=args.learning_rate, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
trainer.train(
train_dataset,
epochs=args.num_epoch,
batch_size=args.batch_size,
eval_dataset=dev_dataset,
save_interval=args.save_interval,
)
```shell
$ hub install panns_cnn10==1.0.0
```
`panns_cnn10` is a sound classification/recognition model trained on the [Google Audioset](https://research.google.com/audioset/) dataset. The model consists mainly of 8 convolutional layers and 2 fully connected layers, with 4.9M parameters. After pretraining, it can be used to extract 512-dimensional audio embeddings.
For more details, see the paper: [PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf)
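A minimal sketch of extracting that 512-dimensional embedding from the underlying CNN10 network directly, assuming the module has been installed so that its package and the pretrained `cnn10.pdparams` are available under `MODULE_HOME`:
```python
import os
import librosa
import paddle
from paddlehub.env import MODULE_HOME
from paddlehub.utils.utils import extract_melspectrogram
from panns_cnn10.network import CNN10

model = CNN10(extract_embedding=True,
              checkpoint=os.path.join(MODULE_HOME, 'panns_cnn10', 'cnn10.pdparams'))
model.eval()
waveform = librosa.load('./cat.wav', sr=32000)[0]  # PANNs models are pretrained on 32 kHz audio
feat = paddle.to_tensor([extract_melspectrogram(waveform, sample_rate=32000)])
embedding = model(feat.unsqueeze(1))  # shape: (1, 512)
```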
## API
```python
def __init__(
task,
num_class=None,
label_map=None,
load_checkpoint=None,
**kwargs,
)
```
Creates a Module object.
**Parameters**
* `task`: task name, either `sound-cls` or `None`. `sound-cls` selects the sound classification task, which can be fine-tuned on a sound classification dataset; with `None` the pretrained model is loaded for audio classification/tagging.
* `num_class`: the number of classes in the sound classification task; required for fine-tuning and must match the dataset in use.
* `label_map`: the class label map used for prediction.
* `load_checkpoint`: path to a model parameter file saved with the PaddleHub Fine-tune API.
* `**kwargs`: additional user-specified keyword arguments.
```python
def predict(
data,
sample_rate,
batch_size=1,
feat_type='mel',
use_gpu=False
)
```
**Parameters**
* `data`: the data to predict, in the form \[waveform1, waveform2, ...\], where each element is a one-dimensional numpy array of audio waveform samples.
* `sample_rate`: the sample rate of the audio files.
* `feat_type`: the audio feature type; currently the mel feature `'mel'` (see [Mel-frequency cepstrum](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) and the raw waveform feature `'raw'` are supported.
* `batch_size`: the batch size for inference.
* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
* `results`: a list whose contents depend on the task type:
    * Sound classification (task is `sound-cls`): the list contains each audio file's predicted class labels.
    * Tagging (task is `None`): the list contains each audio file's scores over the 527 classes ([Audioset labels](https://research.google.com/audioset/)).
**Code Examples**
- [ESC50](https://github.com/karolpiczak/ESC-50) sound classification prediction
```python
import librosa
import paddlehub as hub
from paddlehub.datasets import ESC50
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
checkpoint = 'model.pdparams'  # model parameters used for prediction
label_map = {idx: label for idx, label in enumerate(ESC50.label_list)}
model = hub.Module(
name='panns_cnn10',
version='1.0.0',
task='sound-cls',
num_class=ESC50.num_class,
label_map=label_map,
load_checkpoint=checkpoint)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
print('File: {}\tLabel: {}'.format(wav_file, result[0]))
```
- Audioset Tagging
```python
import librosa
import numpy as np
import paddlehub as hub
def show_topk(k, label_map, file, result):
"""
    Show the top-k classes and their scores.
"""
result = np.asarray(result)
topk_idx = (-result).argsort()[:k]
msg = f'[{file}]\n'
for idx in topk_idx:
label, score = label_map[idx], result[idx]
msg += f'{label}: {score}\n'
print(msg)
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
label_file = './audioset_labels.txt'  # Audioset label text file
topk = 10  # number of top results to show
label_map = {}
with open(label_file, 'r') as f:
for i, l in enumerate(f.readlines()):
label_map[i] = l.strip()
model = hub.Module(
name='panns_cnn10',
version='1.0.0',
task=None)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
show_topk(topk, label_map, wav_file, result[0])
```
For more details, see the PaddleHub demo:
- [AudioClassification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification)
## Reference Code
https://github.com/qiuqiangkong/audioset_tagging_cnn
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## Release History
* 1.0.0
  Initial release: a dynamic-graph model that supports fine-tuning the `sound-cls` sound classification task and Audioset tagging prediction.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from panns_cnn10.network import CNN10
from paddlehub.env import MODULE_HOME
from paddlehub.module.audio_module import AudioClassifierModule
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger
@moduleinfo(name="panns_cnn10",
version="1.0.0",
summary="",
author="Baidu",
author_email="",
type="audio/sound_classification",
meta=AudioClassifierModule)
class PANN(nn.Layer):
def __init__(
self,
task: str,
num_class: int = None,
label_map: Dict = None,
load_checkpoint: str = None,
**kwargs,
):
super(PANN, self).__init__()
if label_map:
self.label_map = label_map
self.num_class = len(label_map)
else:
self.num_class = num_class
if task == 'sound-cls':
self.cnn10 = CNN10(extract_embedding=True,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn10', 'cnn10.pdparams'))
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(self.cnn10.emb_size, num_class)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
else:
self.cnn10 = CNN10(extract_embedding=False,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn10', 'cnn10.pdparams'))
self.task = task
if load_checkpoint is not None and os.path.isfile(load_checkpoint):
state_dict = paddle.load(load_checkpoint)
self.set_state_dict(state_dict)
logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
def forward(self, feats, labels=None):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats = feats.unsqueeze(1)
if self.task == 'sound-cls':
embeddings = self.cnn10(feats)
embeddings = self.dropout(embeddings)
logits = self.fc(embeddings)
probs = F.softmax(logits, axis=1)
if labels is not None:
loss = self.criterion(logits, labels)
correct = self.metric.compute(probs, labels)
acc = self.metric.update(correct)
return probs, loss, {'acc': acc}
return probs
else:
audioset_logits = self.cnn10(feats)
return audioset_logits
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.utils.log import logger
class ConvBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.conv2 = nn.Conv2D(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
self.bn2 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class CNN10(nn.Layer):
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN10, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
            logger.info(f'Loaded CNN10 pretrained parameters from: {checkpoint}')
        else:
            logger.error('No valid checkpoints for CNN10. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
```shell
$ hub install panns_cnn14==1.0.0
```
`panns_cnn14` is a sound classification/recognition model trained on the [Google Audioset](https://research.google.com/audioset/) dataset. The model consists mainly of 12 convolutional layers and 2 fully connected layers, with 79.6M parameters. After pretraining, it can be used to extract 2048-dimensional audio embeddings.
For more details, see the paper: [PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf)
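A minimal sketch of extracting that 2048-dimensional embedding from the underlying CNN14 network directly, assuming the module has been installed so that its package and the pretrained `cnn14.pdparams` are available under `MODULE_HOME`:
```python
import os
import librosa
import paddle
from paddlehub.env import MODULE_HOME
from paddlehub.utils.utils import extract_melspectrogram
from panns_cnn14.network import CNN14

model = CNN14(extract_embedding=True,
              checkpoint=os.path.join(MODULE_HOME, 'panns_cnn14', 'cnn14.pdparams'))
model.eval()
waveform = librosa.load('./cat.wav', sr=32000)[0]  # PANNs models are pretrained on 32 kHz audio
feat = paddle.to_tensor([extract_melspectrogram(waveform, sample_rate=32000)])
embedding = model(feat.unsqueeze(1))  # shape: (1, 2048)
```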
## API
```python
def __init__(
task,
num_class=None,
label_map=None,
load_checkpoint=None,
**kwargs,
)
```
Creates a Module object.
**Parameters**
* `task`: task name, either `sound-cls` or `None`. `sound-cls` selects the sound classification task, which can be fine-tuned on a sound classification dataset; with `None` the pretrained model is loaded for audio classification/tagging.
* `num_class`: the number of classes in the sound classification task; required for fine-tuning and must match the dataset in use.
* `label_map`: the class label map used for prediction.
* `load_checkpoint`: path to a model parameter file saved with the PaddleHub Fine-tune API.
* `**kwargs`: additional user-specified keyword arguments.
```python
def predict(
data,
sample_rate,
batch_size=1,
feat_type='mel',
use_gpu=False
)
```
**Parameters**
* `data`: the data to predict, in the form \[waveform1, waveform2, ...\], where each element is a one-dimensional numpy array of audio waveform samples.
* `sample_rate`: the sample rate of the audio files.
* `feat_type`: the audio feature type; currently the mel feature `'mel'` (see [Mel-frequency cepstrum](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) and the raw waveform feature `'raw'` are supported.
* `batch_size`: the batch size for inference.
* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
* `results`: a list whose contents depend on the task type:
    * Sound classification (task is `sound-cls`): the list contains each audio file's predicted class labels.
    * Tagging (task is `None`): the list contains each audio file's scores over the 527 classes ([Audioset labels](https://research.google.com/audioset/)).
**Code Examples**
- [ESC50](https://github.com/karolpiczak/ESC-50) sound classification prediction
```python
import librosa
import paddlehub as hub
from paddlehub.datasets import ESC50
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
checkpoint = 'model.pdparams'  # model parameters used for prediction
label_map = {idx: label for idx, label in enumerate(ESC50.label_list)}
model = hub.Module(
name='panns_cnn14',
version='1.0.0',
task='sound-cls',
num_class=ESC50.num_class,
label_map=label_map,
load_checkpoint=checkpoint)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
print('File: {}\tLabel: {}'.format(wav_file, result[0]))
```
- Audioset Tagging
```python
import librosa
import numpy as np
import paddlehub as hub
def show_topk(k, label_map, file, result):
"""
    Show the top-k classes and their scores.
"""
result = np.asarray(result)
topk_idx = (-result).argsort()[:k]
msg = f'[{file}]\n'
for idx in topk_idx:
label, score = label_map[idx], result[idx]
msg += f'{label}: {score}\n'
print(msg)
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
label_file = './audioset_labels.txt'  # Audioset label text file
topk = 10  # number of top results to show
label_map = {}
with open(label_file, 'r') as f:
for i, l in enumerate(f.readlines()):
label_map[i] = l.strip()
model = hub.Module(
name='panns_cnn14',
version='1.0.0',
task=None)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
show_topk(topk, label_map, wav_file, result[0])
```
For more details, see the PaddleHub demo:
- [AudioClassification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification)
## Reference Code
https://github.com/qiuqiangkong/audioset_tagging_cnn
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## Release History
* 1.0.0
  Initial release: a dynamic-graph model that supports fine-tuning the `sound-cls` sound classification task and Audioset tagging prediction.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from panns_cnn14.network import CNN14
from paddlehub.env import MODULE_HOME
from paddlehub.module.audio_module import AudioClassifierModule
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger
@moduleinfo(name="panns_cnn14",
version="1.0.0",
summary="",
author="Baidu",
author_email="",
type="audio/sound_classification",
meta=AudioClassifierModule)
class PANN(nn.Layer):
def __init__(
self,
task: str,
num_class: int = None,
label_map: Dict = None,
load_checkpoint: str = None,
**kwargs,
):
super(PANN, self).__init__()
if label_map:
self.label_map = label_map
self.num_class = len(label_map)
else:
self.num_class = num_class
if task == 'sound-cls':
self.cnn14 = CNN14(extract_embedding=True,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn14', 'cnn14.pdparams'))
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(self.cnn14.emb_size, num_class)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
else:
self.cnn14 = CNN14(extract_embedding=False,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn14', 'cnn14.pdparams'))
self.task = task
if load_checkpoint is not None and os.path.isfile(load_checkpoint):
state_dict = paddle.load(load_checkpoint)
self.set_state_dict(state_dict)
logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
def forward(self, feats, labels=None):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats = feats.unsqueeze(1)
if self.task == 'sound-cls':
embeddings = self.cnn14(feats)
embeddings = self.dropout(embeddings)
logits = self.fc(embeddings)
probs = F.softmax(logits, axis=1)
if labels is not None:
loss = self.criterion(logits, labels)
correct = self.metric.compute(probs, labels)
acc = self.metric.update(correct)
return probs, loss, {'acc': acc}
return probs
else:
audioset_logits = self.cnn14(feats)
return audioset_logits
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.utils.log import logger
class ConvBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.conv2 = nn.Conv2D(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
self.bn2 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class CNN14(nn.Layer):
emb_size = 2048
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN14, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.fc1 = nn.Linear(2048, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
logger.info(f'Loaded CNN14 pretrained parameters from: {checkpoint}')
else:
logger.error('No valid checkpoints for CNN14. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
```shell
$ hub install panns_cnn6==1.0.0
```
`panns_cnn6` is a sound classification/recognition model trained on the [Google Audioset](https://research.google.com/audioset/) dataset. The model consists mainly of 4 convolutional layers and 2 fully connected layers, with 4.5M parameters. After pretraining, it can be used to extract 512-dimensional audio embeddings.
For more details, see the paper: [PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf)
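A minimal sketch of extracting that 512-dimensional embedding from the underlying CNN6 network directly, assuming the module has been installed so that its package and the pretrained `cnn6.pdparams` are available under `MODULE_HOME`:
```python
import os
import librosa
import paddle
from paddlehub.env import MODULE_HOME
from paddlehub.utils.utils import extract_melspectrogram
from panns_cnn6.network import CNN6

model = CNN6(extract_embedding=True,
             checkpoint=os.path.join(MODULE_HOME, 'panns_cnn6', 'cnn6.pdparams'))
model.eval()
waveform = librosa.load('./cat.wav', sr=32000)[0]  # PANNs models are pretrained on 32 kHz audio
feat = paddle.to_tensor([extract_melspectrogram(waveform, sample_rate=32000)])
embedding = model(feat.unsqueeze(1))  # shape: (1, 512)
```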
## API
```python
def __init__(
task,
num_class=None,
label_map=None,
load_checkpoint=None,
**kwargs,
)
```
Creates a Module object.
**Parameters**
* `task`: task name, either `sound-cls` or `None`. `sound-cls` selects the sound classification task, which can be fine-tuned on a sound classification dataset; with `None` the pretrained model is loaded for audio classification/tagging.
* `num_class`: the number of classes in the sound classification task; required for fine-tuning and must match the dataset in use.
* `label_map`: the class label map used for prediction.
* `load_checkpoint`: path to a model parameter file saved with the PaddleHub Fine-tune API.
* `**kwargs`: additional user-specified keyword arguments.
```python
def predict(
data,
sample_rate,
batch_size=1,
feat_type='mel',
use_gpu=False
)
```
**Parameters**
* `data`: the data to predict, in the form \[waveform1, waveform2, ...\], where each element is a one-dimensional numpy array of audio waveform samples.
* `sample_rate`: the sample rate of the audio files.
* `feat_type`: the audio feature type; currently the mel feature `'mel'` (see [Mel-frequency cepstrum](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum)) and the raw waveform feature `'raw'` are supported.
* `batch_size`: the batch size for inference.
* `use_gpu`: whether to use the GPU; defaults to False. GPU users are advised to enable use_gpu.
**Returns**
* `results`: a list whose contents depend on the task type:
    * Sound classification (task is `sound-cls`): the list contains each audio file's predicted class labels.
    * Tagging (task is `None`): the list contains each audio file's scores over the 527 classes ([Audioset labels](https://research.google.com/audioset/)).
**Code Examples**
- [ESC50](https://github.com/karolpiczak/ESC-50) sound classification prediction
```python
import librosa
import paddlehub as hub
from paddlehub.datasets import ESC50
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
checkpoint = 'model.pdparams'  # model parameters used for prediction
label_map = {idx: label for idx, label in enumerate(ESC50.label_list)}
model = hub.Module(
name='panns_cnn6',
version='1.0.0',
task='sound-cls',
num_class=ESC50.num_class,
label_map=label_map,
load_checkpoint=checkpoint)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
print('File: {}\tLabel: {}'.format(wav_file, result[0]))
```
- Audioset Tagging
```python
import librosa
import numpy as np
import paddlehub as hub
def show_topk(k, label_map, file, result):
"""
    Show the top-k classes and their scores.
"""
result = np.asarray(result)
topk_idx = (-result).argsort()[:k]
msg = f'[{file}]\n'
for idx in topk_idx:
label, score = label_map[idx], result[idx]
msg += f'{label}: {score}\n'
print(msg)
sr = 44100  # sample rate of the audio file
wav_file = '/data/cat.wav'  # path to the audio file to predict
label_file = './audioset_labels.txt'  # Audioset label text file
topk = 10  # number of top results to show
label_map = {}
with open(label_file, 'r') as f:
for i, l in enumerate(f.readlines()):
label_map[i] = l.strip()
model = hub.Module(
name='panns_cnn6',
version='1.0.0',
task=None)
data = [librosa.load(wav_file, sr=sr)[0]]
result = model.predict(
data,
sample_rate=sr,
batch_size=1,
feat_type='mel',
use_gpu=True)
show_topk(topk, label_map, wav_file, result[0])
```
For more details, see the PaddleHub demo:
- [AudioClassification](https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification)
## Reference Code
https://github.com/qiuqiangkong/audioset_tagging_cnn
## Dependencies
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## Release History
* 1.0.0
  Initial release: a dynamic-graph model that supports fine-tuning the `sound-cls` sound classification task and Audioset tagging prediction.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from panns_cnn6.network import CNN6
from paddlehub.env import MODULE_HOME
from paddlehub.module.audio_module import AudioClassifierModule
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger
@moduleinfo(name="panns_cnn6",
version="1.0.0",
summary="",
author="Baidu",
author_email="",
type="audio/sound_classification",
meta=AudioClassifierModule)
class PANN(nn.Layer):
def __init__(
self,
task: str,
num_class: int = None,
label_map: Dict = None,
load_checkpoint: str = None,
**kwargs,
):
super(PANN, self).__init__()
if label_map:
self.label_map = label_map
self.num_class = len(label_map)
else:
self.num_class = num_class
if task == 'sound-cls':
self.cnn6 = CNN6(extract_embedding=True,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn6', 'cnn6.pdparams'))
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(self.cnn6.emb_size, num_class)
self.criterion = paddle.nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
else:
self.cnn6 = CNN6(extract_embedding=False,
checkpoint=os.path.join(MODULE_HOME, 'panns_cnn6', 'cnn6.pdparams'))
self.task = task
if load_checkpoint is not None and os.path.isfile(load_checkpoint):
state_dict = paddle.load(load_checkpoint)
self.set_state_dict(state_dict)
logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
def forward(self, feats, labels=None):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats = feats.unsqueeze(1)
if self.task == 'sound-cls':
embeddings = self.cnn6(feats)
embeddings = self.dropout(embeddings)
logits = self.fc(embeddings)
probs = F.softmax(logits, axis=1)
if labels is not None:
loss = self.criterion(logits, labels)
correct = self.metric.compute(probs, labels)
acc = self.metric.update(correct)
return probs, loss, {'acc': acc}
return probs
else:
audioset_logits = self.cnn6(feats)
return audioset_logits
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.utils.log import logger
class ConvBlock5x5(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvBlock5x5, self).__init__()
self.conv1 = nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channels)
def forward(self, x, pool_size=(2, 2), pool_type='avg'):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
else:
raise Exception(
f'Pooling type of {pool_type} is not supported. It must be one of "max", "avg" and "avg+max".')
return x
class CNN6(nn.Layer):
emb_size = 512
def __init__(self, extract_embedding: bool = True, checkpoint: str = None):
super(CNN6, self).__init__()
self.bn0 = nn.BatchNorm2D(64)
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, self.emb_size)
self.fc_audioset = nn.Linear(self.emb_size, 527)
if checkpoint is not None and os.path.isfile(checkpoint):
state_dict = paddle.load(checkpoint)
self.set_state_dict(state_dict)
            logger.info(f'Loaded CNN6 pretrained parameters from: {checkpoint}')
        else:
            logger.error('No valid checkpoints for CNN6. Start training from scratch.')
self.extract_embedding = extract_embedding
def forward(self, x):
x.stop_gradient = False
x = x.transpose([0, 3, 2, 1])
x = self.bn0(x)
x = x.transpose([0, 3, 2, 1])
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = x.mean(axis=3)
x = x.max(axis=2) + x.mean(axis=2)
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu(self.fc1(x))
if self.extract_embedding:
output = F.dropout(x, p=0.5, training=self.training)
else:
output = F.sigmoid(self.fc_audioset(x))
return output
@@ -13,9 +13,11 @@
# limitations under the License.
from paddlehub.datasets.canvas import Canvas
from paddlehub.datasets.chnsenticorp import ChnSentiCorp
from paddlehub.datasets.esc50 import ESC50
from paddlehub.datasets.flowers import Flowers
from paddlehub.datasets.lcqmc import LCQMC
from paddlehub.datasets.minicoco import MiniCOCO
from paddlehub.datasets.chnsenticorp import ChnSentiCorp
from paddlehub.datasets.msra_ner import MSRA_NER
from paddlehub.datasets.lcqmc import LCQMC
from paddlehub.datasets.base_seg_dataset import SegDataset
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import io
import os
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
from paddlehub.utils.utils import extract_melspectrogram
class InputExample(object):
"""
Input example of one audio sample.
"""
def __init__(self, guid: int, source: Union[list, str], label: Optional[str] = None):
self.guid = guid
self.source = source
self.label = label
class BaseAudioDataset(object):
"""
Base class of speech dataset.
"""
def __init__(self, base_path: str, data_file: str, mode: Optional[str] = "train"):
self.data_file = os.path.join(base_path, data_file)
self.mode = mode
def _read_file(self, input_file: str):
raise NotImplementedError
class AudioClassificationDataset(BaseAudioDataset, paddle.io.Dataset):
"""
Base class of audio classification dataset.
"""
_supported_features = ['raw', 'mel']
def __init__(self,
base_path: str,
data_file: str,
file_type: str = 'npz',
mode: str = 'train',
feat_type: str = 'mel',
feat_cfg: dict = None):
super(AudioClassificationDataset, self).__init__(base_path=base_path, mode=mode, data_file=data_file)
self.file_type = file_type
self.feat_type = feat_type
self.feat_cfg = feat_cfg
self.examples = self._read_file(self.data_file)
self.records = self._convert_examples_to_records(self.examples)
def _read_file(self, input_file: str) -> List[InputExample]:
if not os.path.exists(input_file):
raise RuntimeError("Data file: {} not found.".format(input_file))
examples = []
if self.file_type == 'npz':
dataset = np.load(os.path.join(self.data_file), allow_pickle=True)
audio_id = 0
for waveform, label in zip(dataset['waveforms'], dataset['labels']):
example = InputExample(guid=audio_id, source=waveform, label=label)
audio_id += 1
examples.append(example)
else:
            raise NotImplementedError(f'Only the npz file type is supported, but got {self.file_type}')
return examples
def _convert_examples_to_records(self, examples: List[InputExample]) -> List[dict]:
records = []
for example in examples:
record = {}
if self.feat_type == 'raw':
record['feat'] = example.source
elif self.feat_type == 'mel':
record['feat'] = extract_melspectrogram(example.source,
sample_rate=self.feat_cfg['sample_rate'],
window_size=self.feat_cfg['window_size'],
hop_size=self.feat_cfg['hop_size'],
mel_bins=self.feat_cfg['mel_bins'],
fmin=self.feat_cfg['fmin'],
fmax=self.feat_cfg['fmax'],
window=self.feat_cfg['window'],
center=True,
pad_mode='reflect',
ref=1.0,
amin=1e-10,
top_db=None)
else:
raise RuntimeError(\
f"Unknown type of self.feat_type: {self.feat_type}, it must be one in {self._supported_features}")
record['label'] = example.label
records.append(record)
return records
def __getitem__(self, idx):
"""
Overload this method when doing extra feature processes or data augmentation.
"""
record = self.records[idx]
return np.array(record['feat']), np.array(record['label'], dtype=np.int64)
def __len__(self):
return len(self.records)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddlehub.datasets.base_audio_dataset import AudioClassificationDataset
from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
@download_data(url="https://bj.bcebos.com/paddlehub-dataset/esc50.tar.gz")
class ESC50(AudioClassificationDataset):
sample_rate = 44100
input_length = int(sample_rate * 5) # 5s
num_class = 50 # class num
label_list = [
# Animals
'Dog',
'Rooster',
'Pig',
'Cow',
'Frog',
'Cat',
'Hen',
'Insects (flying)',
'Sheep',
'Crow',
# Natural soundscapes & water sounds
'Rain',
'Sea waves',
'Crackling fire',
'Crickets',
'Chirping birds',
'Water drops',
'Wind',
'Pouring water',
'Toilet flush',
'Thunderstorm',
# Human, non-speech sounds
'Crying baby',
'Sneezing',
'Clapping',
'Breathing',
'Coughing',
'Footsteps',
'Laughing',
'Brushing teeth',
'Snoring',
'Drinking, sipping',
# Interior/domestic sounds
'Door knock',
'Mouse click',
'Keyboard typing',
'Door, wood creaks',
'Can opening',
'Washing machine',
'Vacuum cleaner',
'Clock alarm',
'Clock tick',
'Glass breaking',
# Exterior/urban noises
'Helicopter',
'Chainsaw',
'Siren',
'Car horn',
'Engine',
'Train',
'Church bells',
'Airplane',
'Fireworks',
'Hand saw',
]
def __init__(self, mode: str = 'train', feat_type: str = 'mel'):
base_path = os.path.join(DATA_HOME, "esc50")
if mode == 'train':
data_file = 'train.npz'
else:
data_file = 'dev.npz'
feat_cfg = dict(sample_rate=self.sample_rate,
window_size=1024,
hop_size=320,
mel_bins=64,
fmin=50,
fmax=14000,
window='hann')
super().__init__(base_path=base_path,
data_file=data_file,
file_type='npz',
mode=mode,
feat_type=feat_type,
feat_cfg=feat_cfg)
if __name__ == "__main__":
train_dataset = ESC50(mode='train')
dev_dataset = ESC50(mode='dev')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List, Tuple
import numpy as np
import paddle
from paddlehub.module.module import RunModule, runnable, serving
from paddlehub.utils.utils import extract_melspectrogram
class AudioClassifierModule(RunModule):
"""
The base class for audio classifier models.
"""
_tasks_supported = [
'sound-cls',
]
def _batchify(self, data: List[List[float]], sample_rate: int, feat_type: str, batch_size: int):
examples = []
for waveform in data:
if feat_type == 'mel':
feat = extract_melspectrogram(waveform, sample_rate=sample_rate)
examples.append(feat)
else:
examples.append(waveform)
        # Separate the examples into batches.
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
yield one_batch
one_batch = []
if one_batch:
# The last batch whose size is less than the config batch_size setting.
yield one_batch
def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
if self.task == 'sound-cls':
predictions, avg_loss, metric = self(feats=batch[0], labels=batch[1])
else:
raise NotImplementedError
self.metric.reset()
return {'loss': avg_loss, 'metrics': metric}
def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
if self.task == 'sound-cls':
predictions, avg_loss, metric = self(feats=batch[0], labels=batch[1])
else:
raise NotImplementedError
return {'metrics': metric}
def predict(self,
data: List[List[float]],
sample_rate: int,
batch_size: int = 1,
feat_type: str = 'mel',
use_gpu: bool = False):
if self.task not in self._tasks_supported \
and self.task is not None: # None for getting audioset tags
raise RuntimeError(f'Unknown task {self.task}, current tasks supported:\n'
'1. sound-cls: sound classification;\n'
'2. None: audioset tags')
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
batches = self._batchify(data, sample_rate, feat_type, batch_size)
results = []
self.eval()
for batch in batches:
feats = paddle.to_tensor(batch)
scores = self(feats)
for score in scores.numpy():
result = OrderedDict()
for i in (-score).argsort():
result[self.label_map[i]] = score[i]
results.append(result)
return results
@@ -15,31 +15,31 @@
import base64
import contextlib
import cv2
import hashlib
import importlib
import math
import os
import requests
import socket
import sys
import time
import tempfile
import time
import traceback
import types
from typing import Generator, List
from urllib.parse import urlparse
import cv2
import numpy as np
import packaging.version
import requests
import paddlehub.env as hubenv
import paddlehub.utils as utils
from paddlehub.utils.log import logger
class Version(packaging.version.Version):
'''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool:
'''
Determine whether the given condition are met
@@ -105,7 +105,6 @@ class Version(packaging.version.Version):
class Timer(object):
'''Calculate runing speed and estimated time of arrival(ETA)'''
def __init__(self, total_step: int):
self.total_step = total_step
self.last_start_step = 0
@@ -303,7 +302,7 @@ def record_exception(msg: str) -> str:
    '''Record the current exception information into the PaddleHub log file, which will be automatically stored according to date.'''
tb = traceback.format_exc()
file = record(tb)
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
def get_record_file() -> str:
@@ -385,3 +384,41 @@ def trunc_sequence(ids: List[int], max_seq_len: int):
f'The input length {len(ids)} is less than max_seq_len {max_seq_len}. ' \
'Please check the input list and max_seq_len if you really want to truncate a sequence.'
return ids[:max_seq_len]
def extract_melspectrogram(y,
sample_rate: int = 32000,
window_size: int = 1024,
hop_size: int = 320,
mel_bins: int = 64,
fmin: int = 50,
fmax: int = 14000,
window: str = 'hann',
center: bool = True,
pad_mode: str = 'reflect',
ref: float = 1.0,
amin: float = 1e-10,
top_db: float = None):
'''
Extract Mel Spectrogram from a waveform.
'''
try:
import librosa
except Exception:
logger.error('Failed to import librosa. Please check that librosa and numba are correctly installed.')
raise
s = librosa.stft(y,
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode)
power = np.abs(s)**2
melW = librosa.filters.mel(sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax)
mel = np.matmul(melW, power)
    db = librosa.power_to_db(mel, ref=ref, amin=amin, top_db=top_db)
db = db.transpose()
return db