Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
23493590
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
23493590
编写于
4月 06, 2021
作者:
K
KP
提交者:
GitHub
4月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add audio classification module and ESC50 dataset.
上级
c849198a
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
1777 addition
and
7 deletion
+1777
-7
demo/audio_classification/README.md
demo/audio_classification/README.md
+181
-0
demo/audio_classification/audioset_predict.py
demo/audio_classification/audioset_predict.py
+47
-0
demo/audio_classification/predict.py
demo/audio_classification/predict.py
+51
-0
demo/audio_classification/train.py
demo/audio_classification/train.py
+50
-0
modules/audio/audio_classification/PANNs/cnn10/README.md
modules/audio/audio_classification/PANNs/cnn10/README.md
+152
-0
modules/audio/audio_classification/PANNs/cnn10/__init__.py
modules/audio/audio_classification/PANNs/cnn10/__init__.py
+0
-0
modules/audio/audio_classification/PANNs/cnn10/module.py
modules/audio/audio_classification/PANNs/cnn10/module.py
+89
-0
modules/audio/audio_classification/PANNs/cnn10/network.py
modules/audio/audio_classification/PANNs/cnn10/network.py
+116
-0
modules/audio/audio_classification/PANNs/cnn14/README.md
modules/audio/audio_classification/PANNs/cnn14/README.md
+152
-0
modules/audio/audio_classification/PANNs/cnn14/__init__.py
modules/audio/audio_classification/PANNs/cnn14/__init__.py
+0
-0
modules/audio/audio_classification/PANNs/cnn14/module.py
modules/audio/audio_classification/PANNs/cnn14/module.py
+89
-0
modules/audio/audio_classification/PANNs/cnn14/network.py
modules/audio/audio_classification/PANNs/cnn14/network.py
+124
-0
modules/audio/audio_classification/PANNs/cnn6/README.md
modules/audio/audio_classification/PANNs/cnn6/README.md
+152
-0
modules/audio/audio_classification/PANNs/cnn6/__init__.py
modules/audio/audio_classification/PANNs/cnn6/__init__.py
+0
-0
modules/audio/audio_classification/PANNs/cnn6/module.py
modules/audio/audio_classification/PANNs/cnn6/module.py
+89
-0
modules/audio/audio_classification/PANNs/cnn6/network.py
modules/audio/audio_classification/PANNs/cnn6/network.py
+105
-0
paddlehub/datasets/__init__.py
paddlehub/datasets/__init__.py
+3
-1
paddlehub/datasets/base_audio_dataset.py
paddlehub/datasets/base_audio_dataset.py
+125
-0
paddlehub/datasets/esc50.py
paddlehub/datasets/esc50.py
+112
-0
paddlehub/module/audio_module.py
paddlehub/module/audio_module.py
+97
-0
paddlehub/utils/utils.py
paddlehub/utils/utils.py
+43
-6
未找到文件。
demo/audio_classification/README.md
0 → 100644
浏览文件 @
23493590
# PaddleHub 声音分类
本示例展示如何使用PaddleHub Fine-tune API以及CNN14等预训练模型完成声音分类和Tagging的任务。
CNN14等预训练模型的详情,请参考论文
[
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
](
https://arxiv.org/pdf/1912.10211.pdf
)
和代码
[
audioset_tagging_cnn
](
https://github.com/qiuqiangkong/audioset_tagging_cnn
)
。
## 如何开始Fine-tune
我们以环境声音分类公开数据集
[
ESC50
](
https://github.com/karolpiczak/ESC-50
)
为示例数据集,可以运行下面的命令,在训练集(train.npz)上进行模型训练,并在开发集(dev.npz)验证。通过如下命令,即可启动训练。
```
python
# 设置使用的GPU卡号
export
CUDA_VISIBLE_DEVICES
=
0
python
train
.
py
```
## 代码步骤
使用PaddleHub Fine-tune API进行Fine-tune可以分为4个步骤。
### Step1: 选择模型
```
python
import
paddle
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
)
```
其中,参数:
-
`name`
: 模型名称,可以选择
`panns_cnn14`
、
`panns_cnn10`
和
`panns_cnn6`
,具体的模型参数信息可见下表。
-
`version`
: module版本号
-
`task`
:模型的执行任务。
`sound-cls`
表示声音分类任务;
`None`
表示Audio Tagging任务。
-
`num_classes`
:表示当前声音分类任务的类别数,根据具体使用的数据集确定。
目前可选用的预训练模型:
模型名 | PaddleHub Module
-----------| :------:
CNN14 |
`hub.Module(name='panns_cnn14')`
CNN10 |
`hub.Module(name='panns_cnn10')`
CNN6 |
`hub.Module(name='panns_cnn6')`
### Step2: 加载数据集
```
python
train_dataset
=
ESC50
(
mode
=
'train'
)
dev_dataset
=
ESC50
(
mode
=
'dev'
)
```
### Step3: 选择优化策略和运行配置
```
python
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
5e-5
,
parameters
=
model
.
parameters
())
trainer
=
hub
.
Trainer
(
model
,
optimizer
,
checkpoint_dir
=
'./'
,
use_gpu
=
True
)
```
#### 优化策略
Paddle2.0提供了多种优化器选择,如
`SGD`
,
`AdamW`
,
`Adamax`
等,详细参见
[
策略
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html
)
。
其中
`AdamW`
:
-
`learning_rate`
: 全局学习率。默认为1e-3;
-
`parameters`
: 待优化模型参数。
其余可配置参数请参考
[
AdamW
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/adamw/AdamW_cn.html#cn-api-paddle-optimizer-adamw
)
。
#### 运行配置
`Trainer`
主要控制Fine-tune的训练,包含以下可控制的参数:
-
`model`
: 被优化模型;
-
`optimizer`
: 优化器选择;
-
`use_vdl`
: 是否使用vdl可视化训练过程;
-
`checkpoint_dir`
: 保存模型参数的地址;
-
`compare_metrics`
: 保存最优模型的衡量指标;
### Step4: 执行训练和模型评估
```
python
trainer
.
train
(
train_dataset
,
epochs
=
50
,
batch_size
=
16
,
eval_dataset
=
dev_dataset
,
save_interval
=
10
,
)
trainer
.
evaluate
(
dev_dataset
,
batch_size
=
16
)
```
`trainer.train`
执行模型的训练,其参数可以控制具体的训练过程,主要的参数包含:
-
`train_dataset`
: 训练时所用的数据集;
-
`epochs`
: 训练轮数;
-
`batch_size`
: 训练时每一步用到的样本数目,如果使用GPU,请根据实际情况调整batch_size;
-
`num_workers`
: works的数量,默认为0;
-
`eval_dataset`
: 验证集;
-
`log_interval`
: 打印日志的间隔, 单位为执行批训练的次数。
-
`save_interval`
: 保存模型的间隔频次,单位为执行训练的轮数。
`trainer.evaluate`
执行模型的评估,主要的参数包含:
-
`eval_dataset`
: 模型评估时所用的数据集;
-
`batch_size`
: 模型评估时每一步用到的样本数目,如果使用GPU,请根据实际情况调整batch_size
## 模型预测
当完成Fine-tune后,Fine-tune过程在验证集上表现最优的模型会被保存在
`${CHECKPOINT_DIR}/best_model`
目录下,其中
`${CHECKPOINT_DIR}`
目录为Fine-tune时所选择的保存checkpoint的目录。
以下代码将本地的音频文件
`./cat.wav`
作为预测数据,使用训好的模型对它进行分类,输出结果。
```
python
import
os
import
librosa
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
wav
=
'./cat.wav'
# 存储在本地的需要预测的wav文件
sr
=
44100
# 音频文件的采样率
checkpoint
=
'./best_model/model.pdparams'
# 模型checkpoint
label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
ESC50
.
label_list
)}
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
,
label_map
=
label_map
,
load_checkpoint
=
checkpoint
)
data
=
[
librosa
.
load
(
wav
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
print
(
result
[
0
])
# result[0]包含音频文件属于各类别的概率值
```
## Audio Tagging
当前使用的模型是基于
[
Audioset数据集
](
https://research.google.com/audioset/
)
的预训练模型,除了以上的针对特定声音分类数据集的finetune任务,模型还支持基于Audioset 527个标签的Tagging功能。
以下代码将本地的音频文件
`./cat.wav`
作为预测数据,使用预训练模型对它进行打分,输出top 10的标签和对应的得分。
```
python
import
os
import
librosa
import
numpy
as
np
import
paddlehub
as
hub
from
paddlehub.env
import
MODULE_HOME
wav
=
'./cat.wav'
# 存储在本地的需要预测的wav文件
sr
=
44100
# 音频文件的采样率
topk
=
10
# 展示音频得分前10的标签和分数
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
version
=
'1.0.0'
,
task
=
None
)
# 读取audioset数据集的label文件
label_file
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn14'
,
'audioset_labels.txt'
)
label_map
=
{}
with
open
(
label_file
,
'r'
)
as
f
:
for
i
,
l
in
enumerate
(
f
.
readlines
()):
label_map
[
i
]
=
l
.
strip
()
data
=
[
librosa
.
load
(
wav
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
# 打印topk的类别和对应得分
for
label
,
score
in
list
(
result
[
0
].
items
())[:
topk
]:
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
```
demo/audio_classification/audioset_predict.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
os
import
librosa
import
numpy
as
np
import
paddlehub
as
hub
from
paddlehub.env
import
MODULE_HOME
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--wav"
,
type
=
str
,
required
=
True
,
help
=
"Audio file to infer."
)
parser
.
add_argument
(
"--sr"
,
type
=
int
,
default
=
32000
,
help
=
"Sample rate of inference audio."
)
parser
.
add_argument
(
"--model_type"
,
type
=
str
,
default
=
'panns_cnn14'
,
help
=
"Select model to to inference."
)
parser
.
add_argument
(
"--topk"
,
type
=
int
,
default
=
10
,
help
=
"Show top k results of audioset labels."
)
args
=
parser
.
parse_args
()
if
__name__
==
'__main__'
:
label_file
=
os
.
path
.
join
(
MODULE_HOME
,
args
.
model_type
,
'audioset_labels.txt'
)
label_map
=
{}
with
open
(
label_file
,
'r'
)
as
f
:
for
i
,
l
in
enumerate
(
f
.
readlines
()):
label_map
[
i
]
=
l
.
strip
()
model
=
hub
.
Module
(
name
=
args
.
model_type
,
version
=
'1.0.0'
,
task
=
None
,
label_map
=
label_map
)
data
=
[
librosa
.
load
(
args
.
wav
,
sr
=
args
.
sr
)[
0
]]
# (t, num_mel_bins)
result
=
model
.
predict
(
data
,
sample_rate
=
args
.
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
msg
=
f
'[
{
args
.
wav
}
]
\n
'
for
label
,
score
in
list
(
result
[
0
].
items
())[:
args
.
topk
]:
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
demo/audio_classification/predict.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
os
import
librosa
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--wav"
,
type
=
str
,
required
=
True
,
help
=
"Audio file to infer."
)
parser
.
add_argument
(
"--sr"
,
type
=
int
,
default
=
44100
,
help
=
"Sample rate of inference audio."
)
parser
.
add_argument
(
"--model_type"
,
type
=
str
,
default
=
'panns_cnn14'
,
help
=
"Select model to to inference."
)
parser
.
add_argument
(
"--topk"
,
type
=
int
,
default
=
1
,
help
=
"Show top k results of prediction labels."
)
parser
.
add_argument
(
"--checkpoint"
,
type
=
str
,
default
=
'./checkpoint/best_model/model.pdparams'
,
help
=
"Checkpoint of model."
)
args
=
parser
.
parse_args
()
if
__name__
==
'__main__'
:
label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
ESC50
.
label_list
)}
model
=
hub
.
Module
(
name
=
args
.
model_type
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
,
label_map
=
label_map
,
load_checkpoint
=
args
.
checkpoint
)
data
=
[
librosa
.
load
(
args
.
wav
,
sr
=
args
.
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
args
.
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
msg
=
f
'[
{
args
.
wav
}
]
\n
'
for
label
,
score
in
list
(
result
[
0
].
items
())[:
args
.
topk
]:
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
demo/audio_classification/train.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
paddle
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
50
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"Whether use GPU for fine-tuning, input should be True or False"
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
5e-5
,
help
=
"Learning rate used to train with warmup."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
'./checkpoint'
,
help
=
"Directory to model checkpoint"
)
parser
.
add_argument
(
"--save_interval"
,
type
=
int
,
default
=
10
,
help
=
"Save checkpoint every n epoch."
)
args
=
parser
.
parse_args
()
if
__name__
==
"__main__"
:
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
)
train_dataset
=
ESC50
(
mode
=
'train'
)
dev_dataset
=
ESC50
(
mode
=
'dev'
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
args
.
learning_rate
,
parameters
=
model
.
parameters
())
trainer
=
hub
.
Trainer
(
model
,
optimizer
,
checkpoint_dir
=
args
.
checkpoint_dir
,
use_gpu
=
args
.
use_gpu
)
trainer
.
train
(
train_dataset
,
epochs
=
args
.
num_epoch
,
batch_size
=
args
.
batch_size
,
eval_dataset
=
dev_dataset
,
save_interval
=
args
.
save_interval
,
)
modules/audio/audio_classification/PANNs/cnn10/README.md
0 → 100644
浏览文件 @
23493590
```
shell
$
hub
install
panns_cnn10
==
1.0.0
```
`panns_cnn10`
是一个基于
[
Google Audioset
](
https://research.google.com/audioset/
)
数据集训练的声音分类/识别的模型。该模型主要包含8个卷积层和2个全连接层,模型参数为4.9M。经过预训练后,可以用于提取音频的embbedding,维度是512。
更多详情请参考论文:
[
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
](
https://arxiv.org/pdf/1912.10211.pdf
)
## API
```
python
def
__init__
(
task
,
num_class
=
None
,
label_map
=
None
,
load_checkpoint
=
None
,
**
kwargs
,
)
```
创建Module对象。
**参数**
*
`task`
: 任务名称,可为
`sound-cls`
或者
`None`
。
`sound-cls`
代表声音分类任务,可以对声音分类的数据集进行finetune;为
`None`
时可以获取预训练模型对音频进行分类/Tagging。
*
`num_classes`
:声音分类任务的类别数,finetune时需要指定,数值与具体使用的数据集类别数一致。
*
`label_map`
:预测时的类别映射表。
*
`load_checkpoint`
:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。
*
`**kwargs`
:用户额外指定的关键字字典类型的参数。
```
python
def
predict
(
data
,
sample_rate
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
False
)
```
**参数**
*
`data`
: 待预测数据,格式为
\[
waveform1, wavwform2…,
\]
,其中每个元素都是一个一维numpy列表,是音频的波形采样数值列表。
*
`sample_rate`
:音频文件的采样率。
*
`feat_type`
:音频特征的种类选取,当前支持
`'mel'`
(详情可查看
[
Mel-frequency cepstrum
](
https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
)
)和原始波形特征
`'raw'`
。
*
`batch_size`
:模型批处理大小。
*
`use_gpu`
:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
**返回**
*
`results`
:list类型,不同任务类型的返回结果如下
*
声音分类(task参数为
`sound-cls`
):列表里包含每个音频文件的分类标签。
*
Tagging(task参数为
`None`
):列表里包含每个音频文件527个类别(
[
Audioset标签
](
https://research.google.com/audioset/
)
)的得分。
**代码示例**
-
[
ESC50
](
https://github.com/karolpiczak/ESC-50
)
声音分类预测
```
python
import
librosa
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
checkpoint
=
'model.pdparams'
# 用于预测的模型参数
label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
ESC50
.
label_list
)}
model
=
hub
.
Module
(
name
=
'panns_cnn10'
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
,
label_map
=
label_map
,
load_checkpoint
=
checkpoint
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
print
(
'File: {}
\t
Lable: {}'
.
format
(
wav_file
,
result
[
0
]))
```
-
Audioset Tagging
```
python
import
librosa
import
numpy
as
np
import
paddlehub
as
hub
def
show_topk
(
k
,
label_map
,
file
,
result
):
"""
展示topk的分的类别和分数。
"""
result
=
np
.
asarray
(
result
)
topk_idx
=
(
-
result
).
argsort
()[:
k
]
msg
=
f
'[
{
file
}
]
\n
'
for
idx
in
topk_idx
:
label
,
score
=
label_map
[
idx
],
result
[
idx
]
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
label_file
=
'./audioset_labels.txt'
# audioset标签文本文件
topk
=
10
# 展示的topk数
label_map
=
{}
with
open
(
label_file
,
'r'
)
as
f
:
for
i
,
l
in
enumerate
(
f
.
readlines
()):
label_map
[
i
]
=
l
.
strip
()
model
=
hub
.
Module
(
name
=
'panns_cnn10'
,
version
=
'1.0.0'
,
task
=
None
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
show_topk
(
topk
,
label_map
,
wav_file
,
result
[
0
])
```
详情可参考PaddleHub示例:
-
[
AudioClassification
](
https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification
)
## 查看代码
https://github.com/qiuqiangkong/audioset_tagging_cnn
## 依赖
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## 更新历史
*
1.0.0
初始发布,动态图版本模型,支持声音分类
`sound-cls`
任务的fine-tune和基于Audioset Tagging预测。
modules/audio/audio_classification/PANNs/cnn10/__init__.py
0 → 100644
浏览文件 @
23493590
modules/audio/audio_classification/PANNs/cnn10/module.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
from
typing
import
Dict
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
panns_cnn10.network
import
CNN10
from
paddlehub.env
import
MODULE_HOME
from
paddlehub.module.audio_module
import
AudioClassifierModule
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"panns_cnn10"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"Baidu"
,
author_email
=
""
,
type
=
"audio/sound_classification"
,
meta
=
AudioClassifierModule
)
class
PANN
(
nn
.
Layer
):
def
__init__
(
self
,
task
:
str
,
num_class
:
int
=
None
,
label_map
:
Dict
=
None
,
load_checkpoint
:
str
=
None
,
**
kwargs
,
):
super
(
PANN
,
self
).
__init__
()
if
label_map
:
self
.
label_map
=
label_map
self
.
num_class
=
len
(
label_map
)
else
:
self
.
num_class
=
num_class
if
task
==
'sound-cls'
:
self
.
cnn10
=
CNN10
(
extract_embedding
=
True
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn10'
,
'cnn10.pdparams'
))
self
.
dropout
=
nn
.
Dropout
(
0.1
)
self
.
fc
=
nn
.
Linear
(
self
.
cnn10
.
emb_size
,
num_class
)
self
.
criterion
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
()
self
.
metric
=
paddle
.
metric
.
Accuracy
()
else
:
self
.
cnn10
=
CNN10
(
extract_embedding
=
False
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn10'
,
'cnn10.pdparams'
))
self
.
task
=
task
if
load_checkpoint
is
not
None
and
os
.
path
.
isfile
(
load_checkpoint
):
state_dict
=
paddle
.
load
(
load_checkpoint
)
self
.
set_state_dict
(
state_dict
)
logger
.
info
(
'Loaded parameters from %s'
%
os
.
path
.
abspath
(
load_checkpoint
))
def
forward
(
self
,
feats
,
labels
=
None
):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats
=
feats
.
unsqueeze
(
1
)
if
self
.
task
==
'sound-cls'
:
embeddings
=
self
.
cnn10
(
feats
)
embeddings
=
self
.
dropout
(
embeddings
)
logits
=
self
.
fc
(
embeddings
)
probs
=
F
.
softmax
(
logits
,
axis
=
1
)
if
labels
is
not
None
:
loss
=
self
.
criterion
(
logits
,
labels
)
correct
=
self
.
metric
.
compute
(
probs
,
labels
)
acc
=
self
.
metric
.
update
(
correct
)
return
probs
,
loss
,
{
'acc'
:
acc
}
return
probs
else
:
audioset_logits
=
self
.
cnn10
(
feats
)
return
audioset_logits
modules/audio/audio_classification/PANNs/cnn10/network.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddlehub.utils.log
import
logger
class
ConvBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
ConvBlock
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
),
bias_attr
=
False
)
self
.
conv2
=
nn
.
Conv2D
(
in_channels
=
out_channels
,
out_channels
=
out_channels
,
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
),
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
out_channels
)
self
.
bn2
=
nn
.
BatchNorm2D
(
out_channels
)
def
forward
(
self
,
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
bn2
(
x
)
x
=
F
.
relu
(
x
)
if
pool_type
==
'max'
:
x
=
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg+max'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
+
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
else
:
raise
Exception
(
f
'Pooling type of
{
pool_type
}
is not supported. It must be one of "max", "avg" and "avg+max".'
)
return
x
class
CNN10
(
nn
.
Layer
):
emb_size
=
512
def
__init__
(
self
,
extract_embedding
:
bool
=
True
,
checkpoint
:
str
=
None
):
super
(
CNN10
,
self
).
__init__
()
self
.
bn0
=
nn
.
BatchNorm2D
(
64
)
self
.
conv_block1
=
ConvBlock
(
in_channels
=
1
,
out_channels
=
64
)
self
.
conv_block2
=
ConvBlock
(
in_channels
=
64
,
out_channels
=
128
)
self
.
conv_block3
=
ConvBlock
(
in_channels
=
128
,
out_channels
=
256
)
self
.
conv_block4
=
ConvBlock
(
in_channels
=
256
,
out_channels
=
512
)
self
.
fc1
=
nn
.
Linear
(
512
,
self
.
emb_size
)
self
.
fc_audioset
=
nn
.
Linear
(
self
.
emb_size
,
527
)
if
checkpoint
is
not
None
and
os
.
path
.
isfile
(
checkpoint
):
state_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_state_dict
(
state_dict
)
print
(
f
'Loaded CNN10 pretrained parameters from:
{
checkpoint
}
'
)
else
:
print
(
'No valid checkpoints for CNN10. Start training from scratch.'
)
self
.
extract_embedding
=
extract_embedding
def
forward
(
self
,
x
):
x
.
stop_gradient
=
False
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
bn0
(
x
)
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
conv_block1
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block2
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block3
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block4
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
x
.
mean
(
axis
=
3
)
x
=
x
.
max
(
axis
=
2
)
+
x
.
mean
(
axis
=
2
)
x
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
if
self
.
extract_embedding
:
output
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
else
:
output
=
F
.
sigmoid
(
self
.
fc_audioset
(
x
))
return
output
modules/audio/audio_classification/PANNs/cnn14/README.md
0 → 100644
浏览文件 @
23493590
```
shell
$
hub
install
panns_cnn14
==
1.0.0
```
`panns_cnn14`
是一个基于
[
Google Audioset
](
https://research.google.com/audioset/
)
数据集训练的声音分类/识别的模型。该模型主要包含12个卷积层和2个全连接层,模型参数为79.6M。经过预训练后,可以用于提取音频的embbedding,维度是2048。
更多详情请参考论文:
[
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
](
https://arxiv.org/pdf/1912.10211.pdf
)
## API
```
python
def
__init__
(
task
,
num_class
=
None
,
label_map
=
None
,
load_checkpoint
=
None
,
**
kwargs
,
)
```
创建Module对象。
**参数**
*
`task`
: 任务名称,可为
`sound-cls`
或者
`None`
。
`sound-cls`
代表声音分类任务,可以对声音分类的数据集进行finetune;为
`None`
时可以获取预训练模型对音频进行分类/Tagging。
*
`num_classes`
:声音分类任务的类别数,finetune时需要指定,数值与具体使用的数据集类别数一致。
*
`label_map`
:预测时的类别映射表。
*
`load_checkpoint`
:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。
*
`**kwargs`
:用户额外指定的关键字字典类型的参数。
```
python
def
predict
(
data
,
sample_rate
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
False
)
```
**参数**
*
`data`
: 待预测数据,格式为
\[
waveform1, wavwform2…,
\]
,其中每个元素都是一个一维numpy列表,是音频的波形采样数值列表。
*
`sample_rate`
:音频文件的采样率。
*
`feat_type`
:音频特征的种类选取,当前支持
`'mel'`
(详情可查看
[
Mel-frequency cepstrum
](
https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
)
)和原始波形特征
`'raw'`
。
*
`batch_size`
:模型批处理大小。
*
`use_gpu`
:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
**返回**
*
`results`
:list类型,不同任务类型的返回结果如下
*
声音分类(task参数为
`sound-cls`
):列表里包含每个音频文件的分类标签。
*
Tagging(task参数为
`None`
):列表里包含每个音频文件527个类别(
[
Audioset标签
](
https://research.google.com/audioset/
)
)的得分。
**代码示例**
-
[
ESC50
](
https://github.com/karolpiczak/ESC-50
)
声音分类预测
```
python
import
librosa
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
checkpoint
=
'model.pdparams'
# 用于预测的模型参数
label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
ESC50
.
label_list
)}
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
,
label_map
=
label_map
,
load_checkpoint
=
checkpoint
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
print
(
'File: {}
\t
Lable: {}'
.
format
(
wav_file
,
result
[
0
]))
```
-
Audioset Tagging
```
python
import
librosa
import
numpy
as
np
import
paddlehub
as
hub
def
show_topk
(
k
,
label_map
,
file
,
result
):
"""
展示topk的分的类别和分数。
"""
result
=
np
.
asarray
(
result
)
topk_idx
=
(
-
result
).
argsort
()[:
k
]
msg
=
f
'[
{
file
}
]
\n
'
for
idx
in
topk_idx
:
label
,
score
=
label_map
[
idx
],
result
[
idx
]
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
label_file
=
'./audioset_labels.txt'
# audioset标签文本文件
topk
=
10
# 展示的topk数
label_map
=
{}
with
open
(
label_file
,
'r'
)
as
f
:
for
i
,
l
in
enumerate
(
f
.
readlines
()):
label_map
[
i
]
=
l
.
strip
()
model
=
hub
.
Module
(
name
=
'panns_cnn14'
,
version
=
'1.0.0'
,
task
=
None
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
show_topk
(
topk
,
label_map
,
wav_file
,
result
[
0
])
```
详情可参考PaddleHub示例:
-
[
AudioClassification
](
https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification
)
## 查看代码
https://github.com/qiuqiangkong/audioset_tagging_cnn
## 依赖
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## 更新历史
*
1.0.0
初始发布,动态图版本模型,支持声音分类
`sound-cls`
任务的fine-tune和基于Audioset Tagging预测。
modules/audio/audio_classification/PANNs/cnn14/__init__.py
0 → 100644
浏览文件 @
23493590
modules/audio/audio_classification/PANNs/cnn14/module.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
from
typing
import
Dict
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
panns_cnn14.network
import
CNN14
from
paddlehub.env
import
MODULE_HOME
from
paddlehub.module.audio_module
import
AudioClassifierModule
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"panns_cnn14"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"Baidu"
,
author_email
=
""
,
type
=
"audio/sound_classification"
,
meta
=
AudioClassifierModule
)
class
PANN
(
nn
.
Layer
):
def
__init__
(
self
,
task
:
str
,
num_class
:
int
=
None
,
label_map
:
Dict
=
None
,
load_checkpoint
:
str
=
None
,
**
kwargs
,
):
super
(
PANN
,
self
).
__init__
()
if
label_map
:
self
.
label_map
=
label_map
self
.
num_class
=
len
(
label_map
)
else
:
self
.
num_class
=
num_class
if
task
==
'sound-cls'
:
self
.
cnn14
=
CNN14
(
extract_embedding
=
True
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn14'
,
'cnn14.pdparams'
))
self
.
dropout
=
nn
.
Dropout
(
0.1
)
self
.
fc
=
nn
.
Linear
(
self
.
cnn14
.
emb_size
,
num_class
)
self
.
criterion
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
()
self
.
metric
=
paddle
.
metric
.
Accuracy
()
else
:
self
.
cnn14
=
CNN14
(
extract_embedding
=
False
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn14'
,
'cnn14.pdparams'
))
self
.
task
=
task
if
load_checkpoint
is
not
None
and
os
.
path
.
isfile
(
load_checkpoint
):
state_dict
=
paddle
.
load
(
load_checkpoint
)
self
.
set_state_dict
(
state_dict
)
logger
.
info
(
'Loaded parameters from %s'
%
os
.
path
.
abspath
(
load_checkpoint
))
def
forward
(
self
,
feats
,
labels
=
None
):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats
=
feats
.
unsqueeze
(
1
)
if
self
.
task
==
'sound-cls'
:
embeddings
=
self
.
cnn14
(
feats
)
embeddings
=
self
.
dropout
(
embeddings
)
logits
=
self
.
fc
(
embeddings
)
probs
=
F
.
softmax
(
logits
,
axis
=
1
)
if
labels
is
not
None
:
loss
=
self
.
criterion
(
logits
,
labels
)
correct
=
self
.
metric
.
compute
(
probs
,
labels
)
acc
=
self
.
metric
.
update
(
correct
)
return
probs
,
loss
,
{
'acc'
:
acc
}
return
probs
else
:
audioset_logits
=
self
.
cnn14
(
feats
)
return
audioset_logits
modules/audio/audio_classification/PANNs/cnn14/network.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddlehub.utils.log
import
logger
class
ConvBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
ConvBlock
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
),
bias_attr
=
False
)
self
.
conv2
=
nn
.
Conv2D
(
in_channels
=
out_channels
,
out_channels
=
out_channels
,
kernel_size
=
(
3
,
3
),
stride
=
(
1
,
1
),
padding
=
(
1
,
1
),
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
out_channels
)
self
.
bn2
=
nn
.
BatchNorm2D
(
out_channels
)
def
forward
(
self
,
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
bn2
(
x
)
x
=
F
.
relu
(
x
)
if
pool_type
==
'max'
:
x
=
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg+max'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
+
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
else
:
raise
Exception
(
f
'Pooling type of
{
pool_type
}
is not supported. It must be one of "max", "avg" and "avg+max".'
)
return
x
class
CNN14
(
nn
.
Layer
):
emb_size
=
2048
def
__init__
(
self
,
extract_embedding
:
bool
=
True
,
checkpoint
:
str
=
None
):
super
(
CNN14
,
self
).
__init__
()
self
.
bn0
=
nn
.
BatchNorm2D
(
64
)
self
.
conv_block1
=
ConvBlock
(
in_channels
=
1
,
out_channels
=
64
)
self
.
conv_block2
=
ConvBlock
(
in_channels
=
64
,
out_channels
=
128
)
self
.
conv_block3
=
ConvBlock
(
in_channels
=
128
,
out_channels
=
256
)
self
.
conv_block4
=
ConvBlock
(
in_channels
=
256
,
out_channels
=
512
)
self
.
conv_block5
=
ConvBlock
(
in_channels
=
512
,
out_channels
=
1024
)
self
.
conv_block6
=
ConvBlock
(
in_channels
=
1024
,
out_channels
=
2048
)
self
.
fc1
=
nn
.
Linear
(
2048
,
self
.
emb_size
)
self
.
fc_audioset
=
nn
.
Linear
(
self
.
emb_size
,
527
)
if
checkpoint
is
not
None
and
os
.
path
.
isfile
(
checkpoint
):
state_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_state_dict
(
state_dict
)
logger
.
info
(
f
'Loaded CNN14 pretrained parameters from:
{
checkpoint
}
'
)
else
:
logger
.
error
(
'No valid checkpoints for CNN14. Start training from scratch.'
)
self
.
extract_embedding
=
extract_embedding
def
forward
(
self
,
x
):
x
.
stop_gradient
=
False
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
bn0
(
x
)
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
conv_block1
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block2
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block3
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block4
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block5
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block6
(
x
,
pool_size
=
(
1
,
1
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
x
.
mean
(
axis
=
3
)
x
=
x
.
max
(
axis
=
2
)
+
x
.
mean
(
axis
=
2
)
x
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
if
self
.
extract_embedding
:
output
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
else
:
output
=
F
.
sigmoid
(
self
.
fc_audioset
(
x
))
return
output
modules/audio/audio_classification/PANNs/cnn6/README.md
0 → 100644
浏览文件 @
23493590
```
shell
$
hub
install
panns_cnn6
==
1.0.0
```
`panns_cnn6`
是一个基于
[
Google Audioset
](
https://research.google.com/audioset/
)
数据集训练的声音分类/识别的模型。该模型主要包含4个卷积层和2个全连接层,模型参数为4.5M。经过预训练后,可以用于提取音频的embbedding,维度是512。
更多详情请参考论文:
[
PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
](
https://arxiv.org/pdf/1912.10211.pdf
)
## API
```
python
def
__init__
(
task
,
num_class
=
None
,
label_map
=
None
,
load_checkpoint
=
None
,
**
kwargs
,
)
```
创建Module对象。
**参数**
*
`task`
: 任务名称,可为
`sound-cls`
或者
`None`
。
`sound-cls`
代表声音分类任务,可以对声音分类的数据集进行finetune;为
`None`
时可以获取预训练模型对音频进行分类/Tagging。
*
`num_classes`
:声音分类任务的类别数,finetune时需要指定,数值与具体使用的数据集类别数一致。
*
`label_map`
:预测时的类别映射表。
*
`load_checkpoint`
:使用PaddleHub Fine-tune api训练保存的模型参数文件路径。
*
`**kwargs`
:用户额外指定的关键字字典类型的参数。
```
python
def
predict
(
data
,
sample_rate
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
False
)
```
**参数**
*
`data`
: 待预测数据,格式为
\[
waveform1, wavwform2…,
\]
,其中每个元素都是一个一维numpy列表,是音频的波形采样数值列表。
*
`sample_rate`
:音频文件的采样率。
*
`feat_type`
:音频特征的种类选取,当前支持
`'mel'`
(详情可查看
[
Mel-frequency cepstrum
](
https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
)
)和原始波形特征
`'raw'`
。
*
`batch_size`
:模型批处理大小。
*
`use_gpu`
:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。
**返回**
*
`results`
:list类型,不同任务类型的返回结果如下
*
声音分类(task参数为
`sound-cls`
):列表里包含每个音频文件的分类标签。
*
Tagging(task参数为
`None`
):列表里包含每个音频文件527个类别(
[
Audioset标签
](
https://research.google.com/audioset/
)
)的得分。
**代码示例**
-
[
ESC50
](
https://github.com/karolpiczak/ESC-50
)
声音分类预测
```
python
import
librosa
import
paddlehub
as
hub
from
paddlehub.datasets
import
ESC50
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
checkpoint
=
'model.pdparams'
# 用于预测的模型参数
label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
ESC50
.
label_list
)}
model
=
hub
.
Module
(
name
=
'panns_cnn6'
,
version
=
'1.0.0'
,
task
=
'sound-cls'
,
num_class
=
ESC50
.
num_class
,
label_map
=
label_map
,
load_checkpoint
=
checkpoint
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
print
(
'File: {}
\t
Lable: {}'
.
format
(
wav_file
,
result
[
0
]))
```
-
Audioset Tagging
```
python
import
librosa
import
numpy
as
np
import
paddlehub
as
hub
def
show_topk
(
k
,
label_map
,
file
,
result
):
"""
展示topk的分的类别和分数。
"""
result
=
np
.
asarray
(
result
)
topk_idx
=
(
-
result
).
argsort
()[:
k
]
msg
=
f
'[
{
file
}
]
\n
'
for
idx
in
topk_idx
:
label
,
score
=
label_map
[
idx
],
result
[
idx
]
msg
+=
f
'
{
label
}
:
{
score
}
\n
'
print
(
msg
)
sr
=
44100
# 音频文件的采样率
wav_file
=
'/data/cat.wav'
# 用于预测的音频文件路径
label_file
=
'./audioset_labels.txt'
# audioset标签文本文件
topk
=
10
# 展示的topk数
label_map
=
{}
with
open
(
label_file
,
'r'
)
as
f
:
for
i
,
l
in
enumerate
(
f
.
readlines
()):
label_map
[
i
]
=
l
.
strip
()
model
=
hub
.
Module
(
name
=
'panns_cnn6'
,
version
=
'1.0.0'
,
task
=
None
)
data
=
[
librosa
.
load
(
wav_file
,
sr
=
sr
)[
0
]]
result
=
model
.
predict
(
data
,
sample_rate
=
sr
,
batch_size
=
1
,
feat_type
=
'mel'
,
use_gpu
=
True
)
show_topk
(
topk
,
label_map
,
wav_file
,
result
[
0
])
```
详情可参考PaddleHub示例:
-
[
AudioClassification
](
https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0/demo/audio_classification
)
## 查看代码
https://github.com/qiuqiangkong/audioset_tagging_cnn
## 依赖
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
## 更新历史
*
1.0.0
初始发布,动态图版本模型,支持声音分类
`sound-cls`
任务的fine-tune和基于Audioset Tagging预测。
modules/audio/audio_classification/PANNs/cnn6/__init__.py
0 → 100644
浏览文件 @
23493590
modules/audio/audio_classification/PANNs/cnn6/module.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
from
typing
import
Dict
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
panns_cnn6.network
import
CNN6
from
paddlehub.env
import
MODULE_HOME
from
paddlehub.module.audio_module
import
AudioClassifierModule
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"panns_cnn6"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"Baidu"
,
author_email
=
""
,
type
=
"audio/sound_classification"
,
meta
=
AudioClassifierModule
)
class
PANN
(
nn
.
Layer
):
def
__init__
(
self
,
task
:
str
,
num_class
:
int
=
None
,
label_map
:
Dict
=
None
,
load_checkpoint
:
str
=
None
,
**
kwargs
,
):
super
(
PANN
,
self
).
__init__
()
if
label_map
:
self
.
label_map
=
label_map
self
.
num_class
=
len
(
label_map
)
else
:
self
.
num_class
=
num_class
if
task
==
'sound-cls'
:
self
.
cnn6
=
CNN6
(
extract_embedding
=
True
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn6'
,
'cnn6.pdparams'
))
self
.
dropout
=
nn
.
Dropout
(
0.1
)
self
.
fc
=
nn
.
Linear
(
self
.
cnn6
.
emb_size
,
num_class
)
self
.
criterion
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
()
self
.
metric
=
paddle
.
metric
.
Accuracy
()
else
:
self
.
cnn6
=
CNN6
(
extract_embedding
=
False
,
checkpoint
=
os
.
path
.
join
(
MODULE_HOME
,
'panns_cnn6'
,
'cnn6.pdparams'
))
self
.
task
=
task
if
load_checkpoint
is
not
None
and
os
.
path
.
isfile
(
load_checkpoint
):
state_dict
=
paddle
.
load
(
load_checkpoint
)
self
.
set_state_dict
(
state_dict
)
logger
.
info
(
'Loaded parameters from %s'
%
os
.
path
.
abspath
(
load_checkpoint
))
def
forward
(
self
,
feats
,
labels
=
None
):
# feats: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
feats
=
feats
.
unsqueeze
(
1
)
if
self
.
task
==
'sound-cls'
:
embeddings
=
self
.
cnn6
(
feats
)
embeddings
=
self
.
dropout
(
embeddings
)
logits
=
self
.
fc
(
embeddings
)
probs
=
F
.
softmax
(
logits
,
axis
=
1
)
if
labels
is
not
None
:
loss
=
self
.
criterion
(
logits
,
labels
)
correct
=
self
.
metric
.
compute
(
probs
,
labels
)
acc
=
self
.
metric
.
update
(
correct
)
return
probs
,
loss
,
{
'acc'
:
acc
}
return
probs
else
:
audioset_logits
=
self
.
cnn6
(
feats
)
return
audioset_logits
modules/audio/audio_classification/PANNs/cnn6/network.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddlehub.utils.log
import
logger
class
ConvBlock5x5
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
ConvBlock5x5
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
(
5
,
5
),
stride
=
(
1
,
1
),
padding
=
(
2
,
2
),
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
out_channels
)
def
forward
(
self
,
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
if
pool_type
==
'max'
:
x
=
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
elif
pool_type
==
'avg+max'
:
x
=
F
.
avg_pool2d
(
x
,
kernel_size
=
pool_size
)
+
F
.
max_pool2d
(
x
,
kernel_size
=
pool_size
)
else
:
raise
Exception
(
f
'Pooling type of
{
pool_type
}
is not supported. It must be one of "max", "avg" and "avg+max".'
)
return
x
class
CNN6
(
nn
.
Layer
):
emb_size
=
512
def
__init__
(
self
,
extract_embedding
:
bool
=
True
,
checkpoint
:
str
=
None
):
super
(
CNN6
,
self
).
__init__
()
self
.
bn0
=
nn
.
BatchNorm2D
(
64
)
self
.
conv_block1
=
ConvBlock5x5
(
in_channels
=
1
,
out_channels
=
64
)
self
.
conv_block2
=
ConvBlock5x5
(
in_channels
=
64
,
out_channels
=
128
)
self
.
conv_block3
=
ConvBlock5x5
(
in_channels
=
128
,
out_channels
=
256
)
self
.
conv_block4
=
ConvBlock5x5
(
in_channels
=
256
,
out_channels
=
512
)
self
.
fc1
=
nn
.
Linear
(
512
,
self
.
emb_size
)
self
.
fc_audioset
=
nn
.
Linear
(
self
.
emb_size
,
527
)
if
checkpoint
is
not
None
and
os
.
path
.
isfile
(
checkpoint
):
state_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_state_dict
(
state_dict
)
print
(
f
'Loaded CNN6 pretrained parameters from:
{
checkpoint
}
'
)
else
:
print
(
'No valid checkpoints for CNN6. Start training from scratch.'
)
self
.
extract_embedding
=
extract_embedding
def
forward
(
self
,
x
):
x
.
stop_gradient
=
False
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
bn0
(
x
)
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
self
.
conv_block1
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block2
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block3
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
self
.
conv_block4
(
x
,
pool_size
=
(
2
,
2
),
pool_type
=
'avg'
)
x
=
F
.
dropout
(
x
,
p
=
0.2
,
training
=
self
.
training
)
x
=
x
.
mean
(
axis
=
3
)
x
=
x
.
max
(
axis
=
2
)
+
x
.
mean
(
axis
=
2
)
x
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
if
self
.
extract_embedding
:
output
=
F
.
dropout
(
x
,
p
=
0.5
,
training
=
self
.
training
)
else
:
output
=
F
.
sigmoid
(
self
.
fc_audioset
(
x
))
return
output
paddlehub/datasets/__init__.py
浏览文件 @
23493590
...
...
@@ -13,9 +13,11 @@
# limitations under the License.
from
paddlehub.datasets.canvas
import
Canvas
from
paddlehub.datasets.chnsenticorp
import
ChnSentiCorp
from
paddlehub.datasets.esc50
import
ESC50
from
paddlehub.datasets.flowers
import
Flowers
from
paddlehub.datasets.lcqmc
import
LCQMC
from
paddlehub.datasets.minicoco
import
MiniCOCO
from
paddlehub.datasets.chnsenticorp
import
ChnSentiCorp
from
paddlehub.datasets.msra_ner
import
MSRA_NER
from
paddlehub.datasets.lcqmc
import
LCQMC
from
paddlehub.datasets.base_seg_dataset
import
SegDataset
...
...
paddlehub/datasets/base_audio_dataset.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
csv
import
io
import
os
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
paddle
from
paddlehub.utils.utils
import
extract_melspectrogram
class
InputExample
(
object
):
"""
Input example of one audio sample.
"""
def
__init__
(
self
,
guid
:
int
,
source
:
Union
[
list
,
str
],
label
:
Optional
[
str
]
=
None
):
self
.
guid
=
guid
self
.
source
=
source
self
.
label
=
label
class
BaseAudioDataset
(
object
):
"""
Base class of speech dataset.
"""
def
__init__
(
self
,
base_path
:
str
,
data_file
:
str
,
mode
:
Optional
[
str
]
=
"train"
):
self
.
data_file
=
os
.
path
.
join
(
base_path
,
data_file
)
self
.
mode
=
mode
def
_read_file
(
self
,
input_file
:
str
):
raise
NotImplementedError
class
AudioClassificationDataset
(
BaseAudioDataset
,
paddle
.
io
.
Dataset
):
"""
Base class of audio classification dataset.
"""
_supported_features
=
[
'raw'
,
'mel'
]
def
__init__
(
self
,
base_path
:
str
,
data_file
:
str
,
file_type
:
str
=
'npz'
,
mode
:
str
=
'train'
,
feat_type
:
str
=
'mel'
,
feat_cfg
:
dict
=
None
):
super
(
AudioClassificationDataset
,
self
).
__init__
(
base_path
=
base_path
,
mode
=
mode
,
data_file
=
data_file
)
self
.
file_type
=
file_type
self
.
feat_type
=
feat_type
self
.
feat_cfg
=
feat_cfg
self
.
examples
=
self
.
_read_file
(
self
.
data_file
)
self
.
records
=
self
.
_convert_examples_to_records
(
self
.
examples
)
def
_read_file
(
self
,
input_file
:
str
)
->
List
[
InputExample
]:
if
not
os
.
path
.
exists
(
input_file
):
raise
RuntimeError
(
"Data file: {} not found."
.
format
(
input_file
))
examples
=
[]
if
self
.
file_type
==
'npz'
:
dataset
=
np
.
load
(
os
.
path
.
join
(
self
.
data_file
),
allow_pickle
=
True
)
audio_id
=
0
for
waveform
,
label
in
zip
(
dataset
[
'waveforms'
],
dataset
[
'labels'
]):
example
=
InputExample
(
guid
=
audio_id
,
source
=
waveform
,
label
=
label
)
audio_id
+=
1
examples
.
append
(
example
)
else
:
raise
NotImplementedError
(
f
'Only soppurts npz file type, but got
{
self
.
file_type
}
'
)
return
examples
def
_convert_examples_to_records
(
self
,
examples
:
List
[
InputExample
])
->
List
[
dict
]:
records
=
[]
for
example
in
examples
:
record
=
{}
if
self
.
feat_type
==
'raw'
:
record
[
'feat'
]
=
example
.
source
elif
self
.
feat_type
==
'mel'
:
record
[
'feat'
]
=
extract_melspectrogram
(
example
.
source
,
sample_rate
=
self
.
feat_cfg
[
'sample_rate'
],
window_size
=
self
.
feat_cfg
[
'window_size'
],
hop_size
=
self
.
feat_cfg
[
'hop_size'
],
mel_bins
=
self
.
feat_cfg
[
'mel_bins'
],
fmin
=
self
.
feat_cfg
[
'fmin'
],
fmax
=
self
.
feat_cfg
[
'fmax'
],
window
=
self
.
feat_cfg
[
'window'
],
center
=
True
,
pad_mode
=
'reflect'
,
ref
=
1.0
,
amin
=
1e-10
,
top_db
=
None
)
else
:
raise
RuntimeError
(
\
f
"Unknown type of self.feat_type:
{
self
.
feat_type
}
, it must be one in
{
self
.
_supported_features
}
"
)
record
[
'label'
]
=
example
.
label
records
.
append
(
record
)
return
records
def
__getitem__
(
self
,
idx
):
"""
Overload this method when doing extra feature processes or data augmentation.
"""
record
=
self
.
records
[
idx
]
return
np
.
array
(
record
[
'feat'
]),
np
.
array
(
record
[
'label'
],
dtype
=
np
.
int64
)
def
__len__
(
self
):
return
len
(
self
.
records
)
paddlehub/datasets/esc50.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
paddlehub.datasets.base_audio_dataset
import
AudioClassificationDataset
from
paddlehub.env
import
DATA_HOME
from
paddlehub.utils.download
import
download_data
@
download_data
(
url
=
"https://bj.bcebos.com/paddlehub-dataset/esc50.tar.gz"
)
class
ESC50
(
AudioClassificationDataset
):
sample_rate
=
44100
input_length
=
int
(
sample_rate
*
5
)
# 5s
num_class
=
50
# class num
label_list
=
[
# Animals
'Dog'
,
'Rooster'
,
'Pig'
,
'Cow'
,
'Frog'
,
'Cat'
,
'Hen'
,
'Insects (flying)'
,
'Sheep'
,
'Crow'
,
# Natural soundscapes & water sounds
'Rain'
,
'Sea waves'
,
'Crackling fire'
,
'Crickets'
,
'Chirping birds'
,
'Water drops'
,
'Wind'
,
'Pouring water'
,
'Toilet flush'
,
'Thunderstorm'
,
# Human, non-speech sounds
'Crying baby'
,
'Sneezing'
,
'Clapping'
,
'Breathing'
,
'Coughing'
,
'Footsteps'
,
'Laughing'
,
'Brushing teeth'
,
'Snoring'
,
'Drinking, sipping'
,
# Interior/domestic sounds
'Door knock'
,
'Mouse click'
,
'Keyboard typing'
,
'Door, wood creaks'
,
'Can opening'
,
'Washing machine'
,
'Vacuum cleaner'
,
'Clock alarm'
,
'Clock tick'
,
'Glass breaking'
,
# Exterior/urban noises
'Helicopter'
,
'Chainsaw'
,
'Siren'
,
'Car horn'
,
'Engine'
,
'Train'
,
'Church bells'
,
'Airplane'
,
'Fireworks'
,
'Hand saw'
,
]
def
__init__
(
self
,
mode
:
str
=
'train'
,
feat_type
:
str
=
'mel'
):
base_path
=
os
.
path
.
join
(
DATA_HOME
,
"esc50"
)
if
mode
==
'train'
:
data_file
=
'train.npz'
else
:
data_file
=
'dev.npz'
feat_cfg
=
dict
(
sample_rate
=
self
.
sample_rate
,
window_size
=
1024
,
hop_size
=
320
,
mel_bins
=
64
,
fmin
=
50
,
fmax
=
14000
,
window
=
'hann'
)
super
().
__init__
(
base_path
=
base_path
,
data_file
=
data_file
,
file_type
=
'npz'
,
mode
=
mode
,
feat_type
=
feat_type
,
feat_cfg
=
feat_cfg
)
if
__name__
==
"__main__"
:
train_dataset
=
ESC50
(
mode
=
'train'
)
dev_dataset
=
ESC50
(
mode
=
'dev'
)
paddlehub/module/audio_module.py
0 → 100644
浏览文件 @
23493590
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
OrderedDict
from
typing
import
List
,
Tuple
import
numpy
as
np
import
paddle
from
paddlehub.module.module
import
RunModule
,
runnable
,
serving
from
paddlehub.utils.utils
import
extract_melspectrogram
class
AudioClassifierModule
(
RunModule
):
"""
The base class for audio classifier models.
"""
_tasks_supported
=
[
'sound-cls'
,
]
def
_batchify
(
self
,
data
:
List
[
List
[
float
]],
sample_rate
:
int
,
feat_type
:
str
,
batch_size
:
int
):
examples
=
[]
for
waveform
in
data
:
if
feat_type
==
'mel'
:
feat
=
extract_melspectrogram
(
waveform
,
sample_rate
=
sample_rate
)
examples
.
append
(
feat
)
else
:
examples
.
append
(
waveform
)
# Seperates data into some batches.
one_batch
=
[]
for
example
in
examples
:
one_batch
.
append
(
example
)
if
len
(
one_batch
)
==
batch_size
:
yield
one_batch
one_batch
=
[]
if
one_batch
:
# The last batch whose size is less than the config batch_size setting.
yield
one_batch
def
training_step
(
self
,
batch
:
List
[
paddle
.
Tensor
],
batch_idx
:
int
):
if
self
.
task
==
'sound-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
feats
=
batch
[
0
],
labels
=
batch
[
1
])
else
:
raise
NotImplementedError
self
.
metric
.
reset
()
return
{
'loss'
:
avg_loss
,
'metrics'
:
metric
}
def
validation_step
(
self
,
batch
:
List
[
paddle
.
Tensor
],
batch_idx
:
int
):
if
self
.
task
==
'sound-cls'
:
predictions
,
avg_loss
,
metric
=
self
(
feats
=
batch
[
0
],
labels
=
batch
[
1
])
else
:
raise
NotImplementedError
return
{
'metrics'
:
metric
}
def
predict
(
self
,
data
:
List
[
List
[
float
]],
sample_rate
:
int
,
batch_size
:
int
=
1
,
feat_type
:
str
=
'mel'
,
use_gpu
:
bool
=
False
):
if
self
.
task
not
in
self
.
_tasks_supported
\
and
self
.
task
is
not
None
:
# None for getting audioset tags
raise
RuntimeError
(
f
'Unknown task
{
self
.
task
}
, current tasks supported:
\n
'
'1. sound-cls: sound classification;
\n
'
'2. None: audioset tags'
)
paddle
.
set_device
(
'gpu'
)
if
use_gpu
else
paddle
.
set_device
(
'cpu'
)
batches
=
self
.
_batchify
(
data
,
sample_rate
,
feat_type
,
batch_size
)
results
=
[]
self
.
eval
()
for
batch
in
batches
:
feats
=
paddle
.
to_tensor
(
batch
)
scores
=
self
(
feats
)
for
score
in
scores
.
numpy
():
result
=
OrderedDict
()
for
i
in
(
-
score
).
argsort
():
result
[
self
.
label_map
[
i
]]
=
score
[
i
]
results
.
append
(
result
)
return
results
paddlehub/utils/utils.py
浏览文件 @
23493590
...
...
@@ -15,31 +15,31 @@
import
base64
import
contextlib
import
cv2
import
hashlib
import
importlib
import
math
import
os
import
requests
import
socket
import
sys
import
time
import
tempfile
import
time
import
traceback
import
types
from
typing
import
Generator
,
List
from
urllib.parse
import
urlparse
import
cv2
import
numpy
as
np
import
packaging.version
import
requests
import
paddlehub.env
as
hubenv
import
paddlehub.utils
as
utils
from
paddlehub.utils.log
import
logger
class
Version
(
packaging
.
version
.
Version
):
'''Extended implementation of packaging.version.Version'''
def
match
(
self
,
condition
:
str
)
->
bool
:
'''
Determine whether the given condition are met
...
...
@@ -105,7 +105,6 @@ class Version(packaging.version.Version):
class
Timer
(
object
):
'''Calculate runing speed and estimated time of arrival(ETA)'''
def
__init__
(
self
,
total_step
:
int
):
self
.
total_step
=
total_step
self
.
last_start_step
=
0
...
...
@@ -303,7 +302,7 @@ def record_exception(msg: str) -> str:
'''Record the current exception infomation into the PaddleHub log file witch will be automatically stored according to date.'''
tb
=
traceback
.
format_exc
()
file
=
record
(
tb
)
utils
.
log
.
logger
.
warning
(
'{}. Detailed error information can be found in the {}.'
.
format
(
msg
,
file
))
logger
.
warning
(
'{}. Detailed error information can be found in the {}.'
.
format
(
msg
,
file
))
def
get_record_file
()
->
str
:
...
...
@@ -385,3 +384,41 @@ def trunc_sequence(ids: List[int], max_seq_len: int):
f
'The input length
{
len
(
ids
)
}
is less than max_seq_len
{
max_seq_len
}
. '
\
'Please check the input list and max_seq_len if you really want to truncate a sequence.'
return
ids
[:
max_seq_len
]
def
extract_melspectrogram
(
y
,
sample_rate
:
int
=
32000
,
window_size
:
int
=
1024
,
hop_size
:
int
=
320
,
mel_bins
:
int
=
64
,
fmin
:
int
=
50
,
fmax
:
int
=
14000
,
window
:
str
=
'hann'
,
center
:
bool
=
True
,
pad_mode
:
str
=
'reflect'
,
ref
:
float
=
1.0
,
amin
:
float
=
1e-10
,
top_db
:
float
=
None
):
'''
Extract Mel Spectrogram from a waveform.
'''
try
:
import
librosa
except
Exception
:
logger
.
error
(
'Failed to import librosa. Please check that librosa and numba are correctly installed.'
)
raise
s
=
librosa
.
stft
(
y
,
n_fft
=
window_size
,
hop_length
=
hop_size
,
win_length
=
window_size
,
window
=
window
,
center
=
center
,
pad_mode
=
pad_mode
)
power
=
np
.
abs
(
s
)
**
2
melW
=
librosa
.
filters
.
mel
(
sr
=
sample_rate
,
n_fft
=
window_size
,
n_mels
=
mel_bins
,
fmin
=
fmin
,
fmax
=
fmax
)
mel
=
np
.
matmul
(
melW
,
power
)
db
=
librosa
.
power_to_db
(
mel
,
ref
=
ref
,
amin
=
amin
,
top_db
=
None
)
db
=
db
.
transpose
()
return
db
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录