PaddlePaddle / DeepSpeech

Commit b60b1dad: Add KWS example.
Authored on Apr 08, 2022 by KP
Parent commit: e01abc50

Showing 8 changed files with 567 additions and 0 deletions (+567, -0):
- examples/hey_snips/README.md (+0, -0)
- examples/hey_snips/kws0/RESULTS.md (+0, -0)
- examples/hey_snips/kws0/path.sh (+28, -0)
- examples/hey_snips/kws0/run.sh (+47, -0)
- paddlespeech/kws/exps/mdtc/compute_det.py (+121, -0)
- paddlespeech/kws/exps/mdtc/plot_det_curve.py (+63, -0)
- paddlespeech/kws/exps/mdtc/score.py (+103, -0)
- paddlespeech/kws/exps/mdtc/train.py (+205, -0)
examples/hey_snips/README.md (new file, mode 100644; empty in this commit)

examples/hey_snips/kws0/RESULTS.md (new file, mode 100644; empty in this commit)
examples/hey_snips/kws0/path.sh (new file, mode 100755)
```bash
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=mdtc
export BIN_DIR=${MAIN_ROOT}/paddlespeech/kws/exps/${MODEL}
```
(no newline at end of file)
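path.sh follows the standard PaddleSpeech recipe layout: MAIN_ROOT resolves to the repository root three levels above the recipe directory, the repository and its utils/ directory are prepended to PATH, and BIN_DIR points at the mdtc experiment scripts added below. The local/*.sh stage scripts that run.sh calls are not part of this commit; presumably they invoke ${BIN_DIR}/train.py and the scoring scripts from there.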
examples/hey_snips/kws0/run.sh (new file, mode 100755)
```bash
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh
set -e

stage=0
stop_stage=50

dir=data/                           # data info directory
exp_dir=exp/ecapa-tdnn-vox12-big/   # experiment directory
conf_path=conf/mdtc.yaml
gpus=0,1,2,3

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

mkdir -p ${exp_dir}

if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # stage 0: prepare the hey_snips data info under ${dir}
    bash ./local/data.sh ${dir} ${conf_path} || exit -1;
fi

if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # stage 1: train the KWS model on the GPUs listed in ${gpus}
    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
fi

if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # stage 2: evaluate the trained model
    CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
fi
```
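Because the variable defaults are assigned before ${MAIN_ROOT}/utils/parse_options.sh is sourced, they can be overridden from the command line in the usual Kaldi style, for example `bash run.sh --stage 1 --stop_stage 1` to re-run only the training stage.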
paddlespeech/kws/exps/mdtc/compute_det.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys

from tqdm import tqdm


def load_label_and_score(keyword, label_file, score_file):
    # score_table: {uttid: [keywordlist]}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            key = arr[0]
            current_keyword = arr[1]
            str_list = arr[2:]
            if int(current_keyword) == keyword:
                scores = list(map(float, str_list))
                if key not in score_table:
                    score_table.update({key: scores})
    keyword_table = {}
    filler_table = {}
    filler_duration = 0.0
    with open(label_file, 'r', encoding='utf8') as fin:
        for line in fin:
            obj = json.loads(line.strip())
            assert 'key' in obj
            assert 'txt' in obj
            assert 'duration' in obj
            key = obj['key']
            index = obj['txt']
            duration = obj['duration']
            assert key in score_table
            if index == keyword:
                keyword_table[key] = score_table[key]
            else:
                filler_table[key] = score_table[key]
                filler_duration += duration
    return keyword_table, filler_table, filler_duration


class Args:
    def __init__(self):
        self.test_data = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/data/test/data.list'
        self.keyword = 0
        self.score_file = os.path.join(
            os.path.abspath(sys.argv[1]), 'score.txt')
        self.stats_file = os.path.join(
            os.path.abspath(sys.argv[1]), 'stats.0.txt')
        self.step = 0.01
        self.window_shift = 50


args = Args()

if __name__ == '__main__':
    # parser = argparse.ArgumentParser(description='compute det curve')
    # parser.add_argument('--test_data', required=True, help='label file')
    # parser.add_argument('--keyword', type=int, default=0, help='keyword label')
    # parser.add_argument('--score_file', required=True, help='score file')
    # parser.add_argument('--step', type=float, default=0.01,
    #                     help='threshold step')
    # parser.add_argument('--window_shift', type=int, default=50,
    #                     help='window_shift is used to skip the frames after triggered')
    # parser.add_argument('--stats_file',
    #                     required=True,
    #                     help='false reject/alarm stats file')
    # args = parser.parse_args()

    window_shift = args.window_shift
    keyword_table, filler_table, filler_duration = load_label_and_score(
        args.keyword, args.test_data, args.score_file)
    print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
    pbar = tqdm(total=int(1.0 / args.step))
    with open(args.stats_file, 'w', encoding='utf8') as fout:
        keyword_index = int(args.keyword)
        threshold = 0.0
        while threshold <= 1.0:
            num_false_reject = 0
            # traverse the whole keyword_table
            for key, score_list in keyword_table.items():
                # compute positive test sample, use the max score of the list
                score = max(score_list)
                if float(score) < threshold:
                    num_false_reject += 1
            num_false_alarm = 0
            # traverse the whole filler_table
            for key, score_list in filler_table.items():
                i = 0
                while i < len(score_list):
                    if score_list[i] >= threshold:
                        num_false_alarm += 1
                        i += window_shift
                    else:
                        i += 1
            if len(keyword_table) != 0:
                false_reject_rate = num_false_reject / len(keyword_table)
            num_false_alarm = max(num_false_alarm, 1e-6)
            if filler_duration != 0:
                false_alarm_per_hour = num_false_alarm / \
                    (filler_duration / 3600.0)
            fout.write('{:.6f} {:.6f} {:.6f}\n'.format(
                threshold, false_alarm_per_hour, false_reject_rate))
            threshold += args.step
            pbar.update(1)

    pbar.close()
    print('DET saved to: {}'.format(args.stats_file))
```
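To make the threshold sweep concrete, here is a minimal, self-contained sketch of the same false-reject / false-alarm bookkeeping on invented toy scores; only the window_shift skipping rule and the per-hour normalization are taken from the script above.

```python
# Toy illustration of the DET bookkeeping in compute_det.py (invented scores).
keyword_table = {'utt1': [0.1, 0.9, 0.2], 'utt2': [0.3, 0.4, 0.2]}  # positives
filler_table = {'utt3': [0.8, 0.7, 0.1, 0.9]}                       # negatives
filler_duration = 7200.0  # seconds of filler audio, i.e. 2 hours
window_shift = 2          # frames to skip after a trigger
threshold = 0.5

# A positive utterance is falsely rejected if its best frame score
# stays below the threshold.
num_false_reject = sum(1 for scores in keyword_table.values()
                       if max(scores) < threshold)

# Each filler frame at or above the threshold is one false alarm; the next
# window_shift frames are skipped so one trigger event is not counted twice.
num_false_alarm = 0
for scores in filler_table.values():
    i = 0
    while i < len(scores):
        if scores[i] >= threshold:
            num_false_alarm += 1
            i += window_shift
        else:
            i += 1

frr = num_false_reject / len(keyword_table)                 # 0.5 here
fa_per_hour = num_false_alarm / (filler_duration / 3600.0)  # 1.0 here
print(frr, fa_per_hour)
```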
paddlespeech/kws/exps/mdtc/plot_det_curve.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import matplotlib.pyplot as plt
import numpy as np


def load_stats_file(stats_file):
    values = []
    with open(stats_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            threshold, fa_per_hour, frr = arr
            values.append([float(fa_per_hour), float(frr) * 100])
    values.reverse()
    return np.array(values)


def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim,
                   y_step):
    plt.figure(dpi=200)
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.rcParams['font.size'] = 12

    for index, keyword in enumerate(keywords):
        stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt')
        values = load_stats_file(stats_file)
        plt.plot(values[:, 0], values[:, 1], label=keyword)

    plt.xlim([0, xlim])
    plt.ylim([0, ylim])
    plt.xticks(range(0, xlim + x_step, x_step))
    plt.yticks(range(0, ylim + y_step, y_step))
    plt.xlabel('False Alarm Per Hour')
    plt.ylabel('False Rejection Rate (\\%)')
    plt.grid(linestyle='--')
    plt.legend(loc='best', fontsize=16)
    plt.savefig(figure_file)


if __name__ == '__main__':
    keywords = ['Hey_Snips']
    img_path = os.path.join(os.path.abspath(sys.argv[1]), 'det.png')
    plot_det_curve(keywords,
                   os.path.abspath(sys.argv[1]), img_path, 10, 2, 10, 2)

    print('DET curve image saved to: {}'.format(img_path))
```
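Each row of a stats file has the form `threshold fa_per_hour frr`, as written by compute_det.py above; load_stats_file keeps the last two columns, scales the false rejection rate to percent, and reverses the rows so the points run from high to low threshold. A toy parse of one row (values invented):

```python
# One row of stats.<index>.txt: 'threshold fa_per_hour frr' (invented values).
line = '0.500000 1.000000 0.500000'
threshold, fa_per_hour, frr = line.strip().split()
point = [float(fa_per_hour), float(frr) * 100]  # [1.0, 50.0]: FA/hour vs. FRR in %
```

Note that xlim, x_step, ylim and y_step must be integers, since the tick positions are generated with range().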
paddlespeech/kws/exps/mdtc/score.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time

import paddle
from mdtc import KWSModel
from mdtc import MDTC
from tqdm import tqdm

from paddleaudio.datasets import HeySnips


def collate_features(batch):
    # (key, feat, label) in one sample
    collate_start = time.time()
    keys = []
    feats = []
    labels = []
    lengths = []
    for sample in batch:
        keys.append(sample[0])
        feats.append(sample[1])
        labels.append(sample[2])
        lengths.append(sample[1].shape[0])

    max_length = max(lengths)
    for i in range(len(feats)):
        feats[i] = paddle.nn.functional.pad(
            feats[i], [0, max_length - feats[i].shape[0], 0, 0],
            data_format='NLC')

    return keys, paddle.stack(feats), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)


if __name__ == '__main__':
    # Dataset
    feat_conf = {
        # 'n_mfcc': 80,
        'n_mels': 80,
        'frame_shift': 10,
        'frame_length': 25,
        # 'dither': 1.0,
    }
    test_ds = HeySnips(
        mode='test',
        feat_type='kaldi_fbank',
        sample_rate=16000,
        **feat_conf)
    test_sampler = paddle.io.BatchSampler(
        test_ds, batch_size=32, drop_last=False)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_sampler,
        num_workers=16,
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # Model
    backbone = MDTC(
        stack_num=3,
        stack_size=4,
        in_channels=80,
        res_channels=32,
        kernel_size=5,
        causal=True, )
    model = KWSModel(backbone=backbone, num_keywords=1)
    model = paddle.DataParallel(model)
    # kws_checkpoint = '/ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddlespeech/kws/models/checkpoint/epoch_10_0.8903940343290826/model.pdparams'
    kws_checkpoint = os.path.join(
        os.path.abspath(sys.argv[1]), 'model.pdparams')
    model.set_state_dict(paddle.load(kws_checkpoint))
    model.eval()

    score_abs_path = os.path.join(os.path.abspath(sys.argv[1]), 'score.txt')
    with paddle.no_grad(), open(
            score_abs_path, 'w', encoding='utf8') as fout:
        for batch_idx, batch in enumerate(
                tqdm(test_loader, total=len(test_loader))):
            keys, feats, labels, lengths = batch
            logits = model(feats)
            num_keywords = logits.shape[2]
            for i in range(len(keys)):
                key = keys[i]
                score = logits[i][:lengths[i]]
                for keyword_i in range(num_keywords):
                    keyword_scores = score[:, keyword_i]
                    score_frames = ' '.join([
                        '{:.6f}'.format(x) for x in keyword_scores.tolist()
                    ])
                    fout.write('{} {} {}\n'.format(key, keyword_i,
                                                   score_frames))

    print('Scores saved to: {}'.format(score_abs_path))
```
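score.py writes one line per utterance and keyword: the utterance key, the keyword index, and one posterior per valid frame. This is exactly the format that load_label_and_score in compute_det.py parses, so the three evaluation scripts chain on the same directory argument: `python score.py <exp_dir>` reads <exp_dir>/model.pdparams and writes <exp_dir>/score.txt, `python compute_det.py <exp_dir>` turns that into <exp_dir>/stats.0.txt, and `python plot_det_curve.py <exp_dir>` renders <exp_dir>/det.png. A toy line and its parse (values invented):

```python
# One line of score.txt: '<utt_key> <keyword_index> <frame posteriors ...>'.
line = 'utt_001 0 0.000123 0.104567 0.987654'  # invented values
arr = line.strip().split()
key, keyword_index = arr[0], int(arr[1])
frame_scores = list(map(float, arr[2:]))  # one posterior per frame
assert key == 'utt_001' and keyword_index == 0 and len(frame_scores) == 3
```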
paddlespeech/kws/exps/mdtc/train.py (new file, mode 100644)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

import paddle
from loss import max_pooling_loss
from mdtc import KWSModel
from mdtc import MDTC

from paddleaudio.datasets import HeySnips
from paddleaudio.utils import logger
from paddleaudio.utils import Timer


def collate_features(batch):
    # (key, feat, label)
    collate_start = time.time()
    keys = []
    feats = []
    labels = []
    lengths = []
    for sample in batch:
        keys.append(sample[0])
        feats.append(sample[1])
        labels.append(sample[2])
        lengths.append(sample[1].shape[0])

    max_length = max(lengths)
    for i in range(len(feats)):
        feats[i] = paddle.nn.functional.pad(
            feats[i], [0, max_length - feats[i].shape[0], 0, 0],
            data_format='NLC')

    return keys, paddle.stack(feats), paddle.to_tensor(
        labels), paddle.to_tensor(lengths)


if __name__ == '__main__':
    # Dataset
    feat_conf = {
        # 'n_mfcc': 80,
        'n_mels': 80,
        'frame_shift': 10,
        'frame_length': 25,
        # 'dither': 1.0,
    }
    data_dir = '/ssd1/chenxiaojie06/datasets/hey_snips/hey_snips_research_6k_en_train_eval_clean_ter'
    train_ds = HeySnips(
        data_dir=data_dir,
        mode='train',
        feat_type='kaldi_fbank',
        sample_rate=16000,
        **feat_conf)
    dev_ds = HeySnips(
        data_dir=data_dir,
        mode='dev',
        feat_type='kaldi_fbank',
        sample_rate=16000,
        **feat_conf)

    training_conf = {
        'epochs': 100,
        'learning_rate': 0.001,
        'weight_decay': 0.00005,
        'num_workers': 16,
        'batch_size': 100,
        'checkpoint_dir': './checkpoint',
        'save_freq': 10,
        'log_freq': 10,
    }

    train_sampler = paddle.io.BatchSampler(
        train_ds,
        batch_size=training_conf['batch_size'],
        shuffle=True,
        drop_last=False)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=training_conf['num_workers'],
        return_list=True,
        use_buffer_reader=True,
        collate_fn=collate_features, )

    # Model
    backbone = MDTC(
        stack_num=3,
        stack_size=4,
        in_channels=80,
        res_channels=32,
        kernel_size=5,
        causal=True, )
    model = KWSModel(backbone=backbone, num_keywords=1)
    model = paddle.DataParallel(model)
    clip = paddle.nn.ClipGradByGlobalNorm(5.0)
    optimizer = paddle.optimizer.Adam(
        learning_rate=training_conf['learning_rate'],
        weight_decay=training_conf['weight_decay'],
        parameters=model.parameters(),
        grad_clip=clip)
    criterion = max_pooling_loss

    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * training_conf['epochs'])
    timer.start()

    for epoch in range(1, training_conf['epochs'] + 1):
        model.train()

        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        batch_start = time.time()
        for batch_idx, batch in enumerate(train_loader):
            # print('Fetch one batch: {:.4f}'.format(time.time() - batch_start))
            keys, feats, labels, lengths = batch
            logits = model(feats)
            loss, corrects, acc = criterion(logits, labels, lengths)
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # Calculate loss
            avg_loss += loss.numpy()[0]

            # Calculate metrics
            num_corrects += corrects
            num_samples += feats.shape[0]

            timer.count()

            if (batch_idx + 1) % training_conf['log_freq'] == 0:
                lr = optimizer.get_lr()
                avg_loss /= training_conf['log_freq']
                avg_acc = num_corrects / num_samples

                print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                    epoch, training_conf['epochs'], batch_idx + 1,
                    steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.train(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

            batch_start = time.time()

            if epoch % training_conf[
                    'save_freq'] == 0 and batch_idx + 1 == steps_per_epoch:
                dev_sampler = paddle.io.BatchSampler(
                    dev_ds,
                    batch_size=training_conf['batch_size'],
                    shuffle=False,
                    drop_last=False)
                dev_loader = paddle.io.DataLoader(
                    dev_ds,
                    batch_sampler=dev_sampler,
                    num_workers=training_conf['num_workers'],
                    return_list=True,
                    use_buffer_reader=True,
                    collate_fn=collate_features, )

                model.eval()
                num_corrects = 0
                num_samples = 0
                with logger.processing('Evaluation on validation dataset'):
                    for batch_idx, batch in enumerate(dev_loader):
                        keys, feats, labels, lengths = batch
                        logits = model(feats)
                        loss, corrects, acc = criterion(logits, labels,
                                                        lengths)
                        num_corrects += corrects
                        num_samples += feats.shape[0]

                eval_acc = num_corrects / num_samples
                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(eval_acc)
                logger.eval(print_msg)

                # Save model
                save_dir = os.path.join(
                    training_conf['checkpoint_dir'],
                    'epoch_{}_{:.4f}'.format(epoch, eval_acc))
                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(save_dir, 'model.pdopt'))
```
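train.py imports max_pooling_loss from a local loss module that is not included in this commit. For orientation only, below is a minimal sketch of a max-pooling KWS loss with the same (logits, labels, lengths) -> (loss, corrects, acc) calling convention used above; the label convention (1 = keyword, 0 = filler) and every implementation detail are assumptions, not the shipped code.

```python
# A minimal sketch of a max-pooling KWS loss; NOT the loss.py shipped with
# this example, whose label conventions and internals may differ.
import paddle


def max_pooling_loss_sketch(logits, labels, lengths):
    """logits:  (B, T, 1) per-frame keyword posteriors in [0, 1].
    labels:  (B,) utterance labels, assuming 1 = keyword, 0 = filler.
    lengths: (B,) valid frame counts before padding.
    Returns (loss, corrects, acc) to match how `criterion` is called above.
    """
    eps = 1e-8
    batch_size = logits.shape[0]
    loss = paddle.zeros([1])
    corrects = 0
    for i in range(batch_size):
        # Max-pool the posterior over the valid (unpadded) frames only,
        # so padding frames can never trigger or suppress the keyword.
        prob = logits[i, :int(lengths[i]), 0].max()
        if int(labels[i]) == 1:
            loss += -paddle.log(prob + eps)        # pull the best frame towards 1
        else:
            loss += -paddle.log(1.0 - prob + eps)  # push every frame towards 0
        corrects += int((float(prob) >= 0.5) == (int(labels[i]) == 1))
    loss = loss / batch_size
    return loss, corrects, corrects / batch_size
```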