未验证 提交 e91141fb 编写于 作者: W wangxinxin08 提交者: GitHub

fix problem of dcnv2 trt (#37345)

* modify code about fp16 of dcnv2 trt
上级 586bafbd
...@@ -70,7 +70,8 @@ class DeformableConvOpConverter : public OpConverter { ...@@ -70,7 +70,8 @@ class DeformableConvOpConverter : public OpConverter {
nvinfer1::Weights weights; nvinfer1::Weights weights;
weights.count = filter_tensor->numel(); weights.count = filter_tensor->numel();
if (engine_->WithFp16()) { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (with_fp16) {
auto half_filter_data = new half[filter_tensor->numel()]; auto half_filter_data = new half[filter_tensor->numel()];
for (int i = 0; i < filter_tensor->numel(); i++) { for (int i = 0; i < filter_tensor->numel(); i++) {
half_filter_data[i] = static_cast<half>(filter_data[i]); half_filter_data[i] = static_cast<half>(filter_data[i]);
...@@ -82,10 +83,9 @@ class DeformableConvOpConverter : public OpConverter { ...@@ -82,10 +83,9 @@ class DeformableConvOpConverter : public OpConverter {
weights.values = filter_data; weights.values = filter_data;
} }
auto* deformable_conv_plugin = new plugin::DeformableConvPlugin( auto* deformable_conv_plugin = new plugin::DeformableConvPlugin(
engine_->WithFp16() ? nvinfer1::DataType::kHALF with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
: nvinfer1::DataType::kFLOAT,
weights, kernel_dims, strides, paddings, dilations, groups, weights, kernel_dims, strides, paddings, dilations, groups,
deformable_groups, im2col_step); deformable_groups, im2col_step, with_fp16);
std::vector<nvinfer1::ITensor*> deformable_conv_inputs; std::vector<nvinfer1::ITensor*> deformable_conv_inputs;
deformable_conv_inputs.push_back(input_tensor); deformable_conv_inputs.push_back(input_tensor);
......
...@@ -71,11 +71,13 @@ DeformableConvPlugin::DeformableConvPlugin( ...@@ -71,11 +71,13 @@ DeformableConvPlugin::DeformableConvPlugin(
const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
const std::vector<int>& kernel_dims, const std::vector<int>& strides, const std::vector<int>& kernel_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& paddings, const std::vector<int>& dilations,
const int groups, const int deformable_groups, const int im2col_step) const int groups, const int deformable_groups, const int im2col_step,
const bool with_fp16)
: data_type_(data_type), : data_type_(data_type),
groups_(groups), groups_(groups),
deformable_groups_(deformable_groups), deformable_groups_(deformable_groups),
im2col_step_(im2col_step) { im2col_step_(im2col_step),
with_fp16_(with_fp16) {
weights_ = copyToDevice(weights.values, weights.count); weights_ = copyToDevice(weights.values, weights.count);
kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(), kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(),
kernel_dims.cend()); kernel_dims.cend());
...@@ -101,11 +103,13 @@ DeformableConvPlugin::DeformableConvPlugin( ...@@ -101,11 +103,13 @@ DeformableConvPlugin::DeformableConvPlugin(
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& paddings, const std::vector<int>& dilations,
const int groups, const int deformable_groups, const int im2col_step, const int groups, const int deformable_groups, const int im2col_step,
const std::vector<int>& input_dim, const std::vector<int>& offset_dim, const std::vector<int>& input_dim, const std::vector<int>& offset_dim,
const std::vector<int>& mask_dim, const std::vector<int>& output_dim) const std::vector<int>& mask_dim, const std::vector<int>& output_dim,
const bool with_fp16)
: data_type_(data_type), : data_type_(data_type),
groups_(groups), groups_(groups),
deformable_groups_(deformable_groups), deformable_groups_(deformable_groups),
im2col_step_(im2col_step) { im2col_step_(im2col_step),
with_fp16_(with_fp16) {
weights_ = copyToDevice(weights.values, weights.count); weights_ = copyToDevice(weights.values, weights.count);
kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(), kernel_dims_.insert(kernel_dims_.end(), kernel_dims.cbegin(),
kernel_dims.cend()); kernel_dims.cend());
...@@ -145,6 +149,7 @@ DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) { ...@@ -145,6 +149,7 @@ DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) {
DeserializeValue(&data, &length, &offset_dim_); DeserializeValue(&data, &length, &offset_dim_);
DeserializeValue(&data, &length, &mask_dim_); DeserializeValue(&data, &length, &mask_dim_);
DeserializeValue(&data, &length, &output_dim_); DeserializeValue(&data, &length, &output_dim_);
DeserializeValue(&data, &length, &with_fp16_);
} }
DeformableConvPlugin::~DeformableConvPlugin() { DeformableConvPlugin::~DeformableConvPlugin() {
...@@ -182,8 +187,19 @@ nvinfer1::Dims DeformableConvPlugin::getOutputDimensions( ...@@ -182,8 +187,19 @@ nvinfer1::Dims DeformableConvPlugin::getOutputDimensions(
bool DeformableConvPlugin::supportsFormat( bool DeformableConvPlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT { nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
return ((type == data_type_ || type == nvinfer1::DataType::kINT32) && if (with_fp16_) {
format == nvinfer1::TensorFormat::kLINEAR); #ifdef TRT_PLUGIN_FP16_AVALIABLE
return (type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::TensorFormat::kLINEAR);
#else
return (type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::TensorFormat::kLINEAR);
}
} }
size_t DeformableConvPlugin::getWorkspaceSize(int max_batch_size) const size_t DeformableConvPlugin::getWorkspaceSize(int max_batch_size) const
...@@ -207,7 +223,7 @@ int DeformableConvPlugin::enqueue(int batch_size, const void* const* inputs, ...@@ -207,7 +223,7 @@ int DeformableConvPlugin::enqueue(int batch_size, const void* const* inputs,
if (data_type_ == nvinfer1::DataType::kFLOAT) { if (data_type_ == nvinfer1::DataType::kFLOAT) {
enqueue_impl<float>(batch_size, inputs, outputs, workspace, stream); enqueue_impl<float>(batch_size, inputs, outputs, workspace, stream);
} else if (data_type_ == nvinfer1::DataType::kHALF) { } else if (data_type_ == nvinfer1::DataType::kHALF) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if TRT_PLUGIN_FP16_AVALIABLE
enqueue_impl<half>(batch_size, inputs, outputs, workspace, stream); enqueue_impl<half>(batch_size, inputs, outputs, workspace, stream);
#else #else
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -225,7 +241,9 @@ __device__ T kFloor(T x); ...@@ -225,7 +241,9 @@ __device__ T kFloor(T x);
template <> template <>
__device__ half kFloor<half>(half x) { __device__ half kFloor<half>(half x) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
return hfloor(x); return hfloor(x);
#endif
} }
template <> template <>
...@@ -235,35 +253,75 @@ __device__ float kFloor<float>(float x) { ...@@ -235,35 +253,75 @@ __device__ float kFloor<float>(float x) {
template <typename T> template <typename T>
__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, __device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
const int height, const int width, T h, T w) { const int height, const int width, T h, T w);
int h_low = kFloor<T>(h);
int w_low = kFloor<T>(w); template <>
__device__ float DmcnIm2colBilinear<float>(const float* bottom_data,
const int data_width,
const int height, const int width,
float h, float w) {
int h_low = kFloor<float>(h);
int w_low = kFloor<float>(w);
int h_high = h_low + 1; int h_high = h_low + 1;
int w_high = w_low + 1; int w_high = w_low + 1;
T h_low_t = h_low, w_low_t = w_low, one = 1.0f; float h_low_t = h_low, w_low_t = w_low, one = 1.0f;
T lh = h - h_low_t; float lh = h - h_low_t;
T lw = w - w_low_t; float lw = w - w_low_t;
T hh = one - lh, hw = one - lw; float hh = one - lh, hw = one - lw;
T v1 = 0; float v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
T v2 = 0; float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high]; v2 = bottom_data[h_low * data_width + w_high];
T v3 = 0; float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low]; v3 = bottom_data[h_high * data_width + w_low];
T v4 = 0; float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high]; v4 = bottom_data[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val; return val;
} }
// Half-precision specialization of DmcnIm2colBilinear.
//
// Bilinearly interpolates `bottom_data` at the fractional position (h, w).
//
//   bottom_data  base pointer of the plane being sampled; element (r, c) is
//                read as bottom_data[r * data_width + c].
//   data_width   row stride used for indexing.
//   height/width upper bounds used to decide whether the "high" neighbours
//                (h_low + 1, w_low + 1) may be read.
//   h, w         sampling position.
//
// Neighbours that fail their bound check contribute 0. Only the lower bound
// is checked for the "low" indices and only the upper bound for the "high"
// ones, so the caller is expected to pass -1 < h < height and -1 < w < width
// (the half im2col kernel in this file guards the call that way).
//
// All arithmetic is guarded by CUDA_ARCH_FP16_SUPPORTED; when that guard is
// false the interpolation is compiled out and the function returns 0.
template <>
__device__ half DmcnIm2colBilinear<half>(const half* bottom_data,
                                         const int data_width, const int height,
                                         const int width, half h, half w) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  // Integer coordinates of the four neighbours surrounding (h, w).
  int h_low = kFloor<half>(h);
  int w_low = kFloor<half>(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  // Fractional distances to the low neighbours and their complements.
  half h_low_t = h_low, w_low_t = w_low, one = 1.0f;
  half lh = h - h_low_t;
  half lw = w - w_low_t;
  half hh = one - lh, hw = one - lw;

  // Fetch each neighbour, leaving it at 0 when its bound check fails.
  half v1 = 0;
  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
  half v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  half v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  half v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  // Weighted sum of the four neighbours.
  half w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  half val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
#else
  // Without this branch control would reach the end of a non-void function,
  // which is undefined behaviour if the function is ever called and triggers
  // a missing-return diagnostic. Return a well-defined zero instead; only a
  // float-to-half conversion is used here, no half arithmetic.
  return static_cast<half>(0.0f);
#endif
}
template <typename T> template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel( __global__ void ModulatedDeformableIm2colGpuKernel(
const int nthreads, const T* data_im, const T* data_offset, const int nthreads, const T* data_im, const T* data_offset,
...@@ -272,11 +330,21 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -272,11 +330,21 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const int stride_w, const int dilation_h, const int dilation_w, const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size, const int channel_per_deformable_group, const int batch_size,
const int num_channels, const int deformable_group, const int height_col, const int num_channels, const int deformable_group, const int height_col,
const int width_col, T* data_col) { const int width_col, T* data_col);
template <>
__global__ void ModulatedDeformableIm2colGpuKernel<float>(
const int nthreads, const float* data_im, const float* data_offset,
const float* data_mask, const int height, const int width,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col, float* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x; int offset = blockDim.x * gridDim.x;
T minus_one = -1.0f, height_t = height, width_t = width; float minus_one = -1.0f, height_t = height, width_t = width;
for (size_t i = index; i < nthreads; i += offset) { for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col; const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col; const int h_col = (i / width_col) % height_col;
...@@ -289,16 +357,16 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -289,16 +357,16 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const int h_in = h_col * stride_h - pad_h; const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w; const int w_in = w_col * stride_w - pad_w;
T* data_col_ptr = float* data_col_ptr =
data_col + data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T* data_im_ptr = const float* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width; data_im + (b_col * num_channels + c_im) * height * width;
const T* data_offset_ptr = const float* data_offset_ptr =
data_offset + data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h * (b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col; kernel_w * height_col * width_col;
const T* data_mask_ptr = const float* data_mask_ptr =
data_mask + data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h * (b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col; kernel_w * height_col * width_col;
...@@ -313,17 +381,17 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -313,17 +381,17 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
const int data_mask_hw_ptr = const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col; ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr];
T val = 0; float val = 0;
T h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w; float h_im_t = h_in + i * dilation_h, w_im_t = w_in + j * dilation_w;
const T h_im = h_im_t + offset_h; const float h_im = h_im_t + offset_h;
const T w_im = w_im_t + offset_w; const float w_im = w_im_t + offset_w;
if (h_im > minus_one && w_im > minus_one && h_im < height_t && if (h_im > minus_one && w_im > minus_one && h_im < height_t &&
w_im < width_t) { w_im < width_t) {
val = DmcnIm2colBilinear<T>(data_im_ptr, width, height, width, h_im, val = DmcnIm2colBilinear<float>(data_im_ptr, width, height, width,
w_im); h_im, w_im);
} }
*data_col_ptr = val * mask; *data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col; data_col_ptr += batch_size * height_col * width_col;
...@@ -332,6 +400,76 @@ __global__ void ModulatedDeformableIm2colGpuKernel( ...@@ -332,6 +400,76 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
} }
} }
// Half-precision specialization of the modulated deformable im2col kernel.
//
// Each thread starts at its global index (blockIdx.x * blockDim.x +
// threadIdx.x) and advances by blockDim.x * gridDim.x until `nthreads` work
// items are covered. A work item is decoded into (c_im, b_col, h_col, w_col).
// For every kernel tap the thread reads the learned (offset_h, offset_w) and
// mask, samples the input at the shifted position with
// DmcnIm2colBilinear<half> (0 when the position is outside (-1, height) x
// (-1, width)), and stores sample * mask into `data_col`.
//
// Indexing used below:
//   data_im     : (b_col * num_channels + c_im) * height * width
//   data_offset : per (batch, deformable group) slab of 2 * kernel_h *
//                 kernel_w planes, each height_col * width_col
//   data_mask   : per (batch, deformable group) slab of kernel_h * kernel_w
//                 planes, each height_col * width_col
//   data_col    : ((c_col * batch_size + b_col) * height_col + h_col) *
//                 width_col + w_col, advancing by batch_size * height_col *
//                 width_col per kernel tap
//
// The whole body is compiled out when CUDA_ARCH_FP16_SUPPORTED is false, in
// which case the kernel does nothing.
template <>
__global__ void ModulatedDeformableIm2colGpuKernel<half>(
    const int nthreads, const half* data_im, const half* data_offset,
    const half* data_mask, const int height, const int width,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int channel_per_deformable_group,
    const int batch_size, const int num_channels, const int deformable_group,
    const int height_col, const int width_col, half* data_col) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const int first_item = blockIdx.x * blockDim.x + threadIdx.x;
  const int item_stride = blockDim.x * gridDim.x;

  // Bounds of the valid sampling window, expressed in half.
  const half lower_bound = -1.0f;
  const half height_limit = height;
  const half width_limit = width;

  // Distance in data_col between two consecutive kernel taps.
  const int col_step = batch_size * height_col * width_col;

  for (size_t item = first_item; item < nthreads; item += item_stride) {
    // Decode the flat work-item index.
    const int w_col = item % width_col;
    const int h_col = (item / width_col) % height_col;
    const int b_col = (item / width_col) / height_col % batch_size;
    const int c_im = (item / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    const int deformable_group_index = c_im / channel_per_deformable_group;

    // Top-left corner of the receptive field before offsets are applied.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    half* col_out =
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const half* im_plane =
        data_im + (b_col * num_channels + c_im) * height * width;

    // Offsets and masks are shared by all channels of a deformable group.
    const int group_slot = b_col * deformable_group + deformable_group_index;
    const half* offset_base =
        data_offset +
        group_slot * 2 * kernel_h * kernel_w * height_col * width_col;
    const half* mask_base =
        data_mask + group_slot * kernel_h * kernel_w * height_col * width_col;

    for (int kh = 0; kh < kernel_h; ++kh) {
      for (int kw = 0; kw < kernel_w; ++kw) {
        const int tap = kh * kernel_w + kw;
        const int offset_h_idx =
            ((2 * tap) * height_col + h_col) * width_col + w_col;
        const int offset_w_idx =
            ((2 * tap + 1) * height_col + h_col) * width_col + w_col;
        const int mask_idx = (tap * height_col + h_col) * width_col + w_col;

        const half offset_h = offset_base[offset_h_idx];
        const half offset_w = offset_base[offset_w_idx];
        const half mask = mask_base[mask_idx];

        // Regular grid position of this tap, then shifted by the offsets.
        half h_im_t = h_in + kh * dilation_h, w_im_t = w_in + kw * dilation_w;
        const half h_im = h_im_t + offset_h;
        const half w_im = w_im_t + offset_w;

        // Positions outside the window keep the default value of 0.
        half val = 0;
        const bool inside_window = h_im > lower_bound && w_im > lower_bound &&
                                   h_im < height_limit && w_im < width_limit;
        if (inside_window) {
          val = DmcnIm2colBilinear<half>(im_plane, width, height, width, h_im,
                                         w_im);
        }

        *col_out = val * mask;
        col_out += col_step;
      }
    }
  }
#endif
}
template <typename T> template <typename T>
void gemm_impl(cublasHandle_t handle, cublasOperation_t transa, void gemm_impl(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k, const T* alpha, cublasOperation_t transb, int m, int n, int k, const T* alpha,
...@@ -353,8 +491,13 @@ void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa, ...@@ -353,8 +491,13 @@ void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k, cublasOperation_t transb, int m, int n, int k,
const half* alpha, const half* A, int lda, const half* B, const half* alpha, const half* A, int lda, const half* B,
int ldb, const half* beta, half* C, int ldc) { int ldb, const half* beta, half* C, int ldc) {
#if TRT_PLUGIN_FP16_AVALIABLE
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc); B, ldb, beta, C, ldc);
#else
PADDLE_THROW(platform::errors::InvalidArgument(
"Current CUDA arch dose not support fp16. Please use fp32 instead."));
#endif
} }
template <typename T> template <typename T>
...@@ -436,6 +579,7 @@ size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT { ...@@ -436,6 +579,7 @@ size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT {
serialize_size += SerializedSize(offset_dim_); serialize_size += SerializedSize(offset_dim_);
serialize_size += SerializedSize(mask_dim_); serialize_size += SerializedSize(mask_dim_);
serialize_size += SerializedSize(output_dim_); serialize_size += SerializedSize(output_dim_);
serialize_size += SerializedSize(with_fp16_);
return serialize_size; return serialize_size;
} }
...@@ -454,6 +598,7 @@ void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT { ...@@ -454,6 +598,7 @@ void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, offset_dim_); SerializeValue(&buffer, offset_dim_);
SerializeValue(&buffer, mask_dim_); SerializeValue(&buffer, mask_dim_);
SerializeValue(&buffer, output_dim_); SerializeValue(&buffer, output_dim_);
SerializeValue(&buffer, with_fp16_);
} }
void DeformableConvPlugin::destroy() TRT_NOEXCEPT {} void DeformableConvPlugin::destroy() TRT_NOEXCEPT {}
...@@ -521,10 +666,10 @@ void DeformableConvPlugin::configurePlugin( ...@@ -521,10 +666,10 @@ void DeformableConvPlugin::configurePlugin(
} }
nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT { nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT {
return new DeformableConvPlugin(data_type_, weights_, kernel_dims_, strides_, return new DeformableConvPlugin(
paddings_, dilations_, groups_, data_type_, weights_, kernel_dims_, strides_, paddings_, dilations_,
deformable_groups_, im2col_step_, input_dim_, groups_, deformable_groups_, im2col_step_, input_dim_, offset_dim_,
offset_dim_, mask_dim_, output_dim_); mask_dim_, output_dim_, with_fp16_);
} }
void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace) void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace)
...@@ -560,6 +705,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( ...@@ -560,6 +705,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
int groups = -1; int groups = -1;
int deformable_groups = -1; int deformable_groups = -1;
int im2col_step = -1; int im2col_step = -1;
bool with_fp16 = false;
for (int i = 0; i < fc->nbFields; ++i) { for (int i = 0; i < fc->nbFields; ++i) {
const std::string field_name(fc->fields[i].name); const std::string field_name(fc->fields[i].name);
...@@ -590,6 +736,8 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( ...@@ -590,6 +736,8 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
} else if (field_name.compare("weights")) { } else if (field_name.compare("weights")) {
weights.count = fc->fields[i].length; weights.count = fc->fields[i].length;
weights.values = fc->fields[i].data; weights.values = fc->fields[i].data;
} else if (field_name.compare("with_fp16")) {
with_fp16 = *static_cast<const bool*>(fc->fields[i].data);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown plugin field name [%s] in the DeformableConv TRT Plugin.", "Unknown plugin field name [%s] in the DeformableConv TRT Plugin.",
...@@ -599,7 +747,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin( ...@@ -599,7 +747,7 @@ nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
weights.type = data_type; weights.type = data_type;
return new DeformableConvPlugin(data_type, weights, kernel_dims, strides, return new DeformableConvPlugin(data_type, weights, kernel_dims, strides,
paddings, dilations, groups, paddings, dilations, groups,
deformable_groups, im2col_step); deformable_groups, im2col_step, with_fp16);
} }
nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin( nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin(
......
...@@ -30,18 +30,22 @@ namespace plugin { ...@@ -30,18 +30,22 @@ namespace plugin {
class DeformableConvPlugin : public nvinfer1::IPluginV2Ext { class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
public: public:
explicit DeformableConvPlugin( explicit DeformableConvPlugin(const nvinfer1::DataType data_type,
const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, const nvinfer1::Weights& weights,
const std::vector<int>& kernel_dims, const std::vector<int>& strides, const std::vector<int>& kernel_dims,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& strides,
const int groups, const int deformable_groups, const int im2col_step); const std::vector<int>& paddings,
const std::vector<int>& dilations,
const int groups, const int deformable_groups,
const int im2col_step, const bool with_fp16);
explicit DeformableConvPlugin( explicit DeformableConvPlugin(
const nvinfer1::DataType data_type, const nvinfer1::Weights& weights, const nvinfer1::DataType data_type, const nvinfer1::Weights& weights,
const std::vector<int>& kernel_dims, const std::vector<int>& strides, const std::vector<int>& kernel_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& paddings, const std::vector<int>& dilations,
const int groups, const int deformable_groups, const int im2col_step, const int groups, const int deformable_groups, const int im2col_step,
const std::vector<int>& input_dim, const std::vector<int>& offset_dim, const std::vector<int>& input_dim, const std::vector<int>& offset_dim,
const std::vector<int>& mask_dim, const std::vector<int>& output_dim); const std::vector<int>& mask_dim, const std::vector<int>& output_dim,
const bool with_fp16);
DeformableConvPlugin(const void* data, size_t length); DeformableConvPlugin(const void* data, size_t length);
~DeformableConvPlugin() override; ~DeformableConvPlugin() override;
...@@ -98,6 +102,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext { ...@@ -98,6 +102,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {
const nvinfer1::Weights& deviceWeights) const; const nvinfer1::Weights& deviceWeights) const;
nvinfer1::Weights deserializeToDevice(const void** hostBuffer, size_t count); nvinfer1::Weights deserializeToDevice(const void** hostBuffer, size_t count);
bool with_fp16_;
nvinfer1::DataType data_type_; nvinfer1::DataType data_type_;
nvinfer1::Weights weights_; nvinfer1::Weights weights_;
std::vector<int> kernel_dims_; std::vector<int> kernel_dims_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册