Commit c532fdab authored by fengjiayi

fix errors

Parent 5b9dbbb9
@@ -195,7 +195,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             _infer_var_data_type_(arg, block)


-def append_backward(loss, parameter_list=None, no_grad_dict=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None):
     """
     Create and add gradient Operators in BlockDesc to compute
     gradients of `loss` for parameters in parameter_list
@@ -213,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
     assert isinstance(loss, framework.Variable)
     program = loss.block.program
-    if no_grad_dict is None:
-        no_grad_dict = dict()
+    no_grad_dict = dict()
+    if no_grad_set is None:
         assert isinstance(program, framework.Program)
         for block in program.blocks:
             assert isinstance(block, framework.Block)
@@ -224,8 +224,10 @@ def append_backward(loss, parameter_list=None, no_grad_dict=None):
                 if var.stop_gradient:
                     block_no_grad_set.add(_append_grad_suffix_(var.name))
             no_grad_dict[block.idx] = block_no_grad_set
-    elif isinstance(no_grad_dict, set):
-        no_grad_dict = {0: no_grad_dict}
+    elif isinstance(no_grad_set, set):
+        no_grad_dict = {0: no_grad_set}
+    else:
+        raise ValueError("'no_grad_set' should be a set or None.")
     grad_info_map = dict()
     root_block = program.block(0)
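For context, a brief usage sketch of the renamed argument (not part of the commit; the `paddle.v2.fluid` import path, the layer calls, and the parameter name below are assumptions about this era of the codebase). After this change, callers hand `append_backward` a plain set instead of a block-index-keyed dict; None still means "derive the per-block set from variables marked stop_gradient=True", and any other type now raises a ValueError.

import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward

# A tiny regression network; `loss` is the scalar to differentiate.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_hat = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(x=fluid.layers.square_error_cost(input=y_hat, label=y))

# After this commit: pass a plain set, stored verbatim as block 0's
# no-grad set. The None branch stores gradient-suffixed names, so the
# @GRAD suffix appears to be expected here too ('fc_0.b_0' is a
# hypothetical parameter name).
append_backward(loss, no_grad_set=set(['fc_0.b_0@GRAD']))

# Passing None instead collects stop_gradient=True variables per block:
#     append_backward(loss, no_grad_set=None)
# The old dict form is now rejected:
#     append_backward(loss, no_grad_set={0: set()})  # ValueError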