From de9430f8bd560929d762e5262ae457641072b13a Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Sat, 1 Jan 2022 21:55:26 +0800
Subject: [PATCH] Add kwmlp_speech_commands.
---
 .../kwmlp_speech_commands/README.md           |  98 ++++++++++++
 .../kwmlp_speech_commands/__init__.py         |  13 ++
 .../kwmlp_speech_commands/feature.py          |  59 ++++++++
 .../kwmlp_speech_commands/kwmlp.py            | 143 ++++++++++++++++++
 .../kwmlp_speech_commands/module.py           |  86 +++++++++++
 .../kwmlp_speech_commands/requirements.txt    |   1 +
 6 files changed, 400 insertions(+)
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
 create mode 100644 modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
new file mode 100644
index 00000000..3e3357a0
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
@@ -0,0 +1,98 @@
+# kwmlp_speech_commands
+
+|模型名称|kwmlp_speech_commands|
+| :--- | :---: |
+|类别|语音-关键词识别|
+|网络|Keyword-MLP|
+|数据集|Google Speech Commands V2|
+|是否支持Fine-tuning|否|
+|模型大小|1.6MB|
+|最新更新日期|2022-01-04|
+|数据指标|ACC 97.56%|
+
+## 一、模型基本信息
+
+### 模型介绍
+
+kwmlp_speech_commands采用了 [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) 的轻量级模型结构,并在 [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) 数据集上进行了预训练,在其测试集的测试结果为 ACC 97.56%。
+
+
+
+
+
+
+更多详情请参考
+- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209)
+- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf)
+- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP)
+
+
+## 二、安装
+
+- ### 1、环境依赖
+
+  - paddlepaddle >= 2.2.0
+
+  - paddlehub >= 2.2.0    | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2、安装
+
+  - ```shell
+    $ hub install kwmlp_speech_commands
+    ```
+  - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+ | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## 三、模型API预测  
+
+- ### 1、预测代码示例
+
+    ```python
+    import paddlehub as hub
+
+    model = hub.Module(
+        name='kwmlp_speech_commands',
+        version='1.0.0')
+
+    # 通过下列链接可下载示例音频
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
+
+    # Keyword spotting
+    score, label = model.keyword_recognize('no.wav')
+    print(score, label)
+    # [0.89498246] no
+    score, label = model.keyword_recognize('go.wav')
+    print(score, label)
+    # [0.8997176] go
+    score, label = model.keyword_recognize('one.wav')
+    print(score, label)
+    # [0.88598305] one
+    ```
+
+- ### 2、API
+  - ```python
+    def keyword_recognize(
+        wav: os.PathLike,
+    )
+    ```
+    - 检测音频中包含的关键词。
+
+    - **参数**
+
+      - `wav`:输入的包含关键词的音频文件,格式为`*.wav`。
+
+    - **返回**
+
+      - 输出结果的得分和对应的关键词标签。
+
+
+## 四、更新历史
+
+* 1.0.0
+
+  初始发布
+
+  ```shell
+  $ hub install kwmlp_speech_commands
+  ```
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
new file mode 100644
index 00000000..185a92b8
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
new file mode 100644
index 00000000..900a2eab
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import paddle
+import paddleaudio
+
+
+def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
+    """Return the DCT-II basis as a (n_mels, n_mfcc) tensor; `norm` must be 'ortho' or None."""
+    mel_pos = paddle.arange(float(n_mels))
+    coef_pos = paddle.arange(float(n_mfcc)).unsqueeze(1)
+    basis = paddle.cos(math.pi / float(n_mels) * (mel_pos + 0.5) * coef_pos)  # (n_mfcc, n_mels)
+    if norm is None:
+        return (basis * 2.0).t()
+    assert norm == "ortho"
+    basis[0] *= 1.0 / math.sqrt(2.0)  # the first row (k = 0) takes an extra 1/sqrt(2)
+    basis *= math.sqrt(2.0 / float(n_mels))
+    return basis.t()
+
+
+def compute_mfcc(
+        x: paddle.Tensor,
+        sr: int = 16000,
+        n_mels: int = 40,
+        n_fft: int = 480,
+        win_length: int = 480,
+        hop_length: int = 160,
+        f_min: float = 0.0,
+        f_max: float = None,
+        center: bool = False,
+        top_db: float = 80.0,
+        norm: str = 'ortho',
+):
+    """Compute MFCCs for a batch of waveforms `x` ~ (B, T); the result is ~ (B, n_mels, L).
+
+    Steps: mel spectrogram -> power_to_db(top_db) -> DCT with n_mfcc == n_mels (`norm` goes to create_dct).
+    `f_min` is forwarded to MelSpectrogram; it used to be hard-coded to 0.0 there, ignoring the argument.
+    """
+    mel_spec = paddleaudio.features.spectrum.MelSpectrogram(
+        sr=sr, n_mels=n_mels, n_fft=n_fft, win_length=win_length, hop_length=hop_length,
+        f_min=f_min, f_max=f_max, center=center)
+    fbank = mel_spec(x)
+    log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
+    dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
+    mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1))
+    return mfcc
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
new file mode 100644
index 00000000..df8c37e6
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class Residual(nn.Layer):  # skip-connection wrapper: forward(x) returns fn(x) + x
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn  # wrapped callable; its output is added to its own input in forward
+
+    def forward(self, x):
+        return self.fn(x) + x
+
+
+class PreNorm(nn.Layer):
+    """Normalize first: forward returns ``fn(LayerNorm(dim)(x), **kwargs)``."""
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+
+class PostNorm(nn.Layer):  # applies `fn` first, then LayerNorm(dim) to its output
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)  # `dim` is handed to nn.LayerNorm as the normalized shape
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.norm(self.fn(x, **kwargs))  # extra keyword arguments go to `fn` only
+
+
+class SpatialGatingUnit(nn.Layer):
+    """Split the last axis into two halves and gate one half with a projection of the other.
+
+    `dim` is the input feature size (split evenly in two) and `dim_seq` the size of axis 1, which the
+    kernel-size-1 Conv1D mixes. `init_eps` is accepted for interface compatibility only: the original
+    code merely rescaled it and never used the result, so that dead statement has been removed.
+    """
+
+    def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim // 2)
+        self.proj = nn.Conv1D(dim_seq, dim_seq, 1)
+        self.act = act
+
+    def forward(self, x):
+        res, gate = x.split(2, axis=-1)
+        gate = self.norm(gate)
+        # Functional conv using the layer's own parameters (kernel size 1, default stride and padding).
+        gate = F.conv1d(gate, self.proj.weight, self.proj.bias)
+        return self.act(gate) * res
+
+
+class gMLPBlock(nn.Layer):
+    """gMLP block: Linear(dim -> dim_ff) + GELU, spatial gating (halves the width), Linear(dim_ff // 2 -> dim)."""
+
+    def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()):
+        super().__init__()
+        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())
+        self.sgu = SpatialGatingUnit(dim_ff, seq_len, act)
+        self.proj_out = nn.Linear(dim_ff // 2, dim)
+
+    def forward(self, x):
+        hidden = self.proj_in(x)
+        gated = self.sgu(hidden)
+        return self.proj_out(gated)
+
+
+class Rearrange(nn.Layer):
+    """Swap the last two axes of a 4-D input, then squeeze axis 1."""
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.transpose([0, 1, 3, 2]).squeeze(1)
+
+
+class Reduce(nn.Layer):
+    """Average the input over `axis` (default 1) without keeping that axis."""
+    def __init__(self, axis=1):
+        super().__init__()
+        self.axis = axis
+
+    def forward(self, x):
+        return x.mean(axis=self.axis, keepdim=False)
+
+
+class KW_MLP(nn.Layer):
+    """Keyword-MLP: patch embedding, `depth` residual gMLP blocks, then a mean-pooled linear classifier.
+
+    `input_res` / `patch_res` are (height, width) of the input and of one patch; the former must be
+    divisible by the latter. `prob_survival` is stored but not used by this forward pass, and extra
+    `**kwargs` are accepted and ignored.
+    """
+
+    def __init__(self,
+                 input_res=(40, 98),
+                 patch_res=(40, 1),
+                 num_classes=35,
+                 dim=64,
+                 depth=12,
+                 ff_mult=4,
+                 channels=1,
+                 prob_survival=0.9,
+                 pre_norm=False,
+                 **kwargs):
+        super().__init__()
+        image_height, image_width = input_res
+        patch_height, patch_width = patch_res
+        assert (image_height % patch_height) == 0 and (
+            image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        norm_cls = PreNorm if pre_norm else PostNorm
+        dim_ff = dim * ff_mult
+
+        self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim))
+        self.prob_survival = prob_survival
+        self.layers = nn.LayerList(
+            [Residual(norm_cls(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for _ in range(depth)])
+        self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes))
+
+    def forward(self, x):
+        x = self.to_patch_embed(x)
+        for layer in self.layers:  # apply the blocks in order; no per-call nn.Sequential is built
+            x = layer(x)
+        return self.to_logits(x)
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
new file mode 100644
index 00000000..34342de3
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import paddle
+import paddleaudio
+
+from .feature import compute_mfcc
+from .kwmlp import KW_MLP
+from paddlehub.module.module import moduleinfo
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+    name="kwmlp_speech_commands",
+    version="1.0.0",
+    summary="",
+    author="paddlepaddle",
+    author_email="",
+    type="audio/language_identification")
+class KWS(paddle.nn.Layer):
+    """Keyword spotting module: a pretrained KW_MLP applied to MFCC features of a 1-second, 16 kHz clip."""
+
+    def __init__(self):
+        """Read the label list and the checkpoint from this module's `assets` directory."""
+        super(KWS, self).__init__()
+        ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
+        label_path = os.path.join(self.directory, 'assets', 'label.txt')
+
+        # One label per line, in class-index order; explicit encoding avoids platform-dependent defaults.
+        with open(label_path, 'r', encoding='utf-8') as f:
+            self.label_list = [line.strip() for line in f]
+
+        self.sr = 16000
+        self.model = KW_MLP(
+            input_res=[40, 98], patch_res=[40, 1], num_classes=35, channels=1, dim=64, depth=12, pre_norm=False,
+            prob_survival=0.9)
+        self.model.set_state_dict(paddle.load(ckpt_path))
+        self.model.eval()
+
+    def load_audio(self, wav):
+        """Load `wav` via paddleaudio as mono audio at `self.sr`; asserts that the path is an existing file."""
+        wav = os.path.abspath(os.path.expanduser(wav))
+        assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
+        waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
+        return waveform
+
+    def keyword_recognize(self, wav):
+        """Classify the keyword in the audio file `wav`.
+
+        Returns:
+            (score, label): softmax probability of the top class (numpy array) and its label string.
+        """
+        waveform = self.load_audio(wav)
+
+        # Fix the length to exactly 1 s: truncate longer clips, zero-pad shorter ones at the end.
+        if len(waveform) > self.sr:
+            waveform = waveform[:self.sr]
+        else:
+            waveform = np.pad(waveform, (0, self.sr - len(waveform)))
+
+        with paddle.no_grad():  # inference only, so skip autograd bookkeeping
+            logits = self(paddle.to_tensor(waveform)).reshape([-1])
+        probs = paddle.nn.functional.softmax(logits)
+        idx = paddle.argmax(probs)
+        return probs[idx].numpy(), self.label_list[idx]
+
+    def forward(self, x):
+        """Return logits for waveform(s) `x` ~ (T,) or (B, T)."""
+        if len(x.shape) == 1:  # single waveform: add the batch axis
+            x = x.unsqueeze(0)
+        mfcc = compute_mfcc(x).unsqueeze(1)  # (B, C, n_mels, L)
+        logits = self.model(mfcc).squeeze(1)
+        return logits
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
new file mode 100644
index 00000000..defe617f
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
@@ -0,0 +1 @@
+paddleaudio==0.1.0
-- 
GitLab