Created by: Aurelius84
gradients API输入类型检查(仅涉及python端)
测试样例:
x = fluid.data(name='x', shape=[None, 2, 8, 8], dtype='float32')
x.stop_gradient = False
conv = fluid.layers.conv2d(x, 4, 1, bias_attr=False)
y = fluid.layers.relu(conv)
list/Variable
1. targets参数必须是x_grad = fluid.gradients(y.name, x)
- 优化前,错误报在子函数
calc_gradient
具体代码执行处,对用户不明朗
File "backward.py", line 1578, in calc_gradient
block = targets[0].block
AttributeError: 'str' object has no attribute 'block'
- 优化后,在gradients接口处直接报错
TypeError: The type of 'targets' in fluid.backward.gradients must be (<class 'paddle.fluid.framework.Variable'>, <class 'list'>), but received <class 'str'>.
list/Variable
2. inputs参数必须是x_grad = fluid.gradients(y, x.name)
- 优化前,错误报在子函数
calc_gradient
具体代码执行处,对用户不明朗
File "backward.py", line 1638, in calc_gradient
if input.block.program != prog:
AttributeError: 'str' object has no attribute 'block'
- 优化后,在gradients接口处直接报错
TypeError: The type of 'inputs' in fluid.backward.gradients must be (<class 'paddle.fluid.framework.Variable'>, <class 'list'>), but received <class 'str'>.
Variable或None
3. target_gradients参数必须是x_grad = fluid.gradients([y], [x], target_gradients=x.name)
- 优化前,错误报在子函数
calc_gradient
具体代码执行处,对用户不明朗
File "backward.py", line 1625, in calc_gradient
if target.shape != grad.shape:
AttributeError: 'str' object has no attribute 'shape'
- 优化后,在gradients接口处直接报错
TypeError: The type of 'target_gradients' in fluid.backward.gradients must be (<class 'paddle.fluid.framework.Variable'>, <class 'list'>, <class 'NoneType'>), but received <class 'str'>.
set或None
4. no_grad_set参数必须是x_grad = fluid.gradients([y], x, no_grad_set=conv)
- 优化前,错误报在子函数
_get_no_grad_set_name
具体代码执行处,对用户不明朗
File "backward.py", line 1141, in _get_no_grad_set_name
format(type(no_grad_set)))
TypeError: The type of no_grad_set should be set or list or tuple, but received <class 'paddle.fluid.framework.Variable'>
- 优化后,在gradients接口处直接报错
TypeError: The type of 'no_grad_set' in fluid.backward.gradients must be (<class 'set'>, <class 'NoneType'>), but received <class 'paddle.fluid.framework.Variable'>.