未验证 提交 957cbe68 编写于 作者: H huangjun12 提交者: GitHub

fix ce error message, test=release/2.1 (#32758)

上级 f54fb1ee
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import unittest import unittest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy from test_softmax_with_cross_entropy_op import cross_entropy
from paddle.fluid import Program, program_guard
def stable_softmax(x): def stable_softmax(x):
...@@ -1363,5 +1364,37 @@ class CrossEntropyLoss(unittest.TestCase): ...@@ -1363,5 +1364,37 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(dy_ret_value, expected)) self.assertTrue(np.allclose(dy_ret_value, expected))
class TestCrossEntropyFAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_LabelValue():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = 255
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=255)
self.assertRaises(ValueError, test_LabelValue)
def test_LabelValueNeg():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = -1
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-1)
self.assertRaises(ValueError, test_LabelValueNeg)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -1411,6 +1411,13 @@ def cross_entropy(input, ...@@ -1411,6 +1411,13 @@ def cross_entropy(input,
out = core.ops.elementwise_mul(out, weight_gather_reshape) out = core.ops.elementwise_mul(out, weight_gather_reshape)
else: else:
label_min = paddle.min(label)
label_max = paddle.max(label)
if label_min < 0 or label_max >= input.shape[-1]:
raise ValueError(
'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '.
format(input.shape[-1],
label_min.numpy(), label_max.numpy()))
weight_gather = core.ops.gather_nd(weight, label) weight_gather = core.ops.gather_nd(weight, label)
input_shape = list(label.shape) input_shape = list(label.shape)
weight_gather_reshape = reshape( weight_gather_reshape = reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册