/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// CUDA: index helpers
#define idx4_4(index, d1, d2, d3, d4) (index % d4)
#define idx4_3(index, d1, d2, d3, d4) ((index / d4) % d3)
#define idx4_2(index, d1, d2, d3, d4) ((index / d4 / d3) % d2)
#define idx4_1(index, d1, d2, d3, d4) ((index / d4 / d3 / d2) % d1)
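
// Worked example (illustrative only, not used by the code below): for a 4-D
// tensor with dims (d1, d2, d3, d4) = (2, 3, 4, 5), the flat index 37
// decomposes as
//   idx4_4 = 37 % 5               = 2
//   idx4_3 = (37 / 5) % 4         = 3
//   idx4_2 = (37 / 5 / 4) % 3     = 1
//   idx4_1 = (37 / 5 / 4 / 3) % 2 = 0
// i.e. coordinates (0, 1, 3, 2), since ((0 * 3 + 1) * 4 + 3) * 5 + 2 == 37.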

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

template <typename T>
__device__ bool GT_E(T a, T b) {
  return (a > b) || fabs(a - b) < 1e-4;
}

template <typename T>
__device__ bool LT_E(T a, T b) {
  return (a < b) || fabs(a - b) < 1e-4;
}

template <typename T>
__device__ bool GT(T a, T b) {
  return (a - b) > 1e-4;
}

template <typename T>
__device__ T max(T a, T b) {
  return a > b ? a : b;
}

template <typename T>
__device__ T min(T a, T b) {
  return a < b ? a : b;
}

/*
 * Check if (x, y) is within the boundary of the RoI quadrangle.
 */
template <typename T>
__device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
  // Points lying (almost) exactly on an edge count as inside.
  for (int i = 0; i < 4; i++) {
    T start_w = roi_x[i];
    T start_h = roi_y[i];
    T end_w = roi_x[(i + 1) % 4];
    T end_h = roi_y[(i + 1) % 4];
    if (fabs(start_h - end_h) < 1e-4) {
      if (fabs(y - start_h) < 1e-4 && fabs(y - end_h) < 1e-4 &&
          GT_E<T>(x, min<T>(start_w, end_w)) &&
          LT_E<T>(x, max<T>(start_w, end_w))) {
        return true;
      }
    } else {
      T intersec_x =
          (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w;
      if (fabs(intersec_x - x) < 1e-4 && GT_E<T>(y, min<T>(start_h, end_h)) &&
          LT_E<T>(y, max<T>(start_h, end_h))) {
        return true;
      }
    }
  }

  // Otherwise use ray casting: count how many edges a horizontal ray crosses.
  int n_cross = 0;
  for (int i = 0; i < 4; i++) {
    T start_w = roi_x[i];
    T start_h = roi_y[i];
    T end_w = roi_x[(i + 1) % 4];
    T end_h = roi_y[(i + 1) % 4];
    if (fabs(start_h - end_h) < 1e-4) {
      continue;
    }
    if (LT_E<T>(y, min<T>(start_h, end_h)) ||
        GT<T>(y, max<T>(start_h, end_h))) {
      continue;
    }
    T intersec_x =
        (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w;
    if (fabs(intersec_x - x) < 1e-4) {
      return true;
    }
    if (GT<T>(intersec_x, x)) {
      n_cross++;
    }
  }
  return (n_cross % 2 == 1);
}

/**
 * Perform bilinear interpolation in the input feature map.
 */
template <typename T>
__device__ void bilinear_interpolate(const T* in_data, const int channels,
                                     const int width, const int height,
                                     int in_n, int in_c, T in_w, T in_h,
                                     T* val) {
  // Deal with cases where source coords are out of the feature map boundary
  if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
      GT<T>(in_h, height - 0.5)) {
    val[0] = 0.0;
    return;
  }

  if (GT<T>(0, in_w)) {
    in_w = 0;
  }
  if (GT<T>(0, in_h)) {
    in_h = 0;
  }

  int in_w_floor = floor(in_w);
  int in_h_floor = floor(in_h);
  int in_w_ceil;
  int in_h_ceil;

  if (GT_E<T>(in_w_floor, width - 1)) {
    in_w_ceil = in_w_floor = width - 1;
    in_w = static_cast<T>(in_w_floor);
  } else {
    in_w_ceil = in_w_floor + 1;
  }

  if (GT_E<T>(in_h_floor, height - 1)) {
    in_h_ceil = in_h_floor = height - 1;
    in_h = static_cast<T>(in_h_floor);
  } else {
    in_h_ceil = in_h_floor + 1;
  }

  T w_floor = in_w - in_w_floor;
  T h_floor = in_h - in_h_floor;
  T w_ceil = 1 - w_floor;
  T h_ceil = 1 - h_floor;
  const T* data = in_data + (in_n * channels + in_c) * height * width;

  // Do bilinear interpolation over the four neighboring pixels
  T v1 = data[in_h_floor * width + in_w_floor];
  T v2 = data[in_h_ceil * width + in_w_floor];
  T v3 = data[in_h_ceil * width + in_w_ceil];
  T v4 = data[in_h_floor * width + in_w_ceil];
  T w1 = w_ceil * h_ceil;
  T w2 = w_ceil * h_floor;
  T w3 = w_floor * h_floor;
  T w4 = w_floor * h_ceil;
  val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}

/**
 * Get the source coordinates in the input feature map.
 *
 * (u, v, w)^T = matrix * (out_w, out_h, 1)^T
 *
 * in_w = u / w
 * in_h = v / w
 *
 */
template <typename T>
__device__ void get_source_coords(T matrix[], int out_w, int out_h, T* in_w,
                                  T* in_h) {
  T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2];
  T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5];
  T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8];

  in_w[0] = u / w;
  in_h[0] = v / w;
}

/**
 * Get the matrix of the perspective transform.
 *
 * dx1 = x1 - x2
 * dx2 = x3 - x2
 * dx3 = x0 - x1 + x2 - x3
 * dy1 = y1 - y2
 * dy2 = y3 - y2
 * dy3 = y0 - y1 + y2 - y3
 *
 * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1)
 * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1)
 * a13 = x0
 * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1)
 * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1)
 * a23 = y0
 * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1)
 * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1)
 * a33 = 1
 *
 */
template <typename T>
__device__ void get_transform_matrix(const int transformed_width,
                                     const int transformed_height, T roi_x[],
                                     T roi_y[], T matrix[]) {
  T x0 = roi_x[0];
  T x1 = roi_x[1];
  T x2 = roi_x[2];
  T x3 = roi_x[3];
  T y0 = roi_y[0];
  T y1 = roi_y[1];
  T y2 = roi_y[2];
  T y3 = roi_y[3];

  // Estimate the height and width of the RoI
  T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1));
  T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2));
  T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3));
  T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0));
  T estimated_height = (len2 + len4) / 2.0;
  T estimated_width = (len1 + len3) / 2.0;

  // Get the normalized height and normalized width
  int normalized_height = transformed_height;
  int normalized_width =
      round(estimated_width * (normalized_height - 1) / estimated_height) + 1;
  normalized_width = min(normalized_width, transformed_width);

  T dx1 = x1 - x2;
  T dx2 = x3 - x2;
  T dx3 = x0 - x1 + x2 - x3;
  T dy1 = y1 - y2;
  T dy2 = y3 - y2;
  T dy3 = y0 - y1 + y2 - y3;

  matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) /
              (normalized_width - 1);
  matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) /
              (normalized_height - 1);
  matrix[8] = 1;

  matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) /
              (normalized_width - 1);
  matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) /
              (normalized_height - 1);
  matrix[5] = y0;

  matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) /
              (normalized_width - 1);
  matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) /
              (normalized_height - 1);
  matrix[2] = x0;
}
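
// Sanity check (illustrative only): for an axis-aligned RoI with corners
// (l, t), (r, t), (r, b), (l, b), the cross terms dx3 and dy3 vanish, so
// matrix[6] = matrix[7] = 0 and the transform reduces to plain scaling:
//   in_w = l + out_w * (r - l) / (normalized_width - 1)
//   in_h = t + out_h * (b - t) / (normalized_height - 1)
// i.e. the output grid is stretched linearly across the rectangle.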

template <typename T>
__global__ void RoiTransformKernel(const float* input_data,
                                   const float* rois_data,
                                   const int* roi2image_data, int num_rois,
                                   int in_height, int in_width, int channels,
                                   int transformed_height,
                                   int transformed_width, float spatial_scale,
                                   T* output_data) {
  int output_size =
      num_rois * transformed_height * transformed_width * channels;

  CUDA_1D_KERNEL_LOOP(index, output_size) {
    // (n, c, out_h, out_w) is an element in the transformed output
    int out_w = idx4_4(index, num_rois, channels, transformed_height,
                       transformed_width);
    int out_h = idx4_3(index, num_rois, channels, transformed_height,
                       transformed_width);
    int c = idx4_2(index, num_rois, channels, transformed_height,
                   transformed_width);
    int n = idx4_1(index, num_rois, channels, transformed_height,
                   transformed_width);

    auto bottom_rois = rois_data + n * 8;
    T roi_x[4];
    T roi_y[4];
    for (int k = 0; k < 4; ++k) {
      roi_x[k] = bottom_rois[2 * k] * spatial_scale;
      roi_y[k] = bottom_rois[2 * k + 1] * spatial_scale;
    }

    // Get transform matrix
    T matrix[9];
    get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
                            roi_y, matrix);

    // Get source coords
    T in_w;
    T in_h;
    get_source_coords<T>(matrix, out_w, out_h, &in_w, &in_h);

    if (in_quad<T>(in_w, in_h, roi_x, roi_y)) {
      if (GT<T>(-0.5, in_w) || GT<T>(in_w, static_cast<T>(in_width - 0.5)) ||
          GT<T>(-0.5, in_h) || GT<T>(in_h, static_cast<T>(in_height - 0.5))) {
        // Skip if source coords are not in the input image
        output_data[index] = 0.0;
      } else {
        // Perform bilinear interpolation
        int in_n = roi2image_data[n];
        bilinear_interpolate<T>(input_data, channels, in_width, in_height,
                                in_n, c, in_w, in_h, output_data + index);
      }
    } else {
      // Skip if source coords are not in the quad
      output_data[index] = 0.0;
    }
  }
}

template <typename T>
class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
    auto* out = ctx.Output<framework::Tensor>("Out");

    auto transformed_height = ctx.Attr<int>("transformed_height");
    auto transformed_width = ctx.Attr<int>("transformed_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int in_height = in_dims[2];
    int in_width = in_dims[3];
    int rois_num = rois->dims()[0];

    const T* input_data = in->data<T>();
    T* output_data = out->mutable_data<T>(ctx.GetPlace());
    const T* rois_data = rois->data<T>();

    // Map each RoI to the image it belongs to, using the LoD of "ROIs"
    framework::Tensor roi2image;
    framework::Tensor roi2image_dev;
    roi2image.Resize({rois_num});
    int* roi2image_data = roi2image.mutable_data<int>(platform::CPUPlace());
    auto lod = rois->lod().back();
    for (size_t i = 0; i < lod.size() - 1; ++i) {
      for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
        roi2image_data[j] = i;
      }
    }
    TensorCopySync(roi2image, ctx.GetPlace(), &roi2image_dev);

    int out_size =
        rois_num * transformed_height * transformed_width * channels;
    auto stream = ctx.cuda_device_context().stream();
    int block = 512;
    int grid = (out_size + block - 1) / block;

    RoiTransformKernel<T><<<grid, block, 0, stream>>>(
        input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
        in_width, channels, transformed_height, transformed_width,
        spatial_scale, output_data);
  }
};
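
// Backward pass. get_feature_gradient returns the bilinear weight with which
// input pixel (w, h) contributed to the sample taken at source coordinates
// (xs, ys); it is zero unless (w, h) is one of the four neighboring pixels.
// The gradient kernel accumulates out_grad * weight over every output element
// of every RoI that sampled this input element.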

template <typename T>
__device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width,
                                  const int height) {
  if (GT<T>(-0.5, xs) || GT<T>(xs, width - 0.5) || GT<T>(-0.5, ys) ||
      GT<T>(ys, height - 0.5)) {
    return 0;
  }

  if (GT<T>(0, xs)) {
    xs = 0;
  }
  if (GT<T>(0, ys)) {
    ys = 0;
  }

  int xs_floor = floor(xs);
  int ys_floor = floor(ys);
  int xs_ceil;
  int ys_ceil;

  if (GT_E<T>(xs_floor, width - 1)) {
    xs_ceil = xs_floor = width - 1;
    xs = static_cast<T>(xs_floor);
  } else {
    xs_ceil = xs_floor + 1;
  }

  if (GT_E<T>(ys_floor, height - 1)) {
    ys_ceil = ys_floor = height - 1;
    ys = static_cast<T>(ys_floor);
  } else {
    ys_ceil = ys_floor + 1;
  }

  T weight = 0;
  if (w == xs_floor) {
    if (h == ys_floor) {
      weight = (w + 1 - xs) * (h + 1 - ys);
    } else if (h == ys_ceil) {
      weight = (w + 1 - xs) * (ys + 1 - h);
    }
  } else if (w == xs_ceil) {
    if (h == ys_floor) {
      weight = (xs + 1 - w) * (h + 1 - ys);
    } else if (h == ys_ceil) {
      weight = (xs + 1 - w) * (ys + 1 - h);
    }
  }

  return weight;
}

template <typename T>
__global__ void RoiTransformGradKernel(
    const size_t* lod, const T* rois_data, int batch_size, int num_rois,
    int in_height, int in_width, int channels, int transformed_height,
    int transformed_width, float spatial_scale, const T* out_grad_data,
    T* in_grad_data) {
  int input_size = batch_size * in_height * in_width * channels;

  CUDA_1D_KERNEL_LOOP(index, input_size) {
    // (n, c, h, w) coords in input
    int in_w = idx4_4(index, batch_size, channels, in_height, in_width);
    int in_h = idx4_3(index, batch_size, channels, in_height, in_width);
    int c = idx4_2(index, batch_size, channels, in_height, in_width);
    int n = idx4_1(index, batch_size, channels, in_height, in_width);

    T gradient = 0.0;
    // Accumulate gradient over all RoIs that interpolated this element
    for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
      const T* rois = rois_data + roi_idx * 8;
      T roi_x[4];
      T roi_y[4];
      for (int k = 0; k < 4; ++k) {
        roi_x[k] = rois[2 * k] * spatial_scale;
        roi_y[k] = rois[2 * k + 1] * spatial_scale;
      }

      // Get transform matrix
      T matrix[9];
      get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
                              roi_y, matrix);

      const T* out_grad_ptr =
          out_grad_data +
          (roi_idx * channels + c) * transformed_height * transformed_width;
      for (int out_h = 0; out_h < transformed_height; ++out_h) {
        for (int out_w = 0; out_w < transformed_width; ++out_w) {
          T src_w;
          T src_h;
          get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
          if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
            if (GT<T>(-0.5, src_w) ||
                GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
                GT<T>(-0.5, src_h) ||
                GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
              // Skip samples that fall outside the input image
              continue;
            }
            T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
                                               in_width, in_height);
            gradient +=
                out_grad_ptr[out_h * transformed_width + out_w] * weight;
          }
        }
      }
    }
    in_grad_data[index] = gradient;
  }
}

template <typename T>
class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    auto transformed_height = ctx.Attr<int>("transformed_height");
    auto transformed_width = ctx.Attr<int>("transformed_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int in_height = in_dims[2];
    int in_width = in_dims[3];
    int rois_num = rois->dims()[0];

    T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
    const T* out_grad_data = out_grad->data<T>();
    const T* rois_data = rois->data<T>();

    auto lod = rois->lod().back();
    auto lod_data = lod.CUDAData(ctx.GetPlace());

    int in_size = in->numel();
    auto stream = ctx.cuda_device_context().stream();
    int block = 512;
    int grid = (in_size + block - 1) / block;

    RoiTransformGradKernel<T><<<grid, block, 0, stream>>>(
        lod_data, rois_data, batch_size, rois_num, in_height, in_width,
        channels, transformed_height, transformed_width, spatial_scale,
        out_grad_data, in_grad_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(roi_perspective_transform,
                        ops::CUDAROIPerspectiveTransformOpKernel<float>);
REGISTER_OP_CUDA_KERNEL(roi_perspective_transform_grad,
                        ops::CUDAROIPerspectiveTransformGradOpKernel<float>);