Commit de9430f8 · Author: KP

Add kwmlp_speech_commands.

Parent f6f5826e
# kwmlp_speech_commands
|Module Name|kwmlp_speech_commands|
| :--- | :---: |
|Category|Audio - Language Identification|
|Network|Keyword-MLP|
|Dataset|Google Speech Commands V2|
|Fine-tuning supported|No|
|Module Size|1.6MB|
|Latest update date|2022-01-04|
|Data indicators|ACC 97.56%|
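## I. Basic Information
### Module Introduction
`kwmlp_speech_commands` uses the lightweight [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) architecture and was pre-trained on the [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) dataset, where it reaches an accuracy (ACC) of 97.56% on the test set.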
<p align="center">
<img src="https://d3i71xaburhd42.cloudfront.net/fa690a97f76ba119ca08fb02fa524a546c47f031/2-Figure1-1.png" hspace='10' height="550"/> <br />
</p>
For more details, see:
- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209)
- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf)
- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP)
## II. Installation
- ### 1. Environmental Dependence
- paddlepaddle >= 2.2.0
- paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation
- ```shell
$ hub install kwmlp_speech_commands
```
- If you run into problems during installation, see: [Windows quick start](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux quick start](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS quick start](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Prediction Code Example
```python
import paddlehub as hub
model = hub.Module(
name='kwmlp_speech_commands',
version='1.0.0')
# A sample audio file can be downloaded from the link below
# https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
# Keyword spotting
score, label = model.keyword_recognize('no.wav')
print(score, label)
# [0.89498246] no
score, label = model.keyword_recognize('go.wav')
print(score, label)
# [0.8997176] go
score, label = model.keyword_recognize('one.wav')
print(score, label)
# [0.88598305] one
```
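- ### 2. API
- ```python
def keyword_recognize(
    wav: os.PathLike,
)
```
- Detects the keyword spoken in an audio file.
- **Parameters**
- `wav`: path to the input audio file containing a keyword, in `*.wav` format. The audio is loaded as 16 kHz mono and trimmed or zero-padded to one second before inference.
- **Return**
- The score of the top prediction (its softmax probability) and the corresponding keyword label.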
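The one-second length matters because the network expects a fixed number of MFCC frames. The sketch below mirrors the trim-or-pad step in plain Python and works out the frame count, assuming the usual frame formula for an un-centered STFT and the defaults used in this commit (`win_length=480`, `hop_length=160`, `center=False`). `fix_length` and `num_frames` are illustrative helper names, not part of the module API.

```python
def fix_length(samples, target_len=16000):
    """Truncate, or right-pad with zeros, so the clip has exactly target_len samples."""
    samples = list(samples)
    if len(samples) > target_len:
        return samples[:target_len]
    return samples + [0.0] * (target_len - len(samples))


def num_frames(num_samples, win_length=480, hop_length=160):
    """Frame count of an STFT without center padding (assumed formula)."""
    return 1 + (num_samples - win_length) // hop_length


short_clip = [0.1] * 12000  # 0.75 s at 16 kHz
long_clip = [0.1] * 20000   # 1.25 s at 16 kHz
print(len(fix_length(short_clip)), len(fix_length(long_clip)))  # 16000 16000
print(num_frames(16000))  # 98
```

Under these assumptions a one-second clip yields 98 frames, which matches the `input_res` of `[40, 98]` in the model configuration.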
## IV. Release Note
* 1.0.0
Initial release
```shell
$ hub install kwmlp_speech_commands
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import paddle
import paddleaudio


def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
    """Build a DCT-II matrix of shape (n_mels, n_mfcc)."""
    n = paddle.arange(float(n_mels))
    k = paddle.arange(float(n_mfcc)).unsqueeze(1)
    dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def compute_mfcc(
        x: paddle.Tensor,
        sr: int = 16000,
        n_mels: int = 40,
        n_fft: int = 480,
        win_length: int = 480,
        hop_length: int = 160,
        f_min: float = 0.0,
        f_max: float = None,
        center: bool = False,
        top_db: float = 80.0,
        norm: str = 'ortho',
):
    """Compute MFCCs from a batch of waveforms of shape (B, T)."""
    fbank = paddleaudio.features.spectrum.MelSpectrogram(
        sr=sr,
        n_mels=n_mels,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,  # was hard-coded to 0.0, which ignored the argument
        f_max=f_max,
        center=center)(x)  # waveforms batch ~ (B, T)
    log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
    # Keep all n_mels cepstral coefficients (n_mfcc == n_mels).
    dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
    mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1))  # (B, n_mels, L)
    return mfcc
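As a sanity check on `create_dct`, the sketch below rebuilds the same orthonormal DCT-II matrix in plain Python, without Paddle, and checks that its columns are orthonormal when `n_mfcc == n_mels`. The names `dct_matrix` and `column_dot` are illustrative only and do not appear in the module.

```python
import math


def dct_matrix(n_mfcc, n_mels):
    """Orthonormal DCT-II basis as nested lists with shape (n_mels, n_mfcc)."""
    scale = math.sqrt(2.0 / n_mels)
    rows = []
    for n in range(n_mels):
        row = []
        for k in range(n_mfcc):
            value = math.cos(math.pi / n_mels * (n + 0.5) * k) * scale
            if k == 0:
                # The zeroth basis vector gets an extra 1/sqrt(2) factor.
                value /= math.sqrt(2.0)
            row.append(value)
        rows.append(row)
    return rows


def column_dot(matrix, i, j):
    """Dot product of columns i and j."""
    return sum(row[i] * row[j] for row in matrix)


m = dct_matrix(n_mfcc=8, n_mels=8)
print(abs(column_dot(m, 0, 0) - 1.0) < 1e-9)  # unit norm
print(abs(column_dot(m, 1, 2)) < 1e-9)        # orthogonal columns
```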
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class Residual(nn.Layer):
    """Add the input back to the output of the wrapped layer."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class PreNorm(nn.Layer):
    """Apply LayerNorm before the wrapped layer."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class PostNorm(nn.Layer):
    """Apply LayerNorm after the wrapped layer."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm(self.fn(x, **kwargs))


class SpatialGatingUnit(nn.Layer):
    def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3):
        super().__init__()
        dim_out = dim // 2
        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1D(dim_seq, dim_seq, 1)
        self.act = act
        # Note: the scaled epsilon is computed here but not used afterwards.
        init_eps /= dim_seq

    def forward(self, x):
        # Split channels into a residual half and a gating half.
        res, gate = x.split(2, axis=-1)
        gate = self.norm(gate)
        weight, bias = self.proj.weight, self.proj.bias
        # Kernel-size-1 convolution mixes information along the sequence axis.
        gate = F.conv1d(gate, weight, bias)
        return self.act(gate) * res


class gMLPBlock(nn.Layer):
    def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())
        self.sgu = SpatialGatingUnit(dim_ff, seq_len, act)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.sgu(x)
        x = self.proj_out(x)
        return x


class Rearrange(nn.Layer):
    """(B, 1, n_mels, L) -> (B, L, n_mels)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.transpose([0, 1, 3, 2]).squeeze(1)
        return x


class Reduce(nn.Layer):
    """Average over the given axis."""

    def __init__(self, axis=1):
        super().__init__()
        self.axis = axis

    def forward(self, x):
        x = x.mean(axis=self.axis, keepdim=False)
        return x


class KW_MLP(nn.Layer):
    """Keyword-MLP."""

    def __init__(self,
                 input_res=[40, 98],
                 patch_res=[40, 1],
                 num_classes=35,
                 dim=64,
                 depth=12,
                 ff_mult=4,
                 channels=1,
                 prob_survival=0.9,
                 pre_norm=False,
                 **kwargs):
        super().__init__()
        image_height, image_width = input_res
        patch_height, patch_width = patch_res
        assert (image_height % patch_height) == 0 and (
            image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
        num_patches = (image_height // patch_height) * (image_width // patch_width)

        P_Norm = PreNorm if pre_norm else PostNorm

        dim_ff = dim * ff_mult

        self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim))

        self.prob_survival = prob_survival

        self.layers = nn.LayerList(
            [Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for i in range(depth)])

        self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes))

    def forward(self, x):
        x = self.to_patch_embed(x)
        # Run the residual gMLP blocks in order (equivalent to wrapping them in nn.Sequential).
        for layer in self.layers:
            x = layer(x)
        return self.to_logits(x)
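With its default arguments, `KW_MLP` treats every MFCC time frame as one patch. The arithmetic below reproduces the sizes that `__init__` derives from those defaults. `kwmlp_shapes` is a hypothetical helper written only for illustration.

```python
def kwmlp_shapes(input_res=(40, 98), patch_res=(40, 1), dim=64, ff_mult=4, channels=1):
    """Return the derived layer sizes for a given KW_MLP configuration."""
    height, width = input_res
    patch_h, patch_w = patch_res
    if height % patch_h or width % patch_w:
        raise ValueError("input resolution must be divisible by patch size")
    return {
        "num_patches": (height // patch_h) * (width // patch_w),  # sequence length
        "patch_dim": channels * patch_h * patch_w,                # input size of the patch embedding
        "dim_ff": dim * ff_mult,                                  # width after proj_in
        "sgu_out": dim * ff_mult // 2,                            # width after the spatial gating unit
    }


print(kwmlp_shapes())
# {'num_patches': 98, 'patch_dim': 40, 'dim_ff': 256, 'sgu_out': 128}
```

So each of the 98 frames is embedded from 40 coefficients into 64 dimensions, expanded to 256 inside a block, and halved to 128 by the gating split before being projected back to 64.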
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import paddle
import paddleaudio

from .feature import compute_mfcc
from .kwmlp import KW_MLP
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger


@moduleinfo(
    name="kwmlp_speech_commands",
    version="1.0.0",
    summary="",
    author="paddlepaddle",
    author_email="",
    type="audio/language_identification")
class KWS(paddle.nn.Layer):
    def __init__(self):
        super(KWS, self).__init__()
        ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
        label_path = os.path.join(self.directory, 'assets', 'label.txt')

        # One keyword label per line.
        self.label_list = []
        with open(label_path, 'r') as f:
            for l in f:
                self.label_list.append(l.strip())

        self.sr = 16000
        model_conf = {
            'input_res': [40, 98],
            'patch_res': [40, 1],
            'num_classes': 35,
            'channels': 1,
            'dim': 64,
            'depth': 12,
            'pre_norm': False,
            'prob_survival': 0.9,
        }
        self.model = KW_MLP(**model_conf)
        self.model.set_state_dict(paddle.load(ckpt_path))
        self.model.eval()

    def load_audio(self, wav):
        wav = os.path.abspath(os.path.expanduser(wav))
        assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
        waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
        return waveform

    def keyword_recognize(self, wav):
        waveform = self.load_audio(wav)

        # fix_length to 1s
        if len(waveform) > self.sr:
            waveform = waveform[:self.sr]
        else:
            waveform = np.pad(waveform, (0, self.sr - len(waveform)))

        logits = self(paddle.to_tensor(waveform)).reshape([-1])
        probs = paddle.nn.functional.softmax(logits)
        idx = paddle.argmax(probs)
        return probs[idx].numpy(), self.label_list[idx]

    def forward(self, x):
        # x: waveform tensor of shape (B, T); a single 1-D waveform gets a batch axis.
        if len(x.shape) == 1:
            x = x.unsqueeze(0)

        mfcc = compute_mfcc(x).unsqueeze(1)  # (B, C, n_mels, L)
        logits = self.model(mfcc).squeeze(1)

        return logits
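`keyword_recognize` returns the top-1 softmax probability together with its label. The following plain-Python sketch shows that post-processing step on its own. The logits and the three-word label list are made up for illustration, whereas the real module is configured for 35 classes and reads its labels from `assets/label.txt`.

```python
import math


def top1(logits, labels):
    """Return (probability, label) of the highest-scoring class."""
    peak = max(logits)
    # Subtract the maximum before exponentiating for numerical stability.
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = max(range(len(probs)), key=probs.__getitem__)
    return probs[idx], labels[idx]


score, label = top1([0.2, 3.1, -1.0], ["no", "go", "one"])
print(round(score, 4), label)
```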