Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e2b924bf
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e2b924bf
编写于
8月 16, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
8月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] Prune D2H memcpy for fp16 pass (#45159)
* prune d2h memcpy for fp16 pass
上级
fa890092
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
54 addition
and
6 deletion
+54
-6
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+3
-2
python/paddle/distributed/auto_parallel/partitioner.py
python/paddle/distributed/auto_parallel/partitioner.py
+1
-1
python/paddle/distributed/passes/auto_parallel_fp16.py
python/paddle/distributed/passes/auto_parallel_fp16.py
+50
-3
未找到文件。
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
e2b924bf
...
@@ -16,7 +16,7 @@ import abc
...
@@ -16,7 +16,7 @@ import abc
import
paddle
import
paddle
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
..dist_attribute
import
OperatorDistributedAttribute
from
..dist_attribute
import
OperatorDistributedAttribute
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
,
is_optimize_op
from
..process_group
import
new_process_group
from
..process_group
import
new_process_group
_g_distributed_operator_impl_containers
=
{}
_g_distributed_operator_impl_containers
=
{}
...
@@ -426,7 +426,8 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
...
@@ -426,7 +426,8 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
rank (int): global ranks index for current process.
rank (int): global ranks index for current process.
"""
"""
if
len
(
act_grad_names
)
==
0
or
len
(
out_grad_names
)
==
0
:
if
is_optimize_op
(
op
)
or
len
(
act_grad_names
)
==
0
or
len
(
out_grad_names
)
==
0
:
return
return
dp_group
=
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
)
dp_group
=
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
)
...
...
python/paddle/distributed/auto_parallel/partitioner.py
浏览文件 @
e2b924bf
...
@@ -279,7 +279,7 @@ class Partitioner(object):
...
@@ -279,7 +279,7 @@ class Partitioner(object):
dist_op_opt_impl
=
_get_dist_op_backward_implement
(
dist_op_opt_impl
=
_get_dist_op_backward_implement
(
op
,
self
.
_dist_context
,
forward_op_id2forward_op
)
op
,
self
.
_dist_context
,
forward_op_id2forward_op
)
dist_op_opt_impl
.
backward
(
self
.
_dist_context
,
**
kinputs
,
dist_op_opt_impl
.
backward
(
self
.
_dist_context
,
**
kinputs
,
**
koutputs
)
**
koutputs
,
**
{
"grad_var_to_var"
:
{}}
)
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"partitioner only support forward and backward, optimize ops, but got {}"
"partitioner only support forward and backward, optimize ops, but got {}"
...
...
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
e2b924bf
...
@@ -491,6 +491,50 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
...
@@ -491,6 +491,50 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
def
_get_memcopy_idx
(
block
,
found_inf_var
):
# use reduce_any op for check_nan_inf as the anchor for now
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
'reduce_any'
and
op
.
output_arg_names
[
0
]
==
found_inf_var
.
name
:
return
idx
+
1
raise
RuntimeError
(
"not found the correct location for memcopy for found_inf_var."
)
def
_insert_memcopy
(
block
,
idx
,
src_var
,
dist_context
,
direction
=
"D2H"
):
src_name
=
src_var
.
name
output_var
=
block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
src_name
.
join
([
'memcopy_'
])),
dtype
=
src_var
.
dtype
,
shape
=
src_var
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
src_var
.
stop_gradient
)
set_var_dist_attr
(
dist_context
,
output_var
,
[
-
1
],
world_process_group
.
ranks
)
# TODO to support CUDAPinned/NPU/XPU Places
if
direction
==
"D2H"
:
dst_place_type
=
0
elif
direction
==
"D2H"
:
dst_place_type
=
1
else
:
raise
NotImplementedError
(
"direction [{}] is not supported yet."
.
format
(
direction
))
attrs
=
{
'dst_place_type'
:
dst_place_type
}
new_op
=
block
.
_insert_op_without_sync
(
index
=
idx
,
type
=
'memcpy'
,
inputs
=
{
'X'
:
[
src_var
]},
outputs
=
{
'Out'
:
[
output_var
]},
attrs
=
attrs
)
_set_op_dist_attr_with_ranks
(
new_op
,
world_process_group
.
ranks
,
block
,
dist_context
)
block
.
_sync_with_cpp
()
return
output_var
@
register_pass
(
"auto_parallel_fp16"
)
@
register_pass
(
"auto_parallel_fp16"
)
class
FP16Pass
(
AMPPass
):
class
FP16Pass
(
AMPPass
):
...
@@ -577,9 +621,12 @@ class FP16Pass(AMPPass):
...
@@ -577,9 +621,12 @@ class FP16Pass(AMPPass):
if
isinstance
(
if
isinstance
(
base_opt
,
base_opt
,
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
)):
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
)):
# with main_program._optimized_guard([]):
with
main_program
.
_optimized_guard
([]):
# found_inf = paddle.tensor.creation._memcpy(
# found_inf = paddle.tensor.creation._memcpy(
# found_inf, paddle.CPUPlace())
# found_inf, paddle.CPUPlace())
insert_idx
=
_get_memcopy_idx
(
block
,
found_inf
)
found_inf
=
_insert_memcopy
(
block
,
insert_idx
,
found_inf
,
self
.
dist_context
)
base_opt
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
.
name
)
base_opt
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
.
name
)
elif
hasattr
(
base_opt
,
"_set_auxiliary_var"
):
elif
hasattr
(
base_opt
,
"_set_auxiliary_var"
):
base_opt
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
.
name
)
base_opt
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
.
name
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录