提交 09d4a3a4 编写于 作者: H HydrogenSulfate 提交者: chajchaj

add static label check

上级 b4eec5d5
......@@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_WeightLength_NotEqual)
def test_LabelValue_ExceedMax():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = 100
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMax)
def test_LabelValue_ExceedMin():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = -1
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMin)
def static_test_WeightLength_NotEqual():
input_np = np.random.random([2, 4]).astype('float32')
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
......
......@@ -1710,6 +1710,22 @@ def cross_entropy(input,
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
ignore_weight_mask = paddle.cast(
ignore_weight_mask, out.dtype) # convert from 0 to 0.0
......@@ -1833,6 +1849,22 @@ def cross_entropy(input,
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
ignore_weight_mask = paddle.cast(ignore_weight_mask,
out.dtype) # convert from 0 to 0.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册