PaddlePaddle / DeepSpeech
Commit f9761d53
Authored on Apr 19, 2022 by KP
Add KWS example.
Parent: b60b1dad

Showing 17 changed files with 354 additions and 225 deletions
audio/paddleaudio/datasets/hey_snips.py        +3   -1
examples/hey_snips/README.md                   +22  -0
examples/hey_snips/RESULTS.md                  +8   -0
examples/hey_snips/kws0/RESULTS.md             +0   -0
examples/hey_snips/kws0/conf/mdtc.yaml         +39  -0
examples/hey_snips/kws0/local/plot.sh          +2   -0
examples/hey_snips/kws0/local/score.sh         +5   -0
examples/hey_snips/kws0/local/train.sh         +12  -0
examples/hey_snips/kws0/run.sh                 +12  -23
paddlespeech/kws/exps/mdtc/collate.py          +39  -0
paddlespeech/kws/exps/mdtc/compute_det.py      +41  -49
paddlespeech/kws/exps/mdtc/plot_det_curve.py   +20  -9
paddlespeech/kws/exps/mdtc/score.py            +30  -54
paddlespeech/kws/exps/mdtc/train.py            +37  -74
paddlespeech/kws/models/__init__.py            +2   -0
paddlespeech/kws/models/loss.py                +80  -0
paddlespeech/kws/models/mdtc.py                +2   -15
audio/paddleaudio/datasets/hey_snips.py
@@ -63,10 +63,12 @@ class HeySnips(AudioClassificationDataset):
         files = []
         labels = []
         self.keys = []
+        self.durations = []
         for sample in meta_info:
-            key, target, _, wav = sample
+            key, target, duration, wav = sample
             files.append(wav)
             labels.append(int(target))
             self.keys.append(key)
+            self.durations.append(float(duration))

         return files, labels
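The added `durations` bookkeeping is what `compute_det.py` (further down in this commit) relies on to total up filler duration straight from the dataset object. A minimal sketch of the resulting interface, assuming the placeholder data path from `conf/mdtc.yaml`:

import paddle
from paddleaudio.datasets import HeySnips

test_ds = HeySnips(
    data_dir='/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter',
    mode='test', feat_type='kaldi_fbank', sample_rate=16000, n_mels=80)

# keys, labels and durations stay aligned one-to-one per utterance
for key, label, duration in zip(test_ds.keys, test_ds.labels, test_ds.durations):
    print(key, label, duration)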
examples/hey_snips/README.md

# MDTC Keyword Spotting with HeySnips Dataset

## Dataset
Before running the scripts, you **MUST** follow these instructions to download the dataset: https://github.com/sonos/keyword-spotting-research-datasets

After you download and decompress the dataset archive, you should **REPLACE** the value of `data_dir` in `conf/*.yaml` to complete the dataset config.

## Get Started
In this section, we will train the [MDTC](https://arxiv.org/pdf/2102.13552.pdf) model and evaluate it on the "Hey Snips" dataset.

```sh
CUDA_VISIBLE_DEVICES=0,1 ./run.sh conf/mdtc.yaml
```

This script contains the training and scoring steps. Set the `CUDA_VISIBLE_DEVICES` environment variable to run on a single GPU or on multiple GPUs.

The vars `stage` and `stop_stage` in `./run.sh` control the running steps:
- stage 1: Training from scratch.
- stage 2: Evaluating the model on the test dataset and computing the detection error tradeoff (DET) at all trigger thresholds.
- stage 3: Plotting the DET curve for visualization.
examples/hey_snips/RESULTS.md (new file, 100644)

## Metrics

We measure FRRs with the number of false alarms fixed at one per hour:

| Model | False Alarms per Hour | False Reject Rate |
|--|--|--|
| MDTC | 1 | 0.003559 |
examples/hey_snips/kws0/RESULTS.md (deleted, 100644 → 0)
examples/hey_snips/kws0/conf/mdtc.yaml (new file, 100644)

data:
  data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
  dataset: 'paddleaudio.datasets:HeySnips'

model:
  num_keywords: 1
  backbone: 'paddlespeech.kws.models:MDTC'
  config:
    stack_num: 3
    stack_size: 4
    in_channels: 80
    res_channels: 32
    kernel_size: 5

feature:
  feat_type: 'kaldi_fbank'
  sample_rate: 16000
  frame_shift: 10
  frame_length: 25
  n_mels: 80

training:
  epochs: 100
  num_workers: 16
  batch_size: 100
  checkpoint_dir: './checkpoint'
  save_freq: 10
  log_freq: 10
  learning_rate: 0.001
  weight_decay: 0.00005
  grad_clip: 5.0

scoring:
  batch_size: 100
  num_workers: 16
  checkpoint: './checkpoint/epoch_100/model.pdparams'
  score_file: './scores.txt'
  stats_file: './stats.0.txt'
  img_file: './det.png'
\ No newline at end of file
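Every exp script in this commit consumes this file through the same pattern: `yaml.safe_load` plus `dynamic_import`, which resolves the `'module:Class'` strings above to classes at runtime. A condensed sketch of what `train.py` (below) does with it:

import yaml

from paddlespeech.kws.models.mdtc import KWSModel
from paddlespeech.s2t.utils.dynamic_import import dynamic_import

with open('conf/mdtc.yaml') as f:
    config = yaml.safe_load(f)

ds_class = dynamic_import(config['data']['dataset'])          # -> HeySnips
backbone_class = dynamic_import(config['model']['backbone'])  # -> MDTC

train_ds = ds_class(
    data_dir=config['data']['data_dir'], mode='train', **config['feature'])
backbone = backbone_class(**config['model']['config'])
model = KWSModel(
    backbone=backbone, num_keywords=config['model']['num_keywords'])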
examples/hey_snips/kws0/local/plot.sh (new file, 100755)

#!/bin/bash

python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
examples/hey_snips/kws0/local/score.sh (new file, 100755)

#!/bin/bash

python3 ${BIN_DIR}/score.py --cfg_path=$1
python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
examples/hey_snips/kws0/local/train.sh (new file, 100755)

#!/bin/bash

ngpu=$1
cfg_path=$2

if [ ${ngpu} -gt 0 ]; then
    python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
        --cfg_path ${cfg_path}
else
    python3 ${BIN_DIR}/train.py \
        --cfg_path ${cfg_path}
fi
examples/hey_snips/kws0/run.sh
@@ -13,35 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-. ./path.sh
 set -e
+source path.sh
 
-stage=0
-stop_stage=50
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
-# data directory
-# if we set the variable ${dir}, we will store the wav info to this directory
-# otherwise, we will store the wav info to vox1 and vox2 directory respectively
-# vox2 wav path, we must convert the m4a format to wav format
-dir=data/                           # data info directory
-exp_dir=exp/ecapa-tdnn-vox12-big/   # experiment directory
-conf_path=conf/mdtc.yaml
-gpus=0,1,2,3
+stage=1
+stop_stage=3
+cfg_path=$1
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
-mkdir -p ${exp_dir}
-
-if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
-    bash ./local/data.sh ${dir} ${conf_path} || exit -1;
-fi
-
-if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    ./local/train.sh ${ngpu} ${cfg_path} || exit -1
 fi
 
-if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    ./local/score.sh ${cfg_path} || exit -1
 fi
 
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    ./local/plot.sh ${cfg_path} || exit -1
+fi
\ No newline at end of file
paddlespeech/kws/exps/mdtc/collate.py (new file, 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import paddle


def collate_features(batch):
    # (key, feat, label)
    collate_start = time.time()
    keys = []
    feats = []
    labels = []
    lengths = []
    for sample in batch:
        keys.append(sample[0])
        feats.append(sample[1])
        labels.append(sample[2])
        lengths.append(sample[1].shape[0])

    max_length = max(lengths)
    for i in range(len(feats)):
        feats[i] = paddle.nn.functional.pad(
            feats[i], [0, max_length - feats[i].shape[0], 0, 0],
            data_format='NLC')

    return keys, paddle.stack(feats), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)
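A minimal sketch of what `collate_features` returns for a toy batch of `(key, feat, label)` samples whose feature matrices have different frame counts:

import paddle

from paddlespeech.kws.exps.mdtc.collate import collate_features

batch = [
    ('utt1', paddle.randn([3, 80]), 1),  # 3 frames of 80-dim fbank
    ('utt2', paddle.randn([5, 80]), 0),  # 5 frames
]
keys, feats, labels, lengths = collate_features(batch)
print(feats.shape)      # [2, 5, 80]: 'utt1' zero-padded to the batch max length
print(lengths.numpy())  # [3 5]: original frame counts, used later for masking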
paddlespeech/kws/exps/mdtc/compute_det.py
@@ -11,15 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
+# Modified from wekws(https://github.com/wenet-e2e/wekws)
+import argparse
 import os
-import sys
 
+import yaml
 from tqdm import tqdm
 
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--cfg_path", type=str, required=True)
+parser.add_argument('--keyword', type=int, default=0, help='keyword label')
+parser.add_argument('--step', type=float, default=0.01, help='threshold step')
+parser.add_argument('--window_shift', type=int, default=50,
+                    help='window_shift is used to skip the frames after triggered')
+args = parser.parse_args()
+# yapf: enable
 
-def load_label_and_score(keyword, label_file, score_file):
+
+def load_label_and_score(keyword, ds, score_file):
     # score_table: {uttid: [keywordlist]}
     score_table = {}
     with open(score_file, 'r', encoding='utf8') as fin:
         for line in fin:
@@ -34,59 +45,40 @@ def load_label_and_score(keyword, label_file, score_file):
     keyword_table = {}
     filler_table = {}
     filler_duration = 0.0
-    with open(label_file, 'r', encoding='utf8') as fin:
-        for line in fin:
-            obj = json.loads(line.strip())
-            assert 'key' in obj
-            assert 'txt' in obj
-            assert 'duration' in obj
-            key = obj['key']
-            index = obj['txt']
-            duration = obj['duration']
-            assert key in score_table
-            if index == keyword:
-                keyword_table[key] = score_table[key]
-            else:
-                filler_table[key] = score_table[key]
-            filler_duration += duration
+    for key, index, duration in zip(ds.keys, ds.labels, ds.durations):
+        assert key in score_table
+        if index == keyword:
+            keyword_table[key] = score_table[key]
+        else:
+            filler_table[key] = score_table[key]
+        filler_duration += duration
 
     return keyword_table, filler_table, filler_duration
 
 
-class Args:
-    def __init__(self):
-        self.test_data = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/data/test/data.list'
-        self.keyword = 0
-        self.score_file = os.path.join(os.path.abspath(sys.argv[1]), 'score.txt')
-        self.stats_file = os.path.join(os.path.abspath(sys.argv[1]), 'stats.0.txt')
-        self.step = 0.01
-        self.window_shift = 50
-args = Args()
-
 if __name__ == '__main__':
-    # parser = argparse.ArgumentParser(description='compute det curve')
-    # parser.add_argument('--test_data', required=True, help='label file')
-    # parser.add_argument('--keyword', type=int, default=0, help='keyword label')
-    # parser.add_argument('--score_file', required=True, help='score file')
-    # parser.add_argument('--step', type=float, default=0.01,
-    #                     help='threshold step')
-    # parser.add_argument('--window_shift', type=int, default=50,
-    #                     help='window_shift is used to skip the frames after triggered')
-    # parser.add_argument('--stats_file',
-    #                     required=True,
-    #                     help='false reject/alarm stats file')
-    # args = parser.parse_args()
-    window_shift = args.window_shift
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
+
+    data_conf = config['data']
+    feat_conf = config['feature']
+    scoring_conf = config['scoring']
+
+    # Dataset
+    ds_class = dynamic_import(data_conf['dataset'])
+    test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
+
+    score_file = os.path.abspath(scoring_conf['score_file'])
+    stats_file = os.path.abspath(scoring_conf['stats_file'])
 
     keyword_table, filler_table, filler_duration = load_label_and_score(
-        args.keyword, args.test_data, args.score_file)
+        args.keyword, test_ds, score_file)
     print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
     pbar = tqdm(total=int(1.0 / args.step))
-    with open(args.stats_file, 'w', encoding='utf8') as fout:
-        keyword_index = int(args.keyword)
+    with open(stats_file, 'w', encoding='utf8') as fout:
+        keyword_index = args.keyword
         threshold = 0.0
         while threshold <= 1.0:
            num_false_reject = 0
@@ -103,7 +95,7 @@ if __name__ == '__main__':
             while i < len(score_list):
                 if score_list[i] >= threshold:
                     num_false_alarm += 1
-                    i += window_shift
+                    i += args.window_shift
                 else:
                     i += 1
 
             if len(keyword_table) != 0:
@@ -118,4 +110,4 @@ if __name__ == '__main__':
             pbar.update(1)
 
     pbar.close()
-    print('DET saved to: {}'.format(args.stats_file))
+    print('DET saved to: {}'.format(stats_file))
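The `window_shift` skip in the false-alarm loop above collapses a run of consecutive above-threshold frames into a single trigger. A toy trace of that logic with illustrative values:

scores = [0.1, 0.9, 0.8, 0.2, 0.95]  # per-frame keyword scores of a filler utt
threshold, window_shift = 0.5, 2

i = num_false_alarm = 0
while i < len(scores):
    if scores[i] >= threshold:
        num_false_alarm += 1
        i += window_shift  # skip the frames right after a trigger
    else:
        i += 1

print(num_false_alarm)  # 2: the 0.9 trigger absorbs 0.8; 0.95 fires separately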
paddlespeech/kws/exps/mdtc/plot_det_curve.py
@@ -11,11 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Modified from wekws(https://github.com/wenet-e2e/wekws)
+import argparse
 import os
-import sys
 
 import matplotlib.pyplot as plt
 import numpy as np
+import yaml
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--cfg_path", type=str, required=True)
+parser.add_argument("--keyword", type=str, required=True)
+args = parser.parse_args()
+# yapf: enable
 
 
 def load_stats_file(stats_file):
@@ -29,7 +38,7 @@ def load_stats_file(stats_file):
     return np.array(values)
 
 
-def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
+def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
                    y_step):
     plt.figure(dpi=200)
     plt.rcParams['xtick.direction'] = 'in'
@@ -37,7 +46,6 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
     plt.rcParams['font.size'] = 12
 
     for index, keyword in enumerate(keywords):
-        stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt')
         values = load_stats_file(stats_file)
         plt.plot(values[:, 0], values[:, 1], label=keyword)
 
@@ -53,11 +61,14 @@ def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
 
 if __name__ == '__main__':
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
 
-    keywords = ['Hey_Snips']
-    img_path = os.path.join(os.path.abspath(sys.argv[1]), 'det.png')
-    plot_det_curve(keywords, os.path.abspath(sys.argv[1]), img_path, 10, 2, 10, 2)
+    scoring_conf = config['scoring']
+    img_file = os.path.abspath(scoring_conf['img_file'])
+    stats_file = os.path.abspath(scoring_conf['stats_file'])
 
-    print('DET curve image saved to: {}'.format(img_path))
+    keywords = [args.keyword]
+    plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
+
+    print('DET curve image saved to: {}'.format(img_file))
paddlespeech/kws/exps/mdtc/score.py
@@ -11,80 +11,56 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Modified from wekws(https://github.com/wenet-e2e/wekws)
+import argparse
 import os
-import sys
-import time
 
 import paddle
-from mdtc import KWSModel
-from mdtc import MDTC
+import yaml
 from tqdm import tqdm
 
-from paddleaudio.datasets import HeySnips
+from paddlespeech.kws.exps.mdtc.collate import collate_features
+from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
-
-def collate_features(batch):
-    # (key, feat, label) in one sample
-    collate_start = time.time()
-    keys = []
-    feats = []
-    labels = []
-    lengths = []
-    for sample in batch:
-        keys.append(sample[0])
-        feats.append(sample[1])
-        labels.append(sample[2])
-        lengths.append(sample[1].shape[0])
-
-    max_length = max(lengths)
-    for i in range(len(feats)):
-        feats[i] = paddle.nn.functional.pad(
-            feats[i], [0, max_length - feats[i].shape[0], 0, 0],
-            data_format='NLC')
-
-    return keys, paddle.stack(feats), paddle.to_tensor(
-        labels), paddle.to_tensor(lengths)
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--cfg_path", type=str, required=True)
+args = parser.parse_args()
+# yapf: enable
 
 if __name__ == '__main__':
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
+
+    model_conf = config['model']
+    data_conf = config['data']
+    feat_conf = config['feature']
+    scoring_conf = config['scoring']
+
     # Dataset
-    feat_conf = {
-        # 'n_mfcc': 80,
-        'n_mels': 80,
-        'frame_shift': 10,
-        'frame_length': 25,
-        # 'dither': 1.0,
-    }
-    test_ds = HeySnips(
-        mode='test', feat_type='kaldi_fbank', sample_rate=16000, **feat_conf)
+    ds_class = dynamic_import(data_conf['dataset'])
+    test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
 
     test_sampler = paddle.io.BatchSampler(
-        test_ds, batch_size=32, drop_last=False)
+        test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
     test_loader = paddle.io.DataLoader(
         test_ds,
         batch_sampler=test_sampler,
-        num_workers=16,
+        num_workers=scoring_conf['num_workers'],
         return_list=True,
         use_buffer_reader=True,
         collate_fn=collate_features, )
 
     # Model
-    backbone = MDTC(
-        stack_num=3,
-        stack_size=4,
-        in_channels=80,
-        res_channels=32,
-        kernel_size=5,
-        causal=True, )
-    model = KWSModel(backbone=backbone, num_keywords=1)
+    backbone_class = dynamic_import(model_conf['backbone'])
+    backbone = backbone_class(**model_conf['config'])
+    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
     model = paddle.DataParallel(model)
-    # kws_checkpoint = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/checkpoint/epoch_10_0.8903940343290826/model.pdparams'
-    kws_checkpoint = os.path.join(os.path.abspath(sys.argv[1]), 'model.pdparams')
-    model.set_state_dict(paddle.load(kws_checkpoint))
+    model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
     model.eval()
 
-    score_abs_path = os.path.join(os.path.abspath(sys.argv[1]), 'score.txt')
-    with paddle.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
+    with paddle.no_grad(), open(
+            scoring_conf['score_file'], 'w', encoding='utf8') as fout:
         for batch_idx, batch in enumerate(
                 tqdm(test_loader, total=len(test_loader))):
             keys, feats, labels, lengths = batch
@@ -100,4 +76,4 @@ if __name__ == '__main__':
             fout.write('{} {} {}\n'.format(key, keyword_i, score_frames))
 
-    print('Scores saved to: {}'.format(score_abs_path))
+    print('Result saved to: {}'.format(scoring_conf['score_file']))
paddlespeech/kws/exps/mdtc/train.py
@@ -11,77 +11,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
 import os
-import time
 
 import paddle
-from loss import max_pooling_loss
-from mdtc import KWSModel
-from mdtc import MDTC
+import yaml
 
-from paddleaudio.datasets import HeySnips
 from paddleaudio.utils import logger
 from paddleaudio.utils import Timer
 
+from paddlespeech.kws.exps.mdtc.collate import collate_features
+from paddlespeech.kws.models.loss import max_pooling_loss
+from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
-def collate_features(batch):
-    # (key, feat, label)
-    collate_start = time.time()
-    keys = []
-    feats = []
-    labels = []
-    lengths = []
-    for sample in batch:
-        keys.append(sample[0])
-        feats.append(sample[1])
-        labels.append(sample[2])
-        lengths.append(sample[1].shape[0])
-
-    max_length = max(lengths)
-    for i in range(len(feats)):
-        feats[i] = paddle.nn.functional.pad(
-            feats[i], [0, max_length - feats[i].shape[0], 0, 0],
-            data_format='NLC')
-
-    return keys, paddle.stack(feats), paddle.to_tensor(
-        labels), paddle.to_tensor(lengths)
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--cfg_path", type=str, required=True)
+args = parser.parse_args()
+# yapf: enable
 
 if __name__ == '__main__':
     nranks = paddle.distributed.get_world_size()
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
     local_rank = paddle.distributed.get_rank()
 
+    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
+    with open(args.cfg_path, 'r') as f:
+        config = yaml.safe_load(f)
+
+    model_conf = config['model']
+    data_conf = config['data']
+    feat_conf = config['feature']
+    training_conf = config['training']
+
     # Dataset
-    feat_conf = {
-        # 'n_mfcc': 80,
-        'n_mels': 80,
-        'frame_shift': 10,
-        'frame_length': 25,
-        # 'dither': 1.0,
-    }
-    data_dir = '/ssd1/chenxiaojie06/datasets/hey_snips/hey_snips_research_6k_en_train_eval_clean_ter'
-    train_ds = HeySnips(
-        data_dir=data_dir, mode='train', feat_type='kaldi_fbank',
-        sample_rate=16000, **feat_conf)
-    dev_ds = HeySnips(
-        data_dir=data_dir, mode='dev', feat_type='kaldi_fbank',
-        sample_rate=16000, **feat_conf)
-
-    training_conf = {
-        'epochs': 100,
-        'learning_rate': 0.001,
-        'weight_decay': 0.00005,
-        'num_workers': 16,
-        'batch_size': 100,
-        'checkpoint_dir': './checkpoint',
-        'save_freq': 10,
-        'log_freq': 10,
-    }
+    ds_class = dynamic_import(data_conf['dataset'])
+    train_ds = ds_class(data_dir=data_conf['data_dir'], mode='train', **feat_conf)
+    dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
 
-    train_sampler = paddle.io.BatchSampler(
+    train_sampler = paddle.io.DistributedBatchSampler(
         train_ds,
         batch_size=training_conf['batch_size'],
         shuffle=True,
@@ -95,16 +65,11 @@ if __name__ == '__main__':
         collate_fn=collate_features, )
 
     # Model
-    backbone = MDTC(
-        stack_num=3,
-        stack_size=4,
-        in_channels=80,
-        res_channels=32,
-        kernel_size=5,
-        causal=True, )
-    model = KWSModel(backbone=backbone, num_keywords=1)
+    backbone_class = dynamic_import(model_conf['backbone'])
+    backbone = backbone_class(**model_conf['config'])
+    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
     model = paddle.DataParallel(model)
-    clip = paddle.nn.ClipGradByGlobalNorm(5.0)
+    clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
     optimizer = paddle.optimizer.Adam(
         learning_rate=training_conf['learning_rate'],
         weight_decay=training_conf['weight_decay'],
@@ -122,9 +87,7 @@ if __name__ == '__main__':
         avg_loss = 0
         num_corrects = 0
         num_samples = 0
-        batch_start = time.time()
         for batch_idx, batch in enumerate(train_loader):
-            # print('Fetch one batch: {:.4f}'.format(time.time()-batch_start))
             keys, feats, labels, lengths = batch
             logits = model(feats)
             loss, corrects, acc = criterion(logits, labels, lengths)
@@ -144,7 +107,8 @@ if __name__ == '__main__':
             timer.count()
 
-            if (batch_idx + 1) % training_conf['log_freq'] == 0:
+            if (batch_idx + 1) % training_conf[
+                    'log_freq'] == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= training_conf['log_freq']
                 avg_acc = num_corrects / num_samples
@@ -161,10 +125,9 @@ if __name__ == '__main__':
                 avg_loss = 0
                 num_corrects = 0
                 num_samples = 0
-                batch_start = time.time()
 
-            if epoch % training_conf[
-                    'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch:
+            if epoch % training_conf[
+                    'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
                 dev_sampler = paddle.io.BatchSampler(
                     dev_ds,
                     batch_size=training_conf['batch_size'],
@@ -197,7 +160,7 @@ if __name__ == '__main__':
                 # Save model
                 save_dir = os.path.join(training_conf['checkpoint_dir'],
-                                        'epoch_{}_{:.4f}'.format(epoch, eval_acc))
+                                        'epoch_{}'.format(epoch))
                 logger.info('Saving model checkpoint to {}'.format(save_dir))
                 paddle.save(model.state_dict(),
                             os.path.join(save_dir, 'model.pdparams'))
paddlespeech/kws/models/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .mdtc import KWSModel
+from .mdtc import MDTC
paddlespeech/kws/models/loss.py (new file, 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle


def fill_mask_elements(condition, value, x):
    assert condition.shape == x.shape
    values = paddle.ones_like(x, dtype=x.dtype) * value
    return paddle.where(condition, values, x)


def max_pooling_loss(logits: paddle.Tensor,
                     target: paddle.Tensor,
                     lengths: paddle.Tensor,
                     min_duration: int=0):
    mask = padding_mask(lengths)
    num_utts = logits.shape[0]
    num_keywords = logits.shape[2]

    loss = 0.0
    for i in range(num_utts):
        for j in range(num_keywords):
            # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
            if target[i] == j:
                # For the keyword, do max-pooling
                prob = logits[i, :, j]
                m = mask[i]
                if min_duration > 0:
                    m[:min_duration] = True
                prob = fill_mask_elements(m, 0.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                max_prob = prob.max()
                loss += -paddle.log(max_prob)
            else:
                # For other keywords or filler, do min-pooling
                prob = 1 - logits[i, :, j]
                prob = fill_mask_elements(mask[i], 1.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                min_prob = prob.min()
                loss += -paddle.log(min_prob)
    loss = loss / num_utts

    # Compute accuracy of current batch
    mask = mask.unsqueeze(-1)
    logits = fill_mask_elements(mask, 0.0, logits)
    max_logits = logits.max(1)
    num_correct = 0
    for i in range(num_utts):
        max_p = max_logits[i].max(0).item()
        idx = max_logits[i].argmax(0).item()
        # Predict correct as the i'th keyword
        if max_p > 0.5 and idx == target[i].item():
            num_correct += 1
        # Predict correct as the filler, filler id < 0
        if max_p < 0.5 and target[i].item() < 0:
            num_correct += 1
    acc = num_correct / num_utts
    # acc = 0.0
    return loss, num_correct, acc


def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    batch_size = lengths.shape[0]
    max_len = int(lengths.max().item())
    seq = paddle.arange(max_len, dtype=paddle.int64)
    seq = seq.expand((batch_size, max_len))
    return seq >= lengths.unsqueeze(1)
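`padding_mask` marks padded frames with `True`, which `max_pooling_loss` then overwrites before pooling over time so padding can never win the max (or min). A minimal check:

import paddle

from paddlespeech.kws.models.loss import padding_mask

mask = padding_mask(paddle.to_tensor([3, 1]))  # utterance lengths in frames
print(mask.numpy())
# [[False False False]
#  [False  True  True]]  <- True marks frames beyond each utterance's length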
paddlespeech/kws/models/mdtc.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Modified from wekws(https://github.com/wenet-e2e/wekws)
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -163,7 +164,7 @@ class MDTC(nn.Layer):
                  in_channels: int,
                  res_channels: int,
                  kernel_size: int,
-                 causal: bool, ):
+                 causal: bool=True, ):
         super(MDTC, self).__init__()
         assert kernel_size % 2 == 1
         self.kernel_size = kernel_size
@@ -230,17 +231,3 @@ class KWSModel(nn.Layer):
         outputs = self.backbone(x)
         outputs = self.linear(outputs)
         return self.activation(outputs)
-
-
-if __name__ == '__main__':
-    paddle.set_device('cpu')
-    from paddleaudio.features import LogMelSpectrogram
-    mdtc = MDTC(3, 4, 80, 32, 5, causal=True)
-    x = paddle.randn(shape=(32, 16000 * 5))
-    feature_extractor = LogMelSpectrogram(sr=16000, n_fft=512, n_mels=80)
-    feats = feature_extractor(x).transpose([0, 2, 1])
-    print(feats.shape)
-    res, _ = mdtc(feats)
-    print(res.shape)
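With the inline demo removed and `KWSModel`/`MDTC` now exported from `paddlespeech.kws.models` (see `__init__.py` above), an equivalent smoke test looks like this; a sketch where random features of shape `(batch, frames, n_mels)` stand in for real fbank input:

import paddle

from paddlespeech.kws.models import MDTC

# stack_num=3, stack_size=4, in_channels=80, res_channels=32, kernel_size=5;
# causal now defaults to True
mdtc = MDTC(3, 4, 80, 32, 5)
feats = paddle.randn(shape=(32, 500, 80))
res, _ = mdtc(feats)
print(res.shape)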