提交 98c3294b 编写于 作者: J jerrywgz

Merge branch 'roialign' of https://github.com/jerrywgz/Paddle into roialign

此差异已折叠。
......@@ -10,7 +10,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
......
......@@ -33,16 +33,9 @@ static inline int NumBlocks(const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
/*
template <class T>
inline __device__ T gpu_atomic_add(const T val, T* address) {
return atomicAdd(address, val);
}
*/
template <class T>
__device__ T bilinear_interpolate(const T* input_data, const int height,
const int width, T y, T x, ) {
const int width, T y, T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
return 0;
}
......@@ -82,15 +75,11 @@ __device__ T bilinear_interpolate(const T* input_data, const int height,
}
template <class T>
__device__ T bilinear_interpolate_gradient(const int height, const int width,
T y, T x, const T& w1, const T& w2,
const T& w3, const T& w4,
const int& x_low, const int& x_high,
const int& y_low,
const int& y_high) {
__device__ void bilinear_interpolate_gradient(const int height, const int width,
T y, T x, T* w1, T* w2, T* w3,
T* w4, int* x_low, int* x_high,
int* y_low, int* y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
......@@ -100,23 +89,23 @@ __device__ T bilinear_interpolate_gradient(const int height, const int width,
if (x <= 0) {
x = 0;
}
y_low = static_cast<int>(y);
x_low = static_cast<int>(x);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
*y_low = static_cast<int>(y);
*x_low = static_cast<int>(x);
if (*y_low >= height - 1) {
*y_high = *y_low = height - 1;
y = static_cast<T>(*y_low);
} else {
y_high = y_low + 1;
*y_high = *y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
if (*x_low >= width - 1) {
*x_high = *x_low = width - 1;
x = static_cast<T>(*x_low);
} else {
x_high = x_low + 1;
*x_high = *x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T ly = y - *y_low, lx = x - *x_low;
T hy = 1. - ly, hx = 1. - lx;
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
......@@ -126,7 +115,7 @@ __global__ void GPUROIAlignForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int sampling_ratio int* roi_batch_id_data, T* output_data) {
const int sampling_ratio, int* roi_batch_id_data, T* output_data) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
......@@ -141,8 +130,8 @@ __global__ void GPUROIAlignForward(
T roi_xmax = offset_input_rois[2] * spatial_scale;
T roi_ymax = offset_input_rois[3] * spatial_scale;
T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.));
T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
......@@ -175,7 +164,7 @@ __global__ void GPUROIAlignForward(
template <typename T>
__global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
const T* output_grad, const int num_rois,
const T* out_grad, const int num_rois,
const float spatial_scale,
const int channels, const int height,
const int width, const int pooled_height,
......@@ -185,7 +174,7 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (ic / pooled_width / pooled_height) % channels;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
......@@ -195,12 +184,12 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
T roi_xmax = offset_input_rois[2] * spatial_scale;
T roi_ymax = offset_input_rois[3] * spatial_scale;
T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.));
T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input_grad =
T* offset_input_grad =
input_grad + (roi_batch_ind * channels + c) * height * width;
const T* offset_out_grad =
......@@ -215,17 +204,17 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
const T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
x_low, x_high, y_low, y_high);
T w1 = 0, w2 = 0, w3 = 0, w4 = 0;
int x_low = -1, x_high = -1, y_low = -1, y_high = -1;
bilinear_interpolate_gradient(height, width, y, x, &w1, &w2, &w3, &w4,
&x_low, &x_high, &y_low, &y_high);
T diff1 = out_grad_this_bin * w1 / count;
T diff2 = out_grad_this_bin * w2 / count;
T diff3 = out_grad_this_bin * w3 / count;
......@@ -238,7 +227,7 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_low,
diff3);
platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_high,
diff3);
diff4);
}
}
}
......@@ -249,7 +238,7 @@ template <typename Place, typename T>
class GPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
i auto* in = ctx.Input<Tensor>("X");
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
......@@ -337,9 +326,9 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
x_grad->mutable_data<T>(ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
......@@ -351,7 +340,7 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
spatial_scale, channels, height, width, pooled_height, pooled_width,
sampling_ratio, roi_batch_id_list_gpu.data<int>(),
x_grad->mutable_data<T>(ctx.GetPlace()));
in_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册