PaddlePaddle / DeepSpeech
Commit 2504ccbf, authored on April 25, 2022 by Hui Zhang, committed via GitHub on April 25, 2022.

Merge pull request #1783 from KPatr1ck/kws

[KWS]Update KWS example.

Parents: 2b44f374, abb15ac6
Showing 9 changed files with 258 additions and 148 deletions:

examples/hey_snips/kws0/conf/mdtc.yaml        +45 −35
examples/hey_snips/kws0/local/plot.sh         +24 −1
examples/hey_snips/kws0/local/score.sh        +24 −2
examples/hey_snips/kws0/local/train.sh        +20 −2
examples/hey_snips/kws0/run.sh                +8 −2
paddlespeech/kws/exps/mdtc/compute_det.py     +42 −25
paddlespeech/kws/exps/mdtc/plot_det_curve.py  +6 −12
paddlespeech/kws/exps/mdtc/score.py           +41 −30
paddlespeech/kws/exps/mdtc/train.py           +48 −39
examples/hey_snips/kws0/conf/mdtc.yaml

-data:
-  data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
-  dataset: 'paddleaudio.datasets:HeySnips'
+# https://yaml.org/type/float.html
+###########################################
+#                Data                     #
+###########################################
+dataset: 'paddleaudio.datasets:HeySnips'
+data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
 
-model:
-  num_keywords: 1
-  backbone: 'paddlespeech.kws.models:MDTC'
-  config:
-    stack_num: 3
-    stack_size: 4
-    in_channels: 80
-    res_channels: 32
-    kernel_size: 5
+############################################
+#            Network Architecture          #
+############################################
+backbone: 'paddlespeech.kws.models:MDTC'
+num_keywords: 1
+stack_num: 3
+stack_size: 4
+in_channels: 80
+res_channels: 32
+kernel_size: 5
 
-feature:
-  feat_type: 'kaldi_fbank'
-  sample_rate: 16000
-  frame_shift: 10
-  frame_length: 25
-  n_mels: 80
+###########################################
+#              Feature                    #
+###########################################
+feat_type: 'kaldi_fbank'
+sample_rate: 16000
+frame_shift: 10
+frame_length: 25
+n_mels: 80
 
-training:
-  epochs: 100
-  num_workers: 16
-  batch_size: 100
-  checkpoint_dir: './checkpoint'
-  save_freq: 10
-  log_freq: 10
-  learning_rate: 0.001
-  weight_decay: 0.00005
-  grad_clip: 5.0
+###########################################
+#              Training                   #
+###########################################
+epochs: 100
+num_workers: 16
+batch_size: 100
+checkpoint_dir: './checkpoint'
+save_freq: 10
+log_freq: 10
+learning_rate: 0.001
+weight_decay: 0.00005
+grad_clip: 5.0
 
-scoring:
-  batch_size: 100
-  num_workers: 16
-  checkpoint: './checkpoint/epoch_100/model.pdparams'
-  score_file: './scores.txt'
-  stats_file: './stats.0.txt'
-  img_file: './det.png'
\ No newline at end of file
+###########################################
+#              Scoring                    #
+###########################################
+batch_size: 100
+num_workers: 16
+checkpoint: './checkpoint/epoch_100/model.pdparams'
+score_file: './scores.txt'
+stats_file: './stats.0.txt'
+img_file: './det.png'
\ No newline at end of file
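
The section banners above are comments only: every key now sits at the top level of the YAML file. A quick smoke test of the flattened layout, mirroring the CfgNode loading that the updated Python entry points below perform (a sketch, assuming yacs is installed and the file is saved as conf/mdtc.yaml):

    python3 -c "from yacs.config import CfgNode; c = CfgNode(new_allowed=True); c.merge_from_file('conf/mdtc.yaml'); print(c['dataset'], c['backbone'], c['num_keywords'])"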
examples/hey_snips/kws0/local/plot.sh

 #!/bin/bash
-python3 ${BIN_DIR}/plot_det_curve.py --cfg_path=$1 --keyword HeySnips
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 3 ];then
+    echo "usage: ${0} keyword stats_file img_file"
+    exit -1
+fi
+
+keyword=$1
+stats_file=$2
+img_file=$3
+
+python3 ${BIN_DIR}/plot_det_curve.py --keyword_label ${keyword} --stats_file ${stats_file} --img_file ${img_file}
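
The wrapper no longer touches the config file; the caller passes the keyword label, the DET stats file, and the output image path directly. A hypothetical invocation, using the same values run.sh passes below:

    ./local/plot.sh HeySnips ./stats.0.txt ./det.png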
examples/hey_snips/kws0/local/score.sh

 #!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-python3 ${BIN_DIR}/score.py --cfg_path=$1
-python3 ${BIN_DIR}/compute_det.py --cfg_path=$1
+
+if [ $# != 4 ];then
+    echo "usage: ${0} config_path checkpoint score_file stats_file"
+    exit -1
+fi
+
+cfg_path=$1
+ckpt=$2
+score_file=$3
+stats_file=$4
+
+python3 ${BIN_DIR}/score.py --config ${cfg_path} --ckpt ${ckpt} --score_file ${score_file} || exit -1
+python3 ${BIN_DIR}/compute_det.py --config ${cfg_path} --score_file ${score_file} --stats_file ${stats_file} || exit -1
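
Scoring and DET computation stay chained in one wrapper, but the checkpoint and output paths now come from the caller rather than from the scoring section of the config. A hypothetical invocation matching the defaults that run.sh uses:

    ./local/score.sh conf/mdtc.yaml ./checkpoint/epoch_100/model.pdparams ./scores.txt ./stats.0.txt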
examples/hey_snips/kws0/local/train.sh

 #!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 2 ];then
+    echo "usage: ${0} num_gpus config_path"
+    exit -1
+fi
+
 ngpu=$1
 cfg_path=$2
 
 if [ ${ngpu} -gt 0 ]; then
     python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES ${BIN_DIR}/train.py \
-    --cfg_path ${cfg_path}
+    --config ${cfg_path}
 else
     echo "set CUDA_VISIBLE_DEVICES to enable multi-GPU training."
     python3 ${BIN_DIR}/train.py \
-    --cfg_path ${cfg_path}
+    --config ${cfg_path}
 fi
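
Apart from the new argument check and the renamed flag, usage is unchanged; for example (a sketch, assuming two visible GPUs and the example's default config path):

    CUDA_VISIBLE_DEVICES=0,1 ./local/train.sh 2 conf/mdtc.yaml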
examples/hey_snips/kws0/run.sh

@@ -32,10 +32,16 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     ./local/train.sh ${ngpu} ${cfg_path} || exit -1
 fi
 
+ckpt=./checkpoint/epoch_100/model.pdparams
+score_file=./scores.txt
+stats_file=./stats.0.txt
+img_file=./det.png
+
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    ./local/score.sh ${cfg_path} || exit -1
+    ./local/score.sh ${cfg_path} ${ckpt} ${score_file} ${stats_file} || exit -1
 fi
 
+keyword=HeySnips
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    ./local/plot.sh ${cfg_path} || exit -1
+    ./local/plot.sh ${keyword} ${stats_file} ${img_file} || exit -1
 fi
\ No newline at end of file
paddlespeech/kws/exps/mdtc/compute_det.py

@@ -12,24 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from wekws(https://github.com/wenet-e2e/wekws)
-import argparse
 import os
 
 import paddle
-import yaml
 from tqdm import tqdm
+from yacs.config import CfgNode
 
+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
-parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
-parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
-args = parser.parse_args()
-# yapf: enable
 
 def load_label_and_score(keyword_index: int,
                          ds: paddle.io.Dataset,

@@ -61,26 +52,52 @@ def load_label_and_score(keyword_index: int,
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
+    parser = default_argument_parser()
+    parser.add_argument('--keyword_index', type=int, default=0, help='keyword index')
+    parser.add_argument('--step', type=float, default=0.01, help='threshold step of trigger score')
+    parser.add_argument('--window_shift', type=int, default=50, help='window_shift is used to skip the frames after triggered')
+    parser.add_argument("--score_file", type=str, required=True, help='output file of trigger scores')
+    parser.add_argument('--stats_file', type=str, default='./stats.0.txt', help='output file of detection error tradeoff')
+    args = parser.parse_args()
 
-    data_conf = config['data']
-    feat_conf = config['feature']
-    scoring_conf = config['scoring']
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
 
     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
-
-    score_file = os.path.abspath(scoring_conf['score_file'])
-    stats_file = os.path.abspath(scoring_conf['stats_file'])
+    ds_class = dynamic_import(config['dataset'])
+    test_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='test',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
 
     keyword_table, filler_table, filler_duration = load_label_and_score(
-        args.keyword, test_ds, score_file)
+        args.keyword_index, test_ds, args.score_file)
     print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
 
     pbar = tqdm(total=int(1.0 / args.step))
-    with open(stats_file, 'w', encoding='utf8') as fout:
+    with open(args.stats_file, 'w', encoding='utf8') as fout:
         keyword_index = args.keyword_index
         threshold = 0.0
         while threshold <= 1.0:

@@ -113,4 +130,4 @@ if __name__ == '__main__':
             pbar.update(1)
 
     pbar.close()
-    print('DET saved to: {}'.format(stats_file))
+    print('DET saved to: {}'.format(args.stats_file))
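
Since default_argument_parser now supplies --config, the script is driven entirely from the command line; a hypothetical direct invocation with the tuning flags spelled out at their defaults (paths illustrative):

    python3 paddlespeech/kws/exps/mdtc/compute_det.py \
        --config conf/mdtc.yaml \
        --score_file ./scores.txt --stats_file ./stats.0.txt \
        --keyword_index 0 --step 0.01 --window_shift 50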
paddlespeech/kws/exps/mdtc/plot_det_curve.py

@@ -17,12 +17,12 @@ import os
 import matplotlib.pyplot as plt
 import numpy as np
-import yaml
 
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-parser.add_argument("--keyword", type=str, required=True)
+parser.add_argument('--keyword_label', type=str, required=True, help='keyword string shown on image')
+parser.add_argument('--stats_file', type=str, required=True, help='output file of detection error tradeoff')
+parser.add_argument('--img_file', type=str, default='./det.png', help='output det image')
 args = parser.parse_args()
 # yapf: enable

@@ -61,14 +61,8 @@ def plot_det_curve(keywords, stats_file, figure_file, xlim, x_step, ylim,
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    scoring_conf = config['scoring']
-    img_file = os.path.abspath(scoring_conf['img_file'])
-    stats_file = os.path.abspath(scoring_conf['stats_file'])
-
-    keywords = [args.keyword]
-    plot_det_curve(keywords, stats_file, img_file, 10, 2, 10, 2)
+    img_file = os.path.abspath(args.img_file)
+    stats_file = os.path.abspath(args.stats_file)
+
+    plot_det_curve([args.keyword_label], stats_file, img_file, 10, 2, 10, 2)
 
     print('DET curve image saved to: {}'.format(img_file))
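
The plotting entry point is now fully decoupled from the YAML config; a hypothetical direct call, equivalent to what the updated local/plot.sh wraps (paths illustrative):

    python3 paddlespeech/kws/exps/mdtc/plot_det_curve.py \
        --keyword_label HeySnips --stats_file ./stats.0.txt --img_file ./det.png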
paddlespeech/kws/exps/mdtc/score.py

@@ -12,55 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from wekws(https://github.com/wenet-e2e/wekws)
-import argparse
-import os
-
 import paddle
-import yaml
 from tqdm import tqdm
+from yacs.config import CfgNode
 
 from paddlespeech.kws.exps.mdtc.collate import collate_features
 from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-args = parser.parse_args()
-# yapf: enable
-
 if __name__ == '__main__':
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
+    parser = default_argument_parser()
+    parser.add_argument("--ckpt", type=str, required=True, help='model checkpoint for evaluation.')
+    parser.add_argument("--score_file", type=str, default='./scores.txt', help='output file of trigger scores')
+    args = parser.parse_args()
 
-    model_conf = config['model']
-    data_conf = config['data']
-    feat_conf = config['feature']
-    scoring_conf = config['scoring']
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
 
     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    test_ds = ds_class(data_dir=data_conf['data_dir'], mode='test', **feat_conf)
+    ds_class = dynamic_import(config['dataset'])
+    test_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='test',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
     test_sampler = paddle.io.BatchSampler(
-        test_ds, batch_size=scoring_conf['batch_size'], drop_last=False)
+        test_ds, batch_size=config['batch_size'], drop_last=False)
     test_loader = paddle.io.DataLoader(
         test_ds,
         batch_sampler=test_sampler,
-        num_workers=scoring_conf['num_workers'],
+        num_workers=config['num_workers'],
         return_list=True,
         use_buffer_reader=True,
         collate_fn=collate_features, )
 
     # Model
-    backbone_class = dynamic_import(model_conf['backbone'])
-    backbone = backbone_class(**model_conf['config'])
-    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
-    model.set_state_dict(paddle.load(scoring_conf['checkpoint']))
+    backbone_class = dynamic_import(config['backbone'])
+    backbone = backbone_class(
+        stack_num=config['stack_num'],
+        stack_size=config['stack_size'],
+        in_channels=config['in_channels'],
+        res_channels=config['res_channels'],
+        kernel_size=config['kernel_size'], )
+    model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
+    model.set_state_dict(paddle.load(args.ckpt))
     model.eval()
 
-    with paddle.no_grad(), open(
-            scoring_conf['score_file'], 'w', encoding='utf8') as fout:
+    with paddle.no_grad(), open(
+            args.score_file, 'w', encoding='utf8') as f:
         for batch_idx, batch in enumerate(
                 tqdm(test_loader, total=len(test_loader))):
             keys, feats, labels, lengths = batch

@@ -73,7 +85,6 @@ if __name__ == '__main__':
             keyword_scores = score[:, keyword_i]
             score_frames = ' '.join(
                 ['{:.6f}'.format(x) for x in keyword_scores.tolist()])
-            fout.write('{} {} {}\n'.format(key, keyword_i, score_frames))
+            f.write('{} {} {}\n'.format(key, keyword_i, score_frames))
 
-    print('Result saved to: {}'.format(scoring_conf['score_file']))
+    print('Result saved to: {}'.format(args.score_file))
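
The evaluation checkpoint is now an explicit --ckpt argument instead of a scoring key in the YAML; a hypothetical direct invocation (paths illustrative):

    python3 paddlespeech/kws/exps/mdtc/score.py \
        --config conf/mdtc.yaml \
        --ckpt ./checkpoint/epoch_100/model.pdparams \
        --score_file ./scores.txt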
paddlespeech/kws/exps/mdtc/train.py

@@ -11,77 +11,88 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
 import os
 
 import paddle
-import yaml
+from yacs.config import CfgNode
 
 from paddleaudio.utils import logger
 from paddleaudio.utils import Timer
 from paddlespeech.kws.exps.mdtc.collate import collate_features
 from paddlespeech.kws.models.loss import max_pooling_loss
 from paddlespeech.kws.models.mdtc import KWSModel
+from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 
-# yapf: disable
-parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--cfg_path", type=str, required=True)
-args = parser.parse_args()
-# yapf: enable
-
 if __name__ == '__main__':
+    parser = default_argument_parser()
+    args = parser.parse_args()
+
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
     nranks = paddle.distributed.get_world_size()
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
     local_rank = paddle.distributed.get_rank()
 
-    args.cfg_path = os.path.abspath(os.path.expanduser(args.cfg_path))
-    with open(args.cfg_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    model_conf = config['model']
-    data_conf = config['data']
-    feat_conf = config['feature']
-    training_conf = config['training']
-
     # Dataset
-    ds_class = dynamic_import(data_conf['dataset'])
-    train_ds = ds_class(data_dir=data_conf['data_dir'], mode='train', **feat_conf)
-    dev_ds = ds_class(data_dir=data_conf['data_dir'], mode='dev', **feat_conf)
+    ds_class = dynamic_import(config['dataset'])
+    train_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='train',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
+    dev_ds = ds_class(
+        data_dir=config['data_dir'],
+        mode='dev',
+        feat_type=config['feat_type'],
+        sample_rate=config['sample_rate'],
+        frame_shift=config['frame_shift'],
+        frame_length=config['frame_length'],
+        n_mels=config['n_mels'], )
 
     train_sampler = paddle.io.DistributedBatchSampler(
         train_ds,
-        batch_size=training_conf['batch_size'],
+        batch_size=config['batch_size'],
         shuffle=True,
         drop_last=False)
     train_loader = paddle.io.DataLoader(
         train_ds,
         batch_sampler=train_sampler,
-        num_workers=training_conf['num_workers'],
+        num_workers=config['num_workers'],
         return_list=True,
         use_buffer_reader=True,
         collate_fn=collate_features, )
 
     # Model
-    backbone_class = dynamic_import(model_conf['backbone'])
-    backbone = backbone_class(**model_conf['config'])
-    model = KWSModel(backbone=backbone, num_keywords=model_conf['num_keywords'])
+    backbone_class = dynamic_import(config['backbone'])
+    backbone = backbone_class(
+        stack_num=config['stack_num'],
+        stack_size=config['stack_size'],
+        in_channels=config['in_channels'],
+        res_channels=config['res_channels'],
+        kernel_size=config['kernel_size'], )
+    model = KWSModel(backbone=backbone, num_keywords=config['num_keywords'])
     model = paddle.DataParallel(model)
-    clip = paddle.nn.ClipGradByGlobalNorm(training_conf['grad_clip'])
+    clip = paddle.nn.ClipGradByGlobalNorm(config['grad_clip'])
     optimizer = paddle.optimizer.Adam(
-        learning_rate=training_conf['learning_rate'],
-        weight_decay=training_conf['weight_decay'],
+        learning_rate=config['learning_rate'],
+        weight_decay=config['weight_decay'],
         parameters=model.parameters(),
         grad_clip=clip)
     criterion = max_pooling_loss
 
     steps_per_epoch = len(train_sampler)
-    timer = Timer(steps_per_epoch * training_conf['epochs'])
+    timer = Timer(steps_per_epoch * config['epochs'])
    timer.start()
 
-    for epoch in range(1, training_conf['epochs'] + 1):
+    for epoch in range(1, config['epochs'] + 1):
         model.train()
 
         avg_loss = 0

@@ -107,15 +118,13 @@ if __name__ == '__main__':
             timer.count()
 
-            if (batch_idx + 1) % training_conf['log_freq'] == 0 and local_rank == 0:
+            if (batch_idx + 1) % config['log_freq'] == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
-                avg_loss /= training_conf['log_freq']
+                avg_loss /= config['log_freq']
                 avg_acc = num_corrects / num_samples
 
                 print_msg = 'Epoch={}/{}, Step={}/{}'.format(
-                    epoch, training_conf['epochs'], batch_idx + 1, steps_per_epoch)
+                    epoch, config['epochs'], batch_idx + 1, steps_per_epoch)
                 print_msg += ' loss={:.4f}'.format(avg_loss)
                 print_msg += ' acc={:.4f}'.format(avg_acc)
                 print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(

@@ -126,17 +135,17 @@ if __name__ == '__main__':
             num_corrects = 0
             num_samples = 0
 
-        if epoch % training_conf[
+        if epoch % config[
                 'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
             dev_sampler = paddle.io.BatchSampler(
                 dev_ds,
-                batch_size=training_conf['batch_size'],
+                batch_size=config['batch_size'],
                 shuffle=False,
                 drop_last=False)
             dev_loader = paddle.io.DataLoader(
                 dev_ds,
                 batch_sampler=dev_sampler,
-                num_workers=training_conf['num_workers'],
+                num_workers=config['num_workers'],
                 return_list=True,
                 use_buffer_reader=True,
                 collate_fn=collate_features, )

@@ -159,7 +168,7 @@ if __name__ == '__main__':
             logger.eval(print_msg)
 
         # Save model
-        save_dir = os.path.join(training_conf['checkpoint_dir'],
+        save_dir = os.path.join(config['checkpoint_dir'],
                                 'epoch_{}'.format(epoch))
         logger.info('Saving model checkpoint to {}'.format(save_dir))
         paddle.save(model.state_dict(),
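
Training follows the same pattern: default_argument_parser supplies --config, and the flat CfgNode replaces the old per-section dicts throughout the loop. A hypothetical multi-GPU launch, mirroring what the updated local/train.sh runs (paths illustrative):

    python3 -m paddle.distributed.launch --gpus $CUDA_VISIBLE_DEVICES \
        paddlespeech/kws/exps/mdtc/train.py --config conf/mdtc.yaml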