Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bb0427ad
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bb0427ad
编写于
12月 28, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add comments for functions in backward.py
上级
18311767
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
67 addition
and
10 deletion
+67
-10
python/paddle/v2/fluid/backward.py
python/paddle/v2/fluid/backward.py
+67
-10
未找到文件。
python/paddle/v2/fluid/backward.py
浏览文件 @
bb0427ad
...
@@ -5,14 +5,17 @@ import collections
...
@@ -5,14 +5,17 @@ import collections
__all__
=
[
'append_backward'
]
__all__
=
[
'append_backward'
]
def
_rename_arg_
(
op_desc_list
,
old_name
,
new_name
,
begin_idx
=
None
,
def
_rename_arg_
(
op_descs
,
old_name
,
new_name
,
begin_idx
=
None
,
end_idx
=
None
):
end_idx
=
None
):
"""
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named "old_name", rename it as 'new_name'
"""
if
begin_idx
is
None
:
if
begin_idx
is
None
:
begin_idx
=
0
begin_idx
=
0
if
end_idx
is
None
:
if
end_idx
is
None
:
end_idx
=
len
(
op_desc
_list
)
end_idx
=
len
(
op_desc
s
)
for
i
in
range
(
begin_idx
,
end_idx
):
for
i
in
range
(
begin_idx
,
end_idx
):
op_desc
=
op_desc
_list
[
i
]
op_desc
=
op_desc
s
[
i
]
if
isinstance
(
op_desc
,
tuple
):
if
isinstance
(
op_desc
,
tuple
):
op_desc
=
op_desc
[
0
]
op_desc
=
op_desc
[
0
]
op_desc
.
rename_input
(
old_name
,
new_name
)
op_desc
.
rename_input
(
old_name
,
new_name
)
...
@@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
...
@@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
def
_create_op_desc_
(
op_type
,
inputs
,
outputs
,
attrs
):
def
_create_op_desc_
(
op_type
,
inputs
,
outputs
,
attrs
):
"""
Create a C++ OpDesc object with specified inputs, outputs and attributes.
"""
op_desc
=
core
.
OpDesc
()
op_desc
=
core
.
OpDesc
()
op_desc
.
set_type
(
op_type
)
op_desc
.
set_type
(
op_type
)
for
para
,
args
in
inputs
.
iteritems
():
for
para
,
args
in
inputs
.
iteritems
():
...
@@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
...
@@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return
op_desc
return
op_desc
def
_infer_var_data_type_
(
var_name
,
block
):
def
_infer_var_data_type_
(
grad_var_name
,
block
):
grad_var
=
block
.
desc
.
find_var
(
var_name
.
encode
(
"ascii"
))
"""
fwd_name
=
_strip_grad_suffix_
(
var_name
.
encode
(
"ascii"
))
Infer the data type of given grad variable
"""
grad_var
=
block
.
desc
.
find_var
(
grad_var_name
.
encode
(
"ascii"
))
fwd_name
=
_strip_grad_suffix_
(
grad_var_name
.
encode
(
"ascii"
))
if
block
.
desc
.
has_var_recursive
(
fwd_name
):
if
block
.
desc
.
has_var_recursive
(
fwd_name
):
fwd_var
=
block
.
desc
.
find_var_recursive
(
fwd_name
.
encode
(
"ascii"
))
fwd_var
=
block
.
desc
.
find_var_recursive
(
fwd_name
.
encode
(
"ascii"
))
grad_var
.
set_dtype
(
fwd_var
.
dtype
())
grad_var
.
set_dtype
(
fwd_var
.
dtype
())
...
@@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block):
...
@@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block):
def
_all_in_set_
(
cands
,
s
):
def
_all_in_set_
(
cands
,
s
):
"""
Test if all elements of 'cands' are in set 's'
"""
for
c
in
cands
:
for
c
in
cands
:
if
not
c
in
s
:
if
not
c
in
s
:
return
False
return
False
...
@@ -52,18 +64,29 @@ def _all_in_set_(cands, s):
...
@@ -52,18 +64,29 @@ def _all_in_set_(cands, s):
def _strip_grad_suffix_(name):
    """
    Remove the gradient suffix from a variable name.

    e.g. x@GRAD ==> x
         y@GRAD@RENAME@1 ==> y
    """
    # Everything from the first occurrence of the grad suffix onward is
    # dropped; a name without the suffix is returned unchanged.
    suffix_pos = name.find(core.grad_var_suffix())
    if suffix_pos == -1:
        return name
    return name[:suffix_pos]
