From 4a5bfa89c342771688f5d62dc2156df85933af50 Mon Sep 17 00:00:00 2001 From: dyning Date: Fri, 27 Apr 2018 10:36:15 +0800 Subject: [PATCH] Modify RoI pooling op to use LoDTensor and expose it into Python API (#10208) * modify roi pool with lod and expose ROI Pooling into Python API * make lod code brief * make doc more clearly * make doc more clearly --- doc/fluid/api/layers.rst | 9 +++ paddle/fluid/operators/roi_pool_op.cc | 17 ++-- paddle/fluid/operators/roi_pool_op.cu | 79 ++++++++++++++----- paddle/fluid/operators/roi_pool_op.h | 63 ++++++++++----- python/paddle/fluid/layers/nn.py | 51 ++++++++++++ .../fluid/tests/unittests/test_layers.py | 10 +++ .../fluid/tests/unittests/test_roi_pool_op.py | 37 +++++---- 7 files changed, 201 insertions(+), 65 deletions(-) diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst index 3790f09c8..ff3c9346a 100644 --- a/doc/fluid/api/layers.rst +++ b/doc/fluid/api/layers.rst @@ -479,6 +479,13 @@ label_smooth .. autofunction:: paddle.fluid.layers.label_smooth :noindex: +roi_pool +--------- + +.. autofunction:: paddle.fluid.layers.roi_pool + :noindex: + + ops === @@ -820,3 +827,5 @@ topk .. autofunction:: paddle.fluid.layers.topk :noindex: + + diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index 224ec93d2..397e49ef2 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -18,8 +18,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; - -static constexpr int kROISize = 5; +using LoDTensor = framework::LoDTensor; class ROIPoolOp : public framework::OperatorWithKernel { public: @@ -40,11 +39,11 @@ class ROIPoolOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(input_dims.size() == 4, "The format of input tensor is NCHW."); PADDLE_ENFORCE(rois_dims.size() == 2, - "ROIs should be a 2-D tensor of shape (num_rois, 5)" - "given as [[batch_id, x1, y1, x2, y2], …]."); + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]."); PADDLE_ENFORCE(rois_dims[1] == kROISize, - "ROIs should be a 2-D tensor of shape (num_rois, 5)" - "given as [[batch_id, x1, y1, x2, y2], …]."); + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]."); int pooled_height = ctx->Attrs().Get("pooled_height"); int pooled_width = ctx->Attrs().Get("pooled_width"); @@ -109,10 +108,10 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { "H is the height of the feature, and " "W is the width of the feature."); AddInput("ROIs", - "(Tensor), " + "(LoDTensor), " "ROIs (Regions of Interest) to pool over. " - "should be a 2-D tensor of shape (num_rois, 5)" - "given as [[batch_id, x1, y1, x2, y2], …]. " + "should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]. " "Where batch_id is the id of the data, " "(x1, y1) is the top left coordinates, and " "(x2, y2) is the bottom right coordinates."); diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 1931629d1..0bdfee043 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -19,10 +19,10 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaxinumNumBlocks = 4096; -static constexpr int kROISize = 5; static inline int NumBlocks(const int N) { return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, @@ -30,13 +30,11 @@ static inline int NumBlocks(const int N) { } template -__global__ void GPUROIPoolForward(const int nthreads, const T* input_data, - const int64_t* input_rois, - const float spatial_scale, const int channels, - const int height, const int width, - const int pooled_height, - const int pooled_width, T* output_data, - int64_t* argmax_data) { +__global__ void GPUROIPoolForward( + const int nthreads, const T* input_data, const int64_t* input_rois, + const float spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + int* roi_batch_id_data, T* output_data, int64_t* argmax_data) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (size_t i = index; i < nthreads; i += offset) { @@ -46,11 +44,11 @@ __global__ void GPUROIPoolForward(const int nthreads, const T* input_data, int n = index / pooled_width / pooled_height / channels; const int64_t* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = offset_input_rois[0]; - int roi_start_w = round(offset_input_rois[1] * spatial_scale); - int roi_start_h = round(offset_input_rois[2] * spatial_scale); - int roi_end_w = round(offset_input_rois[3] * spatial_scale); - int roi_end_h = round(offset_input_rois[4] * spatial_scale); + int roi_batch_ind = roi_batch_id_data[n]; + int roi_start_w = round(offset_input_rois[0] * spatial_scale); + int roi_start_h = round(offset_input_rois[1] * spatial_scale); + int roi_end_w = round(offset_input_rois[2] * spatial_scale); + int roi_end_h = round(offset_input_rois[3] * spatial_scale); int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1); @@ -93,7 +91,8 @@ __global__ void GPUROIPoolBackward( const int nthreads, const int64_t* input_rois, const T* output_grad, const int64_t* argmax_data, const int num_rois, const float spatial_scale, const int channels, const int height, const int width, - const int pooled_height, const int pooled_width, T* input_grad) { + const int pooled_height, const int pooled_width, int* roi_batch_id_data, + T* input_grad) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { @@ -102,8 +101,7 @@ __global__ void GPUROIPoolBackward( int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const int64_t* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = offset_input_rois[0]; + int roi_batch_ind = roi_batch_id_data[n]; int input_offset = (roi_batch_ind * channels + c) * height * width; int output_offset = (n * channels + c) * pooled_height * pooled_width; const T* offset_output_grad = output_grad + output_offset; @@ -124,7 +122,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); + auto* rois = ctx.Input("ROIs"); auto* out = ctx.Output("Out"); auto* argmax = ctx.Output("Argmax"); @@ -133,23 +131,46 @@ class GPUROIPoolOpKernel : public framework::OpKernel { auto spatial_scale = ctx.Attr("spatial_scale"); auto in_dims = in->dims(); + int batch_size = in_dims[0]; auto in_stride = framework::stride(in_dims); int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; - size_t rois_num = rois->dims()[0]; + int rois_num = rois->dims()[0]; if (rois_num == 0) return; int output_size = out->numel(); int blocks = NumBlocks(output_size); int threads = kNumCUDAThreads; + framework::Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and imgs batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + + framework::Tensor roi_batch_id_list_gpu; + framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &roi_batch_id_list_gpu); + GPUROIPoolForward< T><<>>( output_size, in->data(), rois->data(), spatial_scale, channels, height, width, pooled_height, pooled_width, - out->mutable_data(ctx.GetPlace()), + roi_batch_id_list_gpu.data(), out->mutable_data(ctx.GetPlace()), argmax->mutable_data(ctx.GetPlace())); } }; @@ -159,7 +180,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); + auto* rois = ctx.Input("ROIs"); auto* argmax = ctx.Input("Argmax"); auto* out_grad = ctx.Input(framework::GradVarName("Out")); @@ -169,12 +190,27 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); - size_t rois_num = rois->dims()[0]; + int rois_num = rois->dims()[0]; int channels = in->dims()[1]; int height = in->dims()[2]; int width = in->dims()[3]; if (x_grad) { + framework::Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + framework::Tensor roi_batch_id_list_gpu; + framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &roi_batch_id_list_gpu); + x_grad->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; set_zero(ctx.cuda_device_context(), x_grad, static_cast(0)); @@ -189,6 +225,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { output_grad_size, rois->data(), out_grad->data(), argmax->data(), rois_num, spatial_scale, channels, height, width, pooled_height, pooled_width, + roi_batch_id_list_gpu.data(), x_grad->mutable_data(ctx.GetPlace())); } } diff --git a/paddle/fluid/operators/roi_pool_op.h b/paddle/fluid/operators/roi_pool_op.h index 54e074903..c4f739b2c 100644 --- a/paddle/fluid/operators/roi_pool_op.h +++ b/paddle/fluid/operators/roi_pool_op.h @@ -21,12 +21,14 @@ limitations under the License. */ namespace paddle { namespace operators { +static constexpr int kROISize = 4; + template class CPUROIPoolOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); + auto* rois = ctx.Input("ROIs"); auto* out = ctx.Output("Out"); auto* argmax = ctx.Output("Argmax"); @@ -47,24 +49,36 @@ class CPUROIPoolOpKernel : public framework::OpKernel { auto out_stride = framework::stride(out->dims()); const T* input_data = in->data(); - const int64_t* rois_data = rois->data(); - T* output_data = out->mutable_data(ctx.GetPlace()); - int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); - for (int n = 0; n < rois_num; ++n) { - int roi_batch_id = rois_data[0]; - PADDLE_ENFORCE_GE(roi_batch_id, 0); - PADDLE_ENFORCE_LT(roi_batch_id, batch_size); - rois_data += roi_stride[0]; + framework::Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and imgs batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } } - rois_data = rois->data(); + T* output_data = out->mutable_data(ctx.GetPlace()); + int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); + + const int64_t* rois_data = rois->data(); for (int n = 0; n < rois_num; ++n) { - int roi_batch_id = rois_data[0]; - int roi_start_w = round(rois_data[1] * spatial_scale); - int roi_start_h = round(rois_data[2] * spatial_scale); - int roi_end_w = round(rois_data[3] * spatial_scale); - int roi_end_h = round(rois_data[4] * spatial_scale); + int roi_batch_id = roi_batch_id_data[n]; + int roi_start_w = round(rois_data[0] * spatial_scale); + int roi_start_h = round(rois_data[1] * spatial_scale); + int roi_end_w = round(rois_data[2] * spatial_scale); + int roi_end_h = round(rois_data[3] * spatial_scale); // Force malformed ROIs to be 1x1 int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); @@ -133,7 +147,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); + auto* rois = ctx.Input("ROIs"); auto* argmax = ctx.Input("Argmax"); auto* out_grad = ctx.Input(framework::GradVarName("Out")); @@ -143,6 +157,20 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { auto pooled_width = ctx.Attr("pooled_width"); if (in_grad) { + int rois_num = rois->dims()[0]; + framework::Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + const int64_t* rois_data = rois->data(); const T* out_grad_data = out_grad->data(); const int64_t* argmax_data = argmax->data(); @@ -156,11 +184,10 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims()); - int rois_num = rois->dims()[0]; int channels = in->dims()[1]; for (int n = 0; n < rois_num; ++n) { - int roi_batch_idx = rois_data[0]; + int roi_batch_idx = roi_batch_id_data[n]; T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0]; for (int c = 0; c < channels; ++c) { for (int ph = 0; ph < pooled_height; ++ph) { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9a0c32803..7f16bf2a0 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -79,6 +79,7 @@ __all__ = [ 'lrn', 'pad', 'label_smooth', + 'roi_pool', ] @@ -3759,3 +3760,53 @@ def label_smooth(label, outputs={"Out": smooth_label}, attrs={"epsilon": float(epsilon)}) return smooth_label + + +def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): + """ + Region of interest pooling (also known as RoI pooling) is to perform + is to perform max pooling on inputs of nonuniform sizes to obtain + fixed-size feature maps (e.g. 7*7). + The operator has three steps: + 1. Dividing each region proposal into equal-sized sections with + the pooled_width and pooled_height + 2. Finding the largest value in each section + 3. Copying these max values to the output buffer + + Args: + input (Variable): The input for ROI pooling. + rois (Variable): ROIs (Regions of Interest) to pool over. It should + be a 2-D one level LoTensor of shape [num_rois, 4]. + The layout is [x1, y1, x2, y2], where (x1, y1) + is the top left coordinates, and (x2, y2) is the + bottom right coordinates. The num_rois is the + total number of ROIs in this batch data. + pooled_height (integer): The pooled output height. Default: 1 + pooled_width (integer): The pooled output width. Default: 1 + spatial_scale (float): Multiplicative spatial scale factor. To + translate ROI coords from their input scale + to the scale used when pooling. Default: 1.0 + + Returns: + pool_out (Variable): The output is a 4-D tensor of the shape + (num_rois, channels, pooled_h, pooled_w). + + Examples: + pool_out = fluid.layers.roi_pool(input=x, rois=rois, 7, 7, 1.0) + """ + helper = LayerHelper('roi_pool', **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_tmp_variable(dtype) + argmaxes = helper.create_tmp_variable(dtype='int32') + helper.append_op( + type="roi_pool", + inputs={"X": input, + "ROIs": rois}, + outputs={"Out": pool_out, + "Argmax": argmaxes}, + attrs={ + "pooled_height": pooled_height, + "pooled_width": pooled_width, + "spatial_scale": spatial_scale + }) + return pool_out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 17d6afdee..c5414abf0 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -359,6 +359,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(indices) print(str(program)) + def test_roi_pool(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.roi_pool(x, rois, 7, 7, 0.6) + self.assertIsNotNone(output) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index e556d51b0..3d754aff3 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -25,7 +25,7 @@ class TestROIPoolOp(OpTest): self.make_rois() self.calc_roi_pool() - self.inputs = {'X': self.x, 'ROIs': self.rois} + self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} self.attrs = { 'spatial_scale': self.spatial_scale, @@ -36,7 +36,7 @@ class TestROIPoolOp(OpTest): self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes} def init_test_case(self): - self.batch_size = 5 + self.batch_size = 3 self.channels = 3 self.height = 6 self.width = 4 @@ -47,7 +47,6 @@ class TestROIPoolOp(OpTest): self.spatial_scale = 1.0 / 4.0 self.pooled_height = 2 self.pooled_width = 2 - self.rois_num = 2 self.x = np.random.random(self.x_dim).astype('float32') @@ -106,20 +105,24 @@ class TestROIPoolOp(OpTest): def make_rois(self): rois = [] - batch_ids = np.random.randint(0, self.batch_size, size=self.rois_num) - for i in range(self.rois_num): - x1 = np.random.random_integers( - 0, self.width / self.spatial_scale - self.pooled_width) - y1 = np.random.random_integers( - 0, self.height / self.spatial_scale - self.pooled_height) - - x2 = np.random.random_integers(x1 + self.pooled_width, - self.width / self.spatial_scale) - y2 = np.random.random_integers(y1 + self.pooled_height, - self.height / self.spatial_scale) - - roi = [batch_ids[i], x1, y1, x2, y2] - rois.append(roi) + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(len(rois)) + for i in range(bno + 1): + x1 = np.random.random_integers( + 0, self.width / self.spatial_scale - self.pooled_width) + y1 = np.random.random_integers( + 0, self.height / self.spatial_scale - self.pooled_height) + + x2 = np.random.random_integers(x1 + self.pooled_width, + self.width / self.spatial_scale) + y2 = np.random.random_integers(y1 + self.pooled_height, + self.height / self.spatial_scale) + + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + self.rois_lod[0].append(len(rois)) + self.rois_num = len(rois) self.rois = np.array(rois).astype("int64") def setUp(self): -- GitLab