Unverified commit fab9464f authored by Aurelius84, committed by GitHub

API(fluid.gradients) error message enhancement (#23450)

* API(fluid.gradients) error message enhancement test=develop

* fix unittest failure test=develop
Parent 979fb35e
@@ -24,6 +24,7 @@ from .. import compat as cpt
from . import unique_name
from . import log_helper
import paddle.fluid
+from .data_feeder import check_type
__all__ = [
    'append_backward',
    'gradients',
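
Note: check_type is Paddle's internal validation helper from python/paddle/fluid/data_feeder.py. Judging from the call sites below, it takes (value, argument_name, expected_types, api_name) and raises a TypeError naming the offending argument and API when the value is not an instance of the expected types; the exact message wording is Paddle-internal and not shown in this diff.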
@@ -1709,5 +1710,14 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
            z = fluid.gradients([y], x)
            print(z)
    """
+    check_type(targets, 'targets', (framework.Variable, list),
+               'fluid.backward.gradients')
+    check_type(inputs, 'inputs', (framework.Variable, list),
+               'fluid.backward.gradients')
+    check_type(target_gradients, 'target_gradients', (
+        framework.Variable, list, type(None)), 'fluid.backward.gradients')
+    check_type(no_grad_set, 'no_grad_set', (set, type(None)),
+               'fluid.backward.gradients')
+
    outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
    return _as_list(outs)
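
With these checks in place, a bad argument type fails fast at the fluid.gradients boundary instead of producing an obscure error deep inside calc_gradient. A minimal sketch of the new behavior, assuming the Paddle 1.x static-graph fluid API (the exact TypeError wording is an assumption):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 2], dtype='float32')
x.stop_gradient = False
y = fluid.layers.relu(x)

z = fluid.gradients([y], x)  # OK: targets is a list, inputs is a Variable
fluid.gradients(y.name, x)   # TypeError: 'targets' expects Variable or list, got str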
@@ -241,6 +241,26 @@ class TestSimpleNet(TestBackward):
        self._check_all(self.net)
+
+
+class TestGradientsError(unittest.TestCase):
+    def test_error(self):
+        x = fluid.data(name='x', shape=[None, 2, 8, 8], dtype='float32')
+        x.stop_gradient = False
+        conv = fluid.layers.conv2d(x, 4, 1, bias_attr=False)
+        y = fluid.layers.relu(conv)
+
+        with self.assertRaises(TypeError):
+            x_grad = fluid.gradients(y.name, x)
+
+        with self.assertRaises(TypeError):
+            x_grad = fluid.gradients(y, x.name)
+
+        with self.assertRaises(TypeError):
+            x_grad = fluid.gradients([y], [x], target_gradients=x.name)
+
+        with self.assertRaises(TypeError):
+            x_grad = fluid.gradients([y], x, no_grad_set=conv)
+
+
class TestSimpleNetWithErrorParamList(TestBackward):
    def test_parameter_list_type_error(self):
        self.global_block_idx = 0
......
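
The four assertRaises cases above map one-to-one onto the four new check_type calls: y.name and x.name are strings where a Variable or list is expected for targets and inputs, a string is likewise rejected for target_gradients, and conv is a Variable where no_grad_set must be a set (or None).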
@@ -441,7 +441,7 @@ def get_discriminator_loss(image_real, label_org, label_trg, generator,
    d_loss = d_loss_real + d_loss_fake + d_loss_cls
    d_loss_gp = gradient_penalty(discriminator, image_real, fake_img,
-                                discriminator.parameters(), cfg)
+                                set(discriminator.parameters()), cfg)
    if d_loss_gp is not None:
        d_loss += cfg.lambda_gp * d_loss_gp
......
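
This call-site change in the StarGAN example follows from the new validation: discriminator.parameters() returns a list, and, assuming gradient_penalty forwards that argument to fluid.gradients as no_grad_set (an inference from this commit, not shown in the excerpt), the list must now be converted first. A one-line sketch with a hypothetical name:

no_grad_vars = set(discriminator.parameters())  # list -> set; no_grad_set must be a set or None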