diff --git a/paddle/operators/roi_pool_op.cc b/paddle/operators/roi_pool_op.cc new file mode 100755 index 0000000000000000000000000000000000000000..156db9358689c90293311b8f08a7576b680c9472 --- /dev/null +++ b/paddle/operators/roi_pool_op.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/roi_pool_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kROISize = 5; + +class ROIPoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ROIPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ROIs"), + "Input(ROIs) of ROIPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ROIPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Argmax"), + "Output(Argmax) of ROIPoolOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE(input_dims.size() == 4, + "The format of input tensor is NCHW."); + PADDLE_ENFORCE(rois_dims.size() == 2, + "ROIs should be a 2-D tensor of shape (num_rois, 5)" + "given as [[batch_id, x1, y1, x2, y2], …]."); + PADDLE_ENFORCE(rois_dims[1] == kROISize, + "ROIs should be a 2-D tensor of shape (num_rois, 5)" + "given as [[batch_id, x1, y1, x2, y2], …]."); + + int pooled_height = ctx->Attrs().Get("pooled_height"); + int pooled_width = ctx->Attrs().Get("pooled_width"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE_GT(pooled_height, 0, + "The pooled output height must greater than 0"); + PADDLE_ENFORCE_GT(pooled_width, 0, + "The pooled output width must greater than 0"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0"); + + auto out_dims = input_dims; + out_dims[0] = rois_dims[0]; + out_dims[1] = input_dims[1]; + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + + ctx->SetOutputDim("Out", out_dims); + ctx->SetOutputDim("Argmax", out_dims); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPoolGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ROIPoolOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor), " + "the input of ROIPoolOp. " + "The format of input tensor is NCHW. Where N is batch size, " + "C is the number of input channels, " + "H is the height of the feature, and " + "W is the width of the feature."); + AddInput("ROIs", + "(Tensor), " + "ROIs (Regions of Interest) to pool over. " + "should be a 2-D tensor of shape (num_rois, 5)" + "given as [[batch_id, x1, y1, x2, y2], …]. " + "Where batch_id is the id of the data, " + "(x1, y1) is the top left coordinates, and " + "(x2, y2) is the bottom right coordinates."); + AddOutput("Out", + "(Tensor), " + "The output of ROIPoolOp is a 4-D tensor with shape " + "(num_rois, channels, pooled_h, pooled_w)."); + AddOutput("Argmax", + "(Tensor), " + "Argmaxes corresponding to indices in X used " + "for gradient computation. Only output " + "if arg “is_test” is false.").AsIntermediate(); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Multiplicative spatial scale factor " + "to translate ROI coords from their input scale " + "to the scale used when pooling.") + .SetDefault(1.0); + AddAttr("pooled_height", + "(int, default 1), " + "The pooled output height.") + .SetDefault(1); + AddAttr("pooled_width", + "(int, default 1), " + "The pooled output width.") + .SetDefault(1); + AddComment(R"DOC( +ROIPool operator + +ROI Pooling for Faster-RCNN. The link below is a further introduction: +https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, + roi_pool_grad, ops::ROIPoolGradOp); +REGISTER_OP_CPU_KERNEL( + roi_pool, + ops::CPUROIPoolOpKernel, + ops::CPUROIPoolOpKernel); +REGISTER_OP_CPU_KERNEL( + roi_pool_grad, + ops::CPUROIPoolGradOpKernel, + ops::CPUROIPoolOpKernel); diff --git a/paddle/operators/roi_pool_op.cu b/paddle/operators/roi_pool_op.cu new file mode 100755 index 0000000000000000000000000000000000000000..97df45f1b5779d5e28e36814450a9577edf85135 --- /dev/null +++ b/paddle/operators/roi_pool_op.cu @@ -0,0 +1,232 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/roi_pool_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; +static constexpr int kROISize = 5; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + + template + __global__ void GPUROIPoolForward( + const int nthreads, const T* input_data, const int64_t* input_rois, + const float spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + T* output_data, int64_t* argmax_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const int64_t* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = offset_input_rois[0]; + int roi_start_w = round(offset_input_rois[1] * spatial_scale); + int roi_start_h = round(offset_input_rois[2] * spatial_scale); + int roi_end_w = round(offset_input_rois[3] * spatial_scale); + int roi_end_h = round(offset_input_rois[4] * spatial_scale); + + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + T maxval = is_empty ? 0 : -std::numeric_limits::max(); + int maxidx = -1; + const T* offset_input_data = + input_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_data_index = h * width + w; + if (offset_input_data[input_data_index] > maxval) { + maxval = offset_input_data[input_data_index]; + maxidx = input_data_index; + } + } + } + output_data[index] = maxval; + if (argmax_data) { + argmax_data[index] = maxidx; + } + } + } + +template +__global__ void GPUROIPoolBackward( + const int nthreads, + const int64_t* input_rois, + const T* output_grad, + const int64_t* argmax_data, + const int num_rois, + const float spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + T* input_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const int64_t* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = offset_input_rois[0]; + int input_offset = (roi_batch_ind * channels + c) * height * width; + int output_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_output_grad = output_grad + output_offset; + T* offset_input_grad = input_grad + input_offset; + const int64_t* offset_argmax_data = argmax_data + output_offset; + + int argmax = offset_argmax_data[ph * pooled_width + pw]; + if (argmax != -1) { + platform::CudaAtomicAdd(offset_input_grad + argmax, + static_cast(offset_output_grad[ph * pooled_width + pw])); + } + } + } + + +template +class GPUROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + auto* argmax = ctx.Output("Argmax"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + auto in_stride = framework::stride(in_dims); + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + size_t rois_num = rois->dims()[0]; + if (rois_num== 0) return; + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + GPUROIPoolForward + <<>>( + output_size, + in->data(), + rois->data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + out->mutable_data(ctx.GetPlace()), + argmax->mutable_data(ctx.GetPlace())); + } +}; + +template +class GPUROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* argmax = ctx.Input("Argmax"); + + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* x_grad = + ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + size_t rois_num = rois->dims()[0]; + int channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (x_grad) { + x_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.device_context(), x_grad, static_cast(0)); + + int output_grad_size = out_grad->numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPUROIPoolBackward + <<>>( + output_grad_size, + rois->data(), + out_grad->data(), + argmax->data(), + rois_num, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + x_grad->mutable_data(ctx.GetPlace())); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + roi_pool, + ops::GPUROIPoolOpKernel, + ops::GPUROIPoolOpKernel); +REGISTER_OP_GPU_KERNEL( + roi_pool_grad, + ops::GPUROIPoolGradOpKernel, + ops::GPUROIPoolOpKernel); diff --git a/paddle/operators/roi_pool_op.h b/paddle/operators/roi_pool_op.h new file mode 100755 index 0000000000000000000000000000000000000000..bd7736d63125f1be57c8af5141208f66d0592adb --- /dev/null +++ b/paddle/operators/roi_pool_op.h @@ -0,0 +1,190 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class CPUROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + auto* argmax = ctx.Output("Argmax"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + auto in_stride = framework::stride(in_dims); + auto argmax_stride = framework::stride(argmax->dims()); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out->dims()); + + const T* input_data = in->data(); + const int64_t* rois_data = rois->data(); + T* output_data = out->mutable_data(ctx.GetPlace()); + int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); + + for (int n = 0; n < rois_num; ++n) { + int roi_batch_id = rois_data[0]; + PADDLE_ENFORCE_GE(roi_batch_id, 0); + PADDLE_ENFORCE_LT(roi_batch_id, batch_size); + rois_data += roi_stride[0]; + } + + rois_data = rois->data(); + for (int n = 0; n < rois_num; ++n) { + int roi_batch_id = rois_data[0]; + int roi_start_w = round(rois_data[1] * spatial_scale); + int roi_start_h = round(rois_data[2] * spatial_scale); + int roi_end_w = round(rois_data[3] * spatial_scale); + int roi_end_h = round(rois_data[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + + const float bin_size_h = + static_cast(roi_height) / static_cast(pooled_height); + const float bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); + + const T* batch_data = input_data + roi_batch_id * in_stride[0]; + + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + // Compute pooling region for this output unit: + // start (included) = floor(ph * roi_height / pooled_height_) + // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) + int hstart = + static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = + static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + + const int pool_index = ph * pooled_width + pw; + + // Define an empty pooling region to be zero + bool is_empty = (hend <= hstart) || (wend <= wstart); + output_data[pool_index] = + is_empty ? 0 : -std::numeric_limits::max(); + argmax_data[pool_index] = -1; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = h * width + w; + if (batch_data[index] > output_data[pool_index]) { + output_data[pool_index] = batch_data[index]; + argmax_data[pool_index] = index; + } + } + } + } + } + + batch_data += in_stride[1]; + output_data += out_stride[1]; + argmax_data += argmax_stride[1]; + } + // Increment ROI data pointer + rois_data += roi_stride[0]; + } + return; + } +}; + +template +class CPUROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* argmax = ctx.Input("Argmax"); + + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* x_grad = + ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + + if (x_grad) { + int channels = in->dims()[1]; + auto in_stride = framework::stride(in->dims()); + auto roi_stride = framework::stride(rois->dims()); + + const int64_t* rois_data = rois->data(); + int rois_num = rois->dims()[0]; + + T* x_grad_data = x_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.device_context(), x_grad, static_cast(0)); + + size_t roi_offset = roi_stride[0]; + size_t batch_offset = in_stride[0]; + size_t channel_offset = in_stride[1]; + + const T* out_grad_data = out_grad->data(); + size_t pool_channel_offset = pooled_height * pooled_width; + const int64_t* argmax_data = argmax->data(); + + for (size_t n = 0; n < rois_num; ++n) { + size_t roi_batch_idx = rois_data[0]; + T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx; + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + size_t pool_index = ph * pooled_width + pw; + + if (argmax_data[pool_index] >= 0) { + size_t index = static_cast(argmax_data[pool_index]); + batch_grad_data[index] += out_grad_data[pool_index]; + } + } + } + batch_grad_data += channel_offset; + out_grad_data += pool_channel_offset; + argmax_data += pool_channel_offset; + } + rois_data += roi_offset; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_roi_pool_op.py b/python/paddle/v2/fluid/tests/test_roi_pool_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7cedb930ca861aed95c355931d80cb4d265c8235 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_roi_pool_op.py @@ -0,0 +1,127 @@ +import unittest +import numpy as np +import math +import sys +from op_test import OpTest + +class TestROIPoolOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + self.calc_roi_pool() + + self.inputs = { + 'X': self.x, + 'ROIs': self.rois} + + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width} + + self.outputs = { + 'Out': self.outs, + 'Argmax': self.argmaxes} + + def init_test_case(self): + self.batch_size = 5 + self.channels = 3 + self.height = 6 + self.width = 4 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, + self.height, self.width) + + self.spatial_scale = 1.0/4.0 + self.pooled_height = 2 + self.pooled_width = 2 + self.rois_num = 2 + + self.x = np.random.random(self.x_dim).astype('float32') + + def calc_roi_pool(self): + out_data = np.zeros( + (self.rois_num, self.channels, + self.pooled_height, self.pooled_width)) + argmax_data = np.zeros( + (self.rois_num, self.channels, + self.pooled_height, self.pooled_width)) + + for i in range(self.rois_num): + roi = self.rois[i] + roi_batch_id = roi[0] + roi_start_w = int(round(roi[1] * self.spatial_scale)) + roi_start_h = int(round(roi[2] * self.spatial_scale)) + roi_end_w = int(round(roi[3] * self.spatial_scale)) + roi_end_h = int(round(roi[4] * self.spatial_scale)) + + roi_height = int(max(roi_end_h - roi_start_h + 1, 1)); + roi_width = int(max(roi_end_w - roi_start_w + 1, 1)); + + x_i = self.x[roi_batch_id] + + bin_size_h = float(roi_height) / float(self.pooled_height) + bin_size_w = float(roi_width) / float(self.pooled_width) + + for c in range(self.channels): + for ph in range(self.pooled_height): + for pw in range(self.pooled_width): + hstart = int(math.floor(ph * bin_size_h)) + wstart = int(math.floor(pw * bin_size_w)) + hend = int(math.ceil((ph + 1) * bin_size_h)) + wend = int(math.ceil((pw + 1) * bin_size_w)) + + hstart = min(max(hstart + roi_start_h, 0), self.height) + hend = min(max(hend + roi_start_h, 0), self.height) + wstart = min(max(wstart + roi_start_w, 0), self.width) + wend = min(max(wend + roi_start_w, 0), self.width) + + is_empty = (hend <= hstart) or (wend <= wstart) + if is_empty: + out_data[i, c, ph, pw] = 0 + else: + out_data[i, c, ph, pw] = -sys.float_info.max + + argmax_data[i, c, ph, pw] = -1 + + for h in range(hstart, hend): + for w in range(wstart, wend): + if x_i[c, h, w] > out_data[i, c, ph, pw]: + out_data[i, c, ph, pw] = x_i[c, h, w] + argmax_data[i, c, ph, pw] = h * \ + self.width + w + + self.outs = out_data.astype('float32') + self.argmaxes = argmax_data.astype('int64') + + def make_rois(self): + rois = [] + batch_ids = np.random.randint(0, self.batch_size, size=self.rois_num) + for i in range(self.rois_num): + x1 = np.random.random_integers( + 0, self.width / self.spatial_scale - self.pooled_width) + y1 = np.random.random_integers( + 0, self.height / self.spatial_scale - self.pooled_height) + + x2 = np.random.random_integers( + x1 + self.pooled_width, self.width / self.spatial_scale) + y2 = np.random.random_integers( + y1 + self.pooled_height, self.height / self.spatial_scale) + + roi = [batch_ids[i], x1, y1, x2, y2] + rois.append(roi) + self.rois = np.array(rois).astype("int64") + + def setUp(self): + self.op_type = "roi_pool" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + +if __name__ == '__main__': + unittest.main()