未验证 提交 ae31faaa 编写于 作者: C chengduo 提交者: GitHub

refine optimier function (#19886)

test=developt
上级 93364b45
...@@ -491,6 +491,10 @@ class Optimizer(object): ...@@ -491,6 +491,10 @@ class Optimizer(object):
else: else:
assert (isinstance(callbacks, list)) assert (isinstance(callbacks, list))
program = loss.block.program program = loss.block.program
assert len(loss.shape) == 1 and loss.shape[0] == 1, \
"The loss.shape should be (1L,), but the current loss.shape is {}. " \
"Maybe that you should call fluid.layers.mean to process the current loss.".format(
loss.shape)
with program_guard(program, startup_program): with program_guard(program, startup_program):
params_grads = append_backward(loss, parameter_list, params_grads = append_backward(loss, parameter_list,
no_grad_set, callbacks) no_grad_set, callbacks)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册