提交 46e856c7 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Remove the labels range check under the dynamic graph

上级 7b860a23
......@@ -1465,34 +1465,6 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_WeightLength_NotEqual)
def test_LabelValue_ExceedMax():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = 100
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMax)
def test_LabelValue_ExceedMin():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64")
label_data[0] = -1
weight_data = paddle.rand([100])
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMin)
def static_test_WeightLength_NotEqual():
input_np = np.random.random([2, 4]).astype('float32')
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
......
......@@ -1665,26 +1665,6 @@ def cross_entropy(input,
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode():
if soft_label == False:
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
if core.is_compiled_with_npu():
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
......@@ -1716,6 +1696,25 @@ def cross_entropy(input,
out = _C_ops.elementwise_mul(out, weight_gather_reshape)
else:
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
if input.shape[axis] != weight.shape[-1]:
raise ValueError(
"input's class_dimension({}) must equal to "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册