提交 d49daff0 编写于 作者: H HydrogenSulfate 提交者: chajchaj

restore test for min,max labels

上级 1e3e17df
...@@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase): ...@@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_WeightLength_NotEqual) self.assertRaises(ValueError, test_WeightLength_NotEqual)
def test_LabelValue_ExceedMax():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = 100
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMax)
def test_LabelValue_ExceedMin():
input_data = paddle.rand(shape=[20, 100])
label_data = paddle.randint(
0, 100, shape=[20, 1], dtype="int64") # hard label
label_data[0] = -1
weight_data = paddle.rand([100]) # provide weight
paddle.nn.functional.cross_entropy(
input=input_data,
label=label_data,
weight=weight_data,
ignore_index=-100)
self.assertRaises(ValueError, test_LabelValue_ExceedMin)
def static_test_WeightLength_NotEqual(): def static_test_WeightLength_NotEqual():
input_np = np.random.random([2, 4]).astype('float32') input_np = np.random.random([2, 4]).astype('float32')
label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册