未验证 提交 be84cac7 编写于 作者: R RichardWooSJTU 提交者: GitHub

[vision.ops.nms] Fix return order error and duplicate results with specific...

[vision.ops.nms] Fix return order error and duplicate results with specific inputs (#46148) (#46193)

* fix return order error and duplicate results with specific inputs
上级 ad8beaaf
...@@ -2035,8 +2035,8 @@ void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) { ...@@ -2035,8 +2035,8 @@ void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) {
"whose shape must be [N, 4] " "whose shape must be [N, 4] "
"N is the number of boxes " "N is the number of boxes "
"in last dimension in format [x1, x2, y1, y2]. ")); "in last dimension in format [x1, x2, y1, y2]. "));
auto num_boxes = boxes_dim[0]; out->set_dims(phi::make_ddim({-1}));
out->set_dims(phi::make_ddim({num_boxes})); out->set_dtype(DataType::INT64);
} }
void NormInferMeta(const MetaTensor& x, void NormInferMeta(const MetaTensor& x,
......
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/diagonal.h" #include "paddle/phi/kernels/funcs/diagonal.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi { namespace phi {
template <typename T> template <typename T>
static void NMS(const T* boxes_data, static int64_t NMS(const T* boxes_data,
int64_t* output_data, int64_t* output_data,
float threshold, float threshold,
int64_t num_boxes) { int64_t num_boxes) {
...@@ -54,9 +55,13 @@ static void NMS(const T* boxes_data, ...@@ -54,9 +55,13 @@ static void NMS(const T* boxes_data,
output_data[output_data_idx++] = i; output_data[output_data_idx++] = i;
} }
int64_t num_keep_boxes = output_data_idx;
for (; output_data_idx < num_boxes; ++output_data_idx) { for (; output_data_idx < num_boxes; ++output_data_idx) {
output_data[output_data_idx] = 0; output_data[output_data_idx] = 0;
} }
return num_keep_boxes;
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -64,8 +69,15 @@ void NMSKernel(const Context& dev_ctx, ...@@ -64,8 +69,15 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes, const DenseTensor& boxes,
float threshold, float threshold,
DenseTensor* output) { DenseTensor* output) {
auto output_data = dev_ctx.template Alloc<int64_t>(output); int64_t num_boxes = boxes.dims()[0];
NMS<T>(boxes.data<T>(), output_data, threshold, boxes.dims()[0]); DenseTensor output_tmp;
output_tmp.Resize(phi::make_ddim({num_boxes}));
auto output_tmp_data = dev_ctx.template Alloc<int64_t>(&output_tmp);
int64_t num_keep_boxes =
NMS<T>(boxes.data<T>(), output_tmp_data, threshold, num_boxes);
auto slice_out = output_tmp.Slice(0, num_keep_boxes);
phi::Copy(dev_ctx, slice_out, dev_ctx.GetPlace(), false, output);
} }
} // namespace phi } // namespace phi
......
...@@ -59,7 +59,6 @@ void NMSKernel(const Context& dev_ctx, ...@@ -59,7 +59,6 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes, const DenseTensor& boxes,
float threshold, float threshold,
DenseTensor* output) { DenseTensor* output) {
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
const int64_t num_boxes = boxes.dims()[0]; const int64_t num_boxes = boxes.dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
dim3 block(threadsPerBlock); dim3 block(threadsPerBlock);
...@@ -93,11 +92,13 @@ void NMSKernel(const Context& dev_ctx, ...@@ -93,11 +92,13 @@ void NMSKernel(const Context& dev_ctx,
} }
} }
} }
output->Resize(phi::make_ddim({last_box_num}));
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
paddle::memory::Copy(dev_ctx.GetPlace(), paddle::memory::Copy(dev_ctx.GetPlace(),
output_data, output_data,
phi::CPUPlace(), phi::CPUPlace(),
output_host, output_host,
sizeof(int64_t) * num_boxes, sizeof(int64_t) * last_box_num,
dev_ctx.stream()); dev_ctx.stream());
} }
} // namespace phi } // namespace phi
......
...@@ -65,7 +65,7 @@ def nms(boxes, nms_threshold): ...@@ -65,7 +65,7 @@ def nms(boxes, nms_threshold):
else: else:
continue continue
return selected_indices return selected_indices[:cnt]
class TestNMSOp(OpTest): class TestNMSOp(OpTest):
......
...@@ -1611,7 +1611,9 @@ def nms(boxes, ...@@ -1611,7 +1611,9 @@ def nms(boxes,
import paddle import paddle
if category_idxs is None: if category_idxs is None:
sorted_global_indices = paddle.argsort(scores, descending=True) sorted_global_indices = paddle.argsort(scores, descending=True)
return _nms(boxes[sorted_global_indices], iou_threshold) sorted_keep_boxes_indices = _nms(boxes[sorted_global_indices],
iou_threshold)
return sorted_global_indices[sorted_keep_boxes_indices]
if top_k is not None: if top_k is not None:
assert top_k <= scores.shape[ assert top_k <= scores.shape[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册