PaddlePaddle / DeepSpeech

Commit abb15ac6
Authored April 25, 2022 by KP
Update KWS example.
Parent 2b44f374

Showing 9 changed files with 258 additions and 148 deletions (+258 -148)
examples/hey_snips/kws0/conf/mdtc.yaml          +45  -35
examples/hey_snips/kws0/local/plot.sh           +24   -1
examples/hey_snips/kws0/local/score.sh          +24   -2
examples/hey_snips/kws0/local/train.sh          +20   -2
examples/hey_snips/kws0/run.sh                   +8   -2
paddlespeech/kws/exps/mdtc/compute_det.py       +42  -25
paddlespeech/kws/exps/mdtc/plot_det_curve.py     +6  -12
paddlespeech/kws/exps/mdtc/score.py             +41  -30
paddlespeech/kws/exps/mdtc/train.py             +48  -39
examples/hey_snips/kws0/conf/mdtc.yaml  (+45 -35)

The nested sections (data / model / feature / training / scoring) are flattened into top-level keys:

--- a/examples/hey_snips/kws0/conf/mdtc.yaml
+++ b/examples/hey_snips/kws0/conf/mdtc.yaml
-data:
-  data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
-  dataset: 'paddleaudio.datasets:HeySnips'
-model:
-  num_keywords: 1
-  backbone: 'paddlespeech.kws.models:MDTC'
-  config:
-    stack_num: 3
-    stack_size: 4
-    in_channels: 80
-    res_channels: 32
-    kernel_size: 5
-feature:
-  feat_type: 'kaldi_fbank'
-  sample_rate: 16000
-  frame_shift: 10
-  frame_length: 25
-  n_mels: 80
-training:
-  epochs: 100
-  num_workers: 16
-  batch_size: 100
-  checkpoint_dir: './checkpoint'
-  save_freq: 10
-  log_freq: 10
-  learning_rate: 0.001
-  weight_decay: 0.00005
-  grad_clip: 5.0
-scoring:
-  batch_size: 100
-  num_workers: 16
-  checkpoint: './checkpoint/epoch_100/model.pdparams'
-  score_file: './scores.txt'
-  stats_file: './stats.0.txt'
-  img_file: './det.png'
\ No newline at end of file
+# https://yaml.org/type/float.html
+###########################################
+#                  Data                   #
+###########################################
+dataset: 'paddleaudio.datasets:HeySnips'
+data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
+############################################
+#           Network Architecture           #
+############################################
+backbone: 'paddlespeech.kws.models:MDTC'
+num_keywords: 1
+stack_num: 3
+stack_size: 4
+in_channels: 80
+res_channels: 32
+kernel_size: 5
+###########################################
+#                 Feature                 #
+###########################################
+feat_type: 'kaldi_fbank'
+sample_rate: 16000
+frame_shift: 10
+frame_length: 25
+n_mels: 80
+###########################################
+#                 Training                #
+###########################################
+epochs: 100
+num_workers: 16
+batch_size: 100
+checkpoint_dir: './checkpoint'
+save_freq: 10
+log_freq: 10
+learning_rate: 0.001
+weight_decay: 0.00005
+grad_clip: 5.0
+###########################################
+#                 Scoring                 #
+###########################################
+batch_size: 100
+num_workers: 16
+checkpoint: './checkpoint/epoch_100/model.pdparams'
+score_file: './scores.txt'
+stats_file: './stats.0.txt'
+img_file: './det.png'
\ No newline at end of file
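The flattening lets every entry point read keys from a single yacs CfgNode, as the Python diffs below show. A hedged sanity check, assuming the working directory is examples/hey_snips/kws0 and that yacs is installed (PaddleSpeech uses it here):

    # Merge the flat YAML into a CfgNode and read two keys directly,
    # mirroring the loading code added in compute_det.py/score.py/train.py.
    python3 -c "from yacs.config import CfgNode; c = CfgNode(new_allowed=True); c.merge_from_file('conf/mdtc.yaml'); print(c['num_keywords'], c['learning_rate'])"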
examples/hey_snips/kws0/local/plot.sh  (+24 -1)

--- a/examples/hey_snips/kws0/local/plot.sh
+++ b/examples/hey_snips/kws0/local/plot.sh
 #!/bin/bash
-python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 3 ]; then
+    echo "usage: ${0} config_path checkpoint output_file"
+    exit -1
+fi
+
+keyword=$1
+stats_file=$2
+img_file=$3
+
+python3 ${BIN_DIR}/plot_det_curve.py --keyword_label ${keyword} --stats_file ${stats_file} --img_file ${img_file}
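A hedged usage sketch; per the assignments above the three arguments are actually the keyword label, the stats file, and the image path (the usage string still names the old arguments), matching what run.sh below passes:

    # Plot the DET curve from an existing stats file.
    ./local/plot.sh HeySnips ./stats.0.txt ./det.png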
examples/hey_snips/kws0/local/score.sh  (+24 -2)

--- a/examples/hey_snips/kws0/local/score.sh
+++ b/examples/hey_snips/kws0/local/score.sh
 #!/bin/bash
-python3 ${BIN_DIR}/score.py --cfg_path=$1
-python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 4 ]; then
+    echo "usage: ${0} checkpoint score_file stats_file"
+    exit -1
+fi
+
+cfg_path=$1
+ckpt=$2
+score_file=$3
+stats_file=$4
+
+python3 ${BIN_DIR}/score.py --config ${cfg_path} --ckpt ${ckpt} --score_file ${score_file} || exit -1
+python3 ${BIN_DIR}/compute_det.py --config ${cfg_path} --score_file ${score_file} --stats_file ${stats_file} || exit -1
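A hedged usage sketch; the script now expects four arguments with the config path first (the usage string above lists only the last three):

    # Score the test set, then compute the detection error tradeoff.
    ./local/score.sh conf/mdtc.yaml ./checkpoint/epoch_100/model.pdparams ./scores.txt ./stats.0.txt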
examples/hey_snips/kws0/local/train.sh  (+20 -2)

--- a/examples/hey_snips/kws0/local/train.sh
+++ b/examples/hey_snips/kws0/local/train.sh
 #!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 2 ]; then
+    echo "usage: ${0} num_gpus config_path"
+    exit -1
+fi
+
 ngpu=$1
 cfg_path=$2

 if [ ${ngpu} -gt 0 ]; then
     python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
-        --cfg_path ${cfg_path}
+        --config ${cfg_path}
 else
     echo "set CUDA_VISIBLE_DEVICES to enable multi-gpus trainning."
     python3 ${BIN_DIR}/train.py \
-        --cfg_path ${cfg_path}
+        --config ${cfg_path}
 fi
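A hedged usage sketch (CUDA_VISIBLE_DEVICES supplies the GPU list that the script forwards to paddle.distributed.launch):

    # Train on one GPU with the example config.
    CUDA_VISIBLE_DEVICES=0 ./local/train.sh 1 conf/mdtc.yaml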
examples/hey_snips/kws0/run.sh  (+8 -2)

--- a/examples/hey_snips/kws0/run.sh
+++ b/examples/hey_snips/kws0/run.sh
@@ -32,10 +32,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     ./local/train.sh ${ngpu} ${cfg_path} || exit -1
 fi

+ckpt=./checkpoint/epoch_100/model.pdparams
+score_file=./scores.txt
+stats_file=./stats.0.txt
+img_file=./det.png
+
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    ./local/score.sh ${cfg_path} || exit -1
+    ./local/score.sh ${cfg_path} ${ckpt} ${score_file} ${stats_file} || exit -1
 fi

+keyword=HeySnips
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    ./local/plot.sh ${cfg_path} || exit -1
+    ./local/plot.sh ${keyword} ${stats_file} ${img_file} || exit -1
 fi
\ No newline at end of file
paddlespeech/kws/exps/mdtc/compute_det.py  (+42 -25)

--- a/paddlespeech/kws/exps/mdtc/compute_det.py
+++ b/paddlespeech/kws/exps/mdtc/compute_det.py
@@ -12,24 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from wekws(https://github.com/wenet-e2e/wekws)
-import argparse
 import os

 import paddle
-import yaml
 from tqdm import tqdm
+from yacs.config import CfgNode

+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import

-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
-parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
-parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
-args = parser.parse_args()
-# yapf: enable
-

 def load_label_and_score(keyword_index: int,
                          ds: paddle.io.Dataset,
@@ -61,26 +52,52 @@ def load_label_and_score(keyword_index: int,
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    data_conf = config['data']
-    feat_conf = config['feature']
-    scoring_conf = config['scoring']
+    parser = default_argument_parser()
+    parser.add_argument(
+        '--keyword_index', type=int, default=0, help='keyword index')
+    parser.add_argument(
+        '--step',
+        type=float,
+        default=0.01,
+        help='threshold step of trigger score')
+    parser.add_argument(
+        '--window_shift',
+        type=int,
+        default=50,
+        help='window_shift is used to skip the frames after triggered')
+    parser.add_argument(
+        "--score_file",
+        type=str,
+        required=True,
+        help='output file of trigger scores')
+    parser.add_argument(
+        '--stats_file',
+        type=str,
+        default='./stats.0.txt',
+        help='output file of detection error tradeoff')
+    args = parser.parse_args()
+
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)

     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    test_ds = ds_class(
-        data_dir=data_conf['data_dir'], mode='test', **feat_conf)
-
-    score_file = os.path.abspath(scoring_conf['score_file'])
-    stats_file = os.path.abspath(scoring_conf['stats_file'])
+    ds_class = dynamic_import(config['dataset'])
+    test_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='test',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )

     keyword_table, filler_table, filler_duration = load_label_and_score(
-        args.keyword, test_ds, score_file)
+        args.keyword_index, test_ds, args.score_file)
     print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
     pbar = tqdm(total=int(1.0 / args.step))
-    with open(stats_file, 'w', encoding='utf8') as fout:
+    with open(args.stats_file, 'w', encoding='utf8') as fout:
         keyword_index = args.keyword_index
         threshold = 0.0
         while threshold <= 1.0:
@@ -113,4 +130,4 @@ if __name__ == '__main__':
             pbar.update(1)

     pbar.close()
-    print('DET saved to: {}'.format(stats_file))
+    print('DET saved to: {}'.format(args.stats_file))
paddlespeech/kws/exps/mdtc/plot_det_curve.py  (+6 -12)

--- a/paddlespeech/kws/exps/mdtc/plot_det_curve.py
+++ b/paddlespeech/kws/exps/mdtc/plot_det_curve.py
@@ -17,12 +17,12 @@ import os
 import matplotlib.pyplot as plt
 import numpy as np
-import yaml

 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-parser.add_argument("--keyword", type=str, required=True)
+parser.add_argument('--keyword_label', type=str, required=True, help='keyword string shown on image')
+parser.add_argument('--stats_file', type=str, required=True, help='output file of detection error tradeoff')
+parser.add_argument('--img_file', type=str, default='./det.png', help='output det image')
 args = parser.parse_args()
 # yapf: enable
@@ -61,14 +61,8 @@ def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    scoring_conf = config['scoring']
-    img_file = os.path.abspath(scoring_conf['img_file'])
-    stats_file = os.path.abspath(scoring_conf['stats_file'])
-    keywords = [args.keyword]
-    plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
+    img_file = os.path.abspath(args.img_file)
+    stats_file = os.path.abspath(args.stats_file)
+    plot_det_curve([args.keyword_label], stats_file, img_file, 10, 2, 10, 2)

     print('DET curve image saved to: {}'.format(img_file))
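A hedged direct invocation; the script no longer reads a config file, only the stats file produced by compute_det.py:

    # Render the DET curve image from the stats file.
    python3 ${BIN_DIR}/plot_det_curve.py --keyword_label HeySnips --stats_file ./stats.0.txt --img_file ./det.png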
paddlespeech/kws/exps/mdtc/score.py  (+41 -30)

--- a/paddlespeech/kws/exps/mdtc/score.py
+++ b/paddlespeech/kws/exps/mdtc/score.py
@@ -12,55 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from wekws(https://github.com/wenet-e2e/wekws)
-import argparse
-import os
-
 import paddle
-import yaml
 from tqdm import tqdm
+from yacs.config import CfgNode

 from paddlespeech.kws.exps.mdtc.collate import collate_features
 from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import

-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-args = parser.parse_args()
-# yapf: enable
-
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    model_conf = config['model']
-    data_conf = config['data']
-    feat_conf = config['feature']
-    scoring_conf = config['scoring']
+    parser = default_argument_parser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        required=True,
+        help='model checkpoint for evaluation.')
+    parser.add_argument(
+        "--score_file",
+        type=str,
+        default='./scores.txt',
+        help='output file of trigger scores')
+    args = parser.parse_args()
+
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)

     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    test_ds = ds_class(
-        data_dir=data_conf['data_dir'], mode='test', **feat_conf)
+    ds_class = dynamic_import(config['dataset'])
+    test_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='test',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
     test_sampler = paddle.io.BatchSampler(
-        test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
+        test_ds, batch_size=config['batch_size'], drop_last=False)
     test_loader = paddle.io.DataLoader(
         test_ds,
         batch_sampler=test_sampler,
-        num_workers=scoring_conf['num_workers'],
+        num_workers=config['num_workers'],
         return_list=True,
         use_buffer_reader=True,
         collate_fn=collate_features, )

     # Model
-    backbone_class = dynamic_import(model_conf['backbone'])
-    backbone = backbone_class(**model_conf['config'])
-    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
-    model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
+    backbone_class = dynamic_import(config['backbone'])
+    backbone = backbone_class(
+        stack_num=config['stack_num'],
+        stack_size=config['stack_size'],
+        in_channels=config['in_channels'],
+        res_channels=config['res_channels'],
+        kernel_size=config['kernel_size'], )
+    model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
+    model.set_state_dict(paddle.load(args.ckpt))
     model.eval()

     with paddle.no_grad(), open(
-            scoring_conf['score_file'], 'w', encoding='utf8') as fout:
+            args.score_file, 'w', encoding='utf8') as f:
         for batch_idx, batch in enumerate(
                 tqdm(test_loader, total=len(test_loader))):
             keys, feats, labels, lengths = batch
@@ -73,7 +85,6 @@ if __name__ == '__main__':
                 keyword_scores = score[:, keyword_i]
                 score_frames = ' '.join(
                     ['{:.6f}'.format(x) for x in keyword_scores.tolist()])
-                fout.write(
-                    '{} {} {}\n'.format(key, keyword_i, score_frames))
+                f.write('{} {} {}\n'.format(key, keyword_i, score_frames))

-    print('Result saved to: {}'.format(scoring_conf['score_file']))
+    print('Result saved to: {}'.format(args.score_file))
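A hedged direct invocation; the checkpoint now comes from --ckpt rather than the scoring section of the config:

    # Write per-utterance trigger scores for the test set.
    python3 ${BIN_DIR}/score.py --config conf/mdtc.yaml --ckpt ./checkpoint/epoch_100/model.pdparams --score_file ./scores.txt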
paddlespeech/kws/exps/mdtc/train.py  (+48 -39)

--- a/paddlespeech/kws/exps/mdtc/train.py
+++ b/paddlespeech/kws/exps/mdtc/train.py
@@ -11,77 +11,88 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
 import os

 import paddle
-import yaml
 from paddleaudio.utils import logger
 from paddleaudio.utils import Timer
+from yacs.config import CfgNode

 from paddlespeech.kws.exps.mdtc.collate import collate_features
 from paddlespeech.kws.models.loss import max_pooling_loss
 from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import

-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-args = parser.parse_args()
-# yapf: enable
-
 if __name__ == '__main__':
+    parser = default_argument_parser()
+    args = parser.parse_args()
+
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
     nranks = paddle.distributed.get_world_size()
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
     local_rank = paddle.distributed.get_rank()

-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    model_conf = config['model']
-    data_conf = config['data']
-    feat_conf = config['feature']
-    training_conf = config['training']
-
     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    train_ds = ds_class(
-        data_dir=data_conf['data_dir'], mode='train', **feat_conf)
-    dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
+    ds_class = dynamic_import(config['dataset'])
+    train_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='train',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
+    dev_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='dev',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )

     train_sampler = paddle.io.DistributedBatchSampler(
         train_ds,
-        batch_size=training_conf['batch_size'],
+        batch_size=config['batch_size'],
         shuffle=True,
         drop_last=False)
     train_loader = paddle.io.DataLoader(
         train_ds,
         batch_sampler=train_sampler,
-        num_workers=training_conf['num_workers'],
+        num_workers=config['num_workers'],
         return_list=True,
         use_buffer_reader=True,
         collate_fn=collate_features, )

     # Model
-    backbone_class = dynamic_import(model_conf['backbone'])
-    backbone = backbone_class(**model_conf['config'])
-    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
+    backbone_class = dynamic_import(config['backbone'])
+    backbone = backbone_class(
+        stack_num=config['stack_num'],
+        stack_size=config['stack_size'],
+        in_channels=config['in_channels'],
+        res_channels=config['res_channels'],
+        kernel_size=config['kernel_size'], )
+    model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
     model = paddle.DataParallel(model)
-    clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
+    clip = paddle.nn.ClipGradByGlobalNorm(config['grad_clip'])
     optimizer = paddle.optimizer.Adam(
-        learning_rate=training_conf['learning_rate'],
-        weight_decay=training_conf['weight_decay'],
+        learning_rate=config['learning_rate'],
+        weight_decay=config['weight_decay'],
         parameters=model.parameters(),
         grad_clip=clip)
     criterion = max_pooling_loss

     steps_per_epoch = len(train_sampler)
-    timer = Timer(steps_per_epoch * training_conf['epochs'])
+    timer = Timer(steps_per_epoch * config['epochs'])
     timer.start()

-    for epoch in range(1, training_conf['epochs'] + 1):
+    for epoch in range(1, config['epochs'] + 1):
         model.train()

         avg_loss = 0
@@ -107,15 +118,13 @@ if __name__ == '__main__':
             timer.count()

-            if (batch_idx + 1
-                ) % training_conf['log_freq'] == 0 and local_rank == 0:
+            if (batch_idx + 1) % config['log_freq'] == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
-                avg_loss /= training_conf['log_freq']
+                avg_loss /= config['log_freq']
                 avg_acc = num_corrects / num_samples

                 print_msg = 'Epoch={}/{}, Step={}/{}'.format(
-                    epoch, training_conf['epochs'], batch_idx + 1,
-                    steps_per_epoch)
+                    epoch, config['epochs'], batch_idx + 1, steps_per_epoch)
                 print_msg += ' loss={:.4f}'.format(avg_loss)
                 print_msg += ' acc={:.4f}'.format(avg_acc)
                 print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
@@ -126,17 +135,17 @@ if __name__ == '__main__':
                 num_corrects = 0
                 num_samples = 0

-            if epoch % training_conf[
+            if epoch % config[
                     'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
                 dev_sampler = paddle.io.BatchSampler(
                     dev_ds,
-                    batch_size=training_conf['batch_size'],
+                    batch_size=config['batch_size'],
                     shuffle=False,
                     drop_last=False)
                 dev_loader = paddle.io.DataLoader(
                     dev_ds,
                     batch_sampler=dev_sampler,
-                    num_workers=training_conf['num_workers'],
+                    num_workers=config['num_workers'],
                     return_list=True,
                     use_buffer_reader=True,
                     collate_fn=collate_features, )
@@ -159,7 +168,7 @@ if __name__ == '__main__':
             logger.eval(print_msg)

         # Save model
-        save_dir = os.path.join(training_conf['checkpoint_dir'],
+        save_dir = os.path.join(config['checkpoint_dir'],
                                 'epoch_{}'.format(epoch))
         logger.info('Saving model checkpoint to {}'.format(save_dir))
         paddle.save(model.state_dict(),
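A hedged direct invocation of the refactored trainer; this is the single-process case, while train.sh above wraps the multi-GPU case in paddle.distributed.launch:

    # BIN_DIR as set by the example's environment (paddlespeech/kws/exps/mdtc).
    python3 ${BIN_DIR}/train.py --config conf/mdtc.yaml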