diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 850ccbfb397cd9722d02ed8c4923d85dae3d943b..19ef23cdfa90912ff6fbd050a685d10861d909d2 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None)) paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None)) paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) +paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c57a34c3a745e8fc03ca57dce478ecf60058a9a9 --- /dev/null +++ b/paddle/fluid/operators/roi_align_op.cc @@ -0,0 +1,166 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/roi_align_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class ROIAlignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ROIAlignOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ROIs"), + "Input(ROIs) of ROIAlignOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ROIAlignOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE(input_dims.size() == 4, + "The format of input tensor is NCHW."); + PADDLE_ENFORCE(rois_dims.size() == 2, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]."); + PADDLE_ENFORCE(rois_dims[1] == 4, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]."); + int pooled_height = ctx->Attrs().Get("pooled_height"); + int pooled_width = ctx->Attrs().Get("pooled_width"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE_GT(pooled_height, 0, + "The pooled output height must greater than 0"); + PADDLE_ENFORCE_GT(pooled_width, 0, + "The pooled output width must greater than 0"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0"); + + auto out_dims = input_dims; + out_dims[0] = rois_dims[0]; + out_dims[1] = input_dims[1]; + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIAlignGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The GRAD@Out of ROIAlignGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The GRAD@X of ROIAlignGradOp should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "The input of ROIAlignOp. " + "The format of input tensor is NCHW. Where N is batch size, " + "C is the number of input channels, " + "H is the height of the feature, and " + "W is the width of the feature."); + AddInput("ROIs", + "(LoDTensor), " + "ROIs (Regions of Interest) to pool over. " + "should be a 2-D LoDTensor of shape (num_rois, 4)" + "given as [[x1, y1, x2, y2], …]. " + "(x1, y1) is the top left coordinates, and " + "(x2, y2) is the bottom right coordinates."); + AddOutput("Out", + "(Tensor), " + "The output of ROIAlignOp is a 4-D tensor with shape " + "(num_rois, channels, pooled_h, pooled_w)."); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Multiplicative spatial scale factor " + "to translate ROI coords from their input scale " + "to the scale used when pooling.") + .SetDefault(1.0); + AddAttr("pooled_height", + "(int, default 1), " + "The pooled output height.") + .SetDefault(1); + AddAttr("pooled_width", + "(int, default 1), " + "The pooled output width.") + .SetDefault(1); + AddAttr("sampling_ratio", + "(int,default -1)," + "number of sampling points in the interpolation grid" + "If <=0, then grid points are adaptive to roi_width " + "and pooled_w, likewise for height") + .SetDefault(-1); + AddComment(R"DOC( +**RoIAlign Operator** + +Region of interest align (also known as RoI align) is to perform +bilinear interpolation on inputs of nonuniform sizes to obtain +fixed-size feature maps (e.g. 7*7) + +Dividing each region proposal into equal-sized sections with +the pooled_width and pooled_height. Location remains the origin +result. + +In each ROI bin, the value of the four regularly sampled locations +are computed directly through bilinear interpolation. The output is +the mean of four locations. +Thus avoid the misaligned problem. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp); +REGISTER_OP_CPU_KERNEL( + roi_align, + ops::CPUROIAlignOpKernel, + ops::CPUROIAlignOpKernel); +REGISTER_OP_CPU_KERNEL( + roi_align_grad, + ops::CPUROIAlignGradOpKernel, + ops::CPUROIAlignGradOpKernel); diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bcec6f3563df7f4e1e48554cc891d596f9e56024 --- /dev/null +++ b/paddle/fluid/operators/roi_align_op.cu @@ -0,0 +1,353 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/roi_align_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ T BilinearInterpolate(const T* input_data, const int height, + const int width, T y, T x) { + if (y < -1.0 || y > height || x < -1.0 || x > width) { + return 0; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + T ly = y - y_low, lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + T v1 = input_data[y_low * width + x_low]; + T v2 = input_data[y_low * width + x_high]; + T v3 = input_data[y_high * width + x_low]; + T v4 = input_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void BilinearInterpolateGradient(const int height, const int width, + T y, T x, T* w1, T* w2, T* w3, + T* w4, int* x_low, int* x_high, + int* y_low, int* y_high) { + if (y < -1.0 || y > height || x < -1.0 || x > width) { + return; + } + + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + *y_low = static_cast(y); + *x_low = static_cast(x); + if (*y_low >= height - 1) { + *y_high = *y_low = height - 1; + y = static_cast(*y_low); + } else { + *y_high = *y_low + 1; + } + if (*x_low >= width - 1) { + *x_high = *x_low = width - 1; + x = static_cast(*x_low); + } else { + *x_high = *x_low + 1; + } + T ly = y - *y_low, lx = x - *x_low; + T hy = 1. - ly, hx = 1. - lx; + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; + + return; +} + +template +__global__ void GPUROIAlignForward( + const int nthreads, const T* input_data, const T* input_rois, + const float spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int sampling_ratio, int* roi_batch_id_data, T* output_data) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + + const T* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = roi_batch_id_data[n]; + + T roi_xmin = offset_input_rois[0] * spatial_scale; + T roi_ymin = offset_input_rois[1] * spatial_scale; + T roi_xmax = offset_input_rois[2] * spatial_scale; + T roi_ymax = offset_input_rois[3] * spatial_scale; + + T roi_width = max(roi_xmax - roi_xmin, static_cast(1.)); + T roi_height = max(roi_ymax - roi_ymin, static_cast(1.)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input_data = + input_data + (roi_batch_ind * channels + c) * height * width; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + T output_val = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = BilinearInterpolate(offset_input_data, height, width, y, x); + output_val += val; + } + } + output_val /= count; + output_data[i] = output_val; + } +} + +template +__global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois, + const T* out_grad, const int num_rois, + const float spatial_scale, + const int channels, const int height, + const int width, const int pooled_height, + const int pooled_width, + const int sampling_ratio, + int* roi_batch_id_data, T* input_grad) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + const T* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = roi_batch_id_data[n]; + + T roi_xmin = offset_input_rois[0] * spatial_scale; + T roi_ymin = offset_input_rois[1] * spatial_scale; + T roi_xmax = offset_input_rois[2] * spatial_scale; + T roi_ymax = offset_input_rois[3] * spatial_scale; + + T roi_width = max(roi_xmax - roi_xmin, static_cast(1.)); + T roi_height = max(roi_ymax - roi_ymin, static_cast(1.)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_input_grad = + input_grad + (roi_batch_ind * channels + c) * height * width; + + const T* offset_out_grad = + out_grad + (n * channels + c) * pooled_height * pooled_width; + const T out_grad_this_bin = offset_out_grad[ph * pooled_width + pw]; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + const T count = roi_bin_grid_h * roi_bin_grid_w; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T w1 = 0, w2 = 0, w3 = 0, w4 = 0; + int x_low = -1, x_high = -1, y_low = -1, y_high = -1; + BilinearInterpolateGradient(height, width, y, x, &w1, &w2, &w3, &w4, + &x_low, &x_high, &y_low, &y_high); + T diff1 = out_grad_this_bin * w1 / count; + T diff2 = out_grad_this_bin * w2 / count; + T diff3 = out_grad_this_bin * w3 / count; + T diff4 = out_grad_this_bin * w4 / count; + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_low, + diff1); + platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_high, + diff2); + platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_low, + diff3); + platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_high, + diff4); + } + } + } + } +} + +template +class GPUROIAlignOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + int rois_num = rois->dims()[0]; + + if (rois_num == 0) return; + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and imgs batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + Tensor roi_batch_id_list_gpu; + framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(), + &roi_batch_id_list_gpu); + GPUROIAlignForward< + T><<>>( + output_size, in->data(), rois->data(), spatial_scale, channels, + height, width, pooled_height, pooled_width, sampling_ratio, + roi_batch_id_list_gpu.data(), + out->mutable_data(ctx.GetPlace())); + } +}; + +template +class GPUROIAlignGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + + int rois_num = rois->dims()[0]; + int channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (!in_grad) { + return; + } + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + Tensor roi_batch_id_list_gpu; + framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(), + &roi_batch_id_list_gpu); + + in_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.cuda_device_context(), in_grad, static_cast(0)); + + int output_grad_size = out_grad->numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPUROIAlignBackward< + T><<>>( + output_grad_size, rois->data(), out_grad->data(), rois_num, + spatial_scale, channels, height, width, pooled_height, pooled_width, + sampling_ratio, roi_batch_id_list_gpu.data(), + in_grad->mutable_data(ctx.GetPlace())); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + roi_align, + ops::GPUROIAlignOpKernel, + ops::GPUROIAlignOpKernel); +REGISTER_OP_CUDA_KERNEL( + roi_align_grad, + ops::GPUROIAlignGradOpKernel, + ops::GPUROIAlignGradOpKernel); diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a18aee1b86283cbb48f0b804ccfc476d7cd78f3b --- /dev/null +++ b/paddle/fluid/operators/roi_align_op.h @@ -0,0 +1,332 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kROISize = 4; + +template +void PreCalcForBilinearInterpolate( + const platform::DeviceContext& ctx, const int height, const int width, + const int pooled_height, const int pooled_width, const int iy_upper, + const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w, + int roi_bin_grid_h, int roi_bin_grid_w, Tensor* pre_pos, Tensor* pre_w) { + int pre_calc_index = 0; + int* pre_pos_data = pre_pos->mutable_data(ctx.GetPlace()); + T* pre_w_data = pre_w->mutable_data(ctx.GetPlace()); + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + // calculate y of sample points + T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + // calculate x of samle points + for (int ix = 0; ix < ix_upper; ix++) { + T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + // deal with elements out of map + if (y < -1.0 || y > height || x < -1.0 || x > width) { + for (int i = 0; i < kROISize; ++i) { + pre_pos_data[i + pre_calc_index * kROISize] = 0; + pre_w_data[i + pre_calc_index * kROISize] = 0; + } + pre_calc_index += 1; + continue; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + T ly = y - y_low, lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + pre_pos_data[pre_calc_index * kROISize] = y_low * width + x_low; + pre_pos_data[pre_calc_index * kROISize + 1] = y_low * width + x_high; + pre_pos_data[pre_calc_index * kROISize + 2] = y_high * width + x_low; + pre_pos_data[pre_calc_index * kROISize + 3] = y_high * width + x_high; + pre_w_data[pre_calc_index * kROISize] = hy * hx; + pre_w_data[pre_calc_index * kROISize + 1] = hy * lx; + pre_w_data[pre_calc_index * kROISize + 2] = ly * hx; + pre_w_data[pre_calc_index * kROISize + 3] = ly * lx; + pre_calc_index += 1; + } + } + } + } +} + +template +void bilinear_interpolate_gradient(const int height, const int width, T y, T x, + const T out_grad_this_bin, const T count, + T* batch_grad_data) { + int x_low, y_low, x_high, y_high; + T w1, w2, w3, w4; + if (y < -1.0 || y > height || x < -1.0 || x > width) { + w1 = w2 = w3 = w4 = 0; + x_low = x_high = y_low = y_high = -1; + return; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + y_low = static_cast(y); + x_low = static_cast(x); + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + + T ly = y - y_low, lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + T diff1 = out_grad_this_bin * w1 / count; + T diff2 = out_grad_this_bin * w2 / count; + T diff3 = out_grad_this_bin * w3 / count; + T diff4 = out_grad_this_bin * w4 / count; + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + *(batch_grad_data + y_low * width + x_low) += diff1; + *(batch_grad_data + y_low * width + x_high) += diff2; + *(batch_grad_data + y_high * width + x_low) += diff3; + *(batch_grad_data + y_high * width + x_high) += diff4; + } +} + +template +class CPUROIAlignOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + + auto& dev_ctx = ctx.template device_context(); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + auto in_stride = framework::stride(in_dims); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out->dims()); + + const T* input_data = in->data(); + framework::Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and imgs batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* rois_data = rois->data(); + for (int n = 0; n < rois_num; ++n) { + int roi_batch_id = roi_batch_id_data[n]; + T roi_xmin = rois_data[0] * spatial_scale; + T roi_ymin = rois_data[1] * spatial_scale; + T roi_xmax = rois_data[2] * spatial_scale; + T roi_ymax = rois_data[3] * spatial_scale; + + T roi_width = std::max(roi_xmax - roi_xmin, static_cast(1.)); + T roi_height = std::max(roi_ymax - roi_ymin, static_cast(1.)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + const T* batch_data = input_data + roi_batch_id * in_stride[0]; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + Tensor pre_pos; + Tensor pre_w; + int pre_size = count * out_stride[1]; + pre_pos.Resize({pre_size, kROISize}); + pre_w.Resize({pre_size, kROISize}); + + PreCalcForBilinearInterpolate( + dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h, + roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w, + roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w); + const int* pre_pos_data = pre_pos.data(); + const T* pre_w_data = pre_w.data(); + for (int c = 0; c < channels; c++) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + const int pool_index = ph * pooled_width + pw; + T output_val = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + for (int i = 0; i < kROISize; i++) { + int pos = pre_pos_data[pre_calc_index * kROISize + i]; + T w = pre_w_data[pre_calc_index * kROISize + i]; + output_val += w * batch_data[pos]; + } + pre_calc_index += 1; + } + } + output_val /= count; + output_data[pool_index] = output_val; + } + } + batch_data += in_stride[1]; + output_data += out_stride[1]; + } + rois_data += roi_stride[0]; + } + } +}; + +template +class CPUROIAlignGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + auto in_dims = in->dims(); + if (!in_grad) { + return; + } + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + Tensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = + roi_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + + const T* rois_data = rois->data(); + const T* out_grad_data = out_grad->data(); + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + + auto in_stride = framework::stride(in->dims()); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out_grad->dims()); + + for (int n = 0; n < rois_num; ++n) { + int roi_batch_idx = roi_batch_id_data[n]; + T roi_xmin = rois_data[0] * spatial_scale; + T roi_ymin = rois_data[1] * spatial_scale; + T roi_xmax = rois_data[2] * spatial_scale; + T roi_ymax = rois_data[3] * spatial_scale; + T roi_width = std::max(roi_xmax - roi_xmin, static_cast(1.)); + T roi_height = std::max(roi_ymax - roi_ymin, static_cast(1.)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + for (int c = 0; c < channels; ++c) { + T* batch_grad_data = + in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1]; + const T* batch_out_grad_data = + out_grad_data + n * out_stride[0] + c * out_stride[1]; + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int pool_index = ph * pooled_width + pw; + T out_grad_this_bin = batch_out_grad_data[pool_index]; + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_width / pooled_width); + T count = roi_bin_grid_h * roi_bin_grid_w; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + bilinear_interpolate_gradient(height, width, y, x, + out_grad_this_bin, count, + batch_grad_data); + } + } + } + } + } + rois_data += roi_stride[0]; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 58c9ce56bf6306a178727bff4b1fa958685948b1..538035de1a7a062a91cd48a8b4c7c11d5352d6c1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -96,6 +96,7 @@ __all__ = [ 'pad_constant_like', 'label_smooth', 'roi_pool', + 'roi_align', 'dice_loss', 'image_resize', 'image_resize_short', @@ -5430,6 +5431,54 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): return pool_out +@templatedoc() +def roi_align(input, + rois, + pooled_height=1, + pooled_width=1, + spatial_scale=1.0, + sampling_ratio=-1, + name=None): + """ + ${comment} + + Args: + input (Variable): ${x_comment} + rois (Variable): ROIs (Regions of Interest) to pool over. + pooled_height (integer): ${pooled_height_comment} Default: 1 + pooled_width (integer): ${pooled_width_comment} Default: 1 + spatial_scale (float): ${spatial_scale_comment} Default: 1.0 + sampling_ratio(intger): ${sampling_ratio_comment} Default: -1 + + Returns: + Variable: ${out_comment}. + Examples: + .. code-block:: python + + align_out = fluid.layers.roi_align(input=x, + rois=rois, + pooled_height=7, + pooled_width=7, + spatial_scale=0.5, + sampling_ratio=-1) + """ + helper = LayerHelper('roi_align', **locals()) + dtype = helper.input_dtype() + align_out = helper.create_tmp_variable(dtype) + helper.append_op( + type="roi_align", + inputs={"X": input, + "ROIs": rois}, + outputs={"Out": align_out}, + attrs={ + "pooled_height": pooled_height, + "pooled_width": pooled_width, + "spatial_scale": spatial_scale, + "sampling_ratio": sampling_ratio + }) + return align_out + + def dice_loss(input, label, epsilon=0.00001): """ Dice loss for comparing the similarity of two batch of data, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index dc70477ebe1cfbffd207ebb4bbf9d9f39893d79e..50de468dba803d0a2a0c129ad04aac8a3822cdbc 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -465,6 +465,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_roi_align(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.roi_align(x, rois, 14, 14, 0.5, 2) + self.assertIsNotNone(output) + print(str(program)) + def test_resize_bilinear(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_roi_align_op.py b/python/paddle/fluid/tests/unittests/test_roi_align_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1a252ea547e4d93d83f64fa9cdb3605eeef0a3cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_roi_align_op.py @@ -0,0 +1,170 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys +from op_test import OpTest + + +class TestROIAlignOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + self.calc_roi_align() + self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width, + 'sampling_ratio': self.sampling_ratio + } + + self.outputs = {'Out': self.out_data} + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 + self.height = 8 + self.width = 6 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, self.height, self.width) + + self.spatial_scale = 1.0 / 2.0 + self.pooled_height = 2 + self.pooled_width = 2 + self.sampling_ratio = -1 + + self.x = np.random.random(self.x_dim).astype('float32') + + def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w, + bin_size_h, bin_size_w): + count = roi_bin_grid_h * roi_bin_grid_w + bilinear_pos = np.zeros( + [self.channels, self.pooled_height, self.pooled_width, count, 4], + np.float32) + bilinear_w = np.zeros( + [self.pooled_height, self.pooled_width, count, 4], np.float32) + for ph in range(self.pooled_width): + for pw in range(self.pooled_height): + c = 0 + for iy in range(roi_bin_grid_h): + y = roi_ymin + ph * bin_size_h + (iy + 0.5) * \ + bin_size_h / roi_bin_grid_h + for ix in range(roi_bin_grid_w): + x = roi_xmin + pw * bin_size_w + (ix + 0.5) * \ + bin_size_w / roi_bin_grid_w + if y < -1.0 or y > self.height or \ + x < -1.0 or x > self.width: + continue + if y <= 0: + y = 0 + if x <= 0: + x = 0 + y_low = int(y) + x_low = int(x) + if y_low >= self.height - 1: + y = y_high = y_low = self.height - 1 + else: + y_high = y_low + 1 + if x_low >= self.width - 1: + x = x_high = x_low = self.width - 1 + else: + x_high = x_low + 1 + ly = y - y_low + lx = x - x_low + hy = 1 - ly + hx = 1 - lx + for ch in range(self.channels): + bilinear_pos[ch, ph, pw, c, 0] = x_i[ch, y_low, + x_low] + bilinear_pos[ch, ph, pw, c, 1] = x_i[ch, y_low, + x_high] + bilinear_pos[ch, ph, pw, c, 2] = x_i[ch, y_high, + x_low] + bilinear_pos[ch, ph, pw, c, 3] = x_i[ch, y_high, + x_high] + bilinear_w[ph, pw, c, 0] = hy * hx + bilinear_w[ph, pw, c, 1] = hy * lx + bilinear_w[ph, pw, c, 2] = ly * hx + bilinear_w[ph, pw, c, 3] = ly * lx + c = c + 1 + return bilinear_pos, bilinear_w + + def calc_roi_align(self): + self.out_data = np.zeros( + (self.rois_num, self.channels, self.pooled_height, + self.pooled_width)).astype('float32') + + for i in range(self.rois_num): + roi = self.rois[i] + roi_batch_id = int(roi[0]) + x_i = self.x[roi_batch_id] + roi_xmin = roi[1] * self.spatial_scale + roi_ymin = roi[2] * self.spatial_scale + roi_xmax = roi[3] * self.spatial_scale + roi_ymax = roi[4] * self.spatial_scale + roi_width = max(roi_xmax - roi_xmin, 1) + roi_height = max(roi_ymax - roi_ymin, 1) + bin_size_h = float(roi_height) / float(self.pooled_height) + bin_size_w = float(roi_width) / float(self.pooled_width) + roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \ + math.ceil(roi_height / self.pooled_height) + roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \ + math.ceil(roi_width / self.pooled_width) + count = int(roi_bin_grid_h * roi_bin_grid_w) + pre_size = count * self.pooled_width * self.pooled_height + bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin, + int(roi_bin_grid_h), + int(roi_bin_grid_w), + bin_size_h, bin_size_w) + for ch in range(self.channels): + align_per_bin = (bilinear_pos[ch] * bilinear_w).sum(axis=-1) + output_val = align_per_bin.mean(axis=-1) + self.out_data[i, ch, :, :] = output_val + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(bno + 1) + for i in range(bno + 1): + x1 = np.random.random_integers( + 0, self.width // self.spatial_scale - self.pooled_width) + y1 = np.random.random_integers( + 0, self.height // self.spatial_scale - self.pooled_height) + + x2 = np.random.random_integers(x1 + self.pooled_width, + self.width // self.spatial_scale) + y2 = np.random.random_integers( + y1 + self.pooled_height, self.height // self.spatial_scale) + + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype("float32") + + def setUp(self): + self.op_type = "roi_align" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out')