未验证 提交 d3ac561d 编写于 作者: B Bai Yifan 提交者: GitHub

fix deformable_conv_op compile error, test=develop (#18793)

上级 9ecd8ee7
......@@ -200,6 +200,36 @@ __device__ T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
return weight;
}
template <typename T>
__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
const int height, const int width, T h, T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void ModulatedDeformableCol2imCoordGpuKernel(
const int nthreads, const T* data_col, const T* data_im,
......@@ -315,36 +345,6 @@ inline void ModulatedDeformableCol2imCoord(
deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask);
}
template <typename T>
__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
const int height, const int width, T h, T w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
const int nthreads, const T* data_im, const T* data_offset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册