未验证 提交 97411214 编写于 作者: R RedContritio 提交者: GitHub

Fix 堆栈溢出 (stack overflow) of case3: paddle.metric.accuracy (#49984)

* add input check for accuracyOp

* add input check for gpu/accuracyOp

* add unittest

* use rank instead of dimensions in message

* update unittest

* update unittest
上级 85490f70
......@@ -35,10 +35,33 @@ void AccuracyRawKernel(const Context& dev_ctx,
const int64_t* indices_data = indices.data<int64_t>();
const int64_t* label_data = label.data<int64_t>();
PADDLE_ENFORCE_EQ(
inference.dims().size(),
2,
phi::errors::InvalidArgument(
"Rank(Input) of AccuracyOp must be 2, with shape "
"[sample_number, class_dim], But received rank(Input) is %d",
inference.dims().size()));
size_t num_samples = inference.dims()[0];
size_t class_dim = inference.dims()[1];
*accuracy_data = 0.0f;
PADDLE_ENFORCE_GT(label.dims().size(),
0,
phi::errors::InvalidArgument(
"Rank(Label) of AccuracyOp must greater than 0, "
"But received rank(Label) is %d",
label.dims().size()));
PADDLE_ENFORCE_GE(
label.dims()[0],
inference.dims()[0],
phi::errors::InvalidArgument("num_samples(%d) of Label should less than "
"or equal to num_samples(%d) of Input",
label.dims()[0],
num_samples));
if (num_samples == 0) {
return;
}
......
......@@ -82,6 +82,14 @@ void AccuracyRawKernel(const Context& dev_ctx,
const int64_t* indices_data = indices.data<int64_t>();
const int64_t* label_data = label.data<int64_t>();
PADDLE_ENFORCE_EQ(
inference.dims().size(),
2,
phi::errors::InvalidArgument(
"Rank(Input) of AccuracyOp must be 2, with shape "
"[sample_number, class_dim], But received rank(Input) is %d",
inference.dims().size()));
int* correct_data = dev_ctx.template Alloc<int>(correct);
int* total_data = dev_ctx.template Alloc<int>(total);
T* accuracy_data = dev_ctx.template Alloc<T>(accuracy);
......@@ -91,6 +99,21 @@ void AccuracyRawKernel(const Context& dev_ctx,
auto stream = dev_ctx.stream();
phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(T), stream);
PADDLE_ENFORCE_GT(label.dims().size(),
0,
phi::errors::InvalidArgument(
"Rank(Label) of AccuracyOp must greater than 0, "
"But received rank(Label) is %d",
label.dims().size()));
PADDLE_ENFORCE_GE(
label.dims()[0],
inference.dims()[0],
phi::errors::InvalidArgument("num_samples(%d) of Label should less than "
"or equal to num_samples(%d) of Input",
label.dims()[0],
num_samples));
if (num_samples == 0) {
return;
}
......
......@@ -60,7 +60,7 @@ class TestAccuracyOpFp16(TestAccuracyOp):
class TestAccuracyOpError(unittest.TestCase):
def test_errors(self):
def test_type_errors(self):
with program_guard(Program(), Program()):
# The input type of accuracy_op must be Variable.
x1 = fluid.create_lod_tensor(
......@@ -75,12 +75,27 @@ class TestAccuracyOpError(unittest.TestCase):
x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="int32")
self.assertRaises(TypeError, paddle.static.accuracy, x2, label)
self.assertRaises(TypeError, paddle.metric.accuracy, x2, label)
x3 = paddle.static.data(
name='input', shape=[-1, 2], dtype="float16"
)
paddle.static.accuracy(input=x3, label=label)
paddle.metric.accuracy(input=x3, label=label)
def test_value_errors(self):
with program_guard(Program(), Program()):
paddle.disable_static()
# The input rank of accuracy_op must be 2.
with self.assertRaises(ValueError):
x3 = paddle.to_tensor([0.1], dtype='float32')
label3 = paddle.to_tensor(
np.reshape([0], [1, 1]), dtype='int32'
)
paddle.metric.accuracy(x3, label3)
paddle.enable_static()
class TestAccuracyAPI1(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册