未验证 提交 e45d64ec 编写于 作者: J JYChen 提交者: GitHub

[new api] add func/class API psroi_pool and UT (#35352)

* add func/class API psroi_pool and UT

* add UT in static mode

* Remove redundant type checks in static mode

* More detailed description for test_psroi_pool_op

* fix code format of UT

* fix en-doc
上级 991ae3b6
......@@ -25,22 +25,26 @@ class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"Tensor, "
"(Tensor), "
"the input of PSROIPoolOp. "
"The format of input tensor is NCHW. Where N is the batch size, "
"C is the number of input channels, "
"H is the height of the input feature map, and "
"W is the width. The data type can be float32 or float64");
AddInput("ROIs",
"LoDTensor, "
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]. "
"where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. "
"The roi batch index can be calculated from LoD.");
AddInput("RoisNum",
"(Tensor), "
"The number of RoIs in each image.")
.AsDispensable();
AddOutput("Out",
"Tensor, "
"(Tensor), "
"the output of PSROIPoolOp is a 4-D Tensor with shape "
"(num_rois, output_channels, pooled_h, pooled_w). "
"The data type is the same as `x` ");
......@@ -65,8 +69,6 @@ class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"the pooled output width.")
.SetDefault(1);
AddComment(R"Doc(
**PSROIPool Operator,** `rois` **of this op should be a LoDTensor**
Position sensitive region of interest pooling (also known as PSROIPooling) is to perform
position-sensitive average pooling on regions of interest specified by input, takes as
input N position-sensitive score maps and a list of num_rois regions of interest.
......@@ -106,7 +108,14 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
......@@ -184,6 +193,7 @@ class PSROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("psroi_pool_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("RoisNum", this->Input("RoisNum"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
......
......@@ -185,34 +185,67 @@ class GPUPSROIPoolOpKernel : public framework::OpKernel<T> {
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The number of rois from input(ROIs) and its LOD "
"must be the same. Received rois %d of input(ROIs) "
"but the number of rois %d from its LOD is %d",
rois_num, rois_num_with_lod));
// set rois batch id
int rois_batch_size;
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size, batch_size));
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(platform::CPUPlace(), rois_num_list.data(),
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()),
rois_num_data, sizeof(int) * rois_batch_size, 0);
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_list[i];
}
PADDLE_ENFORCE_EQ(
rois_num_count, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and RoisNum must be the same"));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of input(ROIs) and input(X) must be "
"the same but received batch size of input(ROIs) and "
"input(X) is %d and %d respectively.",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The number of rois from input(ROIs) and its LOD "
"must be the same. Received rois %d of input(ROIs) "
"but the number of rois %d from its LOD is %d",
rois_num, rois_num_with_lod));
// set rois batch id
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
......@@ -257,14 +290,30 @@ class GPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(platform::CPUPlace(), rois_num_list.data(),
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()),
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
......
......@@ -40,6 +40,13 @@ class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
int rois_num = rois->dims()[0];
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channels of input "
"X should equal the product of "
"output_channels x pooled_height x pooled_width"));
auto in_stride = framework::stride(in_dims);
auto out_stride = framework::stride(out->dims());
......@@ -49,32 +56,52 @@ class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
platform::errors::InvalidArgument(
"the channels of input "
"X should equal the product of "
"output_channels x pooled_height x pooled_width"));
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
int rois_num_count = 0;
for (int i = 0; i < rois_batch_size; ++i) {
rois_num_count += rois_num_data[i];
}
PADDLE_ENFORCE_EQ(
rois_num_count, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and RoisNum must be the same"));
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num_with_lod, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* input_rois = rois->data<T>();
......@@ -93,7 +120,6 @@ class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
static_cast<T>(round(offset_input_rois[2]) + 1.) * spatial_scale;
T roi_end_h =
static_cast<T>(round(offset_input_rois[3]) + 1.) * spatial_scale;
// Force too small rois to be 1 x 1
T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0
T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1);
......@@ -172,15 +198,28 @@ class CPUPSROIPoolGradOpKernel : public framework::OpKernel<T> {
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
rois_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n;
}
}
}
const T* input_rois = rois->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
......
......@@ -54,6 +54,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"gather", {"X", "Index", "Axis"}},
{"roi_pool", {"X", "ROIs", "RoisNum"}},
{"roi_align", {"X", "ROIs", "RoisNum"}},
{"psroi_pool", {"X", "ROIs", "RoisNum"}},
{"collect_fpn_proposals",
{"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
{"distribute_fpn_proposals", {"FpnRois", "RoisNum"}},
......
......@@ -14,18 +14,89 @@
from __future__ import print_function
import paddle
import math
import numpy as np
import unittest
from op_test import OpTest
def calc_psroi_pool(x, rois, rois_num_per_img, output_channels, spatial_scale,
pooled_height, pooled_width):
"""
Psroi_pool implemented by Numpy.
x: 4-D as (N, C, H, W),
rois: 2-D as [[x1, y1, x2, y2], ...],
rois_num_per_img: 1-D as [nums_of_batch_0, nums_of_batch_1, ...]
"""
output_shape = (len(rois), output_channels, pooled_height, pooled_width)
out_data = np.zeros(output_shape)
batch_id = 0
rois_num_id = 0
rois_num_left = rois_num_per_img[rois_num_id]
for i in range(len(rois)):
roi = rois[i]
roi_batch_id = batch_id
rois_num_left -= 1
if rois_num_left == 0:
rois_num_id += 1
if rois_num_id < len(rois_num_per_img):
rois_num_left = rois_num_per_img[rois_num_id]
batch_id += 1
roi_start_w = round(roi[0]) * spatial_scale
roi_start_h = round(roi[1]) * spatial_scale
roi_end_w = (round(roi[2]) + 1.) * spatial_scale
roi_end_h = (round(roi[3]) + 1.) * spatial_scale
roi_height = max(roi_end_h - roi_start_h, 0.1)
roi_width = max(roi_end_w - roi_start_w, 0.1)
bin_size_h = roi_height / float(pooled_height)
bin_size_w = roi_width / float(pooled_width)
x_i = x[roi_batch_id]
for c in range(output_channels):
for ph in range(pooled_height):
for pw in range(pooled_width):
hstart = int(
math.floor(float(ph) * bin_size_h + roi_start_h))
wstart = int(
math.floor(float(pw) * bin_size_w + roi_start_w))
hend = int(
math.ceil(float(ph + 1) * bin_size_h + roi_start_h))
wend = int(
math.ceil(float(pw + 1) * bin_size_w + roi_start_w))
hstart = min(max(hstart, 0), x.shape[2])
hend = min(max(hend, 0), x.shape[2])
wstart = min(max(wstart, 0), x.shape[3])
wend = min(max(wend, 0), x.shape[3])
c_in = (c * pooled_height + ph) * pooled_width + pw
is_empty = (hend <= hstart) or (wend <= wstart)
out_sum = 0.
for ih in range(hstart, hend):
for iw in range(wstart, wend):
out_sum += x_i[c_in, ih, iw]
bin_area = (hend - hstart) * (wend - wstart)
out_data[i, c, ph, pw] = 0. if is_empty else (
out_sum / float(bin_area))
return out_data
class TestPSROIPoolOp(OpTest):
def set_data(self):
paddle.enable_static()
self.init_test_case()
self.make_rois()
self.calc_psroi_pool()
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.outs = calc_psroi_pool(self.x, self.boxes, self.boxes_num,
self.output_channels, self.spatial_scale,
self.pooled_height,
self.pooled_width).astype('float64')
self.inputs = {
'X': self.x,
'ROIs': (self.rois_with_batch_id[:, 1:5], self.rois_lod)
}
self.attrs = {
'output_channels': self.output_channels,
'spatial_scale': self.spatial_scale,
......@@ -67,57 +138,10 @@ class TestPSROIPoolOp(OpTest):
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype('float64')
def calc_psroi_pool(self):
output_shape = (self.rois_num, self.output_channels, self.pooled_height,
self.pooled_width)
out_data = np.zeros(output_shape)
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = int(roi[0])
roi_start_w = round(roi[1]) * self.spatial_scale
roi_start_h = round(roi[2]) * self.spatial_scale
roi_end_w = (round(roi[3]) + 1.) * self.spatial_scale
roi_end_h = (round(roi[4]) + 1.) * self.spatial_scale
roi_height = max(roi_end_h - roi_start_h, 0.1)
roi_width = max(roi_end_w - roi_start_w, 0.1)
bin_size_h = roi_height / float(self.pooled_height)
bin_size_w = roi_width / float(self.pooled_width)
x_i = self.x[roi_batch_id]
for c in range(self.output_channels):
for ph in range(self.pooled_height):
for pw in range(self.pooled_width):
hstart = int(
math.floor(float(ph) * bin_size_h + roi_start_h))
wstart = int(
math.floor(float(pw) * bin_size_w + roi_start_w))
hend = int(
math.ceil(
float(ph + 1) * bin_size_h + roi_start_h))
wend = int(
math.ceil(
float(pw + 1) * bin_size_w + roi_start_w))
hstart = min(max(hstart, 0), self.height)
hend = min(max(hend, 0), self.height)
wstart = min(max(wstart, 0), self.width)
wend = min(max(wend, 0), self.width)
c_in = (c * self.pooled_height + ph
) * self.pooled_width + pw
is_empty = (hend <= hstart) or (wend <= wstart)
out_sum = 0.
for ih in range(hstart, hend):
for iw in range(wstart, wend):
out_sum += x_i[c_in, ih, iw]
bin_area = (hend - hstart) * (wend - wstart)
out_data[i, c, ph, pw] = 0. if is_empty else (
out_sum / float(bin_area))
self.outs = out_data.astype('float64')
self.rois_with_batch_id = np.array(rois).astype('float64')
self.boxes = self.rois_with_batch_id[:, 1:]
self.boxes_num = np.array(
[bno + 1 for bno in range(self.batch_size)]).astype('int32')
def setUp(self):
self.op_type = 'psroi_pool'
......@@ -130,5 +154,175 @@ class TestPSROIPoolOp(OpTest):
self.check_grad(['X'], 'Out')
class TestPSROIPoolDynamicFunctionAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.random([2, 490, 28, 28]).astype(np.float32)
self.boxes = np.array(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]]).astype(np.float32)
self.boxes_num = np.array([1, 2]).astype(np.int32)
def test_output_size(self):
def test_output_size_is_int():
output_size = 7
out = paddle.vision.ops.psroi_pool(
paddle.to_tensor(self.x),
paddle.to_tensor(self.boxes),
paddle.to_tensor(self.boxes_num), output_size).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 10,
1.0, 7, 7)
self.assertTrue(np.allclose(out, expect_out))
def test_output_size_is_tuple():
output_size = (7, 7)
out = paddle.vision.ops.psroi_pool(
paddle.to_tensor(self.x),
paddle.to_tensor(self.boxes),
paddle.to_tensor(self.boxes_num), output_size).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 10,
1.0, 7, 7)
self.assertTrue(np.allclose(out, expect_out))
def test_dytype_is_float64():
output_size = (7, 7)
out = paddle.vision.ops.psroi_pool(
paddle.to_tensor(self.x, 'float64'),
paddle.to_tensor(self.boxes, 'float64'),
paddle.to_tensor(self.boxes_num, 'int32'), output_size).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 10,
1.0, 7, 7)
self.assertTrue(np.allclose(out, expect_out))
places = ['cpu']
if paddle.fluid.core.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
test_output_size_is_int()
test_output_size_is_tuple()
test_dytype_is_float64()
class TestPSROIPoolDynamicClassAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.random([2, 128, 32, 32]).astype(np.float32)
self.boxes = np.array([[3, 5, 6, 13], [7, 4, 22, 18], [4, 5, 7, 10],
[5, 3, 25, 21]]).astype(np.float32)
self.boxes_num = np.array([2, 2]).astype(np.int32)
def test_output_size(self):
def test_output_size_is_int():
psroi_module = paddle.vision.ops.PSRoIPool(8, 1.1)
out = psroi_module(
paddle.to_tensor(self.x),
paddle.to_tensor(self.boxes),
paddle.to_tensor(self.boxes_num)).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 2,
1.1, 8, 8)
self.assertTrue(np.allclose(out, expect_out))
def test_output_size_is_tuple():
psroi_pool_module = paddle.vision.ops.PSRoIPool(8, 1.1)
out = psroi_pool_module(
paddle.to_tensor(self.x),
paddle.to_tensor(self.boxes),
paddle.to_tensor(self.boxes_num)).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 2,
1.1, 8, 8)
self.assertTrue(np.allclose(out, expect_out))
def test_dytype_is_float64():
psroi_pool_module = paddle.vision.ops.PSRoIPool(8, 1.1)
out = psroi_pool_module(
paddle.to_tensor(self.x, 'float64'),
paddle.to_tensor(self.boxes, 'float64'),
paddle.to_tensor(self.boxes_num, 'int32')).numpy()
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 2,
1.1, 8, 8)
self.assertTrue(np.allclose(out, expect_out))
paddle.disable_static()
places = ['cpu']
if paddle.fluid.core.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
test_output_size_is_int()
test_output_size_is_tuple()
test_dytype_is_float64()
class TestPSROIPoolBoxesNumError(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.uniform([2, 490, 28, 28], dtype='float32')
self.boxes = paddle.to_tensor(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], 'float32')
def test_errors(self):
def test_boxes_num_nums_error():
boxes_num = paddle.to_tensor([1, 5], 'int32')
out = paddle.vision.ops.psroi_pool(
self.x, self.boxes, boxes_num, output_size=7)
self.assertRaises(ValueError, test_boxes_num_nums_error)
def test_boxes_num_length_error():
boxes_num = paddle.to_tensor([1, 1, 1], 'int32')
out = paddle.vision.ops.psroi_pool(
self.x, self.boxes, boxes_num, output_size=7)
self.assertRaises(ValueError, test_boxes_num_length_error)
class TestPSROIPoolChannelError(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.uniform([2, 490, 28, 28], dtype='float32')
self.boxes = paddle.to_tensor(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], 'float32')
self.output_size = 4
def test_errors(self):
def test_channel_error():
boxes_num = paddle.to_tensor([2, 1], 'int32')
out = paddle.vision.ops.psroi_pool(self.x, self.boxes, boxes_num,
self.output_size)
self.assertRaises(ValueError, test_channel_error)
class TestPSROIPoolStaticAPI(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.x_placeholder = paddle.static.data(
name='x', shape=[2, 490, 28, 28])
self.x = np.random.random([2, 490, 28, 28]).astype(np.float32)
self.boxes_placeholder = paddle.static.data(
name='boxes', shape=[3, 4], lod_level=1)
self.boxes = np.array(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]]).astype(np.float32)
self.boxes_num = np.array([1, 2]).astype(np.int32)
def test_function_in_static(self):
output_size = 7
out = paddle.vision.ops.psroi_pool(self.x_placeholder,
self.boxes_placeholder,
self.boxes_num, output_size)
expect_out = calc_psroi_pool(self.x, self.boxes, self.boxes_num, 10,
1.0, 7, 7)
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
exe = paddle.static.Executor(place)
boxes_lod_data = paddle.fluid.create_lod_tensor(self.boxes,
[[1, 2]], place)
out_res = exe.run(paddle.static.default_main_program(),
feed={'x': self.x,
'boxes': boxes_lod_data},
fetch_list=[out.name])
self.assertTrue(np.allclose(out_res, expect_out))
if __name__ == '__main__':
unittest.main()
......@@ -29,7 +29,9 @@ __all__ = [ #noqa
'deform_conv2d',
'DeformConv2D',
'read_file',
'decode_jpeg'
'decode_jpeg',
'psroi_pool',
'PSRoIPool',
]
......@@ -900,3 +902,114 @@ def decode_jpeg(x, mode='unchanged', name=None):
type="decode_jpeg", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
def psroi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None):
"""
Position sensitive region of interest pooling (also known as PSROIPooling) is to perform
position-sensitive average pooling on regions of interest specified by input. It performs
on inputs of nonuniform sizes to obtain fixed-size feature maps.
PSROIPooling is proposed by R-FCN. Please refer to https://arxiv.org/abs/1605.06409 for more details.
Args:
x (Tensor): Input features with shape (N, C, H, W). The data type can be float32 or float64.
boxes (Tensor): Box coordinates of ROIs (Regions of Interest) to pool over. It should be
a 2-D Tensor with shape (num_rois, 4). Given as [[x1, y1, x2, y2], ...],
(x1, y1) is the top left coordinates, and (x2, y2) is the bottom
right coordinates.
boxes_num (Tensor): The number of boxes contained in each picture in the batch.
output_size (int|Tuple(int, int)) The pooled output size(H, W), data type
is int32. If int, H and W are both equal to output_size.
spatial_scale (float): Multiplicative spatial scale factor to translate ROI coords from their
input scale to the scale used when pooling. Default: 1.0
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
4-D Tensor. The pooled ROIs with shape (num_rois, output_channels, pooled_h, pooled_w).
The output_channels equal to C / (pooled_h * pooled_w), where C is the channels of input.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([2, 490, 28, 28], dtype='float32')
boxes = paddle.to_tensor([[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32')
boxes_num = paddle.to_tensor([1, 2], dtype='int32')
pool_out = paddle.vision.ops.psroi_pool(x, boxes, boxes_num, 7, 1.0)
"""
check_type(output_size, 'output_size', (int, tuple, list), 'psroi_pool')
if isinstance(output_size, int):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
assert (len(x.shape) == 4,
"Input features with shape should be (N, C, H, W)")
output_channels = int(x.shape[1] / (pooled_height * pooled_width))
if in_dygraph_mode():
return core.ops.psroi_pool(x, boxes, boxes_num, "output_channels",
output_channels, "spatial_scale",
spatial_scale, "pooled_height",
pooled_height, "pooled_width", pooled_width)
helper = LayerHelper('psroi_pool', **locals())
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='psroi_pool',
inputs={'X': x,
'ROIs': boxes},
outputs={'Out': out},
attrs={
'output_channels': output_channels,
'spatial_scale': spatial_scale,
'pooled_height': pooled_height,
'pooled_width': pooled_width
})
return out
class PSRoIPool(Layer):
"""
This interface is used to construct a callable object of the ``PSRoIPool`` class. Please
refer to :ref:`api_paddle_vision_ops_psroi_pool`.
Args:
output_size (int|Tuple(int, int)) The pooled output size(H, W), data type
is int32. If int, H and W are both equal to output_size.
spatial_scale (float): Multiplicative spatial scale factor to translate ROI coords from their
input scale to the scale used when pooling. Default: 1.0.
Shape:
- x: 4-D Tensor with shape (N, C, H, W).
- boxes: 2-D Tensor with shape (num_rois, 4).
- boxes_num: 1-D Tensor.
- output: 4-D tensor with shape (num_rois, output_channels, pooled_h, pooled_w).
The output_channels equal to C / (pooled_h * pooled_w), where C is the channels of input.
Returns:
None
Examples:
.. code-block:: python
import paddle
psroi_module = paddle.vision.ops.PSRoIPool(7, 1.0)
x = paddle.uniform([2, 490, 28, 28], dtype='float32')
boxes = paddle.to_tensor([[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32')
boxes_num = paddle.to_tensor([1, 2], dtype='int32')
pool_out = psroi_module(x, boxes, boxes_num)
"""
def __init__(self, output_size, spatial_scale=1.0):
super(PSRoIPool, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
def forward(self, x, boxes, boxes_num):
return psroi_pool(x, boxes, boxes_num, self.output_size,
self.spatial_scale)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册