未验证 提交 cf9ca61d 编写于 作者: Z Zhang Ting 提交者: GitHub

Revert #46111 (#46961)

* Revert "【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 (#46111)"
上级 b971ec5e
...@@ -97,14 +97,14 @@ inline void ModulatedDeformableCol2imCPUKernel( ...@@ -97,14 +97,14 @@ inline void ModulatedDeformableCol2imCPUKernel(
width); width);
*(grad_im + cur_bottom_grad_pos) = *(grad_im + cur_bottom_grad_pos) =
*(grad_im + cur_bottom_grad_pos) + (weight * cur_top_grad); *(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad;
} }
} }
} }
} }
} }
template <typename T, typename MT, typename Context> template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx, void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col, const T* data_col,
const T* data_offset, const T* data_offset,
...@@ -116,7 +116,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, ...@@ -116,7 +116,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& stride, const std::vector<int>& stride,
const std::vector<int>& dilation, const std::vector<int>& dilation,
const int deformable_group, const int deformable_group,
MT* grad_im) { T* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group; int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
...@@ -222,8 +222,8 @@ void ModulatedDeformableCol2imCoordCPUKernel( ...@@ -222,8 +222,8 @@ void ModulatedDeformableCol2imCoordCPUKernel(
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2; inv_h = inv_w = -2;
} else { } else {
mval += data_col_ptr[col_pos] * funcs::DmcnIm2colBilinear<T, T>( mval += data_col_ptr[col_pos] *
data_im_ptr + cnt * height * width, funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
width, width,
height, height,
width, width,
...@@ -231,7 +231,7 @@ void ModulatedDeformableCol2imCoordCPUKernel( ...@@ -231,7 +231,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
inv_w); inv_w);
} }
const T weight = const T weight =
DmcnGetCoordinateWeight<T, T>(inv_h, DmcnGetCoordinateWeight(inv_h,
inv_w, inv_w,
height, height,
width, width,
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -82,8 +82,8 @@ inline void ModulatedDeformableIm2colCPUKernel( ...@@ -82,8 +82,8 @@ inline void ModulatedDeformableIm2colCPUKernel(
const T h_im = h_in + i * dilation_h + offset_h; const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w; const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val = DmcnIm2colBilinear<T, T>( val =
data_im_ptr, width, height, width, h_im, w_im); DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
} }
*data_col_ptr = val; *data_col_ptr = val;
if (data_mask_ptr) { if (data_mask_ptr) {
......
...@@ -14,9 +14,6 @@ ...@@ -14,9 +14,6 @@
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/device_context.h"
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -54,8 +51,6 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -54,8 +51,6 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
T* data_col) { T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (size_t i = index; i < nthreads; i += offset) { for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col; const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col; const int h_col = (i / width_col) % height_col;
...@@ -90,22 +85,22 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -90,22 +85,22 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col; w_col;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]); const T offset_h = data_offset_ptr[data_offset_h_ptr];
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]); const T offset_w = data_offset_ptr[data_offset_w_ptr];
MT val = static_cast<MT>(0); T val = static_cast<T>(0);
const MT h_im = h_in + i * dilation_h + offset_h; const T h_im = h_in + i * dilation_h + offset_h;
const MT w_im = w_in + j * dilation_w + offset_w; const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val = DmcnIm2colBilinear<T, MT>( val =
data_im_ptr, width, height, width, h_im, w_im); DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
} }
*data_col_ptr = val;
if (data_mask_ptr) { if (data_mask_ptr) {
const int data_mask_hw_ptr = const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col; ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]); const T mask = data_mask_ptr[data_mask_hw_ptr];
val *= mask; *data_col_ptr *= mask;
} }
*data_col_ptr = static_cast<T>(val);
data_col_ptr += batch_size * height_col * width_col; data_col_ptr += batch_size * height_col * width_col;
} }
} }
...@@ -169,20 +164,6 @@ template void ModulatedDeformableIm2col( ...@@ -169,20 +164,6 @@ template void ModulatedDeformableIm2col(
const int deformable_groups, const int deformable_groups,
float* data_col); float* data_col);
template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const phi::dtype::float16* data_im,
const phi::dtype::float16* data_offset,
const phi::dtype::float16* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
phi::dtype::float16* data_col);
template void ModulatedDeformableIm2col( template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx, const phi::GPUContext& dev_ctx,
const double* data_im, const double* data_im,
......
...@@ -14,47 +14,44 @@ ...@@ -14,47 +14,44 @@
#pragma once #pragma once
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace phi { namespace phi {
namespace funcs { namespace funcs {
template <typename T, typename MT> template <typename T>
HOSTDEVICE MT DmcnIm2colBilinear(const T* bottom_data, HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
const int data_width, const int data_width,
const int height, const int height,
const int width, const int width,
MT h, T h,
MT w) { T w) {
int h_low = floor(h); int h_low = floor(h);
int w_low = floor(w); int w_low = floor(w);
int h_high = h_low + 1; int h_high = h_low + 1;
int w_high = w_low + 1; int w_high = w_low + 1;
MT lh = h - h_low; T lh = h - h_low;
MT lw = w - w_low; T lw = w - w_low;
MT hh = 1 - lh; T hh = 1 - lh;
MT hw = 1 - lw; T hw = 1 - lw;
MT v1 = (h_low >= 0 && w_low >= 0) T v1 =
? static_cast<MT>(bottom_data[h_low * data_width + w_low]) (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0; : 0;
MT v2 = (h_low >= 0 && w_high <= width - 1) T v3 = (h_high <= height - 1 && w_low >= 0)
? static_cast<MT>(bottom_data[h_low * data_width + w_high]) ? bottom_data[h_high * data_width + w_low]
: 0; : 0;
MT v3 = (h_high <= height - 1 && w_low >= 0) T v4 = (h_high <= height - 1 && w_high <= width - 1)
? static_cast<MT>(bottom_data[h_high * data_width + w_low]) ? bottom_data[h_high * data_width + w_high]
: 0;
MT v4 = (h_high <= height - 1 && w_high <= width - 1)
? static_cast<MT>(bottom_data[h_high * data_width + w_high])
: 0; : 0;
MT w1 = hh * hw; T w1 = hh * hw;
MT w2 = hh * lw; T w2 = hh * lw;
MT w3 = lh * hw; T w3 = lh * hw;
MT w4 = lh * lw; T w4 = lh * lw;
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
} }
......
...@@ -29,7 +29,7 @@ static inline int NumBlocks(const int N) { ...@@ -29,7 +29,7 @@ static inline int NumBlocks(const int N) {
kNumMaximumNumBlocks); kNumMaximumNumBlocks);
} }
template <typename T, typename MT> template <typename T>
__global__ void ModulatedDeformableCol2imGpuKernel( __global__ void ModulatedDeformableCol2imGpuKernel(
const int nthreads, const int nthreads,
const T* data_col, const T* data_col,
...@@ -51,10 +51,9 @@ __global__ void ModulatedDeformableCol2imGpuKernel( ...@@ -51,10 +51,9 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
const int deformable_group, const int deformable_group,
const int height_col, const int height_col,
const int width_col, const int width_col,
MT* grad_im) { T* grad_im) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
// using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (size_t thread = index; thread < nthreads; thread += offset) { for (size_t thread = index; thread < nthreads; thread += offset) {
const int j = (thread / width_col / height_col / batch_size) % kernel_w; const int j = (thread / width_col / height_col / batch_size) % kernel_w;
const int i = const int i =
...@@ -79,17 +78,17 @@ __global__ void ModulatedDeformableCol2imGpuKernel( ...@@ -79,17 +78,17 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out; ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]); const T offset_h = data_offset_ptr[data_offset_h_ptr];
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]); const T offset_w = data_offset_ptr[data_offset_w_ptr];
const MT cur_inv_h_data = h_in + i * dilation_h + offset_h; const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const MT cur_inv_w_data = w_in + j * dilation_w + offset_w; const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
MT cur_top_grad = static_cast<MT>(data_col[thread]); T cur_top_grad = data_col[thread];
if (data_mask) { if (data_mask) {
const T* data_mask_ptr = const T* data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) * data_mask + (b * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col; kernel_h * kernel_w * height_col * width_col;
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]); const T mask = data_mask_ptr[data_mask_hw_ptr];
cur_top_grad *= mask; cur_top_grad *= mask;
} }
const int cur_h = static_cast<int>(cur_inv_h_data); const int cur_h = static_cast<int>(cur_inv_h_data);
...@@ -101,12 +100,13 @@ __global__ void ModulatedDeformableCol2imGpuKernel( ...@@ -101,12 +100,13 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
abs(cur_inv_w_data - (cur_w + dx)) < 1) { abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos = int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
MT weight = DmcnGetGradientWeight(cur_inv_h_data, T weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data, cur_inv_w_data,
cur_h + dy, cur_h + dy,
cur_w + dx, cur_w + dx,
height, height,
width); width);
paddle::platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos, paddle::platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad); weight * cur_top_grad);
} }
...@@ -115,7 +115,7 @@ __global__ void ModulatedDeformableCol2imGpuKernel( ...@@ -115,7 +115,7 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
} }
} }
template <typename T, typename MT, typename Context> template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx, void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col, const T* data_col,
const T* data_offset, const T* data_offset,
...@@ -127,13 +127,13 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, ...@@ -127,13 +127,13 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& stride, const std::vector<int>& stride,
const std::vector<int>& dilation, const std::vector<int>& dilation,
const int deformable_group, const int deformable_group,
MT* grad_im) { T* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group; int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];
int blocks = NumBlocks(num_kernels); int blocks = NumBlocks(num_kernels);
int threads = kNumCUDAThreads; int threads = kNumCUDAThreads;
ModulatedDeformableCol2imGpuKernel<T, MT> ModulatedDeformableCol2imGpuKernel<T>
<<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels, <<<blocks, threads, 0, dev_ctx.stream()>>>(num_kernels,
data_col, data_col,
data_offset, data_offset,
...@@ -185,9 +185,8 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( ...@@ -185,9 +185,8 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel(
T* grad_mask) { T* grad_mask) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (size_t i = index; i < nthreads; i += offset) { for (size_t i = index; i < nthreads; i += offset) {
MT val = 0, mval = 0; T val = 0, mval = 0;
const int w = i % width_col; const int w = i % width_col;
const int h = (i / width_col) % height_col; const int h = (i / width_col) % height_col;
const int c = (i / width_col / height_col) % offset_channels; const int c = (i / width_col / height_col) % offset_channels;
...@@ -232,25 +231,23 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( ...@@ -232,25 +231,23 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel(
const int data_offset_w_ptr = const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out); w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]); const T offset_w = data_offset_ptr[data_offset_w_ptr];
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]); T inv_h = h_in + i * dilation_h + offset_h;
MT inv_h = h_in + i * dilation_h + offset_h; T inv_w = w_in + j * dilation_w + offset_w;
MT inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2; inv_h = inv_w = -2;
} else { } else {
mval += mval += data_col_ptr[col_pos] *
static_cast<MT>(data_col_ptr[col_pos]) * funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
funcs::DmcnIm2colBilinear<T, MT>(data_im_ptr + cnt * height * width,
width, width,
height, height,
width, width,
inv_h, inv_h,
inv_w); inv_w);
} }
const MT weight = const T weight =
DmcnGetCoordinateWeight<T, MT>(inv_h, DmcnGetCoordinateWeight(inv_h,
inv_w, inv_w,
height, height,
width, width,
...@@ -260,14 +257,14 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( ...@@ -260,14 +257,14 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel(
if (data_mask_ptr) { if (data_mask_ptr) {
const int data_mask_hw_ptr = const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out); (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]); const T mask = data_mask_ptr[data_mask_hw_ptr];
val += weight * static_cast<MT>(data_col_ptr[col_pos]) * mask; val += weight * data_col_ptr[col_pos] * mask;
} else { } else {
val += weight * static_cast<MT>(data_col_ptr[col_pos]); val += weight * data_col_ptr[col_pos];
} }
cnt += 1; cnt += 1;
} }
grad_offset[i] = static_cast<T>(val); grad_offset[i] = val;
if (grad_mask && offset_c % 2 == 0) if (grad_mask && offset_c % 2 == 0)
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w + kernel_w +
...@@ -362,5 +359,4 @@ PD_REGISTER_KERNEL(deformable_conv_grad, ...@@ -362,5 +359,4 @@ PD_REGISTER_KERNEL(deformable_conv_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::DeformableConvGradKernel, phi::DeformableConvGradKernel,
float, float,
double, double) {}
paddle::platform::float16) {}
...@@ -23,5 +23,4 @@ PD_REGISTER_KERNEL(deformable_conv, ...@@ -23,5 +23,4 @@ PD_REGISTER_KERNEL(deformable_conv,
ALL_LAYOUT, ALL_LAYOUT,
phi::DeformableConvKernel, phi::DeformableConvKernel,
float, float,
double, double) {}
phi::dtype::float16) {}
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
#pragma once #pragma once
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
...@@ -60,9 +58,9 @@ HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, ...@@ -60,9 +58,9 @@ HOSTDEVICE T DmcnGetGradientWeight(T argmax_h,
return weight; return weight;
} }
template <typename T, typename MT> template <typename T>
HOSTDEVICE MT DmcnGetCoordinateWeight(MT argmax_h, HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h,
MT argmax_w, T argmax_w,
const int height, const int height,
const int width, const int width,
const T* im_data, const T* im_data,
...@@ -78,51 +76,43 @@ HOSTDEVICE MT DmcnGetCoordinateWeight(MT argmax_h, ...@@ -78,51 +76,43 @@ HOSTDEVICE MT DmcnGetCoordinateWeight(MT argmax_h,
int argmax_h_high = argmax_h_low + 1; int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1; int argmax_w_high = argmax_w_low + 1;
MT weight = 0; T weight = 0;
if (bp_dir == 0) { if (bp_dir == 0) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0) weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_w_low + 1 - argmax_w) * ? -1 * (argmax_w_low + 1 - argmax_w) *
static_cast<MT>( im_data[argmax_h_low * data_width + argmax_w_low]
im_data[argmax_h_low * data_width + argmax_w_low])
: 0; : 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? -1 * (argmax_w - argmax_w_low) * ? -1 * (argmax_w - argmax_w_low) *
static_cast<MT>( im_data[argmax_h_low * data_width + argmax_w_high]
im_data[argmax_h_low * data_width + argmax_w_high])
: 0; : 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? (argmax_w_low + 1 - argmax_w) * ? (argmax_w_low + 1 - argmax_w) *
static_cast<MT>( im_data[argmax_h_high * data_width + argmax_w_low]
im_data[argmax_h_high * data_width + argmax_w_low])
: 0; : 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_w - argmax_w_low) * ? (argmax_w - argmax_w_low) *
static_cast<MT>( im_data[argmax_h_high * data_width + argmax_w_high]
im_data[argmax_h_high * data_width + argmax_w_high])
: 0; : 0;
} else if (bp_dir == 1) { } else if (bp_dir == 1) {
weight += (argmax_h_low >= 0 && argmax_w_low >= 0) weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
? -1 * (argmax_h_low + 1 - argmax_h) * ? -1 * (argmax_h_low + 1 - argmax_h) *
static_cast<MT>( im_data[argmax_h_low * data_width + argmax_w_low]
im_data[argmax_h_low * data_width + argmax_w_low])
: 0; : 0;
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
? (argmax_h_low + 1 - argmax_h) * ? (argmax_h_low + 1 - argmax_h) *
static_cast<MT>( im_data[argmax_h_low * data_width + argmax_w_high]
im_data[argmax_h_low * data_width + argmax_w_high])
: 0; : 0;
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
? -1 * (argmax_h - argmax_h_low) * ? -1 * (argmax_h - argmax_h_low) *
static_cast<MT>( im_data[argmax_h_high * data_width + argmax_w_low]
im_data[argmax_h_high * data_width + argmax_w_low])
: 0; : 0;
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
? (argmax_h - argmax_h_low) * ? (argmax_h - argmax_h_low) *
static_cast<MT>( im_data[argmax_h_high * data_width + argmax_w_high]
im_data[argmax_h_high * data_width + argmax_w_high])
: 0; : 0;
} }
...@@ -145,7 +135,7 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, ...@@ -145,7 +135,7 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
T* grad_offset, T* grad_offset,
T* grad_mask); T* grad_mask);
template <typename T, typename MT, typename Context> template <typename T, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx, void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col, const T* data_col,
const T* data_offset, const T* data_offset,
...@@ -157,7 +147,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, ...@@ -157,7 +147,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& stride, const std::vector<int>& stride,
const std::vector<int>& dilation, const std::vector<int>& dilation,
const int deformable_group, const int deformable_group,
MT* grad_im); T* grad_im);
template <typename T, typename Context> template <typename T, typename Context>
void FilterGradAddup(const Context& dev_ctx, void FilterGradAddup(const Context& dev_ctx,
...@@ -186,7 +176,7 @@ void DeformableConvGradKernel(const Context& dev_ctx, ...@@ -186,7 +176,7 @@ void DeformableConvGradKernel(const Context& dev_ctx,
DenseTensor* filter_grad, DenseTensor* filter_grad,
DenseTensor* mask_grad) { DenseTensor* mask_grad) {
const int batch_size = static_cast<int>(x.dims()[0]); const int batch_size = static_cast<int>(x.dims()[0]);
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DDim input_shape = phi::slice_ddim(x.dims(), 1, x.dims().size()); DDim input_shape = phi::slice_ddim(x.dims(), 1, x.dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape); std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
...@@ -302,8 +292,8 @@ void DeformableConvGradKernel(const Context& dev_ctx, ...@@ -302,8 +292,8 @@ void DeformableConvGradKernel(const Context& dev_ctx,
mask_grad_data_ptr); mask_grad_data_ptr);
} }
if (dx) { if (dx) {
MT* mt_dx_ptr = dev_ctx.template Alloc<MT>(dx); T* dx_ptr = dx->data<T>();
// get grad of input
ModulatedDeformableCol2im(dev_ctx, ModulatedDeformableCol2im(dev_ctx,
col_buffer_ptr, col_buffer_ptr,
offset_ptr + i * im2col_step * input_offset_dim, offset_ptr + i * im2col_step * input_offset_dim,
...@@ -315,7 +305,7 @@ void DeformableConvGradKernel(const Context& dev_ctx, ...@@ -315,7 +305,7 @@ void DeformableConvGradKernel(const Context& dev_ctx,
strides, strides,
dilations, dilations,
deformable_groups, deformable_groups,
mt_dx_ptr + i * im2col_step * input_dim); dx_ptr + i * im2col_step * input_dim);
dx->Resize(x.dims()); dx->Resize(x.dims());
} }
......
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
#pragma once #pragma once
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
namespace phi { namespace phi {
...@@ -40,12 +38,6 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -40,12 +38,6 @@ void DeformableConvKernel(const Context& dev_ctx,
DenseTensor* out) { DenseTensor* out) {
const int batch_size = static_cast<int>(x.dims()[0]); const int batch_size = static_cast<int>(x.dims()[0]);
int temp_step = std::min(64, batch_size);
if (batch_size % temp_step == 0) {
im2col_step = temp_step;
}
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims())); std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims()));
...@@ -109,11 +101,8 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -109,11 +101,8 @@ void DeformableConvKernel(const Context& dev_ctx,
dilations, dilations,
deformable_groups, deformable_groups,
col_buffer_ptr); col_buffer_ptr);
DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(phi::slice_ddim( DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(
output_4d.dims(), phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
1,
output_4d.dims().size())); // group * C/group * (im2step * H * W)
// get the product of pixel and weight // get the product of pixel and weight
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
...@@ -121,11 +110,8 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -121,11 +110,8 @@ void DeformableConvKernel(const Context& dev_ctx,
DenseTensor col_buffer_3d_slice = DenseTensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
DenseTensor output_3d_slice = DenseTensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
output_3d.Slice(g, g + 1).Resize(phi::slice_ddim( phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
output_3d.dims(),
1,
output_3d.dims().size())); // C * ((im2col_step)*H*W))
blas.MatMul(weight_3d_slice, blas.MatMul(weight_3d_slice,
false, false,
col_buffer_3d_slice, col_buffer_3d_slice,
...@@ -135,29 +121,7 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -135,29 +121,7 @@ void DeformableConvKernel(const Context& dev_ctx,
T(0.0)); T(0.0));
} }
} }
// swap axis to get the right result when im2col_step is greater than 1
if (im2col_step > 1) {
std::vector<int> axis(4);
axis[0] = 0;
axis[1] = 2;
axis[2] = 1;
axis[3] = 3;
DenseTensor real_output_buffer = phi::Transpose<T, Context>(
dev_ctx,
output_4d.Resize(
phi::make_ddim({batch_size / im2col_step,
output_shape_vec[1],
im2col_step,
output_shape_vec[2] * output_shape_vec[3]})),
axis);
out->ShareDataWith(real_output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
} else {
out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec)); out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec));
}
} }
} // namespace phi } // namespace phi
...@@ -19,8 +19,6 @@ import paddle.fluid as fluid ...@@ -19,8 +19,6 @@ import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
def dmc_bilinear(data_im, height, width, h, w): def dmc_bilinear(data_im, height, width, h, w):
h_low = int(np.floor(h)) h_low = int(np.floor(h))
...@@ -60,7 +58,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param): ...@@ -60,7 +58,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
assert f_c * group == in_c assert f_c * group == in_c
assert np.mod(out_c, group) == 0 assert np.mod(out_c, group) == 0
stride, pad, dilation = conv_param['stride'], conv_param['pad'], \ stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
conv_param['dilation'] conv_param['dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0] out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1] out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
...@@ -285,69 +283,6 @@ class TestWithDouble(TestModulatedDeformableConvOp): ...@@ -285,69 +283,6 @@ class TestWithDouble(TestModulatedDeformableConvOp):
self.dtype = np.float64 self.dtype = np.float64
class TestFP16(unittest.TestCase):
def check_main(self, input_np, offset_np, filter_np, dtype):
paddle.disable_static()
input_np = input_np.astype(dtype)
offset_np = offset_np.astype(dtype)
filter_np = filter_np.astype(dtype)
input = paddle.to_tensor(input_np)
offset = paddle.to_tensor(offset_np)
filter = paddle.to_tensor(filter_np)
input.stop_gradient = False
offset.stop_gradient = False
filter.stop_gradient = False
y = paddle.vision.ops.deform_conv2d(input, offset, filter)
input_grad, offset_grad, filter_grad = paddle.grad(
y, [input, offset, filter])
y_np = y.numpy().astype('float32')
input_grad_np = input_grad.numpy().astype('float32')
offset_grad_np = offset_grad.numpy().astype('float32')
filter_grad_np = filter_grad.numpy().astype('float32')
paddle.enable_static()
return y_np, input_grad_np, offset_grad_np, filter_grad_np
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.groups = 1
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [40, f_c, 1, 1]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
input = np.random.random(self.input_size)
offset = 10 * np.random.random(self.offset_size)
filter = np.random.random(self.filter_size)
y_np_1, input_g_np_1, offset_g_np_1, filter_g_np_1 = self.check_main(
input, offset, filter, 'float16')
y_np_2, input_g_np_2, offset_g_np_2, filter_g_np_2 = self.check_main(
input, offset, filter, 'float32')
def assert_equal(x, y):
np.testing.assert_allclose(x, y, atol=3e-2)
assert_equal(y_np_1, y_np_2)
assert_equal(input_g_np_1, input_g_np_2)
assert_equal(offset_g_np_1, offset_g_np_2)
assert_equal(filter_g_np_1, filter_g_np_2)
class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase): class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
def test_error(self): def test_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册