提交 51398ab9 编写于 作者: H HydrogenSulfate 提交者: chajchaj

remove hard labels check

上级 7ddfec00
...@@ -1466,34 +1466,6 @@ class TestCrossEntropyFAPIError(unittest.TestCase): ...@@ -1466,34 +1466,6 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_WeightLength_NotEqual) self.assertRaises(ValueError, test_WeightLength_NotEqual)
def test_LabelValue_ExceedMax():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = 100
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(IndexError, test_LabelValue_ExceedMax)
def test_LabelValue_ExceedMin():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = -1
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(IndexError, test_LabelValue_ExceedMin)
def static_test_WeightLength_NotEqual(): def static_test_WeightLength_NotEqual():
input_np = np.random.random([2, 4]).astype('float32') input_np = np.random.random([2, 4]).astype('float32')
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
......
...@@ -1705,23 +1705,6 @@ def cross_entropy(input, ...@@ -1705,23 +1705,6 @@ def cross_entropy(input,
valid_label = paddle.where(label == ignore_index, valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label) paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise IndexError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise IndexError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
ignore_weight_mask = paddle.cast((label != ignore_index), ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype) out.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册