未验证 提交 0a21836d 编写于 作者: F Feng Ni 提交者: GitHub

fix roi_align roi_pool to static num 0 (#55342)

上级 c80bf368
......@@ -192,6 +192,11 @@ void RoiAlignKernel(const Context& dev_ctx,
int width = in_dims[3];
int rois_num = boxes.dims()[0];
if (rois_num == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
auto in_stride = phi::stride(in_dims);
auto roi_stride = phi::stride(boxes.dims());
auto out_stride = phi::stride(out->dims());
......
......@@ -37,6 +37,11 @@ void RoiPoolKernel(const Context& dev_ctx,
int width = x_dims[3];
int rois_num = boxes.dims()[0];
if (rois_num == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
auto in_stride = phi::stride(x_dims);
auto arg_max_stride = phi::stride(arg_max->dims());
auto box_stride = phi::stride(boxes.dims());
......
......@@ -153,7 +153,10 @@ void RoiAlignKernel(const Context& dev_ctx,
int rois_num = boxes.dims()[0];
if (rois_num == 0) return;
if (rois_num == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
int output_size = out->numel();
int blocks = NumBlocks(output_size);
......
......@@ -118,7 +118,10 @@ void RoiPoolKernel(const Context& dev_ctx,
int rois_num = boxes.dims()[0];
if (rois_num == 0) return;
if (rois_num == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
int output_size = out->numel();
int blocks = NumBlocks(output_size);
......
......@@ -40,7 +40,10 @@ void RoiAlignKernel(const Context& dev_ctx,
int rois_num = boxes.dims()[0];
if (rois_num == 0) return;
if (rois_num == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
DenseTensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册