Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
33b87902
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
33b87902
编写于
9月 23, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor avg_model; fix set_value not support start==end
上级
9d5eb740
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
37 addition
and
26 deletion
+37
-26
deepspeech/utils/tensor_utils.py
deepspeech/utils/tensor_utils.py
+12
-2
utils/avg_model.py
utils/avg_model.py
+25
-24
未找到文件。
deepspeech/utils/tensor_utils.py
浏览文件 @
33b87902
...
@@ -94,9 +94,19 @@ def pad_sequence(sequences: List[paddle.Tensor],
...
@@ -94,9 +94,19 @@ def pad_sequence(sequences: List[paddle.Tensor],
length
=
tensor
.
shape
[
0
]
length
=
tensor
.
shape
[
0
]
# use index notation to prevent duplicate references to the tensor
# use index notation to prevent duplicate references to the tensor
if
batch_first
:
if
batch_first
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[i, :length, ...] = tensor
if
length
!=
0
:
out_tensor
[
i
,
:
length
,
...]
=
tensor
out_tensor
[
i
,
:
length
,
...]
=
tensor
else
:
else
:
out_tensor
[
i
,
length
,
...]
=
tensor
else
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if
length
!=
0
:
out_tensor
[:
length
,
i
,
...]
=
tensor
out_tensor
[:
length
,
i
,
...]
=
tensor
else
:
out_tensor
[
length
,
i
,
...]
=
tensor
return
out_tensor
return
out_tensor
...
...
utils/avg_model.py
浏览文件 @
33b87902
...
@@ -27,8 +27,9 @@ def main(args):
...
@@ -27,8 +27,9 @@ def main(args):
val_scores
=
[]
val_scores
=
[]
beat_val_scores
=
[]
beat_val_scores
=
[]
selected_epochs
=
[]
selected_epochs
=
[]
if
args
.
val_best
:
jsons
=
glob
.
glob
(
f
'
{
args
.
ckpt_dir
}
/[!train]*.json'
)
jsons
=
glob
.
glob
(
f
'
{
args
.
ckpt_dir
}
/[!train]*.json'
)
jsons
=
sorted
(
jsons
,
key
=
os
.
path
.
getmtime
,
reverse
=
True
)
for
y
in
jsons
:
for
y
in
jsons
:
with
open
(
y
,
'r'
)
as
f
:
with
open
(
y
,
'r'
)
as
f
:
dic_json
=
json
.
load
(
f
)
dic_json
=
json
.
load
(
f
)
...
@@ -36,24 +37,23 @@ def main(args):
...
@@ -36,24 +37,23 @@ def main(args):
epoch
=
dic_json
[
'epoch'
]
epoch
=
dic_json
[
'epoch'
]
if
epoch
>=
args
.
min_epoch
and
epoch
<=
args
.
max_epoch
:
if
epoch
>=
args
.
min_epoch
and
epoch
<=
args
.
max_epoch
:
val_scores
.
append
((
epoch
,
loss
))
val_scores
.
append
((
epoch
,
loss
))
val_scores
=
np
.
array
(
val_scores
)
val_scores
=
np
.
array
(
val_scores
)
if
args
.
val_best
:
sort_idx
=
np
.
argsort
(
val_scores
[:,
1
])
sort_idx
=
np
.
argsort
(
val_scores
[:,
1
])
sorted_val_scores
=
val_scores
[
sort_idx
]
sorted_val_scores
=
val_scores
[
sort_idx
]
path_list
=
[
else
:
args
.
ckpt_dir
+
'/{}.pdparams'
.
format
(
int
(
epoch
))
sorted_val_scores
=
val_scores
for
epoch
in
sorted_val_scores
[:
args
.
num
,
0
]
]
beat_val_scores
=
sorted_val_scores
[:
args
.
num
,
1
]
beat_val_scores
=
sorted_val_scores
[:
args
.
num
,
1
]
selected_epochs
=
sorted_val_scores
[:
args
.
num
,
0
].
astype
(
np
.
int64
)
selected_epochs
=
sorted_val_scores
[:
args
.
num
,
0
].
astype
(
np
.
int64
)
print
(
"best
val scores = "
+
str
(
beat_val_scores
))
print
(
"selected
val scores = "
+
str
(
beat_val_scores
))
print
(
"selected epochs = "
+
str
(
selected_epochs
))
print
(
"selected epochs = "
+
str
(
selected_epochs
))
else
:
path_list
=
glob
.
glob
(
f
'
{
args
.
ckpt_dir
}
/[!avg][!final]*.pdparams'
)
path_list
=
sorted
(
path_list
,
key
=
os
.
path
.
getmtime
)
path_list
=
path_list
[
-
args
.
num
:]
path_list
=
[
args
.
ckpt_dir
+
'/{}.pdparams'
.
format
(
int
(
epoch
))
for
epoch
in
sorted_val_scores
[:
args
.
num
,
0
]
]
print
(
path_list
)
print
(
path_list
)
avg
=
None
avg
=
None
...
@@ -78,10 +78,11 @@ def main(args):
...
@@ -78,10 +78,11 @@ def main(args):
meta_path
=
os
.
path
.
splitext
(
args
.
dst_model
)[
0
]
+
'.avg.json'
meta_path
=
os
.
path
.
splitext
(
args
.
dst_model
)[
0
]
+
'.avg.json'
with
open
(
meta_path
,
'w'
)
as
f
:
with
open
(
meta_path
,
'w'
)
as
f
:
data
=
json
.
dumps
({
data
=
json
.
dumps
({
"mode"
:
'val_best'
if
args
.
val_best
else
'latest'
,
"avg_ckpt"
:
args
.
dst_model
,
"avg_ckpt"
:
args
.
dst_model
,
"ckpt"
:
path_list
,
"ckpt"
:
path_list
,
"epoch"
:
selected_epochs
,
"epoch"
:
selected_epochs
.
tolist
()
,
"val_loss"
:
beat_val_scores
,
"val_loss"
:
beat_val_scores
.
tolist
()
,
})
})
f
.
write
(
data
+
"
\n
"
)
f
.
write
(
data
+
"
\n
"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录