diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3e3357a09341435e312e2c12314e1e85b30cff53 --- /dev/null +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md @@ -0,0 +1,98 @@ +# kwmlp_speech_commands + +|模型名称|kwmlp_speech_commands| +| :--- | :---: | +|类别|语音-语言识别| +|网络|Keyword-MLP| +|数据集|Google Speech Commands V2| +|是否支持Fine-tuning|否| +|模型大小|1.6MB| +|最新更新日期|2022-01-04| +|数据指标|ACC 97.56%| + +## 一、模型基本信息 + +### 模型介绍 + +kwmlp_speech_commands采用了 [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) 的轻量级模型结构,并在 [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) 数据集上进行了预训练,在其测试集的测试结果为 ACC 97.56%。 + +

+
+

+ + +更多详情请参考 +- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209) +- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf) +- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP) + + +## 二、安装 + +- ### 1、环境依赖 + + - paddlepaddle >= 2.2.0 + + - paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst) + +- ### 2、安装 + + - ```shell + $ hub install kwmlp_speech_commands + ``` + - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) + | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) + + +## 三、模型API预测 + +- ### 1、预测代码示例 + + ```python + import paddlehub as hub + + model = hub.Module( + name='kwmlp_speech_commands', + version='1.0.0') + + # 通过下列链接可下载示例音频 + # https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav + + # Keyword spotting + score, label = model.keyword_recognize('no.wav') + print(score, label) + # [0.89498246] no + score, label = model.keyword_recognize('go.wav') + print(score, label) + # [0.8997176] go + score, label = model.keyword_recognize('one.wav') + print(score, label) + # [0.88598305] one + ``` + +- ### 2、API + - ```python + def keyword_recognize( + wav: os.PathLike, + ) + ``` + - 检测音频中包含的关键词。 + + - **参数** + + - `wav`:输入的包含关键词的音频文件,格式为`*.wav`。 + + - **返回** + + - 输出结果的得分和对应的关键词标签。 + + +## 四、更新历史 + +* 1.0.0 + + 初始发布 + + ```shell + $ hub install kwmlp_speech_commands + ``` diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py new file mode 100644 index 0000000000000000000000000000000000000000..900a2eab26e4414b487d6d7858381ee302a107e8 --- /dev/null +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math

import numpy as np
import paddle
import paddleaudio


def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
    """Build the DCT-II basis that maps mel bands to cepstral coefficients.

    Args:
        n_mfcc: Number of cepstral coefficients (columns of the result).
        n_mels: Number of mel bands (rows of the result).
        norm: ``'ortho'`` for orthonormal scaling, or ``None`` to scale the
            raw cosine basis by 2. Any other value trips the assertion below.

    Returns:
        A tensor of shape ``(n_mels, n_mfcc)``.
    """
    n = paddle.arange(float(n_mels))
    k = paddle.arange(float(n_mfcc)).unsqueeze(1)
    dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        # Orthonormal scaling: the first basis row gets an extra 1/sqrt(2).
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def compute_mfcc(
        x: paddle.Tensor,
        sr: int = 16000,
        n_mels: int = 40,
        n_fft: int = 480,
        win_length: int = 480,
        hop_length: int = 160,
        f_min: float = 0.0,
        f_max: float = None,
        center: bool = False,
        top_db: float = 80.0,
        norm: str = 'ortho',
):
    """Compute MFCC features for a batch of waveforms.

    The pipeline is mel spectrogram -> ``power_to_db`` -> DCT. The number of
    cepstral coefficients always equals ``n_mels`` because the DCT matrix is
    built with ``n_mfcc=n_mels``.

    Args:
        x: Waveform batch, ``(B, T)``.
        sr: Sample rate handed to ``MelSpectrogram``.
        n_mels: Number of mel bands (and of output coefficients).
        n_fft: FFT size handed to ``MelSpectrogram``.
        win_length: Window length handed to ``MelSpectrogram``.
        hop_length: Hop length handed to ``MelSpectrogram``.
        f_min: Lowest mel filter frequency handed to ``MelSpectrogram``.
        f_max: Highest mel filter frequency, passed through unchanged.
        center: Passed through to ``MelSpectrogram``.
        top_db: Passed through to ``power_to_db``.
        norm: DCT normalisation, see ``create_dct``.

    Returns:
        MFCC tensor of shape ``(B, n_mels, L)``.
    """
    fbank = paddleaudio.features.spectrum.MelSpectrogram(
        sr=sr,
        n_mels=n_mels,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        # Honour the caller's value; this used to be hard-coded to 0.0, which
        # silently ignored the ``f_min`` argument.
        f_min=f_min,
        f_max=f_max,
        center=center)(x)  # waveforms batch ~ (B, T)
    log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
    dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
    # Project the mel axis onto the DCT basis, then restore (B, coeff, frames).
    mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1))  # (B, n_mels, L)
    return mfcc


