diff --git a/paddle/fluid/operators/roi_align_op_xpu.cc b/paddle/fluid/operators/roi_align_op_xpu.cc index 2c3bfdbc16b4d421385abe5613f7e63bb67ea02b..75bd94142e6b7ff0a36220b594ba14380f65053b 100644 --- a/paddle/fluid/operators/roi_align_op_xpu.cc +++ b/paddle/fluid/operators/roi_align_op_xpu.cc @@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel { int width = in_dims[3]; int rois_num = rois->dims()[0]; const T* input_data = in->data(); - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The rois_batch_size and imgs batch_size of roi_align_xpu OP must " - "be the same. But received rois_batch_size %d , batch_size %d", - rois_batch_size, batch_size)); + + framework::Tensor _roi_batch_list; + _roi_batch_list.Resize({rois_num}); + int* rois_lod = _roi_batch_list.mutable_data(ctx.GetPlace()); + int rois_batch_size = 1; + if (ctx.HasInput("RoisNum")) { + auto* rois_num_t = ctx.Input("RoisNum"); + rois_batch_size = rois_num_t->numel(); + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + platform::errors::InvalidArgument( + "The batch size of rois and the batch size of images " + " must be the same. But received the batch size of rois is %d, " + "and the batch size of images is %d", + rois_batch_size, batch_size)); + auto* rois_num_data = rois_num_t->data(); + rois_lod[0] = 0; + for (int n = 0; n < rois_batch_size; ++n) { + rois_lod[n + 1] = rois_lod[n] + rois_num_data[n]; + } + } else { + auto _rois_lod = rois->lod().back(); + rois_batch_size = _rois_lod.size() - 1; + for (int n = 0; n < _rois_lod.size(); ++n) { + rois_lod[n] = _rois_lod[n]; + } + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + platform::errors::InvalidArgument( + "The rois_batch_size and imgs batch_size of roi_align_xpu OP " + "must " + "be the same. But received rois_batch_size %d , batch_size %d", + rois_batch_size, batch_size)); + } int rois_num_with_lod = rois_lod[rois_batch_size]; PADDLE_ENFORCE_EQ( rois_num, rois_num_with_lod, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py index 813bbffefcb34848acab6f5ceaaa7f0318d76340..70f03edb6bac6eaee82ca795dc673e017898f853 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_roi_align_op_xpu.py @@ -179,5 +179,29 @@ class TestROIAlignOp(OpTest): self.check_output_with_place(place) +class TestROIAlignInLodOp(TestROIAlignOp): + def set_data(self): + self.init_test_case() + self.make_rois() + self.calc_roi_align() + + seq_len = self.rois_lod[0] + + self.inputs = { + 'X': self.x, + 'ROIs': (self.rois[:, 1:5], self.rois_lod), + 'RoisNum': np.asarray(seq_len).astype('int32') + } + + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width, + 'sampling_ratio': self.sampling_ratio + } + + self.outputs = {'Out': self.out_data} + + if __name__ == '__main__': unittest.main()