diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 9cac2908ba4ef385bd20fd9f820e21bcaaa6d1c9..44895ff3393d5af8be9e8dd32e63fc8bc88a8a20 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -276,6 +276,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', '13b1cdcb01f5ffdc26591ff9a2ec4669')) paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb')) paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '42d5155374f69786300d90d751956998')) +paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303')) paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '07cb0d95a646dba1b9cc7cdce89e59f0')) paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '11bb8e62cc9256958eff3991fe4834da')) paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '18bc95c62d3300456c3c7da5278b47bb')) diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d5129f8d603088551ca2f7dcf01edf7e5b0ffc4 --- /dev/null +++ b/paddle/fluid/operators/prroi_pool_op.cc @@ -0,0 +1,188 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/prroi_pool_op.h" +#include + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "the input of PRROIPoolOp. " + "The format of input tensor is NCHW. Where N is the batch size, " + "C is the number of input channels, " + "H is the height of the input feature map, and " + "W is the width."); + AddInput("ROIs", + "(LoDTensor), " + "ROIs (Regions of Interest) to pool over. " + "should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]. " + "where (x1, y1) is the top left coordinates, and " + "(x2, y2) is the bottom right coordinates. 
" + "The roi batch index can be calculated from LoD."); + AddOutput("Out", + "(Tensor), " + "the output of PRROIPoolOp is a 4-D Tensor with shape " + "(num_rois, output_channels, pooled_h, pooled_w)."); + AddAttr( + "output_channels", + "(int), " + "the number of channels of the output feature map. " + "For a task of C classes of objects, output_channels should be " + "(C + 1) for classification only."); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Multiplicative spatial scale factor " + "to translate ROI coords from their input scale " + "to the scale used when pooling.") + .SetDefault(1.0); + AddAttr("pooled_height", + "(int, default 1), " + "the pooled output height.") + .SetDefault(1); + AddAttr("pooled_width", + "(int, default 1), " + "the pooled output width.") + .SetDefault(1); + AddComment(R"Doc( +**PRROIPool Operator** + +Precise region of interest pooling (also known as PRROIPooling) is to perform + bilinear interpolation average pooling method for RoI Pooling. + +Please refer to https://arxiv.org/abs/1807.11590 for more details. + + )Doc"); + } +}; + +class PRROIPoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of op(PRROIPool) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true, + "Input(ROIs) of op(PRROIPool) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of op(PRROIPool) should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE_EQ(input_dims.size(), 4, + "The format of input tensor is NCHW"); + PADDLE_ENFORCE_EQ(rois_dims.size(), 2, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]"); + PADDLE_ENFORCE_EQ(rois_dims[1], 4, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]"); + + int pooled_height = ctx->Attrs().Get("pooled_height"); + int pooled_width = ctx->Attrs().Get("pooled_width"); + int output_channels = ctx->Attrs().Get("output_channels"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE_EQ( + input_dims[1], output_channels * pooled_height * pooled_width, + "the channel of X(%d) should be equal to the product of " + "output_channels(%d), pooled_height(%d) and pooled_width(%d)", + input_dims[1], output_channels, pooled_height, pooled_width); + + PADDLE_ENFORCE_GT(pooled_height, 0, + "The pooled output height must be greater than 0"); + PADDLE_ENFORCE_GT(pooled_width, 0, + "The pooled output width must be greater than 0"); + PADDLE_ENFORCE_GT(output_channels, 1, + "The pooled output channels must greater than 1"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0."); + + auto out_dims = input_dims; + out_dims[0] = rois_dims[0]; + out_dims[1] = + output_channels; // input_dims[1] / (pooled_height * pooled_width); + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class PRROIPoolGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void 
InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + "The gradient of Out should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, + "The gradient of X should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("prroi_pool_grad"); + op->SetInput("X", Input("X")); + op->SetInput("ROIs", Input("ROIs")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(prroi_pool, ops::PRROIPoolOp, ops::PRROIPoolOpMaker, + ops::PRROIPoolGradDescMaker); +REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp); +REGISTER_OP_CPU_KERNEL( + prroi_pool, + ops::CPUPRROIPoolOpKernel, + ops::CPUPRROIPoolOpKernel); +REGISTER_OP_CPU_KERNEL( + prroi_pool_grad, + ops::CPUPRROIPoolGradOpKernel, + ops::CPUPRROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/prroi_pool_op.cu b/paddle/fluid/operators/prroi_pool_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..915e3daae538f00434edb26a354252895520a21f --- /dev/null +++ b/paddle/fluid/operators/prroi_pool_op.cu @@ -0,0 +1,309 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/prroi_pool_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff, + const int h, const int w, + const int height, const int width, + const T coeff) { + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + if (!overflow) { + paddle::platform::CudaAtomicAdd(diff + h * width + w, top_diff * coeff); + } +} + +template +__global__ void GPUPRROIPoolForward( + const int nthreads, const T* input_data, const T* input_rois, + const float spatial_scale, const int input_channels, const int height, + const int width, const int output_channels, const int pooled_height, + const int pooled_width, const int* rois_batch_id_data, T* output_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(offset_input_rois[0]) * spatial_scale; + T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; + T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; + T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + + T roi_width = max(roi_end_w - roi_start_w, static_cast(0.0)); + T roi_height = max(roi_end_h - roi_start_h, static_cast(0.0)); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T win_start_w = roi_start_w + bin_size_w * pw; + T win_start_h = roi_start_h + bin_size_h * ph; + T win_end_w = win_start_w + bin_size_w; + T win_end_h = win_start_h + bin_size_h; + + T win_size = max(static_cast(0.0), bin_size_w * bin_size_h); + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + const T* offset_input_data = + input_data + + (roi_batch_id * input_channels + input_channel) * height * width; + + if (win_size > static_cast(0.0)) { + int s_w = floor(win_start_w); + int e_w = ceil(win_end_w); + int s_h = floor(win_start_h); + int e_h = ceil(win_end_h); + T sum_out = 0; + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + sum_out += PrRoIPoolingMatCalculation( + offset_input_data, h_iter, w_iter, h_iter + 1, w_iter + 1, + max(win_start_h, static_cast(h_iter)), + max(win_start_w, static_cast(w_iter)), + min(win_end_h, static_cast(h_iter) + static_cast(1.0)), + min(win_end_w, static_cast(w_iter) + static_cast(1.0)), + height, width); + } + } + output_data[i] = sum_out / win_size; + } else { + output_data[i] = 0.; + } + } +} + +template +__global__ void GPUPRROIPoolBackward( + const int nthreads, const T* input_rois, const T* output_grad_data, + const float spatial_scale, const int 
input_channels, const int height, + const int width, const int output_channels, const int pooled_height, + const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_input_grad_data = input_grad_data + input_offset; + const T* offset_output_grad_data = output_grad_data + i; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(offset_input_rois[0]) * spatial_scale; + T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; + T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; + T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + + T roi_width = max(roi_end_w - roi_start_w, static_cast(0.0)); + T roi_height = max(roi_end_h - roi_start_h, static_cast(0.0)); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T win_start_w = roi_start_w + bin_size_w * pw; + T win_start_h = roi_start_h + bin_size_h * ph; + T win_end_w = win_start_w + bin_size_w; + T win_end_h = win_start_h + bin_size_h; + + T win_size = max(static_cast(0.0), bin_size_w * bin_size_h); + int s_w = floor(win_start_w); + int e_w = ceil(win_end_w); + int s_h = floor(win_start_h); + int e_h = ceil(win_end_h); + + T sum_out = win_size == static_cast(0.) + ? static_cast(0.) 
+ : *offset_output_grad_data / win_size; + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + PrRoIPoolingMatDistributeDiff( + offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1, + w_iter + 1, max(win_start_h, static_cast(h_iter)), + max(win_start_w, static_cast(w_iter)), + min(win_end_h, static_cast(h_iter) + static_cast(1.0)), + min(win_end_w, static_cast(w_iter) + static_cast(1.0)), + height, width, PrRoIPoolingDistributeDiffCUDA); + } + } + } +} + +template +class GPUPRROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + PADDLE_ENFORCE_EQ(input_channels, + output_channels * pooled_height * pooled_width, + "the channels of input X should equal the product of " + "output_channels x pooled_height x pooled_width"); + + int rois_num = rois->dims()[0]; + if (rois_num == 0) return; + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and input(X) batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + + // set rois batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(platform::CPUPlace()); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + framework::Tensor rois_batch_id_list_gpu; + framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &rois_batch_id_list_gpu); + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + // call cuda kernel function + GPUPRROIPoolForward< + T><<>>( + output_size, in->data(), rois->data(), spatial_scale, + input_channels, height, width, output_channels, pooled_height, + pooled_width, rois_batch_id_list_gpu.data(), + out->mutable_data(ctx.GetPlace())); + } +}; + +template +class GPUPRROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + int rois_num = rois->dims()[0]; + int input_channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (input_grad) { + // set roi batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); 
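+      // The last LoD level records, for every input image, the half-open
+      // range [rois_lod[n], rois_lod[n + 1]) of ROI rows that belong to it,
+      // so walking it below recovers the batch index of each ROI.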
+ int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + framework::Tensor rois_batch_id_list_gpu; + framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &rois_batch_id_list_gpu); + + input_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.cuda_device_context(), input_grad, static_cast(0)); + + int output_grad_size = output_grad->numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPUPRROIPoolBackward< + T><<>>( + output_grad_size, rois->data(), output_grad->data(), + spatial_scale, input_channels, height, width, output_channels, + pooled_height, pooled_width, rois_batch_id_list_gpu.data(), + input_grad->mutable_data(ctx.GetPlace())); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(prroi_pool, ops::GPUPRROIPoolOpKernel, + ops::GPUPRROIPoolOpKernel); +REGISTER_OP_CUDA_KERNEL( + prroi_pool_grad, + ops::GPUPRROIPoolGradOpKernel, + ops::GPUPRROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/prroi_pool_op.h b/paddle/fluid/operators/prroi_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..621e543fab5539df15bab65ab7552ae7cf2f2196 --- /dev/null +++ b/paddle/fluid/operators/prroi_pool_op.h @@ -0,0 +1,364 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +HOSTDEVICE T PrRoIPoolingGetData(const T* data, const int h, const int w, + const int height, const int width) { + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + T retVal = overflow ? 
0.0f : data[h * width + w]; + return retVal; +} + +template +HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, const int s_h, + const int s_w, const int e_h, + const int e_w, const T y0, const T x0, + const T y1, const T x1, const int h0, + const int w0) { + T alpha, beta, lim_alpha, lim_beta, tmp; + T sum_out = 0; + + alpha = x0 - static_cast(s_w); + beta = y0 - static_cast(s_h); + lim_alpha = x1 - static_cast(s_w); + lim_beta = y1 - static_cast(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp; + + alpha = static_cast(e_w) - x1; + lim_alpha = static_cast(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp; + + alpha = x0 - static_cast(s_w); + beta = static_cast(e_h) - y1; + lim_alpha = x1 - static_cast(s_w); + lim_beta = static_cast(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp; + + alpha = static_cast(e_w) - x1; + lim_alpha = static_cast(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp; + + return sum_out; +} + +template +HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, + const int h, const int w, + const int height, const int width, + const T coeff) { + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + if (!overflow) { + *(diff + h * width + w) = top_diff * coeff; + } +} + +template +HOSTDEVICE void PrRoIPoolingMatDistributeDiff( + T* diff, const T top_diff, const int s_h, const int s_w, const int e_h, + const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0, + const int w0, Functor functor) { + T alpha, beta, lim_alpha, lim_beta, tmp; + + alpha = x0 - static_cast(s_w); + beta = y0 - static_cast(s_h); + lim_alpha = x1 - static_cast(s_w); + lim_beta = y1 - static_cast(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + functor(diff, top_diff, s_h, s_w, h0, w0, tmp); + + alpha = static_cast(e_w) - x1; + lim_alpha = static_cast(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + functor(diff, top_diff, s_h, e_w, h0, w0, tmp); + + alpha = x0 - static_cast(s_w); + beta = static_cast(e_h) - y1; + lim_alpha = x1 - static_cast(s_w); + lim_beta = static_cast(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + functor(diff, top_diff, e_h, s_w, h0, w0, tmp); + + alpha = static_cast(e_w) - x1; + lim_alpha = static_cast(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + functor(diff, top_diff, e_h, e_w, h0, w0, tmp); +} + +template +class 
CPUPRROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto output_channels = ctx.Attr("output_channels"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + auto in_stride = framework::stride(in_dims); + auto out_stride = framework::stride(out->dims()); + + const T* input_data = in->data(); + + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "the rois_batch_size and input(X) batch_size should be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num, + "the rois_num from input and lod must be the same"); + + PADDLE_ENFORCE_EQ(input_channels, + output_channels * pooled_height * pooled_width, + "the channels of input X should equal the product of " + "output_channels x pooled_height x pooled_width"); + + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* input_rois = rois->data(); + + // calculate prroipooling, parallel processing can be implemented per ROI + for (int n = 0; n < rois_num; ++n) { + // set roi batch id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(offset_input_rois[0]) * spatial_scale; + T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; + T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; + T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + + T roi_width = std::max(roi_end_w - roi_start_w, static_cast(0.0)); + T roi_height = std::max(roi_end_h - roi_start_h, static_cast(0.0)); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + T win_size = std::max(static_cast(0.0), bin_size_w * bin_size_h); + + // calculate each pixel of the output feature map. 
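+      // Precise RoI pooling treats each output bin's window as a continuous
+      // region: the bin value is the integral of the bilinearly interpolated
+      // feature over the window divided by the window area.
+      // PrRoIPoolingMatCalculation evaluates the closed-form integral over a
+      // single unit grid cell, so summing it over every cell touched by
+      // [win_start, win_end) gives the exact (not sampled) bin average.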
+ int out_roi_offset = n * out_stride[0]; + for (int c = 0; c < output_channels; ++c) { + // per category + int out_plane_offset = out_roi_offset + c * out_stride[1]; + for (int ph = 0; ph < pooled_height; ++ph) { + int out_row_offset = out_plane_offset + ph * out_stride[2]; + for (int pw = 0; pw < pooled_width; ++pw) { + // calculate w and h at input feature map + T win_start_h = static_cast(ph) * bin_size_h + roi_start_h; + T win_start_w = static_cast(pw) * bin_size_w + roi_start_w; + T win_end_h = win_start_h + bin_size_h; + T win_end_w = win_start_w + bin_size_w; + // Add roi offsets and clip to input boundaries + int s_w = std::floor(win_start_w); + int e_w = std::ceil(win_end_w); + int s_h = std::floor(win_start_h); + int e_h = std::ceil(win_end_h); + + int output_index = out_row_offset + pw; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_plane_offset = + roi_batch_id * in_stride[0] + input_channel * in_stride[1]; + const T* offset_input_data = input_data + input_plane_offset; + T sum_out = 0.; + + if (win_size > static_cast(0.0)) { + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + sum_out += PrRoIPoolingMatCalculation( + offset_input_data, h_iter, w_iter, h_iter + 1, w_iter + 1, + std::max(win_start_h, static_cast(h_iter)), + std::max(win_start_w, static_cast(w_iter)), + std::min(win_end_h, + static_cast(h_iter) + static_cast(1.0)), + std::min(win_end_w, + static_cast(w_iter) + static_cast(1.0)), + height, width); + } + } + + output_data[output_index] = sum_out / win_size; + } else { + output_data[output_index] = 0.; + } + } + } + } + } + } +}; + +template +class CPUPRROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* output_grad = + ctx.Input(framework::GradVarName("Out")); + auto* input_grad = + ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + if (input_grad) { + auto in_dims = in->dims(); + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + // set roi batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(ctx.GetPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + const T* input_rois = rois->data(); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + // set gradient of X to be 0. before backpropagate. 
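+      // Only the input cells covered by some ROI window receive a gradient
+      // from PrRoIPoolingMatDistributeDiff in the loop below, so the whole
+      // X@GRAD buffer is cleared to zero first.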
+ math::SetConstant set_zero; + set_zero(ctx.template device_context(), input_grad, + static_cast(0)); + + // backpropagate gradient per output pixel + int output_grad_size = output_grad->numel(); + for (int i = 0; i < output_grad_size; ++i) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_input_grad_data = input_grad_data + input_offset; + const T* offset_output_grad_data = output_grad_data + i; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(offset_input_rois[0]) * spatial_scale; + T roi_start_h = static_cast(offset_input_rois[1]) * spatial_scale; + T roi_end_w = static_cast(offset_input_rois[2]) * spatial_scale; + T roi_end_h = static_cast(offset_input_rois[3]) * spatial_scale; + + T roi_width = std::max(roi_end_w - roi_start_w, static_cast(0.0)); + T roi_height = std::max(roi_end_h - roi_start_h, static_cast(0.0)); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T win_start_w = roi_start_w + bin_size_w * pw; + T win_start_h = roi_start_h + bin_size_h * ph; + T win_end_w = win_start_w + bin_size_w; + T win_end_h = win_start_h + bin_size_h; + + T win_size = std::max(static_cast(0.0), bin_size_w * bin_size_h); + + T sum_out = win_size == static_cast(0.) + ? static_cast(0.) + : *offset_output_grad_data / win_size; + + int s_w = std::floor(win_start_w); + int e_w = std::ceil(win_end_w); + int s_h = std::floor(win_start_h); + int e_h = std::ceil(win_end_h); + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + PrRoIPoolingMatDistributeDiff( + offset_input_grad_data, sum_out, h_iter, w_iter, h_iter + 1, + w_iter + 1, std::max(win_start_h, static_cast(h_iter)), + std::max(win_start_w, static_cast(w_iter)), + std::min(win_end_h, + static_cast(h_iter) + static_cast(1.0)), + std::min(win_end_w, + static_cast(w_iter) + static_cast(1.0)), + height, width, PrRoIPoolingDistributeDiff); + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 15ab6610c6b186d1da2ce522024eceb2d4d18317..c3cfbdba78a601a9d7e80ae50c54642d8a1cf230 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -203,6 +203,7 @@ __all__ = [ 'temporal_shift', 'py_func', 'psroi_pool', + 'prroi_pool', 'teacher_student_sigmoid_loss', 'huber_loss', 'kldiv_loss', @@ -12716,6 +12717,70 @@ def psroi_pool(input, return out +@templatedoc() +def prroi_pool(input, + rois, + output_channels, + spatial_scale=1.0, + pooled_height=1, + pooled_width=1, + name=None): + """ + The precise roi pooling implementation for paddle?https://arxiv.org/pdf/1807.11590.pdf + + Args: + input (Variable):The input of Deformable PSROIPooling.The shape of input tensor is + [N,C,H,W]. Where N is batch size,C is number of input channels,H + is height of the feature, and W is the width of the feature. 
+ rois (Variable): ROIs (Regions of Interest) to pool over.It should be + a 2-D LoDTensor of shape (num_rois, 4), the lod level + is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is + the top left coordinates, and (x2, y2) is the bottom + right coordinates. + output_channels (integer): The output's channel. + spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width). + Equals the reciprocal of total stride in convolutional layers, Default: 1.0. + pooled_height (integer): The pooled output height. Default: 1. + pooled_width (integer): The pooled output width. Default: 1. + name (str, default None): The name of this operation. + + Returns: + Variable(Tensor): The shape of the returned Tensor is (num_rois, output_channels, pooled_h, pooled_w), with value type float32,float16.. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32') + rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32') + pool_out = fluid.layers.prroi_pool(x, rois, 10, 1.0, 7, 7) + """ + helper = LayerHelper('prroi_pool', **locals()) + # check attrs + if not isinstance(output_channels, int): + raise TypeError("output_channels must be int type") + if not isinstance(spatial_scale, float): + raise TypeError("spatial_scale must be float type") + if not isinstance(pooled_height, int): + raise TypeError("pooled_height must be int type") + if not isinstance(pooled_width, int): + raise TypeError("pooled_width must be int type") + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='prroi_pool', + inputs={'X': input, + 'ROIs': rois}, + outputs={'Out': out}, + attrs={ + 'output_channels': output_channels, + 'spatial_scale': spatial_scale, + 'pooled_height': pooled_height, + 'pooled_width': pooled_width + }) + return out + + def huber_loss(input, label, delta): """ Huber loss is a loss function used in robust. diff --git a/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py b/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..618ffbdf9fc690b08ade81443a6515b1a74ebc12 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/py_precise_roi_pool.py @@ -0,0 +1,151 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import numpy as np + + +class PyPrRoIPool(object): + def __init__(self): + pass + + def _PrRoIPoolingGetData(self, data, h, w, height, width): + overflow = (h < 0) or (w < 0) or (h >= height) or (w >= width) + if overflow: + return 0.0 + else: + return data[h][w] + + def _PrRoIPoolingMatCalculation(self, this_data, s_h, s_w, e_h, e_w, y0, x0, + y1, x1, h0, w0): + sum_out = 0.0 + alpha = x0 - float(s_w) + beta = y0 - float(s_h) + lim_alpha = x1 - float(s_w) + lim_beta = y1 - float(s_h) + tmp = ( + lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha * + alpha) * ( + lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta) + sum_out += self._PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp + + alpha = float(e_w) - x1 + lim_alpha = float(e_w) - x0 + tmp = ( + lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha * + alpha) * ( + lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta) + sum_out += self._PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp + + alpha = x0 - float(s_w) + beta = float(e_h) - y1 + lim_alpha = x1 - float(s_w) + lim_beta = float(e_h) - y0 + tmp = ( + lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha * + alpha) * ( + lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta) + sum_out += self._PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp + + alpha = float(e_w) - x1 + lim_alpha = float(e_w) - x0 + tmp = ( + lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha * + alpha) * ( + lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta) + sum_out += self._PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp + + return sum_out + + def compute(self, + x, + rois, + output_channels, + spatial_scale=0.1, + pooled_height=1, + pooled_width=1): + ''' + calculate the precise roi pooling values + Note: This function is implements as pure python without any paddle concept involved + :param x (array): array[N, C, H, W] + :param rois (array): ROIs[id, x1, y1, x2, y2] (Regions of Interest) to pool over. 
+ :param output_channels (Integer): Expected output channels + :param spatial_scale (float): spatial scale, default = 0.1 + :param pooled_height (Integer): Expected output height, default = 1 + :param pooled_width (Integer): Expected output width, default = 1 + :return: array[len(rois), output_channels, pooled_height, pooled_width] + ''' + if not isinstance(output_channels, int): + raise TypeError("output_channels must be int type") + if not isinstance(spatial_scale, float): + raise TypeError("spatial_scale must be float type") + if not isinstance(pooled_height, int): + raise TypeError("pooled_height must be int type") + if not isinstance(pooled_width, int): + raise TypeError("pooled_width must be int type") + + (batch_size, channels, height, width) = np.array(x).shape + rois_num = len(rois) + output_shape = (rois_num, output_channels, pooled_height, pooled_width) + out_data = np.zeros(output_shape) + for i in range(rois_num): + roi = rois[i] + roi_batch_id = int(roi[0]) + roi_start_w = roi[1] * spatial_scale + roi_start_h = roi[2] * spatial_scale + roi_end_w = roi[3] * spatial_scale + roi_end_h = roi[4] * spatial_scale + + roi_width = max(roi_end_w - roi_start_w, 0.0) + roi_height = max(roi_end_h - roi_start_h, 0.0) + bin_size_h = roi_height / float(pooled_height) + bin_size_w = roi_width / float(pooled_width) + + x_i = x[roi_batch_id] + + for c in range(output_channels): + for ph in range(pooled_height): + for pw in range(pooled_width): + win_start_w = roi_start_w + bin_size_w * pw + win_start_h = roi_start_h + bin_size_h * ph + win_end_w = win_start_w + bin_size_w + win_end_h = win_start_h + bin_size_h + + win_size = max(0.0, bin_size_w * bin_size_h) + if win_size == 0.0: + out_data[i, c, ph, pw] = 0.0 + else: + sum_out = 0 + + s_w = math.floor(win_start_w) + e_w = math.ceil(win_end_w) + s_h = math.floor(win_start_h) + e_h = math.ceil(win_end_h) + + c_in = (c * pooled_height + ph) * pooled_width + pw + + for w_iter in range(int(s_w), int(e_w)): + for h_iter in range(int(s_h), int(e_h)): + sum_out += self._PrRoIPoolingMatCalculation( + x_i[c_in], h_iter, w_iter, h_iter + 1, + w_iter + 1, + max(win_start_h, float(h_iter)), + max(win_start_w, float(w_iter)), + min(win_end_h, float(h_iter) + 1.0), + min(win_end_w, float(w_iter + 1.0)), + height, width) + + out_data[i, c, ph, pw] = sum_out / win_size + + return out_data diff --git a/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py new file mode 100644 index 0000000000000000000000000000000000000000..49aab6ddfc0158b11d73aac027746ef02edc6d89 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_prroi_pool_op.py @@ -0,0 +1,138 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
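+# Tests for the prroi_pool operator and fluid.layers.prroi_pool: forward
+# results are compared against the PyPrRoIPool reference implementation,
+# gradients are exercised through _get_gradient, a small network is run on
+# CPU/GPU, and the Python layer's argument type checks are verified.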
+ +from __future__ import print_function + +import numpy as np +import unittest +from py_precise_roi_pool import PyPrRoIPool +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestPRROIPoolOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + self.prRoIPool = PyPrRoIPool() + self.outs = self.prRoIPool.compute( + self.x, self.rois, self.output_channels, self.spatial_scale, + self.pooled_height, self.pooled_width).astype('float32') + self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} + self.attrs = { + 'output_channels': self.output_channels, + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width + } + self.outputs = {'Out': self.outs} + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 * 2 * 2 + self.height = 6 + self.width = 4 + + self.x_dim = [self.batch_size, self.channels, self.height, self.width] + + self.spatial_scale = 1.0 / 4.0 + self.output_channels = 3 + self.pooled_height = 2 + self.pooled_width = 2 + + self.x = np.random.random(self.x_dim).astype('float32') + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(bno + 1) + for i in range(bno + 1): + x1 = np.random.random_integers( + 0, self.width // self.spatial_scale - self.pooled_width) + y1 = np.random.random_integers( + 0, self.height // self.spatial_scale - self.pooled_height) + + x2 = np.random.random_integers(x1 + self.pooled_width, + self.width // self.spatial_scale) + y2 = np.random.random_integers( + y1 + self.pooled_height, self.height // self.spatial_scale) + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype('float32') + + def setUp(self): + self.op_type = 'prroi_pool' + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_backward(self): + for place in self._get_places(): + self._get_gradient(['X'], place, ["Out"], None) + + def run_net(self, place): + with program_guard(Program(), Program()): + x = fluid.layers.data( + name="X", + shape=[self.channels, self.height, self.width], + dtype="float32") + rois = fluid.layers.data( + name="ROIs", shape=[4], dtype="float32", lod_level=1) + output = fluid.layers.prroi_pool(x, rois, self.output_channels, + 0.25, 2, 2) + loss = fluid.layers.mean(output) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + input_x = fluid.create_lod_tensor(self.x, [], place) + input_rois = fluid.create_lod_tensor(self.rois[:, 1:5], + self.rois_lod, place) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + exe.run(fluid.default_main_program(), + {'X': input_x, + "ROIs": input_rois}) + + def test_net(self): + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.run_net(place) + + def test_errors(self): + with program_guard(Program(), Program()): + x = fluid.layers.data( + name="x", shape=[245, 30, 30], dtype="float32") + rois = fluid.layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + # channel must be int type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.5, + 0.25, 7, 7) + # spatial_scale must be float type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, 2, + 7, 7) + # pooled_height must be int type + self.assertRaises(TypeError, 
fluid.layers.prroi_pool, x, rois, 5, + 0.25, 0.7, 7) + # pooled_width must be int type + self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, + 0.25, 7, 0.7) + + +if __name__ == '__main__': + unittest.main()