diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index 5627b4f229e100d9979663e8688b8694188bab0f..ac0cd75237baf5e8b860f197d42cd27bae65270e 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -226,11 +226,7 @@ REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker, ops::ROIAlignGradMaker); REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp, ops::RoiAlignGradNoNeedBufVarsInferer); -REGISTER_OP_CPU_KERNEL( - roi_align, - ops::CPUROIAlignOpKernel, - ops::CPUROIAlignOpKernel, - ops::CPUROIAlignOpKernel); + REGISTER_OP_CPU_KERNEL( roi_align_grad, ops::CPUROIAlignGradOpKernel, diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu index 18941d10e937d3c28e5793384f00d9d97225a128..1a2e64cd45ca401f5fb8ca6b6975a029ba735280 100644 --- a/paddle/fluid/operators/roi_align_op.cu +++ b/paddle/fluid/operators/roi_align_op.cu @@ -33,43 +33,6 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } -template -__device__ T BilinearInterpolate(const T* input_data, const int height, - const int width, T y, T x) { - if (y < -1.0 || y > height || x < -1.0 || x > width) { - return 0; - } - y = y <= 0 ? 0 : y; - x = x <= 0 ? 0 : x; - int y_low = static_cast(y); - int x_low = static_cast(x); - int y_high; - int x_high; - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = static_cast(y_low); - } else { - y_high = y_low + 1; - } - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = static_cast(x_low); - } else { - x_high = x_low + 1; - } - T ly = y - y_low, lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - T v1 = input_data[y_low * width + x_low]; - T v2 = input_data[y_low * width + x_high]; - T v3 = input_data[y_high * width + x_low]; - T v4 = input_data[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - template __device__ void BilinearInterpolateGradient(const int height, const int width, T y, T x, T* w1, T* w2, T* w3, @@ -102,65 +65,6 @@ __device__ void BilinearInterpolateGradient(const int height, const int width, return; } -template -__global__ void GPUROIAlignForward( - const int nthreads, const T* input_data, const T* input_rois, - const float spatial_scale, const int channels, const int height, - const int width, const int pooled_height, const int pooled_width, - const int sampling_ratio, int* roi_batch_id_data, T* output_data, - const bool continuous_coordinate) { - CUDA_KERNEL_LOOP(i, nthreads) { - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % channels; - int n = i / pooled_width / pooled_height / channels; - - const T* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = roi_batch_id_data[n]; - - T roi_offset = continuous_coordinate ? static_cast(0.5) : 0; - T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset; - T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset; - T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset; - T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - if (!continuous_coordinate) { - roi_width = max(roi_width, static_cast(1.)); - roi_height = max(roi_height, static_cast(1.)); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - const T* offset_input_data = - input_data + (roi_batch_ind * channels + c) * height * width; - - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); - T output_val = 0; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_ymin + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_xmin + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T val = BilinearInterpolate(offset_input_data, height, width, y, x); - output_val += val; - } - } - output_val /= count; - output_data[i] = output_val; - } -} - template __global__ void GPUROIAlignBackward( const int nthreads, const T* input_rois, const T* out_grad, @@ -236,105 +140,6 @@ __global__ void GPUROIAlignBackward( } } -template -class GPUROIAlignOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto aligned = ctx.Attr("aligned"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - - int rois_num = rois->dims()[0]; - - if (rois_num == 0) return; - - int output_size = out->numel(); - int blocks = NumBlocks(output_size); - int threads = kNumCUDAThreads; -#ifdef WITH_NV_JETSON - platform::ChangeThreadNum(ctx.cuda_device_context(), &threads, 256); -#endif - Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - auto cplace = platform::CPUPlace(); - int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); - auto& dev_ctx = ctx.cuda_device_context(); - auto gplace = ctx.GetPlace(); - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - int rois_batch_size = rois_num_t->numel(); - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The rois_batch_size and imgs " - "batch_size must be the same. But received rois_batch_size = %d, " - "batch_size = %d", - rois_batch_size, batch_size)); - - std::vector rois_num_list(rois_batch_size); - memory::Copy(cplace, rois_num_list.data(), gplace, - rois_num_t->data(), sizeof(int) * rois_batch_size, 0); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_list[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto lod = rois->lod(); - PADDLE_ENFORCE_EQ( - lod.empty(), false, - platform::errors::InvalidArgument("Input(ROIs) in ROIAlignOp does " - "not contain LoD information.")); - auto rois_lod = lod.back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of rois and batch size " - "of images must be the same. But received rois batch size = %d, " - "and images batch size = %d", - rois_batch_size, batch_size)); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The actual number of rois and the number of rois " - "provided from Input(RoIsLoD) in RoIAlign must be the same." - " But received actual number of rois is %d, and the number " - "of rois from RoIsLoD is %d", - rois_num, rois_num_with_lod)); - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - int bytes = roi_batch_id_list.numel() * sizeof(int); - auto roi_ptr = memory::Alloc(dev_ctx, bytes); - int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); - memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes, - dev_ctx.stream()); - GPUROIAlignForward<<>>( - output_size, in->data(), rois->data(), spatial_scale, channels, - height, width, pooled_height, pooled_width, sampling_ratio, roi_id_data, - out->mutable_data(ctx.GetPlace()), aligned); - } -}; - template class GPUROIAlignGradOpKernel : public framework::OpKernel { public: @@ -416,10 +221,6 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - roi_align, - ops::GPUROIAlignOpKernel, - ops::GPUROIAlignOpKernel); REGISTER_OP_CUDA_KERNEL( roi_align_grad, ops::GPUROIAlignGradOpKernel, diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index e71099ed99f00f5846e6e23d5d39b3b2f8997531..589e35e4ab7ae4caf5efd3fb4d93a26b2ca86b26 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -23,152 +23,6 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -namespace { // NOLINT -constexpr size_t get_offset(size_t x, size_t y, size_t width) { - return y * width + x; -} - -template -struct offsets_and_ratios { - offsets_and_ratios() = default; - offsets_and_ratios(std::size_t xy, std::size_t xY, std::size_t Xy, - std::size_t XY, T xy_ratio, T xY_ratio, T Xy_ratio, - T XY_ratio) - : xy(xy), - xY(xY), - Xy(Xy), - XY(XY), - xy_ratio(xy_ratio), - xY_ratio(xY_ratio), - Xy_ratio(Xy_ratio), - XY_ratio(XY_ratio) {} - - std::size_t xy = 0; - std::size_t xY = 0; - std::size_t Xy = 0; - std::size_t XY = 0; - T xy_ratio = 0.0f; - T xY_ratio = 0.0f; - T Xy_ratio = 0.0f; - T XY_ratio = 0.0f; -}; - -template -std::vector> get_indexes_and_ratios( - std::size_t width, std::size_t height, const T roi_width, - const T roi_height, const T roi_xmin, const T roi_ymin, - std::size_t pooled_width, std::size_t roi_bin_grid_w, - std::size_t pooled_height, std::size_t roi_bin_grid_h) { - const auto ind_num = - pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h; - - std::vector> interpolation_cords; - interpolation_cords.reserve(ind_num); - - const auto bin_w = roi_width / pooled_width; - const auto bin_h = roi_height / pooled_height; - - for (std::size_t py = 0; py < pooled_height; py++) { - for (std::size_t px = 0; px < pooled_width; px++) { - for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) { - // calculate x of sample points - auto y = - roi_ymin + - bin_h * (py + - static_cast(iy + .5f) / static_cast(roi_bin_grid_h)); - for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) { - // calculate x of sample points - auto x = roi_xmin + - bin_w * (px + - static_cast(ix + .5f) / - static_cast(roi_bin_grid_w)); - - // deal with elements out of map - if (y < -1.0 || y > height || x < -1.0 || x > width) { - interpolation_cords.emplace_back(); - continue; - } - y = y <= 0 ? 0 : y; - x = x <= 0 ? 0 : x; - - std::size_t x_low_index = static_cast(x); - std::size_t x_high_index; - if (x_low_index >= width - 1) { - x_high_index = x_low_index = width - 1; - x = static_cast(x_low_index); - } else { - x_high_index = x_low_index + 1; - } - T x_ratio = x_high_index - x; - - std::size_t y_low_index = static_cast(y); - std::size_t y_high_index; - if (y_low_index >= height - 1) { - y_high_index = y_low_index = height - 1; - y = static_cast(y_low_index); - } else { - y_high_index = y_low_index + 1; - } - T y_ratio = y_high_index - y; - - auto xy = get_offset(x_low_index, y_low_index, width); - auto xY = get_offset(x_low_index, y_high_index, width); - auto Xy = get_offset(x_high_index, y_low_index, width); - auto XY = get_offset(x_high_index, y_high_index, width); - - auto xy_ratio = x_ratio * y_ratio; - auto xY_ratio = x_ratio * (1 - y_ratio); - auto Xy_ratio = (1 - x_ratio) * y_ratio; - auto XY_ratio = (1 - x_ratio) * (1 - y_ratio); - - interpolation_cords.emplace_back(xy, xY, Xy, XY, xy_ratio, xY_ratio, - Xy_ratio, XY_ratio); - } - } - } - } - return interpolation_cords; -} // namespace - -template -void interpolate(std::vector& interpolated_values, // NOLINT - const std::vector>& interpolation_cords, - const T* data) { - for (auto& ic : interpolation_cords) { - auto xlyl_offset = ic.xy; - auto xhyl_offset = ic.Xy; - auto xlyh_offset = ic.xY; - auto xhyh_offset = ic.XY; - - auto xlyl_ratio = ic.xy_ratio; - auto xhyl_ratio = ic.Xy_ratio; - auto xlyh_ratio = ic.xY_ratio; - auto xhyh_ratio = ic.XY_ratio; - - interpolated_values.emplace_back( - xlyl_ratio * data[xlyl_offset] + xhyl_ratio * data[xhyl_offset] + - xlyh_ratio * data[xlyh_offset] + xhyh_ratio * data[xhyh_offset]); - } -} - -template -void avg_pool(const std::vector& interpolated_values, T* output_data, - int roi_bin_grid_w, int roi_bin_grid_h, int pooled_width, - int pooled_height) { - const auto data_amount = pooled_width * pooled_height; - const auto grid_points = roi_bin_grid_w * roi_bin_grid_h; - const T count = 1.0 / grid_points; - auto val_begin = interpolated_values.cbegin(); - for (auto i = 0; i < data_amount; ++i) { - T sum = 0.0; - auto val_end = val_begin + grid_points; - sum = std::accumulate(val_begin, val_end, sum); - val_begin = val_end; - output_data[i] = sum * count; - } -} -} // NOLINT - template void bilinear_interpolate_gradient(const int height, const int width, T y, T x, const T out_grad_this_bin, const T count, @@ -213,129 +67,6 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x, } } -template -class CPUROIAlignOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - auto sampling_ratio = ctx.Attr("sampling_ratio"); - auto aligned = ctx.Attr("aligned"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int rois_num = rois->dims()[0]; - - auto in_stride = phi::stride(in_dims); - auto roi_stride = phi::stride(rois->dims()); - auto out_stride = phi::stride(out->dims()); - - const T* input_data = in->data(); - framework::Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - int* roi_batch_id_data = - roi_batch_id_list.mutable_data(ctx.GetPlace()); - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of rois and the batch size of images " - " must be the same. But received the batch size of rois is %d, " - "and the batch size of images is %d", - rois_batch_size, batch_size)); - auto* rois_num_data = rois_num_t->data(); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto lod = rois->lod(); - PADDLE_ENFORCE_EQ(lod.empty(), false, - platform::errors::InvalidArgument( - "Input(ROIs) Tensor of ROIAlignOp " - "does not contain LoD information.")); - auto rois_lod = lod.back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The rois_batch_size and imgs " - "batch_size must be the same. But received rois_batch_size = %d, " - "batch_size = %d", - rois_batch_size, batch_size)); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The actual number of rois and the number of rois " - "provided from Input(RoIsLoD) in RoIAlign must be the same." - " But received actual number of rois is %d, and the number " - "of rois from RoIsLoD is %d", - rois_num, rois_num_with_lod)); - for (int n = 0; n < rois_batch_size; ++n) { - for (std::size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - T* output_data = out->mutable_data(ctx.GetPlace()); - const T* rois_data = rois->data(); - T roi_offset = aligned ? T(0.5) : 0; - for (int n = 0; n < rois_num; ++n) { - int roi_batch_id = roi_batch_id_data[n]; - T roi_xmin = rois_data[0] * spatial_scale - roi_offset; - T roi_ymin = rois_data[1] * spatial_scale - roi_offset; - T roi_xmax = rois_data[2] * spatial_scale - roi_offset; - T roi_ymax = rois_data[3] * spatial_scale - roi_offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - if (!aligned) { - roi_width = std::max(roi_width, static_cast(1.)); - roi_height = std::max(roi_height, static_cast(1.)); - } - - const T* batch_data = input_data + roi_batch_id * in_stride[0]; - - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_width / pooled_width); - - auto interpolation_cords = get_indexes_and_ratios( - width, height, roi_width, roi_height, roi_xmin, roi_ymin, - pooled_width, roi_bin_grid_w, pooled_height, roi_bin_grid_h); - - std::vector interpolated_values; - interpolated_values.reserve(interpolation_cords.size()); - for (auto channel = 0; channel < channels; ++channel) { - interpolate(interpolated_values, interpolation_cords, batch_data); - avg_pool(interpolated_values, output_data, roi_bin_grid_w, - roi_bin_grid_h, pooled_width, pooled_height); - batch_data += in_stride[1]; - output_data += out_stride[1]; - interpolated_values.clear(); - } - rois_data += roi_stride[0]; - } - } -}; - template class CPUROIAlignGradOpKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/roi_align_op_npu.cc b/paddle/fluid/operators/roi_align_op_npu.cc index d5b63854d99053ac0620a32cfaba267c7262d515..78509e4299b80ee44610ce3d10f9c57afa0cde18 100644 --- a/paddle/fluid/operators/roi_align_op_npu.cc +++ b/paddle/fluid/operators/roi_align_op_npu.cc @@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/roi_align_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/roi_align_op_xpu.cc b/paddle/fluid/operators/roi_align_op_xpu.cc index 09d2d906653e8c71ddeca7fa606cf5adac8cc596..13490d6fcde3a22e7299db21969d7de6f9a6582c 100644 --- a/paddle/fluid/operators/roi_align_op_xpu.cc +++ b/paddle/fluid/operators/roi_align_op_xpu.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/roi_align_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + template class XPUROIAlignOpKernel : public framework::OpKernel { public: diff --git a/paddle/phi/kernels/cpu/roi_align_kernel.cc b/paddle/phi/kernels/cpu/roi_align_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..35ab99a98eba7e59853fb311d5b2307b69ae31b2 --- /dev/null +++ b/paddle/phi/kernels/cpu/roi_align_kernel.cc @@ -0,0 +1,318 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_align_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { + +constexpr size_t GetOffset(size_t x, size_t y, size_t width) { + return y * width + x; +} + +template +struct OffsetsAndRatios { + OffsetsAndRatios() = default; + OffsetsAndRatios(std::size_t xy, + std::size_t xY, + std::size_t Xy, + std::size_t XY, + T xy_ratio, + T xY_ratio, + T Xy_ratio, + T XY_ratio) + : xy(xy), + xY(xY), + Xy(Xy), + XY(XY), + xy_ratio(xy_ratio), + xY_ratio(xY_ratio), + Xy_ratio(Xy_ratio), + XY_ratio(XY_ratio) {} + + std::size_t xy = 0; + std::size_t xY = 0; + std::size_t Xy = 0; + std::size_t XY = 0; + T xy_ratio = 0.0f; + T xY_ratio = 0.0f; + T Xy_ratio = 0.0f; + T XY_ratio = 0.0f; +}; + +template +std::vector> GetIndexesAndRatios( + std::size_t width, + std::size_t height, + const T roi_width, + const T roi_height, + const T roi_xmin, + const T roi_ymin, + std::size_t pooled_width, + std::size_t roi_bin_grid_w, + std::size_t pooled_height, + std::size_t roi_bin_grid_h) { + const auto ind_num = + pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h; + + std::vector> interpolation_cords; + interpolation_cords.reserve(ind_num); + + const auto bin_w = roi_width / pooled_width; + const auto bin_h = roi_height / pooled_height; + + for (std::size_t py = 0; py < pooled_height; py++) { + for (std::size_t px = 0; px < pooled_width; px++) { + for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) { + // calculate x of sample points + auto y = + roi_ymin + + bin_h * (py + + static_cast(iy + .5f) / static_cast(roi_bin_grid_h)); + for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) { + // calculate x of sample points + auto x = roi_xmin + + bin_w * (px + + static_cast(ix + .5f) / + static_cast(roi_bin_grid_w)); + + // deal with elements out of map + if (y < -1.0 || y > height || x < -1.0 || x > width) { + interpolation_cords.emplace_back(); + continue; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + + std::size_t x_low_index = static_cast(x); + std::size_t x_high_index; + if (x_low_index >= width - 1) { + x_high_index = x_low_index = width - 1; + x = static_cast(x_low_index); + } else { + x_high_index = x_low_index + 1; + } + T x_ratio = x_high_index - x; + + std::size_t y_low_index = static_cast(y); + std::size_t y_high_index; + if (y_low_index >= height - 1) { + y_high_index = y_low_index = height - 1; + y = static_cast(y_low_index); + } else { + y_high_index = y_low_index + 1; + } + T y_ratio = y_high_index - y; + + auto xy = GetOffset(x_low_index, y_low_index, width); + auto xY = GetOffset(x_low_index, y_high_index, width); + auto Xy = GetOffset(x_high_index, y_low_index, width); + auto XY = GetOffset(x_high_index, y_high_index, width); + + auto xy_ratio = x_ratio * y_ratio; + auto xY_ratio = x_ratio * (1 - y_ratio); + auto Xy_ratio = (1 - x_ratio) * y_ratio; + auto XY_ratio = (1 - x_ratio) * (1 - y_ratio); + + interpolation_cords.emplace_back( + xy, xY, Xy, XY, xy_ratio, xY_ratio, Xy_ratio, XY_ratio); + } + } + } + } + return interpolation_cords; +} + +template +void Interpolate(std::vector& interpolated_values, // NOLINT + const std::vector>& interpolation_cords, + const T* data) { + for (auto& ic : interpolation_cords) { + auto xlyl_offset = ic.xy; + auto xhyl_offset = ic.Xy; + auto xlyh_offset = ic.xY; + auto xhyh_offset = ic.XY; + + auto xlyl_ratio = ic.xy_ratio; + auto xhyl_ratio = ic.Xy_ratio; + auto xlyh_ratio = ic.xY_ratio; + auto xhyh_ratio = ic.XY_ratio; + + interpolated_values.emplace_back( + xlyl_ratio * data[xlyl_offset] + xhyl_ratio * data[xhyl_offset] + + xlyh_ratio * data[xlyh_offset] + xhyh_ratio * data[xhyh_offset]); + } +} + +template +void AvgPool(const std::vector& interpolated_values, + T* output_data, + int roi_bin_grid_w, + int roi_bin_grid_h, + int pooled_width, + int pooled_height) { + const auto data_amount = pooled_width * pooled_height; + const auto grid_points = roi_bin_grid_w * roi_bin_grid_h; + const T count = 1.0 / grid_points; + auto val_begin = interpolated_values.cbegin(); + for (auto i = 0; i < data_amount; ++i) { + T sum = 0.0; + auto val_end = val_begin + grid_points; + sum = std::accumulate(val_begin, val_end, sum); + val_begin = val_end; + output_data[i] = sum * count; + } +} + +template +void ROIAlignKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* out) { + auto in_dims = x.dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = boxes.dims()[0]; + + auto in_stride = phi::stride(in_dims); + auto roi_stride = phi::stride(boxes.dims()); + auto out_stride = phi::stride(out->dims()); + + const T* input_data = x.data(); + DenseTensor roi_batch_id_list = Empty(dev_ctx, {rois_num}); + int* roi_batch_id_data = roi_batch_id_list.data(); + int boxes_batch_size; + if (boxes_num) { + boxes_batch_size = boxes_num->numel(); + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + errors::InvalidArgument( + "The batch size of rois and the batch size of images " + " must be the same. But received the batch size of rois is %d, " + "and the batch size of images is %d", + boxes_batch_size, + batch_size)); + auto* boxes_num_data = boxes_num->data(); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_data[n]; ++i) { + roi_batch_id_data[i] = n; + } + start += boxes_num_data[n]; + } + } else { + auto lod = boxes.lod(); + PADDLE_ENFORCE_EQ( + lod.empty(), + false, + errors::InvalidArgument("Input(ROIs) Tensor of ROIAlignOp " + "does not contain LoD information.")); + auto boxes_lod = lod.back(); + int boxes_batch_size = boxes_lod.size() - 1; + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + errors::InvalidArgument( + "The boxes_batch_size and imgs " + "batch_size must be the same. But received boxes_batch_size = %d, " + "batch_size = %d", + boxes_batch_size, + batch_size)); + int boxes_num_with_lod = boxes_lod[boxes_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, + boxes_num_with_lod, + errors::InvalidArgument( + "The actual number of rois and the number of rois " + "provided from Input(RoIsLoD) in RoIAlign must be the same." + " But received actual number of rois is %d, and the number " + "of rois from RoIsLoD is %d", + rois_num, + boxes_num_with_lod)); + for (int n = 0; n < boxes_batch_size; ++n) { + for (std::size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + } + T* output_data = dev_ctx.template Alloc(out); + const T* boxes_data = boxes.data(); + T roi_offset = aligned ? T(0.5) : 0; + for (int n = 0; n < rois_num; ++n) { + int roi_batch_id = roi_batch_id_data[n]; + T roi_xmin = boxes_data[0] * spatial_scale - roi_offset; + T roi_ymin = boxes_data[1] * spatial_scale - roi_offset; + T roi_xmax = boxes_data[2] * spatial_scale - roi_offset; + T roi_ymax = boxes_data[3] * spatial_scale - roi_offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + if (!aligned) { + roi_width = std::max(roi_width, static_cast(1.)); + roi_height = std::max(roi_height, static_cast(1.)); + } + + const T* batch_data = input_data + roi_batch_id * in_stride[0]; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + auto interpolation_cords = GetIndexesAndRatios(width, + height, + roi_width, + roi_height, + roi_xmin, + roi_ymin, + pooled_width, + roi_bin_grid_w, + pooled_height, + roi_bin_grid_h); + + std::vector interpolated_values; + interpolated_values.reserve(interpolation_cords.size()); + for (auto channel = 0; channel < channels; ++channel) { + Interpolate(interpolated_values, interpolation_cords, batch_data); + AvgPool(interpolated_values, + output_data, + roi_bin_grid_w, + roi_bin_grid_h, + pooled_width, + pooled_height); + batch_data += in_stride[1]; + output_data += out_stride[1]; + interpolated_values.clear(); + } + boxes_data += roi_stride[0]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_align, CPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double, int) {} diff --git a/paddle/phi/kernels/gpu/roi_align_kernel.cu b/paddle/phi/kernels/gpu/roi_align_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2f906fa4f663b6da65a3e986af2214dfb49f2ec0 --- /dev/null +++ b/paddle/phi/kernels/gpu/roi_align_kernel.cu @@ -0,0 +1,255 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_align_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" + +#include "paddle/fluid/memory/memory.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; +static constexpr int kROISize = 4; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__device__ T BilinearInterpolate( + const T* input_data, const int height, const int width, T y, T x) { + if (y < -1.0 || y > height || x < -1.0 || x > width) { + return 0; + } + y = y <= 0 ? 0 : y; + x = x <= 0 ? 0 : x; + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = static_cast(y_low); + } else { + y_high = y_low + 1; + } + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = static_cast(x_low); + } else { + x_high = x_low + 1; + } + T ly = y - y_low, lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + T v1 = input_data[y_low * width + x_low]; + T v2 = input_data[y_low * width + x_high]; + T v3 = input_data[y_high * width + x_low]; + T v4 = input_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__global__ void GPUROIAlignForward(const int nthreads, + const T* input_data, + const T* input_rois, + const float spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + int* roi_batch_id_data, + T* output_data, + const bool continuous_coordinate) { + CUDA_KERNEL_LOOP(i, nthreads) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + + const T* offset_input_rois = input_rois + n * kROISize; + int roi_batch_ind = roi_batch_id_data[n]; + + T roi_offset = continuous_coordinate ? static_cast(0.5) : 0; + T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset; + T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset; + T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset; + T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + if (!continuous_coordinate) { + roi_width = max(roi_width, static_cast(1.)); + roi_height = max(roi_height, static_cast(1.)); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input_data = + input_data + (roi_batch_ind * channels + c) * height * width; + + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_ymin + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_xmin + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = BilinearInterpolate(offset_input_data, height, width, y, x); + output_val += val; + } + } + output_val /= count; + output_data[i] = output_val; + } +} + +template +void ROIAlignKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* out) { + auto in_dims = x.dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + int rois_num = boxes.dims()[0]; + + if (rois_num == 0) return; + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; +#ifdef WITH_NV_JETSON + backends::gpu::ChangeThreadNum(dev_ctx, &threads, 256); +#endif + DenseTensor roi_batch_id_list; + roi_batch_id_list.Resize({rois_num}); + int* roi_batch_id_data = dev_ctx.template HostAlloc(&roi_batch_id_list); + auto cplace = phi::CPUPlace(); + auto gplace = dev_ctx.GetPlace(); + if (boxes_num) { + int boxes_batch_size = boxes_num->numel(); + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + errors::InvalidArgument( + "The boxes_batch_size and imgs " + "batch_size must be the same. But received boxes_batch_size = %d, " + "batch_size = %d", + boxes_batch_size, + batch_size)); + + std::vector boxes_num_list(boxes_batch_size); + paddle::memory::Copy(cplace, + boxes_num_list.data(), + gplace, + boxes_num->data(), + sizeof(int) * boxes_batch_size, + 0); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_list[n]; ++i) { + roi_batch_id_data[i] = n; + } + start += boxes_num_list[n]; + } + } else { + auto lod = boxes.lod(); + PADDLE_ENFORCE_EQ(lod.empty(), + false, + errors::InvalidArgument("Input(ROIs) in ROIAlignOp does " + "not contain LoD information.")); + auto boxes_lod = lod.back(); + int boxes_batch_size = boxes_lod.size() - 1; + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + errors::InvalidArgument( + "The batch size of rois and batch size " + "of images must be the same. But received rois batch size = %d, " + "and images batch size = %d", + boxes_batch_size, + batch_size)); + int boxes_num_with_lod = boxes_lod[boxes_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, + boxes_num_with_lod, + errors::InvalidArgument( + "The actual number of rois and the number of rois " + "provided from Input(RoIsLoD) in RoIAlign must be the same." + " But received actual number of rois is %d, and the number " + "of rois from RoIsLoD is %d", + rois_num, + boxes_num_with_lod)); + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + } + int bytes = roi_batch_id_list.numel() * sizeof(int); + auto roi_ptr = paddle::memory::Alloc(dev_ctx, bytes); + int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); + paddle::memory::Copy( + gplace, roi_id_data, cplace, roi_batch_id_data, bytes, dev_ctx.stream()); + GPUROIAlignForward<<>>( + output_size, + x.data(), + boxes.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + roi_id_data, + dev_ctx.template Alloc(out), + aligned); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_align, GPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/scale_kernel.cu b/paddle/phi/kernels/gpu/scale_kernel.cu index 930c50a24be8fae40535c2d5e6dbbe85e7ced990..6f96a697b2f2db6c2097640f34c30142939f80e0 100644 --- a/paddle/phi/kernels/gpu/scale_kernel.cu +++ b/paddle/phi/kernels/gpu/scale_kernel.cu @@ -15,10 +15,9 @@ limitations under the License. */ #include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" -// See Note [ Why still include the fluid headers? ] -#include "paddle/phi/common/float16.h" namespace phi { diff --git a/paddle/phi/kernels/roi_align_kernel.h b/paddle/phi/kernels/roi_align_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..16b52c563a592f0cc23ddca94f554f5dc49e8ccf --- /dev/null +++ b/paddle/phi/kernels/roi_align_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void ROIAlignKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + int sampling_ratio, + bool aligned, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/roi_align_sig.cc b/paddle/phi/ops/compat/roi_align_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..0549103b6fbcb8b2367c34c8a44fb3b52f318859 --- /dev/null +++ b/paddle/phi/ops/compat/roi_align_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ROIAlignOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("roi_align", + {"X", "ROIs", "RoisNum"}, + {"pooled_height", + "pooled_width", + "spatial_scale", + "sampling_ratio", + "aligned"}, + {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(roi_align, phi::ROIAlignOpArgumentMapping);