未验证 提交 08bf1b49 编写于 作者: L Linjie Chen 提交者: GitHub

Add input check for NLLLoss (#49547)

* fix nll_loss

* fix nll_loss

* update

* update

* update

* fix
上级 203f9594
......@@ -1125,6 +1125,37 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
self.assertRaises(ValueError, test_x_dim_imperative_lt_2)
def test_x_shape_lt_1():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
place = paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog):
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [1, 0]), dtype='float32')
label = paddle.to_tensor(
np.reshape(array, [1, 0]), dtype='int64'
)
nll_loss = paddle.nn.loss.NLLLoss()
res = nll_loss(x, label)
self.assertRaises(ValueError, test_x_shape_lt_1)
def test_x_dim_and_label_dim():
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
place = paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog):
x_np = np.random.random(size=(5,)).astype(np.float64)
label_np = np.random.randint(0, 10, size=(5, 1)).astype(
np.int64
)
x = paddle.to_tensor(x_np)
label = paddle.to_tensor(label_np)
nll_loss = paddle.nn.loss.NLLLoss()
res = nll_loss(x, label)
self.assertRaises(ValueError, test_x_dim_and_label_dim)
def test_reduction_value_error(self):
def test_NLLLoss_reduction_not_sum_mean_none():
prog = paddle.static.Program()
......
......@@ -1372,10 +1372,29 @@ def nll_loss(
input_shape = list(input.shape)
input_dims = len(input_shape)
label_shape = list(label.shape)
label_dims = len(label_shape)
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
"Expected input_dims - 1 = label_dims or input_dims == label_dims\
(got input_dims{}, label_dims{})".format(
input_dims, label_dims
)
)
if input_dims < 2:
raise ValueError(
'Expected 2 or more dimensions (got {})'.format(input_dims)
)
if input_shape[1] < 1:
raise ValueError(
"Expected 1 or more classess (got num classes{})".format(
input_shape[1]
)
)
n = input_shape[0]
c = input_shape[1]
if in_dygraph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册