提交 6ea38091 编写于 作者: D Double_V 提交者: lanxianghit

Support prroi_pool_op with Tensor and LoDTensor rois (#20649)

1. Add a new input named batch_roi_nums for prroi_pool_op. batch_roi_nums includes the number of roi for each image in batch when rois is Tensor. This information is saved in rois's lod when rois is LoDTensor.
2. add grad check to prroi_pool_op and solve unnormal X grad diff in CPU.
上级 d8a9b134
...@@ -39,6 +39,11 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -39,6 +39,11 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"where (x1, y1) is the top left coordinates, and " "where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. " "(x2, y2) is the bottom right coordinates. "
"The roi batch index can be calculated from LoD."); "The roi batch index can be calculated from LoD.");
AddInput("BatchRoINums",
"(Tensor), "
"1-D tensor with shape [N], the number of"
" rois for each image in batch, where N is the batch size")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"(Tensor), " "(Tensor), "
"the output of PRROIPoolOp is a 4-D Tensor with shape " "the output of PRROIPoolOp is a 4-D Tensor with shape "
...@@ -75,39 +80,57 @@ class PRROIPoolOp : public framework::OperatorWithKernel { ...@@ -75,39 +80,57 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of op(PRROIPool) should not be null."); platform::errors::NotFound(
"Input(X) of op(PRROIPool) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true, PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
"Input(ROIs) of op(PRROIPool) should not be null."); platform::errors::NotFound(
"Input(ROIs) of op(PRROIPool) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of op(PRROIPool) should not be null."); platform::errors::NotFound(
"Output(Out) of op(PRROIPool) should not be null."));
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs"); auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE_EQ(input_dims.size(), 4, PADDLE_ENFORCE_EQ(input_dims.size(), 4,
"The format of input tensor is NCHW"); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(rois_dims.size(), 2, "The format of input tensor is NCHW"));
PADDLE_ENFORCE_EQ(
rois_dims.size(), 2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"); "given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(rois_dims[1], 4, PADDLE_ENFORCE_EQ(
rois_dims[1], 4,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"); "given as [(x1, y1, x2, y2), ...]"));
int pooled_height = ctx->Attrs().Get<int>("pooled_height"); int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width"); int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale"); float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_GT(pooled_height, 0, PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must be greater than 0"); platform::errors::InvalidArgument(
"The pooled output height must be greater than 0"));
PADDLE_ENFORCE_GT(pooled_width, 0, PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must be greater than 0"); platform::errors::InvalidArgument(
"The pooled output width must be greater than 0"));
PADDLE_ENFORCE_GT(spatial_scale, 0.0f, PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0."); platform::errors::InvalidArgument(
"The spatial scale must greater than 0."));
auto out_dims = input_dims; auto out_dims = input_dims;
out_dims[0] = rois_dims[0]; out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1]; out_dims[1] = input_dims[1];
out_dims[2] = pooled_height; out_dims[2] = pooled_height;
out_dims[3] = pooled_width; out_dims[3] = pooled_width;
if (ctx->HasInput("BatchRoINums")) {
auto rois_batch_index = ctx->GetInputDim("BatchRoINums");
PADDLE_ENFORCE_EQ(rois_batch_index[0], input_dims[0],
platform::errors::InvalidArgument(
"The length of BatchRoINums should equal to "
"first dim of inputs(X)"));
}
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
} }
...@@ -154,6 +177,7 @@ class PRROIPoolGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -154,6 +177,7 @@ class PRROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Out", this->Output("Out")); op->SetInput("Out", this->Output("Out"));
op->SetInput("ROIs", this->Input("ROIs")); op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("BatchRoINums", this->Input("BatchRoINums"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("ROIs"), this->InputGrad("ROIs")); op->SetOutput(framework::GradVarName("ROIs"), this->InputGrad("ROIs"));
...@@ -172,8 +196,12 @@ REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp); ...@@ -172,8 +196,12 @@ REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
prroi_pool, prroi_pool,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>); ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
prroi_pool_grad, prroi_pool_grad,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>); ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -185,8 +185,8 @@ __global__ void GPUPRROIPoolBackward( ...@@ -185,8 +185,8 @@ __global__ void GPUPRROIPoolBackward(
PrRoIPoolingCoorBackward( PrRoIPoolingCoorBackward(
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w, s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w,
win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale, win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale,
offset_in_data, offset_out_data, offset_input_grad_data, offset_in_data, offset_out_data, offset_input_roi_grad_data,
offset_input_roi_grad_data, GPUAccumulateRois<T>, offset_output_grad_data, GPUAccumulateRois<T>,
[](const T x, const T y) { return max(x, y); }, [](const T x, const T y) { return max(x, y); },
[](const T x, const T y) { return min(x, y); }); [](const T x, const T y) { return min(x, y); });
} }
...@@ -214,41 +214,66 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -214,41 +214,66 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel<T> {
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return; if (rois_num == 0) return;
// set rois batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) {
auto* batchroinum = ctx.Input<Tensor>("BatchRoINums");
framework::Tensor batch_index_cpu;
framework::TensorCopySync(*batchroinum, platform::CPUPlace(),
&batch_index_cpu);
int rois_batch_size = batchroinum->dims()[0];
auto* batch_index = batch_index_cpu.data<int64_t>();
size_t c = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int64_t k = 0; k < batch_index[n]; ++k) {
rois_batch_id_data[c] = n;
c = c + 1;
}
}
} else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
"The rois_batch_size and input(X) batch_size must be the same."); platform::errors::InvalidArgument(
"The rois_batch_size and input(X) batch_size must be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, PADDLE_ENFORCE_EQ(
"The rois_num from input and lod must be the same."); rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The rois_num from input and lod must be the same."));
// set rois batch id
framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < rois_batch_size; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
rois_batch_id_data[i] = n; rois_batch_id_data[i] = n;
} }
} }
}
framework::Tensor rois_batch_id_list_gpu;
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &rois_batch_id_list_gpu);
int output_size = out->numel(); int output_size = out->numel();
int blocks = NumBlocks(output_size); int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads; int threads = kNumCUDAThreads;
auto cplace = platform::CPUPlace();
auto& dev_ctx = ctx.cuda_device_context();
int bytes = rois_batch_id_list.numel() * sizeof(int);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, rois_batch_id_data, bytes,
dev_ctx.stream());
// call cuda kernel function // call cuda kernel function
GPUPRROIPoolForward< GPUPRROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, output_size, in->data<T>(), rois->data<T>(), spatial_scale,
input_channels, height, width, output_channels, pooled_height, input_channels, height, width, output_channels, pooled_height,
pooled_width, rois_batch_id_list_gpu.data<int>(), pooled_width, roi_id_data, out->mutable_data<T>(ctx.GetPlace()));
out->mutable_data<T>(ctx.GetPlace()));
} }
}; };
...@@ -275,12 +300,33 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -275,12 +300,33 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
int height = in->dims()[2]; int height = in->dims()[2];
int width = in->dims()[3]; int width = in->dims()[3];
if (input_grad) { if (input_grad || input_roi_grad) {
// set roi batch id // set roi batch id
framework::Tensor rois_batch_id_list; framework::Tensor rois_batch_id_list;
rois_batch_id_list.Resize({rois_num}); rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data = int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(platform::CPUPlace()); rois_batch_id_list.mutable_data<int>(platform::CPUPlace());
if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) {
auto* batchroinum = ctx.Input<Tensor>("BatchRoINums");
framework::Tensor batch_index_cpu;
framework::TensorCopySync(*batchroinum, platform::CPUPlace(),
&batch_index_cpu);
int rois_batch_size = batchroinum->dims()[0];
auto* batch_index = batch_index_cpu.data<int64_t>();
size_t c = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int64_t k = 0; k < batch_index[n]; ++k) {
rois_batch_id_data[c] = n;
c = c + 1;
}
}
} else {
PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
platform::errors::InvalidArgument(
"the lod of Input ROIs shoule not be empty when "
"BatchRoINums is None!"));
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
...@@ -288,10 +334,16 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -288,10 +334,16 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
rois_batch_id_data[i] = n; rois_batch_id_data[i] = n;
} }
} }
}
framework::Tensor rois_batch_id_list_gpu; auto cplace = platform::CPUPlace();
framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), auto& dev_ctx = ctx.cuda_device_context();
ctx.device_context(), &rois_batch_id_list_gpu); int bytes = rois_batch_id_list.numel() * sizeof(int);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, rois_batch_id_data, bytes,
dev_ctx.stream());
input_grad->mutable_data<T>(ctx.GetPlace()); input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
...@@ -304,12 +356,10 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -304,12 +356,10 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
int threads = kNumCUDAThreads; int threads = kNumCUDAThreads;
if (output_grad_size > 0) { if (output_grad_size > 0) {
GPUPRROIPoolBackward< GPUPRROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, in->data<T>(), rois->data<T>(), output_grad_size, in->data<T>(), rois->data<T>(),
output_grad->data<T>(), spatial_scale, input_channels, height, output_grad->data<T>(), spatial_scale, input_channels, height,
width, output_channels, pooled_height, pooled_width, width, output_channels, pooled_height, pooled_width, roi_id_data,
rois_batch_id_list_gpu.data<int>(),
input_grad->mutable_data<T>(ctx.GetPlace()), out->data<T>(), input_grad->mutable_data<T>(ctx.GetPlace()), out->data<T>(),
input_roi_grad->mutable_data<T>(ctx.GetPlace())); input_roi_grad->mutable_data<T>(ctx.GetPlace()));
} }
......
...@@ -81,7 +81,7 @@ inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff, ...@@ -81,7 +81,7 @@ inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff,
const T coeff) { const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) { if (!overflow) {
*(diff + h * width + w) = top_diff * coeff; *(diff + h * width + w) += top_diff * coeff;
} }
} }
...@@ -179,7 +179,7 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( ...@@ -179,7 +179,7 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward(
T win_start_h, T win_end_w, T win_end_h, int pw, int ph, T win_start_h, T win_end_w, T win_end_h, int pw, int ph,
const int pooled_width, const int pooled_height, T win_size, const int pooled_width, const int pooled_height, T win_size,
const float spatial_scale, const T* this_bottom_data, const float spatial_scale, const T* this_bottom_data,
const T* this_top_data, T* this_data_grad, T* this_out_grad, const T* this_top_data, T* this_data_grad, const T* this_out_grad,
Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) { Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) {
T g_x1_y = 0.f; T g_x1_y = 0.f;
T g_x2_y = 0.f; T g_x2_y = 0.f;
...@@ -232,20 +232,19 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward( ...@@ -232,20 +232,19 @@ inline HOSTDEVICE void PrRoIPoolingCoorBackward(
partial_y1 = partial_y1 / win_size * spatial_scale; partial_y1 = partial_y1 / win_size * spatial_scale;
partial_y2 = partial_y2 / win_size * spatial_scale; partial_y2 = partial_y2 / win_size * spatial_scale;
this_data_grad[0] = 0; functor(this_data_grad + 0,
functor(this_data_grad + 1,
(partial_x1 * (1.0 - static_cast<T>(pw) / pooled_width) + (partial_x1 * (1.0 - static_cast<T>(pw) / pooled_width) +
partial_x2 * (1.0 - static_cast<T>(pw + 1) / pooled_width)) * partial_x2 * (1.0 - static_cast<T>(pw + 1) / pooled_width)) *
(*this_out_grad)); (*this_out_grad));
functor(this_data_grad + 2, functor(this_data_grad + 1,
(partial_y1 * (1.0 - static_cast<T>(ph) / pooled_height) + (partial_y1 * (1.0 - static_cast<T>(ph) / pooled_height) +
partial_y2 * (1.0 - static_cast<T>(ph + 1) / pooled_height)) * partial_y2 * (1.0 - static_cast<T>(ph + 1) / pooled_height)) *
(*this_out_grad)); (*this_out_grad));
functor(this_data_grad + 3, functor(this_data_grad + 2,
(partial_x2 * static_cast<T>(pw + 1) / pooled_width + (partial_x2 * static_cast<T>(pw + 1) / pooled_width +
partial_x1 * static_cast<T>(pw) / pooled_width) * partial_x1 * static_cast<T>(pw) / pooled_width) *
(*this_out_grad)); (*this_out_grad));
functor(this_data_grad + 4, functor(this_data_grad + 3,
(partial_y2 * static_cast<T>(ph + 1) / pooled_height + (partial_y2 * static_cast<T>(ph + 1) / pooled_height +
partial_y1 * static_cast<T>(ph) / pooled_height) * partial_y1 * static_cast<T>(ph) / pooled_height) *
(*this_out_grad)); (*this_out_grad));
...@@ -262,7 +261,6 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -262,7 +261,6 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
auto pooled_height = ctx.Attr<int>("pooled_height"); auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width"); auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims(); auto in_dims = in->dims();
int batch_size = in_dims[0]; int batch_size = in_dims[0];
int input_channels = in_dims[1]; int input_channels = in_dims[1];
...@@ -270,6 +268,7 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -270,6 +268,7 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
int height = in_dims[2]; int height = in_dims[2];
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto in_stride = framework::stride(in_dims); auto in_stride = framework::stride(in_dims);
auto out_stride = framework::stride(out->dims()); auto out_stride = framework::stride(out->dims());
...@@ -280,15 +279,33 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -280,15 +279,33 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
rois_batch_id_list.Resize({rois_num}); rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data = int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace()); rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) {
auto* batchroinum = ctx.Input<framework::Tensor>("BatchRoINums");
auto* batch_index = batchroinum->data<int64_t>();
int rois_batch_size = batchroinum->dims()[0];
size_t c = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int64_t k = 0; k < batch_index[n]; ++k) {
rois_batch_id_data[c] = n;
c = c + 1;
}
}
} else {
PADDLE_ENFORCE_EQ(rois->lod().empty(), false,
platform::errors::InvalidArgument(
"the lod of Input ROIs shoule not be empty when "
"BatchRoINums is None!"));
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
"the rois_batch_size and input(X) batch_size should be the same."); platform::errors::InvalidArgument("the rois_batch_size and input(X) "
"batch_size should be the same."));
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num, PADDLE_ENFORCE_EQ(
"the rois_num from input and lod must be the same"); rois_num_with_lod, rois_num,
platform::errors::InvalidArgument(
"the rois_num from input and lod must be the same"));
// calculate batch id index for each roi according to LoD // calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
...@@ -296,10 +313,10 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> { ...@@ -296,10 +313,10 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
rois_batch_id_data[i] = n; rois_batch_id_data[i] = n;
} }
} }
}
T* output_data = out->mutable_data<T>(ctx.GetPlace()); T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* input_rois = rois->data<T>(); const T* input_rois = rois->data<T>();
// calculate prroipooling, parallel processing can be implemented per ROI // calculate prroipooling, parallel processing can be implemented per ROI
for (int n = 0; n < rois_num; ++n) { for (int n = 0; n < rois_num; ++n) {
// set roi batch id // set roi batch id
...@@ -390,7 +407,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -390,7 +407,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
auto pooled_width = ctx.Attr<int>("pooled_width"); auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
if (input_grad && input_roi_grad) { if (input_grad || input_roi_grad) {
auto in_dims = in->dims(); auto in_dims = in->dims();
auto* in_data = in->data<T>(); auto* in_data = in->data<T>();
auto* out_data = out->data<T>(); auto* out_data = out->data<T>();
...@@ -406,6 +423,18 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -406,6 +423,18 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
rois_batch_id_list.Resize({rois_num}); rois_batch_id_list.Resize({rois_num});
int* rois_batch_id_data = int* rois_batch_id_data =
rois_batch_id_list.mutable_data<int>(ctx.GetPlace()); rois_batch_id_list.mutable_data<int>(ctx.GetPlace());
if (ctx.HasInput("BatchRoINums") || rois->lod().empty()) {
auto* batchroinum = ctx.Input<framework::Tensor>("BatchRoINums");
auto* batch_index = batchroinum->data<int64_t>();
int rois_batch_size = batchroinum->dims()[0];
size_t c = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int64_t k = 0; k < batch_index[n]; ++k) {
rois_batch_id_data[c] = n;
c = c + 1;
}
}
} else {
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
// calculate batch id index for each roi according to LoD // calculate batch id index for each roi according to LoD
...@@ -414,16 +443,22 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -414,16 +443,22 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
rois_batch_id_data[i] = n; rois_batch_id_data[i] = n;
} }
} }
}
const T* input_rois = rois->data<T>(); const T* input_rois = rois->data<T>();
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
T* input_roi_grad_data = input_roi_grad->mutable_data<T>(ctx.GetPlace());
input_grad->mutable_data<T>(ctx.GetPlace());
input_roi_grad->mutable_data<T>(ctx.GetPlace());
// set gradient of X to be 0. before backpropagate. // set gradient of X to be 0. before backpropagate.
math::SetConstant<DeviceContext, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), input_grad, set_zero(ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0)); static_cast<T>(0));
set_zero(ctx.template device_context<DeviceContext>(), input_roi_grad,
static_cast<T>(0));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
T* input_roi_grad_data = input_roi_grad->mutable_data<T>(ctx.GetPlace());
// backpropagate gradient per output pixel // backpropagate gradient per output pixel
int output_grad_size = output_grad->numel(); int output_grad_size = output_grad->numel();
...@@ -493,7 +528,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -493,7 +528,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h,
win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size, win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size,
spatial_scale, offset_in_data, offset_out_data, spatial_scale, offset_in_data, offset_out_data,
offset_input_grad_data, offset_input_roi_grad_data, offset_input_roi_grad_data, offset_output_grad_data,
CPUAccumulateRois<T>, CPUAccumulateRois<T>,
[](const T x, const T y) { return std::max(x, y); }, [](const T x, const T y) { return std::max(x, y); },
[](const T x, const T y) { return std::min(x, y); }); [](const T x, const T y) { return std::min(x, y); });
......
...@@ -12655,35 +12655,53 @@ def prroi_pool(input, ...@@ -12655,35 +12655,53 @@ def prroi_pool(input,
spatial_scale=1.0, spatial_scale=1.0,
pooled_height=1, pooled_height=1,
pooled_width=1, pooled_width=1,
batch_roi_nums=None,
name=None): name=None):
""" """
The precise roi pooling implementation for paddle?https://arxiv.org/pdf/1807.11590.pdf The precise roi pooling implementation for paddle. Reference: https://arxiv.org/pdf/1807.11590.pdf
Args: Args:
input (Variable):The input of Deformable PSROIPooling.The shape of input tensor is input (Variable):The input of precise roi pooliing.The shape of input tensor is
[N,C,H,W]. Where N is batch size,C is number of input channels,H [N,C,H,W]. Where N is batch size,C is number of input channels,H
is height of the feature, and W is the width of the feature. is height of the feature, and W is the width of the feature.
rois (Variable): ROIs (Regions of Interest) to pool over.It should be rois (Variable): ROIs (Regions of Interest) to pool over.It should be
a 2-D LoDTensor of shape (num_rois, 4), the lod level a 2-D LoDTensor or Tensor of shape (num_rois, 4), the lod level
is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is is 1 when it is LoDTensor. The LoD include the rois's batch index
information. If rois is Tensor, its batch index information should
be provided by batch_index.
Given as [[x1, y1, x2, y2], ...], (x1, y1) is
the top left coordinates, and (x2, y2) is the bottom the top left coordinates, and (x2, y2) is the bottom
right coordinates. right coordinates.
spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width). spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width).
Equals the reciprocal of total stride in convolutional layers, Default: 1.0. Equals the reciprocal of total stride in convolutional layers, Default: 1.0.
pooled_height (integer): The pooled output height. Default: 1. pooled_height (integer): The pooled output height. Default: 1.
pooled_width (integer): The pooled output width. Default: 1. pooled_width (integer): The pooled output width. Default: 1.
batch_roi_nums (Variable): The number of roi for each image in batch. It
shoule be 1-D Tensor, with shape [N] and dtype int64,
where N is the batch size. Default: None. Be note: The lod of input should be
empty when batch_roi_nums has values;
name (str, default None): The name of this operation. name (str, default None): The name of this operation.
Returns: Returns:
Variable(Tensor): The shape of the returned Tensor is (num_rois, output_channels, pooled_h, pooled_w), with value type float32,float16.. Variable(Tensor):The shape of the returned Tensor is (N, C, pooled_height, pooled_width), with value type float32,float16. N, C denote batch_size and channels of input respectively.
Examples: Examples:
.. code-block:: python .. code-block:: python
## prroi_pool without batch_roi_num
import paddle.fluid as fluid import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32') x = fluid.data(name='x', shape=[None, 490, 28, 28], dtype='float32')
rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32') rois = fluid.data(name='rois', shape=[None, 4], lod_level=1, dtype='float32')
pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7) pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7)
## prroi_pool with batch_roi_num
batchsize=4
x2 = fluid.data(name='x2', shape=[batchsize, 490, 28, 28], dtype='float32')
rois2 = fluid.data(name='rois2', shape=[batchsize, 4], dtype='float32')
batch_rois_num = fluid.data(name='rois_nums', shape=[batchsize], dtype='int64')
pool_out2 = fluid.layers.prroi_pool(x2, rois2, 1.0, 7, 7, batch_roi_nums=batch_rois_num)
""" """
helper = LayerHelper('prroi_pool', **locals()) helper = LayerHelper('prroi_pool', **locals())
# check attrs # check attrs
...@@ -12695,10 +12713,12 @@ def prroi_pool(input, ...@@ -12695,10 +12713,12 @@ def prroi_pool(input,
raise TypeError("pooled_width must be int type") raise TypeError("pooled_width must be int type")
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
inputs_op = {'X': input, 'ROIs': rois}
if batch_roi_nums is not None:
inputs_op['BatchRoINums'] = batch_roi_nums
helper.append_op( helper.append_op(
type='prroi_pool', type='prroi_pool',
inputs={'X': input, inputs=inputs_op,
'ROIs': rois},
outputs={'Out': out}, outputs={'Out': out},
attrs={ attrs={
'spatial_scale': spatial_scale, 'spatial_scale': spatial_scale,
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
from py_precise_roi_pool import PyPrRoIPool from py_precise_roi_pool import PyPrRoIPool
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
...@@ -29,7 +30,7 @@ class TestPRROIPoolOp(OpTest): ...@@ -29,7 +30,7 @@ class TestPRROIPoolOp(OpTest):
self.prRoIPool = PyPrRoIPool() self.prRoIPool = PyPrRoIPool()
self.outs = self.prRoIPool.compute( self.outs = self.prRoIPool.compute(
self.x, self.rois, self.output_channels, self.spatial_scale, self.x, self.rois, self.output_channels, self.spatial_scale,
self.pooled_height, self.pooled_width).astype('float32') self.pooled_height, self.pooled_width).astype('float64')
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = { self.attrs = {
'output_channels': self.output_channels, 'output_channels': self.output_channels,
...@@ -42,17 +43,17 @@ class TestPRROIPoolOp(OpTest): ...@@ -42,17 +43,17 @@ class TestPRROIPoolOp(OpTest):
def init_test_case(self): def init_test_case(self):
self.batch_size = 3 self.batch_size = 3
self.channels = 3 * 2 * 2 self.channels = 3 * 2 * 2
self.height = 6 self.height = 12
self.width = 4 self.width = 16
self.x_dim = [self.batch_size, self.channels, self.height, self.width] self.x_dim = [self.batch_size, self.channels, self.height, self.width]
self.spatial_scale = 1.0 / 4.0 self.spatial_scale = 1.0 / 2.0
self.output_channels = self.channels self.output_channels = self.channels
self.pooled_height = 2 self.pooled_height = 4
self.pooled_width = 2 self.pooled_width = 4
self.x = np.random.random(self.x_dim).astype('float32') self.x = np.random.random(self.x_dim).astype('float64')
def make_rois(self): def make_rois(self):
rois = [] rois = []
...@@ -72,7 +73,7 @@ class TestPRROIPoolOp(OpTest): ...@@ -72,7 +73,7 @@ class TestPRROIPoolOp(OpTest):
roi = [bno, x1, y1, x2, y2] roi = [bno, x1, y1, x2, y2]
rois.append(roi) rois.append(roi)
self.rois_num = len(rois) self.rois_num = len(rois)
self.rois = np.array(rois).astype('float32') self.rois = np.array(rois).astype('float64')
def setUp(self): def setUp(self):
self.op_type = 'prroi_pool' self.op_type = 'prroi_pool'
...@@ -82,17 +83,20 @@ class TestPRROIPoolOp(OpTest): ...@@ -82,17 +83,20 @@ class TestPRROIPoolOp(OpTest):
self.check_output() self.check_output()
def test_backward(self): def test_backward(self):
for place in self._get_places(): places = [fluid.CPUPlace()]
self._get_gradient(['X'], place, ["Out"], None) if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_grad_with_place(place, ['X'], 'Out')
def run_net(self, place): def run_net(self, place):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
x = fluid.layers.data( x = fluid.layers.data(
name="X", name="X",
shape=[self.channels, self.height, self.width], shape=[self.channels, self.height, self.width],
dtype="float32") dtype="float64")
rois = fluid.layers.data( rois = fluid.layers.data(
name="ROIs", shape=[4], dtype="float32", lod_level=1) name="ROIs", shape=[4], dtype="float64", lod_level=1)
output = fluid.layers.prroi_pool(x, rois, 0.25, 2, 2) output = fluid.layers.prroi_pool(x, rois, 0.25, 2, 2)
loss = fluid.layers.mean(output) loss = fluid.layers.mean(output)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3) optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
...@@ -116,9 +120,127 @@ class TestPRROIPoolOp(OpTest): ...@@ -116,9 +120,127 @@ class TestPRROIPoolOp(OpTest):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
x = fluid.layers.data( x = fluid.layers.data(
name="x", shape=[245, 30, 30], dtype="float32") name="x", shape=[245, 30, 30], dtype="float64")
rois = fluid.layers.data(
name="rois", shape=[4], dtype="float64", lod_level=1)
# spatial_scale must be float type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7,
7)
# pooled_height must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
0.7, 7)
# pooled_width must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
7, 0.7)
class TestPRROIPoolOpTensorRoIs(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.prRoIPool = PyPrRoIPool()
self.outs = self.prRoIPool.compute(
self.x, self.rois, self.output_channels, self.spatial_scale,
self.pooled_height, self.pooled_width).astype('float64')
self.rois_index = np.array(self.rois_lod).reshape([-1]).astype(np.int64)
self.inputs = {
'X': self.x,
'ROIs': self.rois[:, 1:5],
'BatchRoINums': self.rois_index
}
self.attrs = {
'output_channels': self.output_channels,
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width
}
self.outputs = {'Out': self.outs}
def init_test_case(self):
self.batch_size = 3
self.channels = 3 * 2 * 2
self.height = 12
self.width = 16
self.x_dim = [self.batch_size, self.channels, self.height, self.width]
self.spatial_scale = 1.0 / 2.0
self.output_channels = self.channels
self.pooled_height = 4
self.pooled_width = 4
self.x = np.random.random(self.x_dim).astype('float64')
def make_rois(self):
rois = []
self.rois_lod = []
for bno in range(self.batch_size):
self.rois_lod.append(bno + 1)
for i in range(bno + 1):
x1 = np.random.uniform(
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.uniform(
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.uniform(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.uniform(y1 + self.pooled_height,
self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype('float64')
def setUp(self):
self.op_type = 'prroi_pool'
self.set_data()
def test_check_output(self):
self.check_output()
def test_backward(self):
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_grad_with_place(place, ['X'], 'Out')
def run_net(self, place):
with program_guard(Program(), Program()):
x = fluid.layers.data(
name="X",
shape=[self.channels, self.height, self.width],
dtype="float64")
rois = fluid.layers.data(name="ROIs", shape=[4], dtype="float64")
rois_index = fluid.layers.data(
name='rois_idx', shape=[], dtype="int64")
output = fluid.layers.prroi_pool(
x, rois, 0.25, 2, 2, batch_roi_nums=rois_index)
loss = fluid.layers.mean(output)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(fluid.default_main_program(), {
'X': self.x,
"ROIs": self.rois[:, 1:5],
"rois_idx": self.rois_index
})
def test_net(self):
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.run_net(place)
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(
name="x", shape=[245, 30, 30], dtype="float64")
rois = fluid.layers.data( rois = fluid.layers.data(
name="rois", shape=[4], dtype="float32", lod_level=1) name="rois", shape=[4], dtype="float64", lod_level=1)
# spatial_scale must be float type # spatial_scale must be float type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7, self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7,
7) 7)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册