提交 400eb9d8 编写于 作者: H huangjun12 提交者: chajchaj

fix ce bug in label value, test=develop

上级 1eb59ef0
......@@ -1363,5 +1363,19 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(dy_ret_value, expected))
class TestCrossEntropyFAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_LabelValue():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(0, 100, shape=[5, 1], dtype="int64")
label_data[0] = 255
paddle.nn.functional.cross_entropy(
input=input_data, label=label_data)
self.assertRaises(ValueError, test_LabelValue)
if __name__ == "__main__":
unittest.main()
......@@ -1411,6 +1411,11 @@ def cross_entropy(input,
out = core.ops.elementwise_mul(out, weight_gather_reshape)
else:
for label_val in label:
if label_val < 0 or label_val >= input.shape[-1]:
raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got label_value {}'.
format(input.shape[-1], label_val.numpy()))
weight_gather = core.ops.gather_nd(weight, label)
input_shape = list(label.shape)
weight_gather_reshape = reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册