From d3ac561d65d77c06026f3b33a5cf5f41f065f0b5 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Thu, 25 Jul 2019 13:34:37 +0800 Subject: [PATCH] fix deformable_conv_op compile error, test=develop (#18793) --- paddle/fluid/operators/deformable_conv_op.cu | 60 ++++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/deformable_conv_op.cu b/paddle/fluid/operators/deformable_conv_op.cu index afcd418f056..cbb9bed90ce 100644 --- a/paddle/fluid/operators/deformable_conv_op.cu +++ b/paddle/fluid/operators/deformable_conv_op.cu @@ -200,6 +200,36 @@ __device__ T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height, return weight; } +template +__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, + const int height, const int width, T h, T w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + template __global__ void ModulatedDeformableCol2imCoordGpuKernel( const int nthreads, const T* data_col, const T* data_im, @@ -315,36 +345,6 @@ inline void ModulatedDeformableCol2imCoord( deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask); } -template -__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, - const int height, const int width, T h, T w) { - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh, hw = 1 - lw; - - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = bottom_data[h_low * data_width + w_high]; - T v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = bottom_data[h_high * data_width + w_low]; - T v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = bottom_data[h_high * data_width + w_high]; - - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - template __global__ void ModulatedDeformableIm2colGpuKernel( const int nthreads, const T* data_im, const T* data_offset, -- GitLab