def _append_grad_suffix_(name):
    """
    Append the gradient suffix to a variable name.

    e.g. x ==> x@GRAD
    """
    return name + core.grad_var_suffix()
def
_addup_repetitive_outputs_
(
op_descs
):
def
_addup_repetitive_outputs_
(
op_descs
):
# In the backward part, a variable may be the output of more than one op.
"""
# In this case, the variable should be the accumulation of all the outputs.
In the backward part, a variable may be the output of more than one op.
# We adopt adding `sum_op`s to implement the accumulate.
In this case, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulate.
"""
pending_sum_ops
=
[]
pending_sum_ops
=
[]
var_rename_count
=
collections
.
defaultdict
(
int
)
var_rename_count
=
collections
.
defaultdict
(
int
)
renamed_vars
=
collections
.
defaultdict
(
list
)
renamed_vars
=
collections
.
defaultdict
(
list
)
...
@@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs):
...
@@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs):
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
):
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
"""
# Remove ops whose outputs are all in no_grad_dict
# Remove ops whose outputs are all in no_grad_dict
op_descs
=
filter
(
op_descs
=
filter
(
lambda
op_desc
:
not
_all_in_set_
(
op_desc
.
output_arg_names
(),
no_grad_set
),
lambda
op_desc
:
not
_all_in_set_
(
op_desc
.
output_arg_names
(),
no_grad_set
),
...
@@ -133,6 +162,20 @@ def _append_backward_ops_(target,
...
@@ -133,6 +162,20 @@ def _append_backward_ops_(target,
no_grad_dict
,
no_grad_dict
,
grad_to_var
,
grad_to_var
,
callback
=
None
):
callback
=
None
):
"""
Create all grad ops, and insert them into given block
Args:
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
val(set) a set of variable names. These variables have no gradient
grad_to_var(dict)(output argument):
key(str): grad variable name
val(str): corresponding forward variable name
"""
grad_op_descs
=
[]
grad_op_descs
=
[]
program
=
block
.
program
program
=
block
.
program
for
op
in
reversed
(
block
.
ops
):
for
op
in
reversed
(
block
.
ops
):
...
@@ -170,6 +213,20 @@ def _append_backward_ops_(target,
...
@@ -170,6 +213,20 @@ def _append_backward_ops_(target,
def
_append_backward_vars_
(
block
,
start_op_idx
,
grad_to_var
,
grad_info_map
):
def
_append_backward_vars_
(
block
,
start_op_idx
,
grad_to_var
,
grad_info_map
):
"""
Create new variables required by backward pass.
Args:
block(Block): the block where new variables will be created
start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
grad_to_var(dict):
key(str): grad variable name
val(str): corresponding forward variable name
In most cases, this dict is generated by _append_backward_ops_()
grad_info_map(dict)(output argument):
key(str): forward variable name
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
"""
for
op_idx
in
range
(
start_op_idx
,
block
.
desc
.
op_size
()):
for
op_idx
in
range
(
start_op_idx
,
block
.
desc
.
op_size
()):
op_desc
=
block
.
desc
.
op
(
op_idx
)
op_desc
=
block
.
desc
.
op
(
op_idx
)
if
op_desc
.
has_attr
(
"sub_block"
):
if
op_desc
.
has_attr
(
"sub_block"
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录