未验证 提交 2b76db99 编写于 作者: R RichardWooSJTU 提交者: GitHub

[vision.ops.nms] Fix return order error and duplicate results with specific inputs (#46148)

* fix return order error and duplicate results with specific inputs
上级 76154c94
......@@ -2035,8 +2035,8 @@ void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) {
"whose shape must be [N, 4] "
"N is the number of boxes "
"in last dimension in format [x1, x2, y1, y2]. "));
auto num_boxes = boxes_dim[0];
out->set_dims(phi::make_ddim({num_boxes}));
out->set_dims(phi::make_ddim({-1}));
out->set_dtype(DataType::INT64);
}
void NormInferMeta(const MetaTensor& x,
......
......@@ -16,13 +16,14 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/diagonal.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
static void NMS(const T* boxes_data,
static int64_t NMS(const T* boxes_data,
int64_t* output_data,
float threshold,
int64_t num_boxes) {
......@@ -54,9 +55,13 @@ static void NMS(const T* boxes_data,
output_data[output_data_idx++] = i;
}
int64_t num_keep_boxes = output_data_idx;
for (; output_data_idx < num_boxes; ++output_data_idx) {
output_data[output_data_idx] = 0;
}
return num_keep_boxes;
}
template <typename T, typename Context>
......@@ -64,8 +69,15 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
auto output_data = dev_ctx.template Alloc<int64_t>(output);
NMS<T>(boxes.data<T>(), output_data, threshold, boxes.dims()[0]);
int64_t num_boxes = boxes.dims()[0];
DenseTensor output_tmp;
output_tmp.Resize(phi::make_ddim({num_boxes}));
auto output_tmp_data = dev_ctx.template Alloc<int64_t>(&output_tmp);
int64_t num_keep_boxes =
NMS<T>(boxes.data<T>(), output_tmp_data, threshold, num_boxes);
auto slice_out = output_tmp.Slice(0, num_keep_boxes);
phi::Copy(dev_ctx, slice_out, dev_ctx.GetPlace(), false, output);
}
} // namespace phi
......
......@@ -59,7 +59,6 @@ void NMSKernel(const Context& dev_ctx,
const DenseTensor& boxes,
float threshold,
DenseTensor* output) {
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
const int64_t num_boxes = boxes.dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
dim3 block(threadsPerBlock);
......@@ -93,11 +92,13 @@ void NMSKernel(const Context& dev_ctx,
}
}
}
output->Resize(phi::make_ddim({last_box_num}));
auto* output_data = dev_ctx.template Alloc<int64_t>(output);
paddle::memory::Copy(dev_ctx.GetPlace(),
output_data,
phi::CPUPlace(),
output_host,
sizeof(int64_t) * num_boxes,
sizeof(int64_t) * last_box_num,
dev_ctx.stream());
}
} // namespace phi
......
......@@ -65,7 +65,7 @@ def nms(boxes, nms_threshold):
else:
continue
return selected_indices
return selected_indices[:cnt]
class TestNMSOp(OpTest):
......
......@@ -1611,7 +1611,9 @@ def nms(boxes,
import paddle
if category_idxs is None:
sorted_global_indices = paddle.argsort(scores, descending=True)
return _nms(boxes[sorted_global_indices], iou_threshold)
sorted_keep_boxes_indices = _nms(boxes[sorted_global_indices],
iou_threshold)
return sorted_global_indices[sorted_keep_boxes_indices]
if top_k is not None:
assert top_k <= scores.shape[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册