diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e3357a09341435e312e2c12314e1e85b30cff53
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
@@ -0,0 +1,98 @@
+# kwmlp_speech_commands
+
+|模型名称|kwmlp_speech_commands|
+| :--- | :---: |
+|类别|语音-关键词识别|
+|网络|Keyword-MLP|
+|数据集|Google Speech Commands V2|
+|是否支持Fine-tuning|否|
+|模型大小|1.6MB|
+|最新更新日期|2022-01-04|
+|数据指标|ACC 97.56%|
+
+## 一、模型基本信息
+
+### 模型介绍
+
+kwmlp_speech_commands采用了 [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) 的轻量级模型结构,并在 [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) 数据集上进行了预训练,在其测试集的测试结果为 ACC 97.56%。
+
+
+
+
+
+
+更多详情请参考
+- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209)
+- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf)
+- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP)
+
+
+## 二、安装
+
+- ### 1、环境依赖
+
+ - paddlepaddle >= 2.2.0
+
+ - paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2、安装
+
+ - ```shell
+ $ hub install kwmlp_speech_commands
+ ```
+ - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+ | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## 三、模型API预测
+
+- ### 1、预测代码示例
+
+ ```python
+ import paddlehub as hub
+
+ model = hub.Module(
+ name='kwmlp_speech_commands',
+ version='1.0.0')
+
+    # 通过下列链接可下载示例音频 go.wav;no.wav、one.wav 请替换为您自己的音频文件
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
+
+ # Keyword spotting
+ score, label = model.keyword_recognize('no.wav')
+ print(score, label)
+ # [0.89498246] no
+ score, label = model.keyword_recognize('go.wav')
+ print(score, label)
+ # [0.8997176] go
+ score, label = model.keyword_recognize('one.wav')
+ print(score, label)
+ # [0.88598305] one
+ ```
+
+- ### 2、API
+ - ```python
+ def keyword_recognize(
+ wav: os.PathLike,
+ )
+ ```
+ - 检测音频中包含的关键词。
+
+ - **参数**
+
+ - `wav`:输入的包含关键词的音频文件,格式为`*.wav`。
+
+ - **返回**
+
+ - 输出结果的得分和对应的关键词标签。
+
+
+## 四、更新历史
+
+* 1.0.0
+
+ 初始发布
+
+ ```shell
+ $ hub install kwmlp_speech_commands
+ ```
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..900a2eab26e4414b487d6d7858381ee302a107e8
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import paddle
+import paddleaudio
+
+
+def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
+ n = paddle.arange(float(n_mels))
+ k = paddle.arange(float(n_mfcc)).unsqueeze(1)
+ dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
+ if norm is None:
+ dct *= 2.0
+ else:
+ assert norm == "ortho"
+ dct[0] *= 1.0 / math.sqrt(2.0)
+ dct *= math.sqrt(2.0 / float(n_mels))
+ return dct.t()
+
+
+def compute_mfcc(
+ x: paddle.Tensor,
+ sr: int = 16000,
+ n_mels: int = 40,
+ n_fft: int = 480,
+ win_length: int = 480,
+ hop_length: int = 160,
+ f_min: float = 0.0,
+ f_max: float = None,
+ center: bool = False,
+ top_db: float = 80.0,
+ norm: str = 'ortho',
+):
+ fbank = paddleaudio.features.spectrum.MelSpectrogram(
+ sr=sr,
+ n_mels=n_mels,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ f_min=0.0,
+ f_max=f_max,
+ center=center)(x) # waveforms batch ~ (B, T)
+ log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
+ dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
+ mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1)) # (B, n_mels, L)
+ return mfcc
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..df8c37e6fb14d1f5c43b0080410d3288406cfa77
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class Residual(nn.Layer):  # skip connection wrapper: forward(x) = fn(x) + x
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn  # wrapped sub-layer; its attribute name feeds into the state-dict keys
+
+    def forward(self, x):
+        return self.fn(x) + x
+
+
+class PreNorm(nn.Layer):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.fn = fn
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x, **kwargs):
+ x = self.norm(x)
+ return self.fn(x, **kwargs)
+
+
+class PostNorm(nn.Layer):  # applies `fn` first, then LayerNorm(dim) to its output
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn  # wrapped sub-layer; receives any extra forward kwargs
+
+    def forward(self, x, **kwargs):
+        return self.norm(self.fn(x, **kwargs))
+
+
+class SpatialGatingUnit(nn.Layer):
+ def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3):
+ super().__init__()
+ dim_out = dim // 2
+
+ self.norm = nn.LayerNorm(dim_out)
+ self.proj = nn.Conv1D(dim_seq, dim_seq, 1)
+
+ self.act = act
+
+ init_eps /= dim_seq
+
+ def forward(self, x):
+ res, gate = x.split(2, axis=-1)
+ gate = self.norm(gate)
+
+ weight, bias = self.proj.weight, self.proj.bias
+ gate = F.conv1d(gate, weight, bias)
+
+ return self.act(gate) * res
+
+
+class gMLPBlock(nn.Layer):
+ def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()):
+ super().__init__()
+ self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())
+
+ self.sgu = SpatialGatingUnit(dim_ff, seq_len, act)
+ self.proj_out = nn.Linear(dim_ff // 2, dim)
+
+ def forward(self, x):
+ x = self.proj_in(x)
+ x = self.sgu(x)
+ x = self.proj_out(x)
+ return x
+
+
+class Rearrange(nn.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ x = x.transpose([0, 1, 3, 2]).squeeze(1)
+ return x
+
+
+class Reduce(nn.Layer):
+ def __init__(self, axis=1):
+ super().__init__()
+ self.axis = axis
+
+ def forward(self, x):
+ x = x.mean(axis=self.axis, keepdim=False)
+ return x
+
+
+class KW_MLP(nn.Layer):
+ """Keyword-MLP."""
+
+ def __init__(self,
+ input_res=[40, 98],
+ patch_res=[40, 1],
+ num_classes=35,
+ dim=64,
+ depth=12,
+ ff_mult=4,
+ channels=1,
+ prob_survival=0.9,
+ pre_norm=False,
+ **kwargs):
+ super().__init__()
+ image_height, image_width = input_res
+ patch_height, patch_width = patch_res
+ assert (image_height % patch_height) == 0 and (
+ image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
+ num_patches = (image_height // patch_height) * (image_width // patch_width)
+
+ P_Norm = PreNorm if pre_norm else PostNorm
+
+ dim_ff = dim * ff_mult
+
+ self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim))
+
+ self.prob_survival = prob_survival
+
+ self.layers = nn.LayerList(
+ [Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for i in range(depth)])
+
+ self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes))
+
+ def forward(self, x):
+ x = self.to_patch_embed(x)
+ layers = self.layers
+ x = nn.Sequential(*layers)(x)
+ return self.to_logits(x)
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..34342de360f2927236429baaa41789993038bd5a
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import paddle
+import paddleaudio
+
+from .feature import compute_mfcc
+from .kwmlp import KW_MLP
+from paddlehub.module.module import moduleinfo
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+ name="kwmlp_speech_commands",
+ version="1.0.0",
+ summary="",
+ author="paddlepaddle",
+ author_email="",
+ type="audio/language_identification")
+class KWS(paddle.nn.Layer):
+ def __init__(self):
+ super(KWS, self).__init__()
+ ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
+ label_path = os.path.join(self.directory, 'assets', 'label.txt')
+
+ self.label_list = []
+ with open(label_path, 'r') as f:
+ for l in f:
+ self.label_list.append(l.strip())
+
+ self.sr = 16000
+ model_conf = {
+ 'input_res': [40, 98],
+ 'patch_res': [40, 1],
+ 'num_classes': 35,
+ 'channels': 1,
+ 'dim': 64,
+ 'depth': 12,
+ 'pre_norm': False,
+ 'prob_survival': 0.9,
+ }
+ self.model = KW_MLP(**model_conf)
+ self.model.set_state_dict(paddle.load(ckpt_path))
+ self.model.eval()
+
+ def load_audio(self, wav):
+ wav = os.path.abspath(os.path.expanduser(wav))
+ assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
+ waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
+ return waveform
+
+ def keyword_recognize(self, wav):
+ waveform = self.load_audio(wav)
+
+ # fix_length to 1s
+ if len(waveform) > self.sr:
+ waveform = waveform[:self.sr]
+ else:
+ waveform = np.pad(waveform, (0, self.sr - len(waveform)))
+
+ logits = self(paddle.to_tensor(waveform)).reshape([-1])
+ probs = paddle.nn.functional.softmax(logits)
+ idx = paddle.argmax(probs)
+ return probs[idx].numpy(), self.label_list[idx]
+
+ def forward(self, x):
+ if len(x.shape) == 1: # x: waveform tensors with (B, T) shape
+ x = x.unsqueeze(0)
+
+ mfcc = compute_mfcc(x).unsqueeze(1) # (B, C, n_mels, L)
+ logits = self.model(mfcc).squeeze(1)
+
+ return logits
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..defe617fa36bc5ab7b72438034c785ee2b3ac3c9
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
@@ -0,0 +1 @@
+paddleaudio==0.1.0