PaddlePaddle / DeepSpeech
Commit fbd27aab
Authored on Apr 17, 2023 by zxcd
add amp for U2 conformer.
Parent: d3d86f59

Showing 2 changed files with 40 additions and 8 deletions (+40 -8):
paddlespeech/s2t/exps/u2/model.py      +35 -7
paddlespeech/s2t/training/trainer.py   +5  -1
paddlespeech/s2t/exps/u2/model.py
@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len,
+                                                        text, text_len)
 
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
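For context: paddle.amp.auto_cast is a context manager that runs the ops inside it in reduced precision according to `level` ('O1' casts only a white-list of ops such as matmul and conv, 'O2' runs nearly everything in float16), and it becomes a no-op when enable=False, which is how the hunk above gates AMP on whether a GradScaler was created. A minimal sketch, assuming a GPU build of PaddlePaddle and a toy Linear layer in place of the U2 model:

    import paddle

    layer = paddle.nn.Linear(8, 2)
    x = paddle.randn([4, 8])
    with paddle.amp.auto_cast(level='O1', enable=True):
        y = layer(x)  # under O1, the underlying matmul runs in float16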
@@ -77,12 +80,24 @@ class U2Trainer(Trainer):
         # processes.
         context = nullcontext
         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
 
         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
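Taken together, the two hunks above follow the standard dynamic loss-scaling recipe: the loss is scaled before backward so float16 gradients do not underflow, the scaler then steps the optimizer (skipping the step when it detects inf/nan gradients), and the scale factor is updated afterwards. A minimal, self-contained sketch of that loop, assuming PaddlePaddle >= 2.5 (for paddle.nn.utils.clip_grad_norm_, as the in-diff comment notes), with a toy model and an illustrative clip value in place of the real U2 configuration:

    import paddle
    from paddle.nn.utils import clip_grad_norm_

    model = paddle.nn.Linear(8, 2)  # stand-in for the U2 conformer
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

    for step in range(10):
        x = paddle.randn([4, 8])
        y = paddle.randn([4, 2])
        with paddle.amp.auto_cast(level='O1', enable=True):
            loss = paddle.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()             # scaled backward pass, as in the diff
        clip_grad_norm_(model.parameters(), 5.0)  # global grad clip (illustrative value)
        scaler.step(optimizer)                    # unscales gradients, skips step on inf/nan
        scaler.update()                           # adjusts the loss scale for the next step
        optimizer.clear_grad()

As in the diff, the clip here is applied before scaler.step, i.e. to still-scaled gradients; recent PaddlePaddle versions also expose GradScaler.unscale_(optimizer) if clipping of unscaled gradients is wanted.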
@@ -173,7 +188,8 @@ class U2Trainer(Trainer):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("lr", self.lr_scheduler())
-                            self.train_batch(batch_index, batch, msg)
+                            self.train_batch(batch_index, batch, self.scaler,
+                                             msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
                             if not self.use_streamdata:
@@ -253,6 +269,19 @@ class U2Trainer(Trainer):
                 model_conf.output_dim = self.test_loader.vocab_size
 
         model = U2Model.from_config(model_conf)
+
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  #amp default num 32768.0
+            #Set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
 
         if self.parallel:
             model = paddle.DataParallel(model)
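The setup hunk reads three config keys (use_amp, amp_level, scale_loss), builds a GradScaler when AMP is on, and only wraps the model with paddle.amp.decorate for level 'O2', which converts the model's parameters to float16; under 'O1' the parameters stay float32 and only selected ops are cast at run time. A minimal sketch of the same decision, with a plain dict standing in for the YAML-backed config and a toy layer in place of U2Model.from_config(model_conf):

    import paddle

    config = {"use_amp": True, "amp_level": "O1", "scale_loss": 32768.0}
    model = paddle.nn.Linear(8, 2)

    if config.get("use_amp", True):
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=config.get("scale_loss", 32768.0))
        if config.get("amp_level", "O1") == "O2":
            # O2 converts the model's parameters to float16
            model = paddle.amp.decorate(models=model, level="O2")
    else:
        scaler = None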
@@ -290,7 +319,6 @@ class U2Trainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
 
         return {
-            "grad_clip": train_config.global_grad_clip,
             "weight_decay": optim_conf.weight_decay,
             "learning_rate": lr_scheduler if lr_scheduler else optim_conf.lr,

(The grad_clip entry is dropped from the optimizer arguments because gradient clipping is now applied explicitly in train_batch via clip_grad_norm_.)
paddlespeech/s2t/training/trainer.py
@@ -110,6 +110,7 @@ class Trainer():
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self._train = True
+        self.scaler = None
 
         # print deps version
         all_version()
@@ -187,7 +188,8 @@ class Trainer():
         infos.update({
             "step": self.iteration,
             "epoch": self.epoch,
-            "lr": self.optimizer.get_lr()
+            "lr": self.optimizer.get_lr(),
+            "scaler": self.scaler
         })
         self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                         if tag is None else tag, self.model,
@@ -211,6 +213,8 @@ class Trainer():
             # lr will resotre from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
+            self.scaler = paddle.amp.GradScaler()
+            self.scaler.load_state_dict(infos["scaler"])
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
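On resume, the trainer rebuilds a GradScaler and loads the saved state so the dynamic loss scale carries over instead of restarting at its default. A minimal sketch of that round trip, assuming the scaler state is serialized with GradScaler.state_dict() on the save side (the Trainer above stores it in the checkpoint infos dict):

    import paddle

    # save side: capture the loss-scaling state next to step/epoch metadata
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
    infos = {"step": 1000, "epoch": 3, "scaler": scaler.state_dict()}

    # restore side: rebuild a scaler and load the saved state, as the diff does
    restored = paddle.amp.GradScaler()
    restored.load_state_dict(infos["scaler"])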