diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc index 17246de35db22c079a0bcad3598b172abe6ea808..2c9312e63ac89994ed19d1dc77dd36bc20c3e7be 100644 --- a/paddle/phi/kernels/cpu/accuracy_kernel.cc +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -35,10 +35,33 @@ void AccuracyRawKernel(const Context& dev_ctx, const int64_t* indices_data = indices.data(); const int64_t* label_data = label.data(); + PADDLE_ENFORCE_EQ( + inference.dims().size(), + 2, + phi::errors::InvalidArgument( + "Rank(Input) of AccuracyOp must be 2, with shape " + "[sample_number, class_dim], But received rank(Input) is %d", + inference.dims().size())); + size_t num_samples = inference.dims()[0]; size_t class_dim = inference.dims()[1]; *accuracy_data = 0.0f; + PADDLE_ENFORCE_GT(label.dims().size(), + 0, + phi::errors::InvalidArgument( + "Rank(Label) of AccuracyOp must greater than 0, " + "But received rank(Label) is %d", + label.dims().size())); + + PADDLE_ENFORCE_GE( + label.dims()[0], + inference.dims()[0], + phi::errors::InvalidArgument("num_samples(%d) of Label should less than " + "or equal to num_samples(%d) of Input", + label.dims()[0], + num_samples)); + if (num_samples == 0) { return; } diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu index 8a4aa2a6397c91dd4258032f358288e11be8450b..6cdad23bfd5e180ecd943e1462de111c2bf318c9 100644 --- a/paddle/phi/kernels/gpu/accuracy_kernel.cu +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -82,6 +82,14 @@ void AccuracyRawKernel(const Context& dev_ctx, const int64_t* indices_data = indices.data(); const int64_t* label_data = label.data(); + PADDLE_ENFORCE_EQ( + inference.dims().size(), + 2, + phi::errors::InvalidArgument( + "Rank(Input) of AccuracyOp must be 2, with shape " + "[sample_number, class_dim], But received rank(Input) is %d", + inference.dims().size())); + int* correct_data = dev_ctx.template Alloc(correct); int* total_data = dev_ctx.template Alloc(total); T* accuracy_data = dev_ctx.template Alloc(accuracy); @@ -91,6 +99,21 @@ void AccuracyRawKernel(const Context& dev_ctx, auto stream = dev_ctx.stream(); phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(T), stream); + PADDLE_ENFORCE_GT(label.dims().size(), + 0, + phi::errors::InvalidArgument( + "Rank(Label) of AccuracyOp must greater than 0, " + "But received rank(Label) is %d", + label.dims().size())); + + PADDLE_ENFORCE_GE( + label.dims()[0], + inference.dims()[0], + phi::errors::InvalidArgument("num_samples(%d) of Label should less than " + "or equal to num_samples(%d) of Input", + label.dims()[0], + num_samples)); + if (num_samples == 0) { return; } diff --git a/python/paddle/fluid/tests/unittests/test_accuracy_op.py b/python/paddle/fluid/tests/unittests/test_accuracy_op.py index 431d8b24bcee2f7581c25a542a71fc9cb1a6340f..5627ead0a6b59e62a62bfb518842a50592453e9a 100755 --- a/python/paddle/fluid/tests/unittests/test_accuracy_op.py +++ b/python/paddle/fluid/tests/unittests/test_accuracy_op.py @@ -60,7 +60,7 @@ class TestAccuracyOpFp16(TestAccuracyOp): class TestAccuracyOpError(unittest.TestCase): - def test_errors(self): + def test_type_errors(self): with program_guard(Program(), Program()): # The input type of accuracy_op must be Variable. x1 = fluid.create_lod_tensor( @@ -75,12 +75,27 @@ class TestAccuracyOpError(unittest.TestCase): x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="int32") self.assertRaises(TypeError, paddle.static.accuracy, x2, label) self.assertRaises(TypeError, paddle.metric.accuracy, x2, label) + x3 = paddle.static.data( name='input', shape=[-1, 2], dtype="float16" ) paddle.static.accuracy(input=x3, label=label) paddle.metric.accuracy(input=x3, label=label) + def test_value_errors(self): + with program_guard(Program(), Program()): + paddle.disable_static() + + # The input rank of accuracy_op must be 2. + with self.assertRaises(ValueError): + x3 = paddle.to_tensor([0.1], dtype='float32') + label3 = paddle.to_tensor( + np.reshape([0], [1, 1]), dtype='int32' + ) + paddle.metric.accuracy(x3, label3) + + paddle.enable_static() + class TestAccuracyAPI1(unittest.TestCase): def setUp(self):