Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
db937b5a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
db937b5a
编写于
8月 23, 2022
作者:
X
xiongkun
提交者:
GitHub
8月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multi-targets bugs which this is common case in dy2static (#45277)
上级
229befc8
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
18 addition
and
4 deletion
+18
-4
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+18
-4
未找到文件。
python/paddle/fluid/backward.py
浏览文件 @
db937b5a
...
...
@@ -642,12 +642,16 @@ def _addup_repetitive_outputs_(op_descs,
return
op_descs
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
,
grad_op_id_to_fwd_op
=
None
):
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
,
grad_op_id_to_fwd_op
=
None
,
target_vars
=
[]):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. all grad inputs of the grad op are in 'no_grad_set'
NOTE: we will skip target_vars's grad name.
"""
def
_op_can_be_removed_
(
op_desc
,
no_grad_set
):
...
...
@@ -658,11 +662,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
name
for
name
in
op_desc
.
input_arg_names
()
if
name
.
find
(
core
.
grad_var_suffix
())
!=
-
1
],
no_grad_set
):
no_grad_set
.
update
(
out_arg
_names
)
no_grad_set
.
update
(
set
(
out_arg_names
)
-
target_grad_var
_names
)
return
True
return
False
# Remove ops whose outputs are all in no_grad_dict
target_grad_var_names
=
set
(
[
var
.
name
+
core
.
grad_var_suffix
()
for
var
in
target_vars
])
op_descs
=
[
op_desc
for
op_desc
in
op_descs
if
not
_op_can_be_removed_
(
op_desc
,
no_grad_set
)
...
...
@@ -824,6 +830,7 @@ def serialize_op_decs(op_desc):
def
_append_backward_ops_with_checkpoints_
(
block
,
ops
,
target_vars
,
target_block
,
no_grad_dict
,
grad_to_var
,
...
...
@@ -835,6 +842,7 @@ def _append_backward_ops_with_checkpoints_(block,
Args:
block(Block): the block where forward ops are
ops(Op): the forward operators whose forward recomputation backward ops need to be added
target_vars(list[Tensor]): the loss vars we want to calculate gradient.
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
...
...
@@ -1070,7 +1078,7 @@ def _append_backward_ops_with_checkpoints_(block,
# 4) remove no grad branch as it is in _remove_no_grad_branch_
grad_op_descs
=
_remove_no_grad_branch_
(
grad_op_descs
,
no_grad_dict
[
block
.
idx
],
grad_op_id_to_fwd_op
)
grad_op_id_to_fwd_op
,
target_vars
)
added_descs
=
_add_descs_to_block
(
grad_op_descs
,
target_block
,
grad_op_id_to_fwd_op
)
return
program_stat
,
checkpoints_name
,
vars_should_be_hold
,
recompute_segments
...
...
@@ -1140,6 +1148,7 @@ def _rename_grad_name_(name, grad_order):
def
_append_backward_ops_
(
block
,
ops
,
target_vars
,
target_block
,
no_grad_dict
,
grad_to_var
,
...
...
@@ -1155,6 +1164,7 @@ def _append_backward_ops_(block,
Args:
block(Block): the block where forward ops are
ops(Op): the forward operators whose backward ops need to be added
target_vars(list[Tensor]): the loss vars we want to calculate gradient.
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
...
...
@@ -1212,6 +1222,7 @@ def _append_backward_ops_(block,
sub_block_path
=
op_path_dict
[
op
.
_block_attr_id
(
"sub_block"
)]
_append_backward_ops_
(
sub_block
,
sub_block_path
,
target_vars
,
grad_sub_block
,
no_grad_dict
,
grad_to_var
,
...
...
@@ -1330,7 +1341,7 @@ def _append_backward_ops_(block,
# if all inputs of the grad op are in no_grad_set, just remove this op
grad_op_descs
=
_remove_no_grad_branch_
(
grad_op_descs
,
no_grad_dict
[
block
.
idx
],
grad_op_id_to_fwd_op
)
grad_op_id_to_fwd_op
,
target_vars
)
# remove some backward ops
not_need_ops
=
_find_not_need_ops
(
grad_op_descs
,
ops
,
input_grad_names_set
)
...
...
@@ -1765,6 +1776,7 @@ def append_backward(loss,
_append_backward_ops_with_checkpoints_
(
root_block
,
op_path
,
[
loss
],
root_block
,
no_grad_dict
,
grad_to_var
,
...
...
@@ -1774,6 +1786,7 @@ def append_backward(loss,
_append_backward_ops_
(
block
,
# the block where forward ops are in
op_path
,
[
loss
],
target_grad_block
,
no_grad_dict
,
grad_to_var
,
...
...
@@ -2135,6 +2148,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
grad_info_map
=
dict
()
_append_backward_ops_
(
block
,
op_path
,
targets
,
block
,
no_grad_dict
,
grad_to_var
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录