# --- patch scaffolding for the next file, kept verbatim as comments ---
# diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..df8c37e6fb14d1f5c43b0080410d3288406cfa77
# --- /dev/null
# +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
# @@ -0,0 +1,143 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class Residual(nn.Layer):
    """Wrap ``fn`` with a skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class PreNorm(nn.Layer):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class PostNorm(nn.Layer):
    """Call ``fn``, then apply ``LayerNorm(dim)`` to its output."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm(self.fn(x, **kwargs))


class SpatialGatingUnit(nn.Layer):
    """Gate one half of the features with a projection of the other half.

    The input is split in two along the last axis. The second half is layer
    normalised, mixed by a kernel-size-1 convolution whose channels are the
    ``dim_seq`` sequence positions, passed through ``act`` and multiplied
    element-wise with the first half.

    Args:
        dim: Size of the last axis of the input; each half has ``dim // 2``.
        dim_seq: Channel count (in and out) of the 1x1 ``Conv1D``.
        act: Activation applied to the projected gate.
        init_eps: Accepted for signature compatibility only. The previous code
            rescaled it into a local variable that was never read, so it has
            no effect on the layer.
    """

    def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3):
        super().__init__()
        dim_out = dim // 2

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1D(dim_seq, dim_seq, 1)

        self.act = act

    def forward(self, x):
        res, gate = x.split(2, axis=-1)
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        gate = F.conv1d(gate, weight, bias)

        return self.act(gate) * res


class gMLPBlock(nn.Layer):
    """Linear + GELU expansion to ``dim_ff``, spatial gating, projection back to ``dim``."""

    def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())

        self.sgu = SpatialGatingUnit(dim_ff, seq_len, act)
        # The gating unit returns one half of its input features.
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.sgu(x)
        x = self.proj_out(x)
        return x


class Rearrange(nn.Layer):
    """Swap the last two axes of a 4-D tensor, then squeeze axis 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.transpose([0, 1, 3, 2]).squeeze(1)
        return x


class Reduce(nn.Layer):
    """Average over ``axis`` and drop that axis."""

    def __init__(self, axis=1):
        super().__init__()
        self.axis = axis

    def forward(self, x):
        x = x.mean(axis=self.axis, keepdim=False)
        return x


class KW_MLP(nn.Layer):
    """Keyword-MLP.

    Args:
        input_res: ``(height, width)`` of the input feature map.
        patch_res: ``(height, width)`` of one patch; must divide ``input_res``.
        num_classes: Size of the output logits.
        dim: Embedding width of every block.
        depth: Number of residual gMLP blocks.
        ff_mult: ``dim_ff = dim * ff_mult`` inside each block.
        channels: Input channel count used to size the patch embedding.
        prob_survival: Stored on the instance; ``forward`` does not read it.
        pre_norm: Use ``PreNorm`` instead of ``PostNorm`` around each block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_res=(40, 98),
                 patch_res=(40, 1),
                 num_classes=35,
                 dim=64,
                 depth=12,
                 ff_mult=4,
                 channels=1,
                 prob_survival=0.9,
                 pre_norm=False,
                 **kwargs):
        super().__init__()
        # Defaults are tuples rather than lists so no mutable object is shared
        # between calls; any two-item sequence is still accepted.
        image_height, image_width = input_res
        patch_height, patch_width = patch_res
        assert (image_height % patch_height) == 0 and (
            image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
        num_patches = (image_height // patch_height) * (image_width // patch_width)

        P_Norm = PreNorm if pre_norm else PostNorm

        dim_ff = dim * ff_mult

        self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim))

        self.prob_survival = prob_survival

        self.layers = nn.LayerList(
            [Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for _ in range(depth)])

        self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes))

    def forward(self, x):
        """Embed the input, run every block in order, and return the logits."""
        x = self.to_patch_embed(x)
        # Iterate the registered blocks directly instead of building a new
        # nn.Sequential container on every forward call.
        for layer in self.layers:
            x = layer(x)
        return self.to_logits(x)


