From 3044a62f2ab820c58e62b32a5afd850e0dd56892 Mon Sep 17 00:00:00 2001
From: wopeizl
Date: Fri, 11 Oct 2019 16:21:38 +0800
Subject: [PATCH] fix the precise roi pool op test=develop (#20126)

* fix the precise roi pool op test=develop

add roi backward implementation, fix the output-channel
---
 paddle/fluid/API.spec                      |   2 +-
 paddle/fluid/operators/prroi_pool_op.cc    |  21 +-
 paddle/fluid/operators/prroi_pool_op.cu    |  54 +++--
 paddle/fluid/operators/prroi_pool_op.h     | 184 ++++++++++++++++--
 python/paddle/fluid/layers/nn.py           |   7 +-
 .../tests/unittests/py_precise_roi_pool.py |   3 +-
 .../tests/unittests/test_prroi_pool_op.py  |  32 ++-
 7 files changed, 221 insertions(+), 82 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 3a438f2d009..20be1488ad3 100755
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -290,7 +290,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg
 paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'd5945431cdcae3cda21914db5bbf383e'))
 paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '231f91231430f5dae2b757df22317c67'))
 paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9bf0cc6b0717010b8ceec5dc2541d566'))
-paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303'))
+paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '466be691ac4c1cd7b88fccb40846afce'))
 paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', 'b0e07aa41caae04b07a8e8217cc96020'))
 paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '9d93ee81f7a3e526d68bb280bc695d6c'))
 paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '45f3ebbcb766fca84cb2fe6307086573'))
diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc
index 6d5129f8d60..5c559bda339 100644
--- a/paddle/fluid/operators/prroi_pool_op.cc
+++ b/paddle/fluid/operators/prroi_pool_op.cc
@@ -43,12 +43,6 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
               "(Tensor), "
               "the output of PRROIPoolOp is a 4-D Tensor with shape "
               "(num_rois, output_channels, pooled_h, pooled_w).");
-    AddAttr<int>(
-        "output_channels",
-        "(int), "
-        "the number of channels of the output feature map. "
" - "For a task of C classes of objects, output_channels should be " - "(C + 1) for classification only."); AddAttr("spatial_scale", "(float, default 1.0), " "Multiplicative spatial scale factor " @@ -100,28 +94,18 @@ class PRROIPoolOp : public framework::OperatorWithKernel { int pooled_height = ctx->Attrs().Get("pooled_height"); int pooled_width = ctx->Attrs().Get("pooled_width"); - int output_channels = ctx->Attrs().Get("output_channels"); float spatial_scale = ctx->Attrs().Get("spatial_scale"); - PADDLE_ENFORCE_EQ( - input_dims[1], output_channels * pooled_height * pooled_width, - "the channel of X(%d) should be equal to the product of " - "output_channels(%d), pooled_height(%d) and pooled_width(%d)", - input_dims[1], output_channels, pooled_height, pooled_width); - PADDLE_ENFORCE_GT(pooled_height, 0, "The pooled output height must be greater than 0"); PADDLE_ENFORCE_GT(pooled_width, 0, "The pooled output width must be greater than 0"); - PADDLE_ENFORCE_GT(output_channels, 1, - "The pooled output channels must greater than 1"); PADDLE_ENFORCE_GT(spatial_scale, 0.0f, "The spatial scale must greater than 0."); auto out_dims = input_dims; out_dims[0] = rois_dims[0]; - out_dims[1] = - output_channels; // input_dims[1] / (pooled_height * pooled_width); + out_dims[1] = input_dims[1]; out_dims[2] = pooled_height; out_dims[3] = pooled_width; ctx->SetOutputDim("Out", out_dims); @@ -145,6 +129,7 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, "The gradient of X should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->SetOutputDim(framework::GradVarName("ROIs"), ctx->GetInputDim("ROIs")); } protected: @@ -164,9 +149,11 @@ class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker { std::unique_ptr op(new framework::OpDesc()); op->SetType("prroi_pool_grad"); op->SetInput("X", Input("X")); + op->SetInput("Out", Output("Out")); op->SetInput("ROIs", Input("ROIs")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("ROIs"), InputGrad("ROIs")); op->SetAttrMap(Attrs()); return op; } diff --git a/paddle/fluid/operators/prroi_pool_op.cu b/paddle/fluid/operators/prroi_pool_op.cu index 915e3daae53..35180dc91ae 100644 --- a/paddle/fluid/operators/prroi_pool_op.cu +++ b/paddle/fluid/operators/prroi_pool_op.cu @@ -40,6 +40,11 @@ DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff, } } +template +DEVICE void GPUAccumulateRois(T* offset, T data) { + paddle::platform::CudaAtomicAdd(offset, data); +} + template __global__ void GPUPRROIPoolForward( const int nthreads, const T* input_data, const T* input_rois, @@ -78,7 +83,7 @@ __global__ void GPUPRROIPoolForward( T win_end_h = win_start_h + bin_size_h; T win_size = max(static_cast(0.0), bin_size_w * bin_size_h); - int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_channel = c; const T* offset_input_data = input_data + (roi_batch_id * input_channels + input_channel) * height * width; @@ -110,10 +115,12 @@ __global__ void GPUPRROIPoolForward( template __global__ void GPUPRROIPoolBackward( - const int nthreads, const T* input_rois, const T* output_grad_data, - const float spatial_scale, const int input_channels, const int height, - const int width, const int output_channels, const int pooled_height, - const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) 
{ + const int nthreads, const T* in_data, const T* input_rois, + const T* output_grad_data, const float spatial_scale, + const int input_channels, const int height, const int width, + const int output_channels, const int pooled_height, const int pooled_width, + const int* rois_batch_id_data, T* input_grad_data, const T* out_data, + T* input_roi_grad_data) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; for (int i = index; i < nthreads; i += offset) { @@ -125,7 +132,7 @@ __global__ void GPUPRROIPoolBackward( // set roi_batch_id int roi_batch_id = rois_batch_id_data[n]; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_channel = c; int input_offset = (roi_batch_id * input_channels + input_channel) * height * width; T* offset_input_grad_data = input_grad_data + input_offset; @@ -137,6 +144,7 @@ __global__ void GPUPRROIPoolBackward( T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + T* offset_input_roi_grad_data = input_roi_grad_data + n * 4; T roi_width = max(roi_end_w - roi_start_w, static_cast(0.0)); T roi_height = max(roi_end_h - roi_start_h, static_cast(0.0)); @@ -171,6 +179,16 @@ __global__ void GPUPRROIPoolBackward( height, width, PrRoIPoolingDistributeDiffCUDA); } } + + const T* offset_out_data = out_data + i; + const T* offset_in_data = in_data + input_offset; + PrRoIPoolingCoorBackward( + s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, + win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, + offset_in_data, offset_out_data, offset_input_grad_data, + offset_input_roi_grad_data, GPUAccumulateRois, + [](const T x, const T y) { return max(x, y); }, + [](const T x, const T y) { return min(x, y); }); } } @@ -184,20 +202,15 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel { auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); auto spatial_scale = ctx.Attr("spatial_scale"); auto in_dims = in->dims(); int batch_size = in_dims[0]; int input_channels = in_dims[1]; + auto output_channels = input_channels; int height = in_dims[2]; int width = in_dims[3]; - PADDLE_ENFORCE_EQ(input_channels, - output_channels * pooled_height * pooled_width, - "the channels of input X should equal the product of " - "output_channels x pooled_height x pooled_width"); - int rois_num = rois->dims()[0]; if (rois_num == 0) return; @@ -245,17 +258,20 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Input("Out"); auto* output_grad = ctx.Input(framework::GradVarName("Out")); auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* input_roi_grad = + ctx.Output(framework::GradVarName("ROIs")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); auto spatial_scale = ctx.Attr("spatial_scale"); int rois_num = rois->dims()[0]; int input_channels = in->dims()[1]; + auto output_channels = input_channels; int height = in->dims()[2]; int width = in->dims()[3]; @@ -280,6 +296,8 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel { input_grad->mutable_data(ctx.GetPlace()); math::SetConstant 
set_zero; set_zero(ctx.cuda_device_context(), input_grad, static_cast(0)); + input_roi_grad->mutable_data(ctx.GetPlace()); + set_zero(ctx.cuda_device_context(), input_roi_grad, static_cast(0)); int output_grad_size = output_grad->numel(); int blocks = NumBlocks(output_grad_size); @@ -288,10 +306,12 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel { if (output_grad_size > 0) { GPUPRROIPoolBackward< T><<>>( - output_grad_size, rois->data(), output_grad->data(), - spatial_scale, input_channels, height, width, output_channels, - pooled_height, pooled_width, rois_batch_id_list_gpu.data(), - input_grad->mutable_data(ctx.GetPlace())); + output_grad_size, in->data(), rois->data(), + output_grad->data(), spatial_scale, input_channels, height, + width, output_channels, pooled_height, pooled_width, + rois_batch_id_list_gpu.data(), + input_grad->mutable_data(ctx.GetPlace()), out->data(), + input_roi_grad->mutable_data(ctx.GetPlace())); } } } diff --git a/paddle/fluid/operators/prroi_pool_op.h b/paddle/fluid/operators/prroi_pool_op.h index 621e543fab5..641309c730f 100644 --- a/paddle/fluid/operators/prroi_pool_op.h +++ b/paddle/fluid/operators/prroi_pool_op.h @@ -21,19 +21,20 @@ namespace paddle { namespace operators { template -HOSTDEVICE T PrRoIPoolingGetData(const T* data, const int h, const int w, - const int height, const int width) { +inline HOSTDEVICE T PrRoIPoolingGetData(const T* data, const int h, const int w, + const int height, const int width) { bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); T retVal = overflow ? 0.0f : data[h * width + w]; return retVal; } template -HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, const int s_h, - const int s_w, const int e_h, - const int e_w, const T y0, const T x0, - const T y1, const T x1, const int h0, - const int w0) { +inline HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, + const int s_h, const int s_w, + const int e_h, const int e_w, + const T y0, const T x0, + const T y1, const T x1, + const int h0, const int w0) { T alpha, beta, lim_alpha, lim_beta, tmp; T sum_out = 0; @@ -73,10 +74,11 @@ HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, const int s_h, } template -HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, - const int h, const int w, - const int height, const int width, - const T coeff) { +inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, + const int h, const int w, + const int height, + const int width, + const T coeff) { bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); if (!overflow) { *(diff + h * width + w) = top_diff * coeff; @@ -123,6 +125,132 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff( functor(diff, top_diff, e_h, e_w, h0, w0, tmp); } +template +inline HOSTDEVICE void CPUAccumulateRois(T* offset, T data) { + *offset += data; +} + +template +inline HOSTDEVICE static T PrRoIPoolingGetCoeff(T dh, T dw) { + dw = dw > 0 ? dw : -dw; + dh = dh > 0 ? 
dh : -dh; + return (1.0f - dh) * (1.0f - dw); +} + +template +inline HOSTDEVICE static T PrRoIPoolingInterpolation(const T* data, const H h, + const W w, + const int height, + const int width) { + T retVal = 0.0f; + int h1 = floorf(h); + int w1 = floorf(w); + retVal += + PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - static_cast(h1), w - static_cast(w1)); + h1 = floorf(h) + 1; + w1 = floorf(w); + retVal += + PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - static_cast(h1), w - static_cast(w1)); + h1 = floorf(h); + w1 = floorf(w) + 1; + retVal += + PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - static_cast(h1), w - static_cast(w1)); + h1 = floorf(h) + 1; + w1 = floorf(w) + 1; + retVal += + PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - static_cast(h1), w - static_cast(w1)); + return retVal; +} + +template +inline HOSTDEVICE T PrRoIPoolingSingleCoorIntegral(T s, T t, T c1, T c2) { + return 0.5f * (t * t - s * s) * c2 + + (t - 0.5f * t * t - s + 0.5f * s * s) * c1; +} + +template +inline HOSTDEVICE void PrRoIPoolingCoorBackward( + int s_w, int e_w, int s_h, int e_h, int width, int height, T win_start_w, + T win_start_h, T win_end_w, T win_end_h, int pw, int ph, + const int pooled_width, const int pooled_height, T win_size, + const float spatial_scale, const T* this_bottom_data, + const T* this_top_data, T* this_data_grad, T* this_out_grad, + Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) { + T g_x1_y = 0.f; + T g_x2_y = 0.f; + T g_x_y1 = 0.f; + T g_x_y2 = 0.f; + + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + g_x1_y += PrRoIPoolingSingleCoorIntegral( + maxFunctor(win_start_h, static_cast(h_iter)) - h_iter, + minFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, + PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, + width), + PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, + height, width)); + + g_x2_y += PrRoIPoolingSingleCoorIntegral( + maxFunctor(win_start_h, static_cast(h_iter)) - h_iter, + minFunctor(win_end_h, static_cast(h_iter + 1)) - h_iter, + PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, + width), + PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, + height, width)); + } + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + g_x_y1 += PrRoIPoolingSingleCoorIntegral( + maxFunctor(win_start_w, static_cast(w_iter)) - w_iter, + minFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, + PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, + width), + PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, + height, width)); + + g_x_y2 += PrRoIPoolingSingleCoorIntegral( + maxFunctor(win_start_w, static_cast(w_iter)) - w_iter, + minFunctor(win_end_w, static_cast(w_iter + 1)) - w_iter, + PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, + width), + PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, + height, width)); + } + + float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data); + float partial_y1 = -g_x_y1 + (win_end_w - win_start_w) * (*this_top_data); + float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data); + float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data); + + partial_x1 = partial_x1 / win_size * spatial_scale; + partial_x2 = partial_x2 / win_size * spatial_scale; + partial_y1 = partial_y1 / win_size * spatial_scale; + partial_y2 
= partial_y2 / win_size * spatial_scale; + + this_data_grad[0] = 0; + functor(this_data_grad + 1, + (partial_x1 * (1.0 - static_cast(pw) / pooled_width) + + partial_x2 * (1.0 - static_cast(pw + 1) / pooled_width)) * + (*this_out_grad)); + functor(this_data_grad + 2, + (partial_y1 * (1.0 - static_cast(ph) / pooled_height) + + partial_y2 * (1.0 - static_cast(ph + 1) / pooled_height)) * + (*this_out_grad)); + functor(this_data_grad + 3, + (partial_x2 * static_cast(pw + 1) / pooled_width + + partial_x1 * static_cast(pw) / pooled_width) * + (*this_out_grad)); + functor(this_data_grad + 4, + (partial_y2 * static_cast(ph + 1) / pooled_height + + partial_y1 * static_cast(ph) / pooled_height) * + (*this_out_grad)); +} + template class CPUPRROIPoolOpKernel : public framework::OpKernel { public: @@ -134,11 +262,11 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); auto spatial_scale = ctx.Attr("spatial_scale"); - auto output_channels = ctx.Attr("output_channels"); auto in_dims = in->dims(); int batch_size = in_dims[0]; int input_channels = in_dims[1]; + auto output_channels = input_channels; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; @@ -162,11 +290,6 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num, "the rois_num from input and lod must be the same"); - PADDLE_ENFORCE_EQ(input_channels, - output_channels * pooled_height * pooled_width, - "the channels of input X should equal the product of " - "output_channels x pooled_height x pooled_width"); - // calculate batch id index for each roi according to LoD for (int n = 0; n < rois_batch_size; ++n) { for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { @@ -217,7 +340,7 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel { int e_h = std::ceil(win_end_h); int output_index = out_row_offset + pw; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_channel = c; int input_plane_offset = roi_batch_id * in_stride[0] + input_channel * in_stride[1]; const T* offset_input_data = input_data + input_plane_offset; @@ -254,20 +377,26 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); + auto* out = ctx.Input("Out"); auto* rois = ctx.Input("ROIs"); auto* output_grad = ctx.Input(framework::GradVarName("Out")); auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* input_roi_grad = + ctx.Output(framework::GradVarName("ROIs")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); - auto output_channels = ctx.Attr("output_channels"); auto spatial_scale = ctx.Attr("spatial_scale"); - if (input_grad) { + if (input_grad && input_roi_grad) { auto in_dims = in->dims(); + auto* in_data = in->data(); + auto* out_data = out->data(); + int input_channels = in_dims[1]; + auto output_channels = input_channels; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; @@ -289,6 +418,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { const T* input_rois = rois->data(); const T* output_grad_data = output_grad->data(); T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + T* input_roi_grad_data = input_roi_grad->mutable_data(ctx.GetPlace()); // set gradient of X to be 0. before backpropagate. 
math::SetConstant set_zero; @@ -306,11 +436,12 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { // set roi_batch_id int roi_batch_id = rois_batch_id_data[n]; - int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_channel = c; int input_offset = (roi_batch_id * input_channels + input_channel) * height * width; T* offset_input_grad_data = input_grad_data + input_offset; const T* offset_output_grad_data = output_grad_data + i; + const T* offset_out_data = out_data + i; // [start, end) interval for spatial sampling const T* offset_input_rois = input_rois + n * 4; @@ -318,6 +449,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + T* offset_input_roi_grad_data = input_roi_grad_data + n * 4; T roi_width = std::max(roi_end_w - roi_start_w, static_cast(0.0)); T roi_height = std::max(roi_end_h - roi_start_h, static_cast(0.0)); @@ -355,6 +487,16 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel { height, width, PrRoIPoolingDistributeDiff); } } + + const T* offset_in_data = in_data + input_offset; + PrRoIPoolingCoorBackward( + s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, + win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, + spatial_scale, offset_in_data, offset_out_data, + offset_input_grad_data, offset_input_roi_grad_data, + CPUAccumulateRois, + [](const T x, const T y) { return std::max(x, y); }, + [](const T x, const T y) { return std::min(x, y); }); } } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4132618ab7b..e1c0c3117ad 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15442,7 +15442,6 @@ def psroi_pool(input, @templatedoc() def prroi_pool(input, rois, - output_channels, spatial_scale=1.0, pooled_height=1, pooled_width=1, @@ -15459,7 +15458,6 @@ def prroi_pool(input, is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates. - output_channels (integer): The output's channel. spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width). Equals the reciprocal of total stride in convolutional layers, Default: 1.0. pooled_height (integer): The pooled output height. Default: 1. 
@@ -15475,12 +15473,10 @@ def prroi_pool(input,
             import paddle.fluid as fluid
             x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32')
             rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32')
-            pool_out = fluid.layers.prroi_pool(x, rois, 10, 1.0, 7, 7)
+            pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7)
     """
     helper = LayerHelper('prroi_pool', **locals())
     # check attrs
-    if not isinstance(output_channels, int):
-        raise TypeError("output_channels must be int type")
     if not isinstance(spatial_scale, float):
         raise TypeError("spatial_scale must be float type")
     if not isinstance(pooled_height, int):
@@ -15495,7 +15491,6 @@ def prroi_pool(input,
                 'ROIs': rois},
         outputs={'Out': out},
         attrs={
-            'output_channels': output_channels,
             'spatial_scale': spatial_scale,
             'pooled_height': pooled_height,
             'pooled_width': pooled_width
diff --git a/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py b/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py
index 618ffbdf9fc..aa7b8420f48 100644
--- a/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py
+++ b/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py
@@ -133,8 +133,7 @@ class PyPrRoIPool(object):
                     s_h = math.floor(win_start_h)
                     e_h = math.ceil(win_end_h)
 
-                    c_in = (c * pooled_height + ph) * pooled_width + pw
-
+                    c_in = c
                     for w_iter in range(int(s_w), int(e_w)):
                         for h_iter in range(int(s_h), int(e_h)):
                             sum_out += self._PrRoIPoolingMatCalculation(
diff --git a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py
index 49aab6ddfc0..e3bfa062a3e 100644
--- a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py
@@ -48,7 +48,7 @@ class TestPRROIPoolOp(OpTest):
         self.x_dim = [self.batch_size, self.channels, self.height, self.width]
 
         self.spatial_scale = 1.0 / 4.0
-        self.output_channels = 3
+        self.output_channels = self.channels
         self.pooled_height = 2
         self.pooled_width = 2
 
@@ -60,15 +60,15 @@ class TestPRROIPoolOp(OpTest):
         for bno in range(self.batch_size):
             self.rois_lod[0].append(bno + 1)
             for i in range(bno + 1):
-                x1 = np.random.random_integers(
-                    0, self.width // self.spatial_scale - self.pooled_width)
-                y1 = np.random.random_integers(
-                    0, self.height // self.spatial_scale - self.pooled_height)
-                x2 = np.random.random_integers(x1 + self.pooled_width,
-                                               self.width // self.spatial_scale)
-                y2 = np.random.random_integers(
-                    y1 + self.pooled_height, self.height // self.spatial_scale)
+                x1 = np.random.uniform(
+                    0, self.width // self.spatial_scale - self.pooled_width)
+                y1 = np.random.uniform(
+                    0, self.height // self.spatial_scale - self.pooled_height)
+                x2 = np.random.uniform(x1 + self.pooled_width,
+                                       self.width // self.spatial_scale)
+                y2 = np.random.uniform(y1 + self.pooled_height,
+                                       self.height // self.spatial_scale)
                 roi = [bno, x1, y1, x2, y2]
                 rois.append(roi)
         self.rois_num = len(rois)
@@ -93,8 +93,7 @@ class TestPRROIPoolOp(OpTest):
                 dtype="float32")
             rois = fluid.layers.data(
                 name="ROIs", shape=[4], dtype="float32", lod_level=1)
-            output = fluid.layers.prroi_pool(x, rois, self.output_channels,
-                                             0.25, 2, 2)
+            output = fluid.layers.prroi_pool(x, rois, 0.25, 2, 2)
             loss = fluid.layers.mean(output)
             optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
             optimizer.minimize(loss)
@@ -120,18 +119,15 @@ class TestPRROIPoolOp(OpTest):
                 name="x", shape=[245, 30, 30], dtype="float32")
             rois = fluid.layers.data(
                 name="rois", shape=[4], dtype="float32", lod_level=1)
-            # channel must be int type
-            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.5,
-                              0.25, 7, 7)
             # spatial_scale must be float type
-            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, 2,
-                              7, 7)
+            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7,
+                              7)
             # pooled_height must be int type
-            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
-                              0.25, 0.7, 7)
+            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
+                              0.7, 7)
             # pooled_width must be int type
-            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
-                              0.25, 7, 0.7)
+            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
+                              7, 0.7)
 
 
 if __name__ == '__main__':
-- 
GitLab
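
For reference, a minimal usage sketch of the prroi_pool Python API after this change, following the docstring example updated in the patch above (the Fluid 1.x static-graph layers API is assumed): the output_channels argument is removed and the output now keeps the input channel count.

    import paddle.fluid as fluid

    # Input feature map is NCHW ([N, 490, 28, 28] here); ROIs are a LoDTensor of
    # [x1, y1, x2, y2] boxes, with one LoD level grouping the ROIs of each image.
    x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32')
    rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32')

    # spatial_scale=1.0, pooled_height=7, pooled_width=7; the result keeps all
    # 490 input channels, i.e. its shape is [num_rois, 490, 7, 7].
    pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7)

With this patch the backward pass also produces a gradient for the ROI coordinates (the new ROIs@GRAD output) in addition to the gradient for X.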