提交 05c08214 编写于 作者: Y yangyaming

Bug fix when inserting fill_zeros_like_op.

上级 24341d3a
......@@ -7,7 +7,7 @@ __all__ = ['append_backward']
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
"""
Traverse all ops in op_descs[begin_idx : end_idx],
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named "old_name", rename it as 'new_name'
"""
if begin_idx is None:
......@@ -162,7 +162,7 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
if core.grad_var_suffix() in arg and arg in no_grad_set:
to_insert.append((_create_op_desc_("fill_zeros_like", {
"X": [_strip_grad_suffix_(arg)]
}, {"Y": [arg]}, {}), idx))
}, {"Out": [arg]}, {}), idx))
map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
......@@ -182,7 +182,7 @@ def _append_backward_ops_(target,
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
no_grad_dict(dict):
key(int) block index
val(set) a set of varibale names. These varibales have no gradient
grad_to_var(dict)(output argument):
......@@ -276,8 +276,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
loss(Variable): The variable generated by cost function.
parameter_list(list): Parameters that need to be updated by optimizer.
If None, it means all parameters need to be updated.
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
contains all variables with `step_gradient=True` from all blocks.
Return:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册