# --- patch scaffolding for the next file, kept verbatim as comments ---
# diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..34342de360f2927236429baaa41789993038bd5a
# --- /dev/null
# +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
# @@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import paddle
import paddleaudio

from .feature import compute_mfcc
from .kwmlp import KW_MLP
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger


@moduleinfo(
    name="kwmlp_speech_commands",
    version="1.0.0",
    summary="",
    author="paddlepaddle",
    author_email="",
    type="audio/language_identification")
class KWS(paddle.nn.Layer):
    """PaddleHub module that runs a pretrained Keyword-MLP on one-second clips."""

    def __init__(self):
        super(KWS, self).__init__()
        ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
        label_path = os.path.join(self.directory, 'assets', 'label.txt')

        # One label per line. Every line is kept (even an empty one) because
        # the position in this list is used as the class index.
        # Explicit encoding so the result does not depend on the platform
        # default; assumes label.txt is ASCII/UTF-8 - TODO confirm.
        with open(label_path, 'r', encoding='utf-8') as f:
            self.label_list = [line.strip() for line in f]

        self.sr = 16000
        model_conf = {
            'input_res': [40, 98],
            'patch_res': [40, 1],
            'num_classes': 35,
            'channels': 1,
            'dim': 64,
            'depth': 12,
            'pre_norm': False,
            'prob_survival': 0.9,
        }
        self.model = KW_MLP(**model_conf)
        self.model.set_state_dict(paddle.load(ckpt_path))
        self.model.eval()

    def load_audio(self, wav):
        """Load ``wav`` as a mono waveform resampled to ``self.sr``.

        Args:
            wav: Path to the audio file; ``~`` is expanded and the path is
                made absolute.

        Returns:
            The waveform returned by ``paddleaudio.load``.

        Raises:
            AssertionError: If the resolved path is not an existing file.
        """
        wav = os.path.abspath(os.path.expanduser(wav))
        assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
        waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
        return waveform

    def keyword_recognize(self, wav):
        """Predict the keyword spoken in an audio file.

        Args:
            wav: Path to the audio file.

        Returns:
            A ``(score, label)`` pair: the softmax probability of the best
            class as a NumPy value, and the matching entry of ``label_list``.
        """
        waveform = self.load_audio(wav)

        # fix_length to 1s
        if len(waveform) > self.sr:
            waveform = waveform[:self.sr]
        else:
            waveform = np.pad(waveform, (0, self.sr - len(waveform)))

        # Inference only: skip autograd bookkeeping.
        with paddle.no_grad():
            logits = self(paddle.to_tensor(waveform)).reshape([-1])
            probs = paddle.nn.functional.softmax(logits)
            idx = paddle.argmax(probs)
        return probs[idx].numpy(), self.label_list[idx]

    def forward(self, x):
        """Return class logits for a waveform ``(T,)`` or a batch ``(B, T)``."""
        if len(x.shape) == 1:  # single waveform (T,): add the batch axis -> (1, T)
            x = x.unsqueeze(0)

        mfcc = compute_mfcc(x).unsqueeze(1)  # (B, C, n_mels, L)
        logits = self.model(mfcc).squeeze(1)

        return logits


# --- patch scaffolding for the next file, kept verbatim as comments ---
# diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
# new file mode 100644
# index 0000000000000000000000000000000000000000..defe617fa36bc5ab7b72438034c785ee2b3ac3c9
# --- /dev/null
# +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
# @@ -0,0 +1 @@
# +paddleaudio==0.1.0