提交 c532fdab 编写于 作者: F fengjiayi

fix errors

上级 5b9dbbb9
...@@ -195,7 +195,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -195,7 +195,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
_infer_var_data_type_(arg, block) _infer_var_data_type_(arg, block)
def append_backward(loss, parameter_list=None, no_grad_dict=None): def append_backward(loss, parameter_list=None, no_grad_set=None):
""" """
Create and add gradient Operators in BlockDesc to compute Create and add gradient Operators in BlockDesc to compute
gradients of `loss` for parameters in parameter_list gradients of `loss` for parameters in parameter_list
...@@ -213,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None): ...@@ -213,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
assert isinstance(loss, framework.Variable) assert isinstance(loss, framework.Variable)
program = loss.block.program program = loss.block.program
if no_grad_dict is None:
no_grad_dict = dict() no_grad_dict = dict()
if no_grad_set is None:
assert isinstance(program, framework.Program) assert isinstance(program, framework.Program)
for block in program.blocks: for block in program.blocks:
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -224,8 +224,10 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None): ...@@ -224,8 +224,10 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
if var.stop_gradient: if var.stop_gradient:
block_no_grad_set.add(_append_grad_suffix_(var.name)) block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_dict, set): elif isinstance(no_grad_set, set):
no_grad_dict = {0: no_grad_dict} no_grad_dict = {0: no_grad_set}
else:
raise ValueError("'no_grad_set' should be a set or None.")
grad_info_map = dict() grad_info_map = dict()
root_block = program.block(0) root_block = program.block(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册