Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
86f65f0b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86f65f0b
编写于
10月 16, 2022
作者:
T
tianhao zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix wav2vec2 report loss bug
上级
49c0cf9e
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
11 addition
and
15 deletion
+11
-15
paddlespeech/s2t/exps/wav2vec2/model.py
paddlespeech/s2t/exps/wav2vec2/model.py
+11
-15
未找到文件。
paddlespeech/s2t/exps/wav2vec2/model.py
浏览文件 @
86f65f0b
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains wav2vec2 model."""
import
json
import
math
import
os
import
time
from
collections
import
defaultdict
...
...
@@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
class
Wav2Vec2ASRTrainer
(
Trainer
):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
self
.
avg_train_loss
=
0
self
.
avg_train_loss
=
0
.0
def
update_average
(
self
,
batch_index
,
loss
,
avg_loss
):
def
update_average
(
self
,
batch_index
,
loss
):
"""Update running average of the loss.
Arguments
---------
batch_index : int
current batch index
loss : paddle.tensor
detached loss, a single float value.
avg_loss : float
current running average.
Returns
-------
avg_loss : float
The average loss.
"""
if
paddle
.
isfinite
(
loss
):
avg_loss
-=
avg_loss
/
(
batch_index
+
1
)
avg_loss
+=
float
(
loss
)
/
(
batch_index
+
1
)
return
avg_loss
if
math
.
isfinite
(
loss
):
self
.
avg_train_loss
-=
self
.
avg_train_loss
/
(
batch_index
+
1
)
self
.
avg_train_loss
+=
loss
/
(
batch_index
+
1
)
def
train_batch
(
self
,
batch_index
,
batch
,
msg
):
train_conf
=
self
.
config
...
...
@@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
self
.
avg_train_loss
=
self
.
update_average
(
batch_index
,
loss
,
self
.
avg_train_loss
)
# update self.avg_train_loss
self
.
update_average
(
batch_index
,
float
(
loss
)
)
# loss backward
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
!=
0
:
...
...
@@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
self
.
lr_scheduler
.
step
()
self
.
iteration
+=
1
losses_np
=
{
'loss'
:
float
(
self
.
avg_train_loss
)
*
train_conf
.
accum_grad
}
losses_np
=
{
'loss'
:
self
.
avg_train_loss
*
train_conf
.
accum_grad
}
iteration_time
=
time
.
time
()
-
start
for
k
,
v
in
losses_np
.
items
():
report
(
k
,
v
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录