未验证 提交 4976153d 编写于 作者: R RedContritio 提交者: GitHub

add dims check for nms_kernel (#49993)

上级 3586e856
...@@ -69,6 +69,18 @@ void NMSKernel(const Context& dev_ctx, ...@@ -69,6 +69,18 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes, const DenseTensor& boxes,
float threshold, float threshold,
DenseTensor* output) { DenseTensor* output) {
PADDLE_ENFORCE_EQ(
boxes.dims().size(),
2,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));
PADDLE_ENFORCE_EQ(
boxes.dims()[1],
4,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));
int64_t num_boxes = boxes.dims()[0]; int64_t num_boxes = boxes.dims()[0];
DenseTensor output_tmp; DenseTensor output_tmp;
output_tmp.Resize(phi::make_ddim({num_boxes})); output_tmp.Resize(phi::make_ddim({num_boxes}));
......
...@@ -59,6 +59,18 @@ void NMSKernel(const Context& dev_ctx, ...@@ -59,6 +59,18 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes, const DenseTensor& boxes,
float threshold, float threshold,
DenseTensor* output) { DenseTensor* output) {
PADDLE_ENFORCE_EQ(
boxes.dims().size(),
2,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));
PADDLE_ENFORCE_EQ(
boxes.dims()[1],
4,
phi::errors::InvalidArgument("The shape [%s] of boxes must be (N, 4).",
boxes.dims()));
const int64_t num_boxes = boxes.dims()[0]; const int64_t num_boxes = boxes.dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
dim3 block(threadsPerBlock); dim3 block(threadsPerBlock);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册