Commit 3070dc8b (BaiXuePrincess / Paddle, forked from PaddlePaddle / Paddle)
Authored by JZ-LIANG on Sep 28, 2022; committed via GitHub on Sep 28, 2022
[Auto Parallel] Generalize Amp Pass (#46519)
* support input mask
Parent: 526d963e

Showing 5 changed files with 59 additions and 13 deletions (+59, -13)
python/paddle/distributed/auto_parallel/parallelizer.py              +2   -0
python/paddle/distributed/auto_parallel/parallelizer_v2.py           +2   -0
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py  +2   -0
python/paddle/distributed/passes/auto_parallel_amp.py                +42  -9
python/paddle/distributed/passes/auto_parallel_fp16.py               +11  -4
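The three parallelizer-side changes below all follow the same pattern: after applying the AMP or FP16 pass, the caller re-reads the effective loss from the pass instead of keeping its original loss variable. The snippet below is a minimal, framework-free sketch of that pattern; MockAMPPass and its attribute names are hypothetical stand-ins for illustration, not Paddle APIs.

    class MockAMPPass:
        def __init__(self, attrs):
            self._attrs = attrs
            self._loss = None  # set by apply() only if the pass replaces the loss

        def get_attr(self, name):
            return self._attrs.get(name)

        def apply(self, replaced_loss=None):
            # the real pass rewrites the program here; if it replaces the loss
            # (e.g. with an FP32 cast of an FP16 loss), it records the new variable
            self._loss = replaced_loss

        def get_loss(self):
            # mirrors AMPPass.get_loss() added in auto_parallel_amp.py below:
            # prefer the loss the pass produced, otherwise the configured one
            if self._loss:
                return self._loss
            else:
                return self.get_attr("loss")

    amp_pass = MockAMPPass({"loss": "loss_var"})
    amp_pass.apply(replaced_loss="loss_var.cast_fp32")
    loss = amp_pass.get_loss()   # callers (parallelizer*.py below) now read this
    print(loss)                  # prints: loss_var.cast_fp32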
python/paddle/distributed/auto_parallel/parallelizer.py
@@ -110,10 +110,12 @@ class AutoParallelizer:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+                loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              self._pass_context)
+                loss = auto_parallel_amp_pass.get_loss()
 
         # apply recompute pass
         if self._dist_strategy.recompute:
python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -192,10 +192,12 @@ class Parallelizer:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+                loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              self._pass_context)
+                loss = auto_parallel_amp_pass.get_loss()
 
         # apply recompute pass
         # recompute is then train-only optimization
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
@@ -271,10 +271,12 @@ class OptimizationTuner:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               pass_context)
+                dist_context.serial_loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              pass_context)
+                dist_context.serial_loss = auto_parallel_amp_pass.get_loss()
 
         if new_strategy.recompute.enable:
             config = copy.deepcopy(new_strategy.recompute.to_dict())
python/paddle/distributed/passes/auto_parallel_amp.py
@@ -614,21 +614,17 @@ class AMPPass(PassBase):
             loss_op)
 
         if loss.dtype != core.VarDesc.VarType.FP32:
-            # cast loss here will change the effective loss tensor for the computation graph
-            # and therefore will effect all following passes whose logic is based on the loss tensor(Recompute & Gradient Merge),
-            # so we it is not allowed by now. fixed it in future.
-            raise NotImplementedError(
-                "Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
-            )
             tmp_name = unique_name.generate(loss.name + ".cast_fp32")
-            cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
+            cast_loss = main_block.create_var(name=tmp_name,
+                                              dtype=core.VarDesc.VarType.FP32)
             loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                 loss)
             ref_mesh = loss_op_dist_attr.process_mesh
             self.dist_context.set_tensor_dist_attr_for_program(
                 cast_loss, loss_dist_attr)
 
             # forward
             loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
             cast_op = main_block._insert_op(
                 loss_op_idx + 1,
@@ -645,7 +641,34 @@ class AMPPass(PassBase):
                 core.op_proto_and_checker_maker.OpRole.Forward)
             naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                 cast_op, ref_mesh, [-1], self.dist_context)
-            loss = loss.astype('float32')
+
+            # backward
+            first_backward_op = main_block.ops[loss_op_idx + 2]
+            assert first_backward_op.type == "fill_constant" and int(
+                first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
+            cast_loss_grad = main_block.create_var(
+                name=unique_name.generate(tmp_name + "@GRAD"),
+                shape=loss.shape,
+                dtype=core.VarDesc.VarType.FP32,
+                persistable=loss.persistable)
+            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
+
+            pre_grad_name = first_backward_op.output_arg_names[0]
+            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
+            cast_grad_op = main_block._insert_op(
+                loss_op_idx + 3,
+                type='cast',
+                inputs={'X': [cast_loss_grad]},
+                outputs={'Out': [pre_grad_name]},
+                attrs={
+                    "in_dtype": core.VarDesc.VarType.FP32,
+                    "out_dtype": core.VarDesc.VarType.FP16,
+                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
+                })
+            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                cast_grad_op, ref_mesh, [-1], self.dist_context)
+            loss_op = cast_op
+            loss = cast_loss
 
         if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
                 "init_loss_scaling") != 1.0:
@@ -718,7 +741,7 @@ class AMPPass(PassBase):
         else:
             self._scaled_loss = loss
+        self._loss = loss
         main_block._sync_with_cpp()
 
     def _update_loss_scaling(self, grads, found_inf):
@@ -782,3 +805,13 @@ class AMPPass(PassBase):
             self.dist_context.set_op_dist_attr_for_program(new_op,
                                                            new_op_dist_attr)
 
         main_block._sync_with_cpp()
+
+    def get_loss(self):
+        # the amp / fp16 might change the effective loss variable for network and
+        # therefore would affect the subsequent passes that rely on the loss.
+        # return the effective loss after amp / fp16 pass.
+        if self._loss:
+            return self._loss
+        else:
+            return self.get_attr("loss")
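The change above lets an FP16 loss pass through instead of raising NotImplementedError: a forward cast op produces an FP32 copy of the loss right after the loss op, and a backward cast op converts the FP32 loss gradient back to FP16 right after the fill_constant that initializes it. The snippet below is a minimal, framework-free sketch of that rewiring on a toy op list; the dict-based "ops" and their fields are illustrative stand-ins, not Paddle's program IR.

    # toy program: a loss op with an FP16 output, then the first backward op
    ops = [
        {"type": "reduce_mean", "out": "loss_fp16"},
        {"type": "fill_constant", "out": "loss_fp16@GRAD"},
    ]

    # forward: insert a cast right after the loss op (loss_op_idx + 1)
    ops.insert(1, {"type": "cast", "in": "loss_fp16", "out": "loss_fp32",
                   "in_dtype": "fp16", "out_dtype": "fp32"})

    # backward: fill_constant now initializes the FP32 gradient instead ...
    ops[2]["out"] = "loss_fp32@GRAD"
    # ... and a second cast (loss_op_idx + 3) feeds the FP16 gradient that the
    # rest of the backward pass still expects
    ops.insert(3, {"type": "cast", "in": "loss_fp32@GRAD", "out": "loss_fp16@GRAD",
                   "in_dtype": "fp32", "out_dtype": "fp16"})

    for op in ops:
        print(op)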
python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -368,6 +368,10 @@ class FP16State(object):
         for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[
                 forward_op_id]:
+            # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
+            if slot_name not in op.input_names:
+                continue
+
             # rename input
             assert src_name in op.input(slot_name), "var: {} not in op's {}. {}".format(
@@ -379,12 +383,15 @@ class FP16State(object):
             # create cast grad
             grad_slot_name = slot_name + "@GRAD"
-            assert grad_slot_name in op.output_names
+            assert grad_slot_name in op.output_names, "[{}], Current Op: {}".format(
+                grad_slot_name, str(op))
+
+            # some forward input maybe stop_gradient=True, e.g. input_mask
+            if len(op.output(grad_slot_name)) == 0:
+                var = block.var(src_name)
+                assert var.stop_gradient is True
+                continue
-            assert len(op.output(grad_slot_name)) == 1
+            assert len(op.output(grad_slot_name)) == 1, "[{}], Current Op: {}".format(
+                grad_slot_name, str(op))
             grad_name = op.output(grad_slot_name)[0]
             grad = block.var(grad_name)
             grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
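The guard added above handles forward inputs that carry stop_gradient=True (the commit message's "input mask" case): such inputs have an empty @GRAD slot in the backward op, so the cast-gradient handling skips them instead of asserting exactly one gradient output. The snippet below is a small, self-contained illustration of that control flow in plain Python; the slot and variable names are hypothetical, not Paddle's op interface.

    backward_outputs = {
        "X@GRAD": ["x@GRAD"],   # regular input: exactly one gradient output
        "Mask@GRAD": [],        # stop_gradient input (e.g. input_mask): none
    }
    stop_gradient = {"x": False, "input_mask": True}

    for slot_name, src_name in [("X", "x"), ("Mask", "input_mask")]:
        grad_slot_name = slot_name + "@GRAD"
        grads = backward_outputs[grad_slot_name]
        if len(grads) == 0:
            # only stop_gradient inputs may lack a gradient output
            assert stop_gradient[src_name] is True
            continue
        assert len(grads) == 1, "[{}], unexpected gradient count".format(grad_slot_name)
        print("would cast the gradient of", grads[0])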