From e91141fb8543ae27bf6b5568d3a24ec297b95fbb Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 23 Nov 2021 16:00:42 +0800 Subject: [PATCH] fix problem of dcnv2 trt (#37345) * modify code about fp16 of dcnv2 trt --- .../tensorrt/convert/deformable_conv_op.cc | 8 +- .../plugin/deformable_conv_op_plugin.cu | 228 +++++++++++++++--- .../plugin/deformable_conv_op_plugin.h | 17 +- 3 files changed, 203 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc index 02d460ffa1c..d8534a4183b 100644 --- a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc @@ -70,7 +70,8 @@ class DeformableConvOpConverter : public OpConverter { nvinfer1::Weights weights; weights.count = filter_tensor->numel(); - if (engine_->WithFp16()) { + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + if (with_fp16) { auto half_filter_data = new half[filter_tensor->numel()]; for (int i = 0; i < filter_tensor->numel(); i++) { half_filter_data[i] = static_cast(filter_data[i]); @@ -82,10 +83,9 @@ class DeformableConvOpConverter : public OpConverter { weights.values = filter_data; } auto* deformable_conv_plugin = new plugin::DeformableConvPlugin( - engine_->WithFp16() ? nvinfer1::DataType::kHALF - : nvinfer1::DataType::kFLOAT, + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, weights, kernel_dims, strides, paddings, dilations, groups, - deformable_groups, im2col_step); + deformable_groups, im2col_step, with_fp16); std::vector deformable_conv_inputs; deformable_conv_inputs.push_back(input_tensor); diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu index 760f379eb07..0f32183c0fb 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu @@ -71,11 +71,13 @@ DeformableConvPlugin::DeformableConvPlugin( const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, const std::vector& kernel_dims, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, - const int groups, const int deformable_groups, const int im2col_step) + const int groups, const int deformable_groups, const int im2col_step, + const bool with_fp16) : data_type_(data_type), groups_(groups), deformable_groups_(deformable_groups), - im2col_step_(im2col_step) { + im2col_step_(im2col_step), + with_fp16_(with_fp16) { weights_ = copyToDevice(weights.values, weights.count); kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(), kernel_dims.cend()); @@ -101,11 +103,13 @@ DeformableConvPlugin::DeformableConvPlugin( const std::vector& paddings, const std::vector& dilations, const int groups, const int deformable_groups, const int im2col_step, const std::vector& input_dim, const std::vector& offset_dim, - const std::vector& mask_dim, const std::vector& output_dim) + const std::vector& mask_dim, const std::vector& output_dim, + const bool with_fp16) : data_type_(data_type), groups_(groups), deformable_groups_(deformable_groups), - im2col_step_(im2col_step) { + im2col_step_(im2col_step), + with_fp16_(with_fp16) { weights_ = copyToDevice(weights.values, weights.count); kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(), kernel_dims.cend()); @@ -145,6 +149,7 @@ DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) { DeserializeValue(&data, &length, &offset_dim_); DeserializeValue(&data, &length, &mask_dim_); DeserializeValue(&data, &length, &output_dim_); + DeserializeValue(&data, &length, &with_fp16_); } DeformableConvPlugin::~DeformableConvPlugin() { @@ -182,8 +187,19 @@ nvinfer1::Dims DeformableConvPlugin::getOutputDimensions( bool DeformableConvPlugin::supportsFormat( nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT { - return ((type == data_type_ || type == nvinfer1::DataType::kINT32) && - format == nvinfer1::TensorFormat::kLINEAR); + if (with_fp16_) { +#ifdef TRT_PLUGIN_FP16_AVALIABLE + return (type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF) && + (format == nvinfer1::TensorFormat::kLINEAR); +#else + return (type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::TensorFormat::kLINEAR); +#endif + } else { + return (type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::TensorFormat::kLINEAR); + } } size_t DeformableConvPlugin::getWorkspaceSize(int max_batch_size) const @@ -207,7 +223,7 @@ int DeformableConvPlugin::enqueue(int batch_size, const void* const* inputs, if (data_type_ == nvinfer1::DataType::kFLOAT) { enqueue_impl(batch_size, inputs, outputs, workspace, stream); } else if (data_type_ == nvinfer1::DataType::kHALF) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if TRT_PLUGIN_FP16_AVALIABLE enqueue_impl(batch_size, inputs, outputs, workspace, stream); #else PADDLE_THROW(platform::errors::InvalidArgument( @@ -225,7 +241,9 @@ __device__ T kFloor(T x); template <> __device__ half kFloor(half x) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) return hfloor(x); +#endif } template <> @@ -235,35 +253,75 @@ __device__ float kFloor(float x) { template __device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, - const int height, const int width, T h, T w) { - int h_low = kFloor(h); - int w_low = kFloor(w); + const int height, const int width, T h, T w); + +template <> +__device__ float DmcnIm2colBilinear(const float* bottom_data, + const int data_width, + const int height, const int width, + float h, float w) { + int h_low = kFloor(h); + int w_low = kFloor(w); int h_high = h_low + 1; int w_high = w_low + 1; - T h_low_t = h_low, w_low_t = w_low, one = 1.0f; - T lh = h - h_low_t; - T lw = w - w_low_t; - T hh = one - lh, hw = one - lw; + float h_low_t = h_low, w_low_t = w_low, one = 1.0f; + float lh = h - h_low_t; + float lw = w - w_low_t; + float hh = one - lh, hw = one - lw; - T v1 = 0; + float v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; - T v2 = 0; + float v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high]; - T v3 = 0; + float v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low]; - T v4 = 0; + float v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high]; - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } +template <> +__device__ half DmcnIm2colBilinear(const half* bottom_data, + const int data_width, const int height, + const int width, half h, half w) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + int h_low = kFloor(h); + int w_low = kFloor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + half h_low_t = h_low, w_low_t = w_low, one = 1.0f; + half lh = h - h_low_t; + half lw = w - w_low_t; + half hh = one - lh, hw = one - lw; + + half v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; + half v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + half v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + half v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + half w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + half val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +#endif +} + template __global__ void ModulatedDeformableIm2colGpuKernel( const int nthreads, const T* data_im, const T* data_offset, @@ -272,11 +330,21 @@ __global__ void ModulatedDeformableIm2colGpuKernel( const int stride_w, const int dilation_h, const int dilation_w, const int channel_per_deformable_group, const int batch_size, const int num_channels, const int deformable_group, const int height_col, - const int width_col, T* data_col) { + const int width_col, T* data_col); + +template <> +__global__ void ModulatedDeformableIm2colGpuKernel( + const int nthreads, const float* data_im, const float* data_offset, + const float* data_mask, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, float* data_col) { int index = blockIdx.x * blockDim.x + threadIdx.x; int offset = blockDim.x * gridDim.x; - T minus_one = -1.0f, height_t = height, width_t = width; + float minus_one = -1.0f, height_t = height, width_t = width; for (size_t i = index; i < nthreads; i += offset) { const int w_col = i % width_col; const int h_col = (i / width_col) % height_col; @@ -289,16 +357,16 @@ __global__ void ModulatedDeformableIm2colGpuKernel( const int h_in = h_col * stride_h - pad_h; const int w_in = w_col * stride_w - pad_w; - T* data_col_ptr = + float* data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T* data_im_ptr = + const float* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const T* data_offset_ptr = + const float* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const T* data_mask_ptr = + const float* data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; @@ -313,17 +381,17 @@ __global__ void ModulatedDeformableIm2colGpuKernel( const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - T val = 0; - T h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w; - const T h_im = h_im_t + offset_h; - const T w_im = w_im_t + offset_w; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float val = 0; + float h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w; + const float h_im = h_im_t + offset_h; + const float w_im = w_im_t + offset_w; if (h_im > minus_one && w_im > minus_one && h_im < height_t && w_im < width_t) { - val = DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, - w_im); + val = DmcnIm2colBilinear(data_im_ptr, width, height, width, + h_im, w_im); } *data_col_ptr = val * mask; data_col_ptr += batch_size * height_col * width_col; @@ -332,6 +400,76 @@ __global__ void ModulatedDeformableIm2colGpuKernel( } } +template <> +__global__ void ModulatedDeformableIm2colGpuKernel( + const int nthreads, const half* data_im, const half* data_offset, + const half* data_mask, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, half* data_col) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + + half minus_one = -1.0f, height_t = height, width_t = width; + for (size_t i = index; i < nthreads; i += offset) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + half* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const half* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const half* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const half* data_mask_ptr = + data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + + const half offset_h = data_offset_ptr[data_offset_h_ptr]; + const half offset_w = data_offset_ptr[data_offset_w_ptr]; + const half mask = data_mask_ptr[data_mask_hw_ptr]; + half val = 0; + half h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w; + const half h_im = h_im_t + offset_h; + const half w_im = w_im_t + offset_w; + if (h_im > minus_one && w_im > minus_one && h_im < height_t && + w_im < width_t) { + val = DmcnIm2colBilinear(data_im_ptr, width, height, width, + h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +#endif +} + template void gemm_impl(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const T* alpha, @@ -353,8 +491,13 @@ void gemm_impl(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const half* alpha, const half* A, int lda, const half* B, int ldb, const half* beta, half* C, int ldc) { +#if TRT_PLUGIN_FP16_AVALIABLE platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +#else + PADDLE_THROW(platform::errors::InvalidArgument( + "Current CUDA arch dose not support fp16. Please use fp32 instead.")); +#endif } template @@ -436,6 +579,7 @@ size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT { serialize_size += SerializedSize(offset_dim_); serialize_size += SerializedSize(mask_dim_); serialize_size += SerializedSize(output_dim_); + serialize_size += SerializedSize(with_fp16_); return serialize_size; } @@ -454,6 +598,7 @@ void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, offset_dim_); SerializeValue(&buffer, mask_dim_); SerializeValue(&buffer, output_dim_); + SerializeValue(&buffer, with_fp16_); } void DeformableConvPlugin::destroy() TRT_NOEXCEPT {} @@ -521,10 +666,10 @@ void DeformableConvPlugin::configurePlugin( } nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT { - return new DeformableConvPlugin(data_type_, weights_, kernel_dims_, strides_, - paddings_, dilations_, groups_, - deformable_groups_, im2col_step_, input_dim_, - offset_dim_, mask_dim_, output_dim_); + return new DeformableConvPlugin( + data_type_, weights_, kernel_dims_, strides_, paddings_, dilations_, + groups_, deformable_groups_, im2col_step_, input_dim_, offset_dim_, + mask_dim_, output_dim_, with_fp16_); } void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace) @@ -560,6 +705,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( int groups = -1; int deformable_groups = -1; int im2col_step = -1; + bool with_fp16 = false; for (int i = 0; i < fc->nbFields; ++i) { const std::string field_name(fc->fields[i].name); @@ -590,6 +736,8 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( } else if (field_name.compare("weights")) { weights.count = fc->fields[i].length; weights.values = fc->fields[i].data; + } else if (field_name.compare("with_fp16")) { + with_fp16 = *static_cast(fc->fields[i].data); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Unknown plugin field name [%s] in the DeformableConv TRT Plugin.", @@ -599,7 +747,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( weights.type = data_type; return new DeformableConvPlugin(data_type, weights, kernel_dims, strides, paddings, dilations, groups, - deformable_groups, im2col_step); + deformable_groups, im2col_step, with_fp16); } nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin( diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h index 8ba19288ce5..46811b08d77 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h @@ -30,18 +30,22 @@ namespace plugin { class DeformableConvPlugin : public nvinfer1::IPluginV2Ext { public: - explicit DeformableConvPlugin( - const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, - const std::vector& kernel_dims, const std::vector& strides, - const std::vector& paddings, const std::vector& dilations, - const int groups, const int deformable_groups, const int im2col_step); + explicit DeformableConvPlugin(const nvinfer1::DataType data_type, + const nvinfer1::Weights& weights, + const std::vector& kernel_dims, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const int groups, const int deformable_groups, + const int im2col_step, const bool with_fp16); explicit DeformableConvPlugin( const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, const std::vector& kernel_dims, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, const int groups, const int deformable_groups, const int im2col_step, const std::vector& input_dim, const std::vector& offset_dim, - const std::vector& mask_dim, const std::vector& output_dim); + const std::vector& mask_dim, const std::vector& output_dim, + const bool with_fp16); DeformableConvPlugin(const void* data, size_t length); ~DeformableConvPlugin() override; @@ -98,6 +102,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext { const nvinfer1::Weights& deviceWeights) const; nvinfer1::Weights deserializeToDevice(const void** hostBuffer, size_t count); + bool with_fp16_; nvinfer1::DataType data_type_; nvinfer1::Weights weights_; std::vector kernel_dims_; -- GitLab