From 6ea3809143f86cf8de3e0c85dde39f36b8ff9d41 Mon Sep 17 00:00:00 2001 From: Double_V Date: Wed, 8 Jan 2020 11:27:06 +0800 Subject: [PATCH] Support prroi_pool_op with Tensor and LoDTensor rois (#20649) 1. Add a new input named batch_roi_nums for prroi_pool_op. batch_roi_nums includes the number of roi for each image in batch when rois is Tensor. This information is saved in rois's lod when rois is LoDTensor. 2. add grad check to prroi_pool_op and solve unnormal X grad diff in CPU. --- paddle/fluid/operators/prroi_pool_op.cc | 60 +++++-- paddle/fluid/operators/prroi_pool_op.cu | 120 ++++++++++---- paddle/fluid/operators/prroi_pool_op.h | 99 ++++++++---- python/paddle/fluid/layers/nn.py | 38 +++-- .../tests/unittests/test_prroi_pool_op.py | 150 ++++++++++++++++-- 5 files changed, 361 insertions(+), 106 deletions(-) diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc index 85e15d0ed1b..b301b3a926f 100644 --- a/paddle/fluid/operators/prroi_pool_op.cc +++ b/paddle/fluid/operators/prroi_pool_op.cc @@ -39,6 +39,11 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { "where (x1, y1) is the top left coordinates, and " "(x2, y2) is the bottom right coordinates. " "The roi batch index can be calculated from LoD."); + AddInput("BatchRoINums", + "(Tensor), " + "1-D tensor with shape [N], the number of" + " rois for each image in batch, where N is the batch size") + .AsDispensable(); AddOutput("Out", "(Tensor), " "the output of PRROIPoolOp is a 4-D Tensor with shape " @@ -75,39 +80,57 @@ class PRROIPoolOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - "Input(X) of op(PRROIPool) should not be null."); + platform::errors::NotFound( + "Input(X) of op(PRROIPool) should not be null.")); PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true, - "Input(ROIs) of op(PRROIPool) should not be null."); + platform::errors::NotFound( + "Input(ROIs) of op(PRROIPool) should not be null.")); PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - "Output(Out) of op(PRROIPool) should not be null."); + platform::errors::NotFound( + "Output(Out) of op(PRROIPool) should not be null.")); auto input_dims = ctx->GetInputDim("X"); auto rois_dims = ctx->GetInputDim("ROIs"); PADDLE_ENFORCE_EQ(input_dims.size(), 4, - "The format of input tensor is NCHW"); - PADDLE_ENFORCE_EQ(rois_dims.size(), 2, - "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " - "given as [(x1, y1, x2, y2), ...]"); - PADDLE_ENFORCE_EQ(rois_dims[1], 4, - "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " - "given as [(x1, y1, x2, y2), ...]"); - + platform::errors::InvalidArgument( + "The format of input tensor is NCHW")); + PADDLE_ENFORCE_EQ( + rois_dims.size(), 2, + platform::errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]")); + PADDLE_ENFORCE_EQ( + rois_dims[1], 4, + platform::errors::InvalidArgument( + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]")); int pooled_height = ctx->Attrs().Get("pooled_height"); int pooled_width = ctx->Attrs().Get("pooled_width"); float spatial_scale = ctx->Attrs().Get("spatial_scale"); PADDLE_ENFORCE_GT(pooled_height, 0, - "The pooled output height must be greater than 0"); + platform::errors::InvalidArgument( + "The pooled output height must be greater than 0")); PADDLE_ENFORCE_GT(pooled_width, 0, - "The pooled output width must be greater than 0"); + platform::errors::InvalidArgument( + "The pooled output width must be greater than 0")); PADDLE_ENFORCE_GT(spatial_scale, 0.0f, - "The spatial scale must greater than 0."); + platform::errors::InvalidArgument( + "The spatial scale must greater than 0.")); auto out_dims = input_dims; out_dims[0] = rois_dims[0]; out_dims[1] = input_dims[1]; out_dims[2] = pooled_height; out_dims[3] = pooled_width; + + if (ctx->HasInput("BatchRoINums")) { + auto rois_batch_index = ctx->GetInputDim("BatchRoINums"); + PADDLE_ENFORCE_EQ(rois_batch_index[0], input_dims[0], + platform::errors::InvalidArgument( + "The length of BatchRoINums should equal to " + "first dim of inputs(X)")); + } ctx->SetOutputDim("Out", out_dims); } @@ -154,6 +177,7 @@ class PRROIPoolGradMaker : public framework::SingleGradOpMaker { op->SetInput("X", this->Input("X")); op->SetInput("Out", this->Output("Out")); op->SetInput("ROIs", this->Input("ROIs")); + op->SetInput("BatchRoINums", this->Input("BatchRoINums")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("ROIs"), this->InputGrad("ROIs")); @@ -172,8 +196,12 @@ REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp); REGISTER_OP_CPU_KERNEL( prroi_pool, ops::CPUPRROIPoolOpKernel, - ops::CPUPRROIPoolOpKernel); + ops::CPUPRROIPoolOpKernel, + ops::CPUPRROIPoolOpKernel, + ops::CPUPRROIPoolOpKernel); REGISTER_OP_CPU_KERNEL( prroi_pool_grad, ops::CPUPRROIPoolGradOpKernel, - ops::CPUPRROIPoolGradOpKernel); + ops::CPUPRROIPoolGradOpKernel, + ops::CPUPRROIPoolGradOpKernel, + ops::CPUPRROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/prroi_pool_op.cu b/paddle/fluid/operators/prroi_pool_op.cu index 35180dc91ae..caf6892a987 100644 --- a/paddle/fluid/operators/prroi_pool_op.cu +++ b/paddle/fluid/operators/prroi_pool_op.cu @@ -185,8 +185,8 @@ __global__ void GPUPRROIPoolBackward( PrRoIPoolingCoorBackward( s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, - offset_in_data, offset_out_data, offset_input_grad_data, - offset_input_roi_grad_data, GPUAccumulateRois, + offset_in_data, offset_out_data, offset_input_roi_grad_data, + offset_output_grad_data, GPUAccumulateRois, [](const T x, const T y) { return max(x, y); }, [](const T x, const T y) { return min(x, y); }); } @@ -214,41 +214,66 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel { int rois_num = rois->dims()[0]; if (rois_num == 0) return; - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - "The rois_batch_size and input(X) batch_size must be the same."); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, - "The rois_num from input and lod must be the same."); - // set rois batch id framework::Tensor rois_batch_id_list; rois_batch_id_list.Resize({rois_num}); int* rois_batch_id_data = rois_batch_id_list.mutable_data(platform::CPUPlace()); - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; + + if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) { + auto* batchroinum = ctx.Input("BatchRoINums"); + framework::Tensor batch_index_cpu; + framework::TensorCopySync(*batchroinum, platform::CPUPlace(), + &batch_index_cpu); + + int rois_batch_size = batchroinum->dims()[0]; + auto* batch_index = batch_index_cpu.data(); + size_t c = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int64_t k = 0; k < batch_index[n]; ++k) { + rois_batch_id_data[c] = n; + c = c + 1; + } } - } - framework::Tensor rois_batch_id_list_gpu; - framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), - ctx.device_context(), &rois_batch_id_list_gpu); + } else { + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + platform::errors::InvalidArgument( + "The rois_batch_size and input(X) batch_size must be the same.")); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, rois_num_with_lod, + platform::errors::InvalidArgument( + "The rois_num from input and lod must be the same.")); + + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + } int output_size = out->numel(); int blocks = NumBlocks(output_size); int threads = kNumCUDAThreads; + auto cplace = platform::CPUPlace(); + auto& dev_ctx = ctx.cuda_device_context(); + int bytes = rois_batch_id_list.numel() * sizeof(int); + auto roi_ptr = memory::Alloc(dev_ctx, bytes); + int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); + const auto gplace = boost::get(ctx.GetPlace()); + memory::Copy(gplace, roi_id_data, cplace, rois_batch_id_data, bytes, + dev_ctx.stream()); + // call cuda kernel function - GPUPRROIPoolForward< - T><<>>( + GPUPRROIPoolForward<<>>( output_size, in->data(), rois->data(), spatial_scale, input_channels, height, width, output_channels, pooled_height, - pooled_width, rois_batch_id_list_gpu.data(), - out->mutable_data(ctx.GetPlace())); + pooled_width, roi_id_data, out->mutable_data(ctx.GetPlace())); } }; @@ -275,23 +300,50 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel { int height = in->dims()[2]; int width = in->dims()[3]; - if (input_grad) { + if (input_grad || input_roi_grad) { // set roi batch id framework::Tensor rois_batch_id_list; rois_batch_id_list.Resize({rois_num}); int* rois_batch_id_data = rois_batch_id_list.mutable_data(platform::CPUPlace()); - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; + + if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) { + auto* batchroinum = ctx.Input("BatchRoINums"); + framework::Tensor batch_index_cpu; + framework::TensorCopySync(*batchroinum, platform::CPUPlace(), + &batch_index_cpu); + + int rois_batch_size = batchroinum->dims()[0]; + auto* batch_index = batch_index_cpu.data(); + size_t c = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int64_t k = 0; k < batch_index[n]; ++k) { + rois_batch_id_data[c] = n; + c = c + 1; + } + } + } else { + PADDLE_ENFORCE_EQ(rois->lod().empty(), false, + platform::errors::InvalidArgument( + "the lod of Input ROIs shoule not be empty when " + "BatchRoINums is None!")); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } } } - framework::Tensor rois_batch_id_list_gpu; - framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), - ctx.device_context(), &rois_batch_id_list_gpu); + auto cplace = platform::CPUPlace(); + auto& dev_ctx = ctx.cuda_device_context(); + int bytes = rois_batch_id_list.numel() * sizeof(int); + auto roi_ptr = memory::Alloc(dev_ctx, bytes); + int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); + const auto gplace = boost::get(ctx.GetPlace()); + memory::Copy(gplace, roi_id_data, cplace, rois_batch_id_data, bytes, + dev_ctx.stream()); input_grad->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; @@ -304,12 +356,10 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel { int threads = kNumCUDAThreads; if (output_grad_size > 0) { - GPUPRROIPoolBackward< - T><<>>( + GPUPRROIPoolBackward<<>>( output_grad_size, in->data(), rois->data(), output_grad->data(), spatial_scale, input_channels, height, - width, output_channels, pooled_height, pooled_width, - rois_batch_id_list_gpu.data(), + width, output_channels, pooled_height, pooled_width, roi_id_data, input_grad->mutable_data(ctx.GetPlace()), out->data(), input_roi_grad->mutable_data(ctx.GetPlace())); } diff --git a/paddle/fluid/operators/prroi_pool_op.h b/paddle/fluid/operators/prroi_pool_op.h index 641309c730f..25f45d0b2c9 100644 --- a/paddle/fluid/operators/prroi_pool_op.h +++ b/paddle/fluid/operators/prroi_pool_op.h @@ -81,7 +81,7 @@ inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, const T coeff) { bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); if (!overflow) { - *(diff + h * width + w) = top_diff * coeff; + *(diff + h * width + w) += top_diff * coeff; } } @@ -179,7 +179,7 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( T win_start_h, T win_end_w, T win_end_h, int pw, int ph, const int pooled_width, const int pooled_height, T win_size, const float spatial_scale, const T* this_bottom_data, - const T* this_top_data, T* this_data_grad, T* this_out_grad, + const T* this_top_data, T* this_data_grad, const T* this_out_grad, Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) { T g_x1_y = 0.f; T g_x2_y = 0.f; @@ -232,20 +232,19 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( partial_y1 = partial_y1 / win_size * spatial_scale; partial_y2 = partial_y2 / win_size * spatial_scale; - this_data_grad[0] = 0; - functor(this_data_grad + 1, + functor(this_data_grad + 0, (partial_x1 * (1.0 - static_cast(pw) / pooled_width) + partial_x2 * (1.0 - static_cast(pw + 1) / pooled_width)) * (*this_out_grad)); - functor(this_data_grad + 2, + functor(this_data_grad + 1, (partial_y1 * (1.0 - static_cast(ph) / pooled_height) + partial_y2 * (1.0 - static_cast(ph + 1) / pooled_height)) * (*this_out_grad)); - functor(this_data_grad + 3, + functor(this_data_grad + 2, (partial_x2 * static_cast(pw + 1) / pooled_width + partial_x1 * static_cast(pw) / pooled_width) * (*this_out_grad)); - functor(this_data_grad + 4, + functor(this_data_grad + 3, (partial_y2 * static_cast(ph + 1) / pooled_height + partial_y1 * static_cast(ph) / pooled_height) * (*this_out_grad)); @@ -262,7 +261,6 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); - auto in_dims = in->dims(); int batch_size = in_dims[0]; int input_channels = in_dims[1]; @@ -270,6 +268,7 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; + if (rois_num == 0) return; auto in_stride = framework::stride(in_dims); auto out_stride = framework::stride(out->dims()); @@ -280,26 +279,44 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { rois_batch_id_list.Resize({rois_num}); int* rois_batch_id_data = rois_batch_id_list.mutable_data(ctx.GetPlace()); + if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) { + auto* batchroinum = ctx.Input("BatchRoINums"); + auto* batch_index = batchroinum->data(); + int rois_batch_size = batchroinum->dims()[0]; + size_t c = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int64_t k = 0; k < batch_index[n]; ++k) { + rois_batch_id_data[c] = n; + c = c + 1; + } + } + } else { + PADDLE_ENFORCE_EQ(rois->lod().empty(), false, + platform::errors::InvalidArgument( + "the lod of Input ROIs shoule not be empty when " + "BatchRoINums is None!")); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + platform::errors::InvalidArgument("the rois_batch_size and input(X) " + "batch_size should be the same.")); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num_with_lod, rois_num, + platform::errors::InvalidArgument( + "the rois_num from input and lod must be the same")); - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - "the rois_batch_size and input(X) batch_size should be the same."); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num, - "the rois_num from input and lod must be the same"); - - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } } } T* output_data = out->mutable_data(ctx.GetPlace()); const T* input_rois = rois->data(); - // calculate prroipooling, parallel processing can be implemented per ROI for (int n = 0; n < rois_num; ++n) { // set roi batch id @@ -390,7 +407,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); - if (input_grad && input_roi_grad) { + if (input_grad || input_roi_grad) { auto in_dims = in->dims(); auto* in_data = in->data(); auto* out_data = out->data(); @@ -406,24 +423,42 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { rois_batch_id_list.Resize({rois_num}); int* rois_batch_id_data = rois_batch_id_list.mutable_data(ctx.GetPlace()); - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - // calculate batch id index for each roi according to LoD - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - rois_batch_id_data[i] = n; + if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) { + auto* batchroinum = ctx.Input("BatchRoINums"); + auto* batch_index = batchroinum->data(); + int rois_batch_size = batchroinum->dims()[0]; + size_t c = 0; + for (int n = 0; n < rois_batch_size; ++n) { + for (int64_t k = 0; k < batch_index[n]; ++k) { + rois_batch_id_data[c] = n; + c = c + 1; + } + } + } else { + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } } } const T* input_rois = rois->data(); const T* output_grad_data = output_grad->data(); - T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - T* input_roi_grad_data = input_roi_grad->mutable_data(ctx.GetPlace()); + input_grad->mutable_data(ctx.GetPlace()); + input_roi_grad->mutable_data(ctx.GetPlace()); // set gradient of X to be 0. before backpropagate. math::SetConstant set_zero; set_zero(ctx.template device_context(), input_grad, static_cast(0)); + set_zero(ctx.template device_context(), input_roi_grad, + static_cast(0)); + + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + T* input_roi_grad_data = input_roi_grad->mutable_data(ctx.GetPlace()); // backpropagate gradient per output pixel int output_grad_size = output_grad->numel(); @@ -493,7 +528,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, offset_in_data, offset_out_data, - offset_input_grad_data, offset_input_roi_grad_data, + offset_input_roi_grad_data, offset_output_grad_data, CPUAccumulateRois, [](const T x, const T y) { return std::max(x, y); }, [](const T x, const T y) { return std::min(x, y); }); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c1fdf12df73..f61a4d04163 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12655,35 +12655,53 @@ def prroi_pool(input, spatial_scale=1.0, pooled_height=1, pooled_width=1, + batch_roi_nums=None, name=None): """ - The precise roi pooling implementation for paddle?https://arxiv.org/pdf/1807.11590.pdf + The precise roi pooling implementation for paddle. Reference: https://arxiv.org/pdf/1807.11590.pdf Args: - input (Variable):The input of Deformable PSROIPooling.The shape of input tensor is + input (Variable):The input of precise roi pooliing.The shape of input tensor is [N,C,H,W]. Where N is batch size,C is number of input channels,H is height of the feature, and W is the width of the feature. rois (Variable): ROIs (Regions of Interest) to pool over.It should be - a 2-D LoDTensor of shape (num_rois, 4), the lod level - is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is + a 2-D LoDTensor or Tensor of shape (num_rois, 4), the lod level + is 1 when it is LoDTensor. The LoD include the rois's batch index + information. If rois is Tensor, its batch index information should + be provided by batch_index. + Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates. spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width). Equals the reciprocal of total stride in convolutional layers, Default: 1.0. pooled_height (integer): The pooled output height. Default: 1. pooled_width (integer): The pooled output width. Default: 1. + batch_roi_nums (Variable): The number of roi for each image in batch. It + shoule be 1-D Tensor, with shape [N] and dtype int64, + where N is the batch size. Default: None. Be note: The lod of input should be + empty when batch_roi_nums has values; name (str, default None): The name of this operation. Returns: - Variable(Tensor): The shape of the returned Tensor is (num_rois, output_channels, pooled_h, pooled_w), with value type float32,float16.. + Variable(Tensor):The shape of the returned Tensor is (N, C, pooled_height, pooled_width), with value type float32,float16. N, C denote batch_size and channels of input respectively. Examples: .. code-block:: python + ## prroi_pool without batch_roi_num import paddle.fluid as fluid - x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32') - rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32') + x = fluid.data(name='x', shape=[None, 490, 28, 28], dtype='float32') + rois = fluid.data(name='rois', shape=[None, 4], lod_level=1, dtype='float32') pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7) + + ## prroi_pool with batch_roi_num + batchsize=4 + x2 = fluid.data(name='x2', shape=[batchsize, 490, 28, 28], dtype='float32') + rois2 = fluid.data(name='rois2', shape=[batchsize, 4], dtype='float32') + batch_rois_num = fluid.data(name='rois_nums', shape=[batchsize], dtype='int64') + pool_out2 = fluid.layers.prroi_pool(x2, rois2, 1.0, 7, 7, batch_roi_nums=batch_rois_num) + + """ helper = LayerHelper('prroi_pool', **locals()) # check attrs @@ -12695,10 +12713,12 @@ def prroi_pool(input, raise TypeError("pooled_width must be int type") dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) + inputs_op = {'X': input, 'ROIs': rois} + if batch_roi_nums is not None: + inputs_op['BatchRoINums'] = batch_roi_nums helper.append_op( type='prroi_pool', - inputs={'X': input, - 'ROIs': rois}, + inputs=inputs_op, outputs={'Out': out}, attrs={ 'spatial_scale': spatial_scale, diff --git a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py index e3bfa062a3e..cf9d69247f1 100644 --- a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py @@ -19,6 +19,7 @@ import unittest from py_precise_roi_pool import PyPrRoIPool from op_test import OpTest import paddle.fluid as fluid +import paddle.fluid.core as core from paddle.fluid import compiler, Program, program_guard @@ -29,7 +30,7 @@ class TestPRROIPoolOp(OpTest): self.prRoIPool = PyPrRoIPool() self.outs = self.prRoIPool.compute( self.x, self.rois, self.output_channels, self.spatial_scale, - self.pooled_height, self.pooled_width).astype('float32') + self.pooled_height, self.pooled_width).astype('float64') self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} self.attrs = { 'output_channels': self.output_channels, @@ -42,17 +43,17 @@ class TestPRROIPoolOp(OpTest): def init_test_case(self): self.batch_size = 3 self.channels = 3 * 2 * 2 - self.height = 6 - self.width = 4 + self.height = 12 + self.width = 16 self.x_dim = [self.batch_size, self.channels, self.height, self.width] - self.spatial_scale = 1.0 / 4.0 + self.spatial_scale = 1.0 / 2.0 self.output_channels = self.channels - self.pooled_height = 2 - self.pooled_width = 2 + self.pooled_height = 4 + self.pooled_width = 4 - self.x = np.random.random(self.x_dim).astype('float32') + self.x = np.random.random(self.x_dim).astype('float64') def make_rois(self): rois = [] @@ -72,7 +73,7 @@ class TestPRROIPoolOp(OpTest): roi = [bno, x1, y1, x2, y2] rois.append(roi) self.rois_num = len(rois) - self.rois = np.array(rois).astype('float32') + self.rois = np.array(rois).astype('float64') def setUp(self): self.op_type = 'prroi_pool' @@ -82,17 +83,20 @@ class TestPRROIPoolOp(OpTest): self.check_output() def test_backward(self): - for place in self._get_places(): - self._get_gradient(['X'], place, ["Out"], None) + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.check_grad_with_place(place, ['X'], 'Out') def run_net(self, place): with program_guard(Program(), Program()): x = fluid.layers.data( name="X", shape=[self.channels, self.height, self.width], - dtype="float32") + dtype="float64") rois = fluid.layers.data( - name="ROIs", shape=[4], dtype="float32", lod_level=1) + name="ROIs", shape=[4], dtype="float64", lod_level=1) output = fluid.layers.prroi_pool(x, rois, 0.25, 2, 2) loss = fluid.layers.mean(output) optimizer = fluid.optimizer.SGD(learning_rate=1e-3) @@ -116,9 +120,127 @@ class TestPRROIPoolOp(OpTest): def test_errors(self): with program_guard(Program(), Program()): x = fluid.layers.data( - name="x", shape=[245, 30, 30], dtype="float32") + name="x", shape=[245, 30, 30], dtype="float64") + rois = fluid.layers.data( + name="rois", shape=[4], dtype="float64", lod_level=1) + # spatial_scale must be float type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7, + 7) + # pooled_height must be int type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25, + 0.7, 7) + # pooled_width must be int type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25, + 7, 0.7) + + +class TestPRROIPoolOpTensorRoIs(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + self.prRoIPool = PyPrRoIPool() + self.outs = self.prRoIPool.compute( + self.x, self.rois, self.output_channels, self.spatial_scale, + self.pooled_height, self.pooled_width).astype('float64') + + self.rois_index = np.array(self.rois_lod).reshape([-1]).astype(np.int64) + self.inputs = { + 'X': self.x, + 'ROIs': self.rois[:, 1:5], + 'BatchRoINums': self.rois_index + } + self.attrs = { + 'output_channels': self.output_channels, + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width + } + self.outputs = {'Out': self.outs} + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 * 2 * 2 + self.height = 12 + self.width = 16 + + self.x_dim = [self.batch_size, self.channels, self.height, self.width] + + self.spatial_scale = 1.0 / 2.0 + self.output_channels = self.channels + self.pooled_height = 4 + self.pooled_width = 4 + + self.x = np.random.random(self.x_dim).astype('float64') + + def make_rois(self): + rois = [] + self.rois_lod = [] + for bno in range(self.batch_size): + self.rois_lod.append(bno + 1) + for i in range(bno + 1): + x1 = np.random.uniform( + 0, self.width // self.spatial_scale - self.pooled_width) + y1 = np.random.uniform( + 0, self.height // self.spatial_scale - self.pooled_height) + + x2 = np.random.uniform(x1 + self.pooled_width, + self.width // self.spatial_scale) + y2 = np.random.uniform(y1 + self.pooled_height, + self.height // self.spatial_scale) + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype('float64') + + def setUp(self): + self.op_type = 'prroi_pool' + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_backward(self): + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.check_grad_with_place(place, ['X'], 'Out') + + def run_net(self, place): + with program_guard(Program(), Program()): + x = fluid.layers.data( + name="X", + shape=[self.channels, self.height, self.width], + dtype="float64") + rois = fluid.layers.data(name="ROIs", shape=[4], dtype="float64") + rois_index = fluid.layers.data( + name='rois_idx', shape=[], dtype="int64") + output = fluid.layers.prroi_pool( + x, rois, 0.25, 2, 2, batch_roi_nums=rois_index) + loss = fluid.layers.mean(output) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + exe.run(fluid.default_main_program(), { + 'X': self.x, + "ROIs": self.rois[:, 1:5], + "rois_idx": self.rois_index + }) + + def test_net(self): + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.run_net(place) + + def test_errors(self): + with program_guard(Program(), Program()): + x = fluid.layers.data( + name="x", shape=[245, 30, 30], dtype="float64") rois = fluid.layers.data( - name="rois", shape=[4], dtype="float32", lod_level=1) + name="rois", shape=[4], dtype="float64", lod_level=1) # spatial_scale must be float type self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7, 7) -- GitLab