Unverified commit 39de9b8a authored by zyfncg, committed by GitHub

[PHI] Move forward kernel of roi_align into phi (#40382)

* move roi_align kernel to phi

* fix bug of roi_align xpu
Parent 573ca984
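For orientation, a minimal sketch of the net effect, assembled from the hunks below (not an extra change in this commit): the forward kernel is now declared in paddle/phi/kernels/roi_align_kernel.h and registered through PD_REGISTER_KERNEL, replacing the REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL entries for roi_align that are removed here.

// Declaration added by this commit in paddle/phi/kernels/roi_align_kernel.h.
template <typename T, typename Context>
void ROIAlignKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& boxes,
                    paddle::optional<const DenseTensor&> boxes_num,
                    int pooled_height,
                    int pooled_width,
                    float spatial_scale,
                    int sampling_ratio,
                    bool aligned,
                    DenseTensor* out);

// New phi registrations; the CPU kernel keeps the int instantiation.
PD_REGISTER_KERNEL(
    roi_align, CPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double, int) {}
PD_REGISTER_KERNEL(
    roi_align, GPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double) {}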
...@@ -226,11 +226,7 @@ REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
ops::ROIAlignGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp,
ops::RoiAlignGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
roi_align,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(
roi_align_grad,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -33,43 +33,6 @@ static inline int NumBlocks(const int N) {
kNumMaxinumNumBlocks);
}
template <class T>
__device__ T BilinearInterpolate(const T* input_data, const int height,
const int width, T y, T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
return 0;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T v1 = input_data[y_low * width + x_low];
T v2 = input_data[y_low * width + x_high];
T v3 = input_data[y_high * width + x_low];
T v4 = input_data[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <class T>
__device__ void BilinearInterpolateGradient(const int height, const int width,
T y, T x, T* w1, T* w2, T* w3,
...@@ -102,65 +65,6 @@ __device__ void BilinearInterpolateGradient(const int height, const int width,
return;
}
template <class T>
__global__ void GPUROIAlignForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int sampling_ratio, int* roi_batch_id_data, T* output_data,
const bool continuous_coordinate) {
CUDA_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : 0;
T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset;
T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset;
T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset;
T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset;
T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin;
if (!continuous_coordinate) {
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = BilinearInterpolate(offset_input_data, height, width, y, x);
output_val += val;
}
}
output_val /= count;
output_data[i] = output_val;
}
}
template <typename T>
__global__ void GPUROIAlignBackward(
const int nthreads, const T* input_rois, const T* out_grad,
...@@ -236,105 +140,6 @@ __global__ void GPUROIAlignBackward(
}
}
template <typename Place, typename T>
class GPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto aligned = ctx.Attr<bool>("aligned");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(ctx.cuda_device_context(), &threads, 256);
#endif
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = ctx.GetPlace();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d",
rois_batch_size, batch_size));
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto lod = rois->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
platform::errors::InvalidArgument("Input(ROIs) in ROIAlignOp does "
"not contain LoD information."));
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and batch size "
"of images must be the same. But received rois batch size = %d, "
"and images batch size = %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The actual number of rois and the number of rois "
"provided from Input(RoIsLoD) in RoIAlign must be the same."
" But received actual number of rois is %d, and the number "
"of rois from RoIsLoD is %d",
rois_num, rois_num_with_lod));
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
GPUROIAlignForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width, sampling_ratio, roi_id_data,
out->mutable_data<T>(ctx.GetPlace()), aligned);
}
};
template <typename Place, typename T>
class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
public:
...@@ -416,10 +221,6 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
roi_align,
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
roi_align_grad,
ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, float>,
......
...@@ -23,152 +23,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace { // NOLINT
constexpr size_t get_offset(size_t x, size_t y, size_t width) {
return y * width + x;
}
template <class T>
struct offsets_and_ratios {
offsets_and_ratios() = default;
offsets_and_ratios(std::size_t xy, std::size_t xY, std::size_t Xy,
std::size_t XY, T xy_ratio, T xY_ratio, T Xy_ratio,
T XY_ratio)
: xy(xy),
xY(xY),
Xy(Xy),
XY(XY),
xy_ratio(xy_ratio),
xY_ratio(xY_ratio),
Xy_ratio(Xy_ratio),
XY_ratio(XY_ratio) {}
std::size_t xy = 0;
std::size_t xY = 0;
std::size_t Xy = 0;
std::size_t XY = 0;
T xy_ratio = 0.0f;
T xY_ratio = 0.0f;
T Xy_ratio = 0.0f;
T XY_ratio = 0.0f;
};
template <typename T>
std::vector<offsets_and_ratios<T>> get_indexes_and_ratios(
std::size_t width, std::size_t height, const T roi_width,
const T roi_height, const T roi_xmin, const T roi_ymin,
std::size_t pooled_width, std::size_t roi_bin_grid_w,
std::size_t pooled_height, std::size_t roi_bin_grid_h) {
const auto ind_num =
pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h;
std::vector<offsets_and_ratios<T>> interpolation_cords;
interpolation_cords.reserve(ind_num);
const auto bin_w = roi_width / pooled_width;
const auto bin_h = roi_height / pooled_height;
for (std::size_t py = 0; py < pooled_height; py++) {
for (std::size_t px = 0; px < pooled_width; px++) {
for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) {
// calculate y of sample points
auto y =
roi_ymin +
bin_h * (py +
static_cast<T>(iy + .5f) / static_cast<T>(roi_bin_grid_h));
for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) {
// calculate x of sample points
auto x = roi_xmin +
bin_w * (px +
static_cast<T>(ix + .5f) /
static_cast<T>(roi_bin_grid_w));
// deal with elements out of map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
interpolation_cords.emplace_back();
continue;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
std::size_t x_low_index = static_cast<std::size_t>(x);
std::size_t x_high_index;
if (x_low_index >= width - 1) {
x_high_index = x_low_index = width - 1;
x = static_cast<T>(x_low_index);
} else {
x_high_index = x_low_index + 1;
}
T x_ratio = x_high_index - x;
std::size_t y_low_index = static_cast<std::size_t>(y);
std::size_t y_high_index;
if (y_low_index >= height - 1) {
y_high_index = y_low_index = height - 1;
y = static_cast<T>(y_low_index);
} else {
y_high_index = y_low_index + 1;
}
T y_ratio = y_high_index - y;
auto xy = get_offset(x_low_index, y_low_index, width);
auto xY = get_offset(x_low_index, y_high_index, width);
auto Xy = get_offset(x_high_index, y_low_index, width);
auto XY = get_offset(x_high_index, y_high_index, width);
auto xy_ratio = x_ratio * y_ratio;
auto xY_ratio = x_ratio * (1 - y_ratio);
auto Xy_ratio = (1 - x_ratio) * y_ratio;
auto XY_ratio = (1 - x_ratio) * (1 - y_ratio);
interpolation_cords.emplace_back(xy, xY, Xy, XY, xy_ratio, xY_ratio,
Xy_ratio, XY_ratio);
}
}
}
}
return interpolation_cords;
}
template <typename T>
void interpolate(std::vector<T>& interpolated_values, // NOLINT
const std::vector<offsets_and_ratios<T>>& interpolation_cords,
const T* data) {
for (auto& ic : interpolation_cords) {
auto xlyl_offset = ic.xy;
auto xhyl_offset = ic.Xy;
auto xlyh_offset = ic.xY;
auto xhyh_offset = ic.XY;
auto xlyl_ratio = ic.xy_ratio;
auto xhyl_ratio = ic.Xy_ratio;
auto xlyh_ratio = ic.xY_ratio;
auto xhyh_ratio = ic.XY_ratio;
interpolated_values.emplace_back(
xlyl_ratio * data[xlyl_offset] + xhyl_ratio * data[xhyl_offset] +
xlyh_ratio * data[xlyh_offset] + xhyh_ratio * data[xhyh_offset]);
}
}
template <typename T>
void avg_pool(const std::vector<T>& interpolated_values, T* output_data,
int roi_bin_grid_w, int roi_bin_grid_h, int pooled_width,
int pooled_height) {
const auto data_amount = pooled_width * pooled_height;
const auto grid_points = roi_bin_grid_w * roi_bin_grid_h;
const T count = 1.0 / grid_points;
auto val_begin = interpolated_values.cbegin();
for (auto i = 0; i < data_amount; ++i) {
T sum = 0.0;
auto val_end = val_begin + grid_points;
sum = std::accumulate(val_begin, val_end, sum);
val_begin = val_end;
output_data[i] = sum * count;
}
}
} // NOLINT
template <class T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
const T out_grad_this_bin, const T count,
...@@ -213,129 +67,6 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
}
}
template <typename DeviceContext, typename T>
class CPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto aligned = ctx.Attr<bool>("aligned");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = phi::stride(in_dims);
auto roi_stride = phi::stride(rois->dims());
auto out_stride = phi::stride(out->dims());
const T* input_data = in->data<T>();
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto lod = rois->lod();
PADDLE_ENFORCE_EQ(lod.empty(), false,
platform::errors::InvalidArgument(
"Input(ROIs) Tensor of ROIAlignOp "
"does not contain LoD information."));
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d",
rois_batch_size, batch_size));
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(
rois_num, rois_num_with_lod,
platform::errors::InvalidArgument(
"The actual number of rois and the number of rois "
"provided from Input(RoIsLoD) in RoIAlign must be the same."
" But received actual number of rois is %d, and the number "
"of rois from RoIsLoD is %d",
rois_num, rois_num_with_lod));
for (int n = 0; n < rois_batch_size; ++n) {
for (std::size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* rois_data = rois->data<T>();
T roi_offset = aligned ? T(0.5) : 0;
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
T roi_xmin = rois_data[0] * spatial_scale - roi_offset;
T roi_ymin = rois_data[1] * spatial_scale - roi_offset;
T roi_xmax = rois_data[2] * spatial_scale - roi_offset;
T roi_ymax = rois_data[3] * spatial_scale - roi_offset;
T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin;
if (!aligned) {
roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
const T* batch_data = input_data + roi_batch_id * in_stride[0];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_width / pooled_width);
auto interpolation_cords = get_indexes_and_ratios(
width, height, roi_width, roi_height, roi_xmin, roi_ymin,
pooled_width, roi_bin_grid_w, pooled_height, roi_bin_grid_h);
std::vector<T> interpolated_values;
interpolated_values.reserve(interpolation_cords.size());
for (auto channel = 0; channel < channels; ++channel) {
interpolate(interpolated_values, interpolation_cords, batch_data);
avg_pool(interpolated_values, output_data, roi_bin_grid_w,
roi_bin_grid_h, pooled_width, pooled_height);
batch_data += in_stride[1];
output_data += out_stride[1];
interpolated_values.clear();
}
rois_data += roi_stride[0];
}
}
};
template <typename DeviceContext, typename T>
class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
public:
......
...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/roi_align_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class XPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_align_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
constexpr size_t GetOffset(size_t x, size_t y, size_t width) {
return y * width + x;
}
template <class T>
struct OffsetsAndRatios {
OffsetsAndRatios() = default;
OffsetsAndRatios(std::size_t xy,
std::size_t xY,
std::size_t Xy,
std::size_t XY,
T xy_ratio,
T xY_ratio,
T Xy_ratio,
T XY_ratio)
: xy(xy),
xY(xY),
Xy(Xy),
XY(XY),
xy_ratio(xy_ratio),
xY_ratio(xY_ratio),
Xy_ratio(Xy_ratio),
XY_ratio(XY_ratio) {}
std::size_t xy = 0;
std::size_t xY = 0;
std::size_t Xy = 0;
std::size_t XY = 0;
T xy_ratio = 0.0f;
T xY_ratio = 0.0f;
T Xy_ratio = 0.0f;
T XY_ratio = 0.0f;
};
template <typename T>
std::vector<OffsetsAndRatios<T>> GetIndexesAndRatios(
std::size_t width,
std::size_t height,
const T roi_width,
const T roi_height,
const T roi_xmin,
const T roi_ymin,
std::size_t pooled_width,
std::size_t roi_bin_grid_w,
std::size_t pooled_height,
std::size_t roi_bin_grid_h) {
const auto ind_num =
pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h;
std::vector<OffsetsAndRatios<T>> interpolation_cords;
interpolation_cords.reserve(ind_num);
const auto bin_w = roi_width / pooled_width;
const auto bin_h = roi_height / pooled_height;
for (std::size_t py = 0; py < pooled_height; py++) {
for (std::size_t px = 0; px < pooled_width; px++) {
for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) {
// calculate y of sample points
auto y =
roi_ymin +
bin_h * (py +
static_cast<T>(iy + .5f) / static_cast<T>(roi_bin_grid_h));
for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) {
// calculate x of sample points
auto x = roi_xmin +
bin_w * (px +
static_cast<T>(ix + .5f) /
static_cast<T>(roi_bin_grid_w));
// deal with elements out of map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
interpolation_cords.emplace_back();
continue;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
std::size_t x_low_index = static_cast<std::size_t>(x);
std::size_t x_high_index;
if (x_low_index >= width - 1) {
x_high_index = x_low_index = width - 1;
x = static_cast<T>(x_low_index);
} else {
x_high_index = x_low_index + 1;
}
T x_ratio = x_high_index - x;
std::size_t y_low_index = static_cast<std::size_t>(y);
std::size_t y_high_index;
if (y_low_index >= height - 1) {
y_high_index = y_low_index = height - 1;
y = static_cast<T>(y_low_index);
} else {
y_high_index = y_low_index + 1;
}
T y_ratio = y_high_index - y;
auto xy = GetOffset(x_low_index, y_low_index, width);
auto xY = GetOffset(x_low_index, y_high_index, width);
auto Xy = GetOffset(x_high_index, y_low_index, width);
auto XY = GetOffset(x_high_index, y_high_index, width);
auto xy_ratio = x_ratio * y_ratio;
auto xY_ratio = x_ratio * (1 - y_ratio);
auto Xy_ratio = (1 - x_ratio) * y_ratio;
auto XY_ratio = (1 - x_ratio) * (1 - y_ratio);
interpolation_cords.emplace_back(
xy, xY, Xy, XY, xy_ratio, xY_ratio, Xy_ratio, XY_ratio);
}
}
}
}
return interpolation_cords;
}
template <typename T>
void Interpolate(std::vector<T>& interpolated_values, // NOLINT
const std::vector<OffsetsAndRatios<T>>& interpolation_cords,
const T* data) {
for (auto& ic : interpolation_cords) {
auto xlyl_offset = ic.xy;
auto xhyl_offset = ic.Xy;
auto xlyh_offset = ic.xY;
auto xhyh_offset = ic.XY;
auto xlyl_ratio = ic.xy_ratio;
auto xhyl_ratio = ic.Xy_ratio;
auto xlyh_ratio = ic.xY_ratio;
auto xhyh_ratio = ic.XY_ratio;
interpolated_values.emplace_back(
xlyl_ratio * data[xlyl_offset] + xhyl_ratio * data[xhyl_offset] +
xlyh_ratio * data[xlyh_offset] + xhyh_ratio * data[xhyh_offset]);
}
}
template <typename T>
void AvgPool(const std::vector<T>& interpolated_values,
T* output_data,
int roi_bin_grid_w,
int roi_bin_grid_h,
int pooled_width,
int pooled_height) {
const auto data_amount = pooled_width * pooled_height;
const auto grid_points = roi_bin_grid_w * roi_bin_grid_h;
const T count = 1.0 / grid_points;
auto val_begin = interpolated_values.cbegin();
for (auto i = 0; i < data_amount; ++i) {
T sum = 0.0;
auto val_end = val_begin + grid_points;
sum = std::accumulate(val_begin, val_end, sum);
val_begin = val_end;
output_data[i] = sum * count;
}
}
template <typename T, typename Context>
void ROIAlignKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
int sampling_ratio,
bool aligned,
DenseTensor* out) {
auto in_dims = x.dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = boxes.dims()[0];
auto in_stride = phi::stride(in_dims);
auto roi_stride = phi::stride(boxes.dims());
auto out_stride = phi::stride(out->dims());
const T* input_data = x.data<T>();
DenseTensor roi_batch_id_list = Empty<int>(dev_ctx, {rois_num});
int* roi_batch_id_data = roi_batch_id_list.data<int>();
int boxes_batch_size;
if (boxes_num) {
boxes_batch_size = boxes_num->numel();
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
boxes_batch_size,
batch_size));
auto* boxes_num_data = boxes_num->data<int>();
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += boxes_num_data[n];
}
} else {
auto lod = boxes.lod();
PADDLE_ENFORCE_EQ(
lod.empty(),
false,
errors::InvalidArgument("Input(ROIs) Tensor of ROIAlignOp "
"does not contain LoD information."));
auto boxes_lod = lod.back();
int boxes_batch_size = boxes_lod.size() - 1;
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
errors::InvalidArgument(
"The boxes_batch_size and imgs "
"batch_size must be the same. But received boxes_batch_size = %d, "
"batch_size = %d",
boxes_batch_size,
batch_size));
int boxes_num_with_lod = boxes_lod[boxes_batch_size];
PADDLE_ENFORCE_EQ(
rois_num,
boxes_num_with_lod,
errors::InvalidArgument(
"The actual number of rois and the number of rois "
"provided from Input(RoIsLoD) in RoIAlign must be the same."
" But received actual number of rois is %d, and the number "
"of rois from RoIsLoD is %d",
rois_num,
boxes_num_with_lod));
for (int n = 0; n < boxes_batch_size; ++n) {
for (std::size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
T* output_data = dev_ctx.template Alloc<T>(out);
const T* boxes_data = boxes.data<T>();
T roi_offset = aligned ? T(0.5) : 0;
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
T roi_xmin = boxes_data[0] * spatial_scale - roi_offset;
T roi_ymin = boxes_data[1] * spatial_scale - roi_offset;
T roi_xmax = boxes_data[2] * spatial_scale - roi_offset;
T roi_ymax = boxes_data[3] * spatial_scale - roi_offset;
T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin;
if (!aligned) {
roi_width = std::max(roi_width, static_cast<T>(1.));
roi_height = std::max(roi_height, static_cast<T>(1.));
}
const T* batch_data = input_data + roi_batch_id * in_stride[0];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
auto interpolation_cords = GetIndexesAndRatios(width,
height,
roi_width,
roi_height,
roi_xmin,
roi_ymin,
pooled_width,
roi_bin_grid_w,
pooled_height,
roi_bin_grid_h);
std::vector<T> interpolated_values;
interpolated_values.reserve(interpolation_cords.size());
for (auto channel = 0; channel < channels; ++channel) {
Interpolate(interpolated_values, interpolation_cords, batch_data);
AvgPool(interpolated_values,
output_data,
roi_bin_grid_w,
roi_bin_grid_h,
pooled_width,
pooled_height);
batch_data += in_stride[1];
output_data += out_stride[1];
interpolated_values.clear();
}
boxes_data += roi_stride[0];
}
}
} // namespace phi
PD_REGISTER_KERNEL(
roi_align, CPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double, int) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roi_align_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/fluid/memory/memory.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static constexpr int kROISize = 4;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <class T>
__device__ T BilinearInterpolate(
const T* input_data, const int height, const int width, T y, T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
return 0;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T v1 = input_data[y_low * width + x_low];
T v2 = input_data[y_low * width + x_high];
T v3 = input_data[y_high * width + x_low];
T v4 = input_data[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <class T>
__global__ void GPUROIAlignForward(const int nthreads,
const T* input_data,
const T* input_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
int* roi_batch_id_data,
T* output_data,
const bool continuous_coordinate) {
CUDA_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
T roi_offset = continuous_coordinate ? static_cast<T>(0.5) : 0;
T roi_xmin = offset_input_rois[0] * spatial_scale - roi_offset;
T roi_ymin = offset_input_rois[1] * spatial_scale - roi_offset;
T roi_xmax = offset_input_rois[2] * spatial_scale - roi_offset;
T roi_ymax = offset_input_rois[3] * spatial_scale - roi_offset;
T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin;
if (!continuous_coordinate) {
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = BilinearInterpolate(offset_input_data, height, width, y, x);
output_val += val;
}
}
output_val /= count;
output_data[i] = output_val;
}
}
template <typename T, typename Context>
void ROIAlignKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
int sampling_ratio,
bool aligned,
DenseTensor* out) {
auto in_dims = x.dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = boxes.dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(dev_ctx, &threads, 256);
#endif
DenseTensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data = dev_ctx.template HostAlloc<int>(&roi_batch_id_list);
auto cplace = phi::CPUPlace();
auto gplace = dev_ctx.GetPlace();
if (boxes_num) {
int boxes_batch_size = boxes_num->numel();
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
errors::InvalidArgument(
"The boxes_batch_size and imgs "
"batch_size must be the same. But received boxes_batch_size = %d, "
"batch_size = %d",
boxes_batch_size,
batch_size));
std::vector<int> boxes_num_list(boxes_batch_size);
paddle::memory::Copy(cplace,
boxes_num_list.data(),
gplace,
boxes_num->data<int>(),
sizeof(int) * boxes_batch_size,
0);
int start = 0;
for (int n = 0; n < boxes_batch_size; ++n) {
for (int i = start; i < start + boxes_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += boxes_num_list[n];
}
} else {
auto lod = boxes.lod();
PADDLE_ENFORCE_EQ(lod.empty(),
false,
errors::InvalidArgument("Input(ROIs) in ROIAlignOp does "
"not contain LoD information."));
auto boxes_lod = lod.back();
int boxes_batch_size = boxes_lod.size() - 1;
PADDLE_ENFORCE_EQ(
boxes_batch_size,
batch_size,
errors::InvalidArgument(
"The batch size of rois and batch size "
"of images must be the same. But received rois batch size = %d, "
"and images batch size = %d",
boxes_batch_size,
batch_size));
int boxes_num_with_lod = boxes_lod[boxes_batch_size];
PADDLE_ENFORCE_EQ(
rois_num,
boxes_num_with_lod,
errors::InvalidArgument(
"The actual number of rois and the number of rois "
"provided from Input(RoIsLoD) in RoIAlign must be the same."
" But received actual number of rois is %d, and the number "
"of rois from RoIsLoD is %d",
rois_num,
boxes_num_with_lod));
for (int n = 0; n < boxes_batch_size; ++n) {
for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = paddle::memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
paddle::memory::Copy(
gplace, roi_id_data, cplace, roi_batch_id_data, bytes, dev_ctx.stream());
GPUROIAlignForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size,
x.data<T>(),
boxes.data<T>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
roi_id_data,
dev_ctx.template Alloc<T>(out),
aligned);
}
} // namespace phi
PD_REGISTER_KERNEL(
roi_align, GPU, ALL_LAYOUT, phi::ROIAlignKernel, float, double) {}
...@@ -15,10 +15,9 @@ limitations under the License. */
#include "paddle/phi/kernels/scale_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/phi/common/float16.h"
namespace phi {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void ROIAlignKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& boxes,
paddle::optional<const DenseTensor&> boxes_num,
int pooled_height,
int pooled_width,
float spatial_scale,
int sampling_ratio,
bool aligned,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ROIAlignOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("roi_align",
{"X", "ROIs", "RoisNum"},
{"pooled_height",
"pooled_width",
"spatial_scale",
"sampling_ratio",
"aligned"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(roi_align, phi::ROIAlignOpArgumentMapping);