提交 4a5bfa89 编写于 作者: D dyning 提交者: qingqing01

Modify RoI pooling op to use LoDTensor and expose it into Python API (#10208)

* modify roi pool with lod and expose ROI Pooling into Python API

* make lod code brief

* make doc clearer

* make doc clearer
上级 87d86ee3
......@@ -479,6 +479,13 @@ label_smooth
.. autofunction:: paddle.fluid.layers.label_smooth
:noindex:
roi_pool
---------
.. autofunction:: paddle.fluid.layers.roi_pool
:noindex:
ops
===
......@@ -820,3 +827,5 @@ topk
.. autofunction:: paddle.fluid.layers.topk
:noindex:
......@@ -18,8 +18,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static constexpr int kROISize = 5;
using LoDTensor = framework::LoDTensor;
class ROIPoolOp : public framework::OperatorWithKernel {
public:
......@@ -40,11 +39,11 @@ class ROIPoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …].");
PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
......@@ -109,10 +108,10 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"H is the height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(Tensor), "
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …]. "
"should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …]. "
"Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
......
......@@ -19,10 +19,10 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static constexpr int kROISize = 5;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
......@@ -30,13 +30,11 @@ static inline int NumBlocks(const int N) {
}
template <typename T>
__global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
const int64_t* input_rois,
const float spatial_scale, const int channels,
const int height, const int width,
const int pooled_height,
const int pooled_width, T* output_data,
int64_t* argmax_data) {
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
......@@ -46,11 +44,11 @@ __global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_batch_ind = roi_batch_id_data[n];
int roi_start_w = round(offset_input_rois[0] * spatial_scale);
int roi_start_h = round(offset_input_rois[1] * spatial_scale);
int roi_end_w = round(offset_input_rois[2] * spatial_scale);
int roi_end_h = round(offset_input_rois[3] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
......@@ -93,7 +91,8 @@ __global__ void GPUROIPoolBackward(
const int nthreads, const int64_t* input_rois, const T* output_grad,
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, T* input_grad) {
const int pooled_height, const int pooled_width, int* roi_batch_id_data,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
......@@ -102,8 +101,7 @@ __global__ void GPUROIPoolBackward(
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_batch_ind = roi_batch_id_data[n];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
......@@ -124,7 +122,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto* argmax = ctx.Output<Tensor>("Argmax");
......@@ -133,23 +131,46 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
auto in_stride = framework::stride(in_dims);
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
size_t rois_num = rois->dims()[0];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same.");
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
framework::Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
GPUROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
channels, height, width, pooled_height, pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
......@@ -159,7 +180,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<Tensor>("ROIs");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -169,12 +190,27 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
size_t rois_num = rois->dims()[0];
int rois_num = rois->dims()[0];
int channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (x_grad) {
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
framework::Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));
......@@ -189,6 +225,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width,
roi_batch_id_list_gpu.data<int>(),
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
......
......@@ -21,12 +21,14 @@ limitations under the License. */
namespace paddle {
namespace operators {
static constexpr int kROISize = 4;
template <typename DeviceContext, typename T>
class CPUROIPoolOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* argmax = ctx.Output<framework::Tensor>("Argmax");
......@@ -47,24 +49,36 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
auto out_stride = framework::stride(out->dims());
const T* input_data = in->data<T>();
const int64_t* rois_data = rois->data<int64_t>();
T* output_data = out->mutable_data<T>(ctx.GetPlace());
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = rois_data[0];
PADDLE_ENFORCE_GE(roi_batch_id, 0);
PADDLE_ENFORCE_LT(roi_batch_id, batch_size);
rois_data += roi_stride[0];
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same.");
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
rois_data = rois->data<int64_t>();
T* output_data = out->mutable_data<T>(ctx.GetPlace());
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
const int64_t* rois_data = rois->data<int64_t>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = rois_data[0];
int roi_start_w = round(rois_data[1] * spatial_scale);
int roi_start_h = round(rois_data[2] * spatial_scale);
int roi_end_w = round(rois_data[3] * spatial_scale);
int roi_end_h = round(rois_data[4] * spatial_scale);
int roi_batch_id = roi_batch_id_data[n];
int roi_start_w = round(rois_data[0] * spatial_scale);
int roi_start_h = round(rois_data[1] * spatial_scale);
int roi_end_w = round(rois_data[2] * spatial_scale);
int roi_end_h = round(rois_data[3] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
......@@ -133,7 +147,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......@@ -143,6 +157,20 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto pooled_width = ctx.Attr<int>("pooled_width");
if (in_grad) {
int rois_num = rois->dims()[0];
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
const int64_t* rois_data = rois->data<int64_t>();
const T* out_grad_data = out_grad->data<T>();
const int64_t* argmax_data = argmax->data<int64_t>();
......@@ -156,11 +184,10 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims());
int rois_num = rois->dims()[0];
int channels = in->dims()[1];
for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = rois_data[0];
int roi_batch_idx = roi_batch_id_data[n];
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
......
......@@ -79,6 +79,7 @@ __all__ = [
'lrn',
'pad',
'label_smooth',
'roi_pool',
]
......@@ -3759,3 +3760,53 @@ def label_smooth(label,
outputs={"Out": smooth_label},
attrs={"epsilon": float(epsilon)})
return smooth_label
def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
    """
    Region of interest pooling (also known as RoI pooling) performs max
    pooling on inputs of nonuniform sizes to obtain fixed-size feature maps
    (e.g. 7*7).

    The operator has three steps:

    1. Dividing each region proposal into equal-sized sections with
       the pooled_width and pooled_height
    2. Finding the largest value in each section
    3. Copying these max values to the output buffer

    Args:
        input (Variable): The input for ROI pooling.
        rois (Variable): ROIs (Regions of Interest) to pool over. It should
                         be a 2-D one level LoDTensor of shape [num_rois, 4].
                         The layout is [x1, y1, x2, y2], where (x1, y1)
                         is the top left coordinates, and (x2, y2) is the
                         bottom right coordinates. The num_rois is the
                         total number of ROIs in this batch data. The LoD
                         assigns each ROI to an image: the number of LoD
                         sequences must equal the batch size of input.
        pooled_height (integer): The pooled output height. Default: 1
        pooled_width (integer): The pooled output width. Default: 1
        spatial_scale (float): Multiplicative spatial scale factor. To
                               translate ROI coords from their input scale
                               to the scale used when pooling. Default: 1.0

    Returns:
        pool_out (Variable): The output is a 4-D tensor of the shape
                             (num_rois, channels, pooled_h, pooled_w).

    Examples:
        .. code-block:: python

            pool_out = fluid.layers.roi_pool(
                input=x,
                rois=rois,
                pooled_height=7,
                pooled_width=7,
                spatial_scale=1.0)
    """
    # NOTE: LayerHelper receives **locals(), so no extra local variable may be
    # introduced before this call.
    helper = LayerHelper('roi_pool', **locals())
    dtype = helper.input_dtype()
    pool_out = helper.create_tmp_variable(dtype)
    # Argmax is a mandatory output of the roi_pool op; it is created here but
    # not returned to the caller.
    # NOTE(review): the roi_pool kernels write Argmax as int64
    # (argmax->mutable_data<int64_t>), while it is declared int32 here --
    # confirm whether the declared dtype should be 'int64'.
    argmaxes = helper.create_tmp_variable(dtype='int32')
    helper.append_op(
        type="roi_pool",
        inputs={"X": input,
                "ROIs": rois},
        outputs={"Out": pool_out,
                 "Argmax": argmaxes},
        attrs={
            "pooled_height": pooled_height,
            "pooled_width": pooled_width,
            "spatial_scale": spatial_scale
        })
    return pool_out
......@@ -359,6 +359,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(indices)
print(str(program))
def test_roi_pool(self):
    # Build-only smoke test: verifies that the roi_pool layer can be appended
    # to a program and yields an output variable; nothing is executed.
    program = Program()
    with program_guard(program):
        x = layers.data(name="x", shape=[256, 30, 30], dtype="float32")
        # ROIs are a one-level LoDTensor of [x1, y1, x2, y2] rows. The
        # roi_pool kernels read them via rois->data<int64_t>(), so the data
        # layer is declared as int64 rather than float32.
        rois = layers.data(
            name="rois", shape=[4], dtype="int64", lod_level=1)
        output = layers.roi_pool(x, rois, 7, 7, 0.6)
        self.assertIsNotNone(output)
    print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -25,7 +25,7 @@ class TestROIPoolOp(OpTest):
self.make_rois()
self.calc_roi_pool()
self.inputs = {'X': self.x, 'ROIs': self.rois}
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'spatial_scale': self.spatial_scale,
......@@ -36,7 +36,7 @@ class TestROIPoolOp(OpTest):
self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes}
def init_test_case(self):
self.batch_size = 5
self.batch_size = 3
self.channels = 3
self.height = 6
self.width = 4
......@@ -47,7 +47,6 @@ class TestROIPoolOp(OpTest):
self.spatial_scale = 1.0 / 4.0
self.pooled_height = 2
self.pooled_width = 2
self.rois_num = 2
self.x = np.random.random(self.x_dim).astype('float32')
......@@ -106,20 +105,24 @@ class TestROIPoolOp(OpTest):
def make_rois(self):
rois = []
batch_ids = np.random.randint(0, self.batch_size, size=self.rois_num)
for i in range(self.rois_num):
x1 = np.random.random_integers(
0, self.width / self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height / self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width / self.spatial_scale)
y2 = np.random.random_integers(y1 + self.pooled_height,
self.height / self.spatial_scale)
roi = [batch_ids[i], x1, y1, x2, y2]
rois.append(roi)
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(len(rois))
for i in range(bno + 1):
x1 = np.random.random_integers(
0, self.width / self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height / self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width / self.spatial_scale)
y2 = np.random.random_integers(y1 + self.pooled_height,
self.height / self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_lod[0].append(len(rois))
self.rois_num = len(rois)
self.rois = np.array(rois).astype("int64")
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册