deformable_conv_op_plugin.cu 52.6 KB
Newer Older
W
wangxinxin08 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>
#include <cuda_runtime.h>
17

W
wangxinxin08 已提交
18 19 20
#include <algorithm>
#include <cstdio>

21
#include "paddle/fluid/framework/eigen.h"
W
wangxinxin08 已提交
22
#include "paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h"
23 24 25 26 27
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
W
wangxinxin08 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Launch-configuration limits used by the kernels in this file.
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns ceil(N / kNumCUDAThreads), capped at kNumMaximumNumBlocks.
//
// The rounding-up addition is carried out in 64-bit arithmetic so that a very
// large N cannot overflow `int` and yield a negative grid size.
static inline int NumBlocks(const int N) {
  const int64_t wanted =
      (static_cast<int64_t>(N) + kNumCUDAThreads - 1) / kNumCUDAThreads;
  const int64_t capped =
      std::min<int64_t>(wanted, static_cast<int64_t>(kNumMaximumNumBlocks));
  return static_cast<int>(capped);
}

42 43
// Output extent of a convolution along one spatial axis:
//   (input + 2 * padding - dilated_kernel) / stride + 1
// where dilated_kernel = dilation * (filter_size - 1) + 1.
static inline int ConvOutputSize(
    int input_size, int filter_size, int dilation, int padding, int stride) {
  // Extent covered by the filter once the dilation gaps are included.
  const int dilated_kernel = dilation * (filter_size - 1) + 1;
  const int padded_input = input_size + 2 * padding;
  return (padded_input - dilated_kernel) / stride + 1;
}

// Allocates a device buffer with cudaMalloc and copies `count` elements from
// `hostData` into it. Each element is 4 bytes when data_type_ is kFLOAT and
// 2 bytes otherwise. Returns a Weights record pointing at the device buffer.
nvinfer1::Weights DeformableConvPlugin::copyToDevice(const void* hostData,
                                                     size_t count) {
  const int element_size = data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2;
  const size_t total_bytes = count * element_size;

  void* device_ptr;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&device_ptr, total_bytes));
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemcpy(device_ptr, hostData, total_bytes, cudaMemcpyHostToDevice));

  return nvinfer1::Weights{data_type_, device_ptr, int64_t(count)};
}

// Copies the device-resident weight payload into the host buffer at
// *hostBuffer, then advances *hostBuffer past the bytes written.
void DeformableConvPlugin::serializeFromDevice(
    void** hostBuffer, const nvinfer1::Weights& deviceWeights) const {
  const int element_size = data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2;
  const int64_t payload_bytes = deviceWeights.count * element_size;

  char* destination = static_cast<char*>(*hostBuffer);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(destination,
                                        deviceWeights.values,
                                        payload_bytes,
                                        cudaMemcpyDeviceToHost));
  *hostBuffer = destination + payload_bytes;
}

// Uploads `count` weight elements found at *hostBuffer to the device via
// copyToDevice, advances *hostBuffer past them, and returns the device
// Weights record.
nvinfer1::Weights DeformableConvPlugin::deserializeToDevice(
    const void** hostBuffer, size_t count) {
  const char* source = static_cast<const char*>(*hostBuffer);
  nvinfer1::Weights device_weights = copyToDevice(source, count);

  const int element_size = data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2;
  *hostBuffer = source + count * element_size;
  return device_weights;
}

79 80 81 82 83 84 85 86 87 88
// Builds a plugin from convolution attributes. The weights are copied to the
// device; the data type must be float or half, and paddings/strides must have
// the same number of entries.
DeformableConvPlugin::DeformableConvPlugin(const nvinfer1::DataType data_type,
                                           const nvinfer1::Weights& weights,
                                           const std::vector<int>& kernel_dims,
                                           const std::vector<int>& strides,
                                           const std::vector<int>& paddings,
                                           const std::vector<int>& dilations,
                                           const int groups,
                                           const int deformable_groups,
                                           const int im2col_step,
                                           const bool with_fp16)
    : data_type_(data_type),
      groups_(groups),
      deformable_groups_(deformable_groups),
      im2col_step_(im2col_step),
      with_fp16_(with_fp16) {
  // Appends every element of `src` to the end of `*dst`.
  const auto append = [](std::vector<int>* dst, const std::vector<int>& src) {
    dst->insert(dst->end(), src.begin(), src.end());
  };

  weights_ = copyToDevice(weights.values, weights.count);

  append(&kernel_dims_, kernel_dims);
  append(&strides_, strides);
  append(&paddings_, paddings);
  append(&dilations_, dilations);

  PADDLE_ENFORCE_EQ(data_type_ == nvinfer1::DataType::kFLOAT ||
                        data_type_ == nvinfer1::DataType::kHALF,
                    true,
                    platform::errors::InvalidArgument(
                        "The DeformableConv TRT Plugin's input type "
                        "should be float or half."));
  PADDLE_ENFORCE_EQ(
      paddings_.size(),
      strides_.size(),
      platform::errors::InvalidArgument(
          "The size of paddings (%d) is not equal to the size of strides (%d).",
          paddings_.size(),
          strides_.size()));
}

// Builds a plugin from convolution attributes and additionally records the
// input, offset, mask and output dimensions. The weights are copied to the
// device; the data type must be float or half, and paddings/strides must have
// the same number of entries.
DeformableConvPlugin::DeformableConvPlugin(const nvinfer1::DataType data_type,
                                           const nvinfer1::Weights& weights,
                                           const std::vector<int>& kernel_dims,
                                           const std::vector<int>& strides,
                                           const std::vector<int>& paddings,
                                           const std::vector<int>& dilations,
                                           const int groups,
                                           const int deformable_groups,
                                           const int im2col_step,
                                           const std::vector<int>& input_dim,
                                           const std::vector<int>& offset_dim,
                                           const std::vector<int>& mask_dim,
                                           const std::vector<int>& output_dim,
                                           const bool with_fp16)
    : data_type_(data_type),
      groups_(groups),
      deformable_groups_(deformable_groups),
      im2col_step_(im2col_step),
      with_fp16_(with_fp16) {
  // Appends every element of `src` to the end of `*dst`.
  const auto append = [](std::vector<int>* dst, const std::vector<int>& src) {
    dst->insert(dst->end(), src.begin(), src.end());
  };

  weights_ = copyToDevice(weights.values, weights.count);

  // Convolution geometry.
  append(&kernel_dims_, kernel_dims);
  append(&strides_, strides);
  append(&paddings_, paddings);
  append(&dilations_, dilations);

  // Tensor shapes.
  append(&input_dim_, input_dim);
  append(&offset_dim_, offset_dim);
  append(&mask_dim_, mask_dim);
  append(&output_dim_, output_dim);

  PADDLE_ENFORCE_EQ(data_type_ == nvinfer1::DataType::kFLOAT ||
                        data_type_ == nvinfer1::DataType::kHALF,
                    true,
                    platform::errors::InvalidArgument(
                        "The DeformableConv TRT Plugin's input type "
                        "should be float or half."));
  PADDLE_ENFORCE_EQ(
      paddings_.size(),
      strides_.size(),
      platform::errors::InvalidArgument(
          "The size of paddings (%d) is not equal to the size of strides (%d).",
          paddings_.size(),
          strides_.size()));
}

// Rebuilds a plugin from a serialized byte stream. Fields are read in the
// same order in which serialize() writes them; the weight payload is uploaded
// to the device as it is consumed.
DeformableConvPlugin::DeformableConvPlugin(const void* data, size_t length) {
  // Scalar attributes and convolution geometry.
  DeserializeValue(&data, &length, &data_type_);
  DeserializeValue(&data, &length, &strides_);
  DeserializeValue(&data, &length, &paddings_);
  DeserializeValue(&data, &length, &dilations_);
  DeserializeValue(&data, &length, &groups_);
  DeserializeValue(&data, &length, &deformable_groups_);
  DeserializeValue(&data, &length, &im2col_step_);
  DeserializeValue(&data, &length, &kernel_dims_);

  // Weights: element count followed by the raw payload.
  // NOTE(review): deserializeToDevice advances `data` but `length` is not
  // reduced by the payload size here -- confirm this is intended.
  int64_t weight_count;
  DeserializeValue(&data, &length, &weight_count);
  weights_ = deserializeToDevice(&data, weight_count);

  // Tensor shapes and the fp16 switch.
  DeserializeValue(&data, &length, &input_dim_);
  DeserializeValue(&data, &length, &offset_dim_);
  DeserializeValue(&data, &length, &mask_dim_);
  DeserializeValue(&data, &length, &output_dim_);
  DeserializeValue(&data, &length, &with_fp16_);
}

// Releases the device copy of the weights, if one is held.
DeformableConvPlugin::~DeformableConvPlugin() {
  if (!weights_.values) {
    return;
  }
  cudaFree(const_cast<void*>(weights_.values));
  weights_.values = nullptr;
}

// Returns the type name of this plugin.
const char* DeformableConvPlugin::getPluginType() const TRT_NOEXCEPT {
  static constexpr const char* kPluginType = "deformable_conv_plugin";
  return kPluginType;
}

// Returns the version string of this plugin.
const char* DeformableConvPlugin::getPluginVersion() const TRT_NOEXCEPT {
  static constexpr const char* kPluginVersion = "1";
  return kPluginVersion;
}

// The plugin reports a single output tensor.
int DeformableConvPlugin::getNbOutputs() const TRT_NOEXCEPT {
  return 1;
}

// Computes the output dimensions from the first input's dimensions.
// d[0] is taken from kernel_dims_[0]; d[1] and d[2] are the convolution output
// sizes of the first input's d[1] and d[2]. Exactly three inputs are required.
nvinfer1::Dims DeformableConvPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* inputs, int nb_input_dims) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nb_input_dims,
                    3,
                    platform::errors::InvalidArgument(
                        "The number of inputs should be equal to 3, but got %d",
                        nb_input_dims));

  const nvinfer1::Dims& in_dims = inputs[0];
  nvinfer1::Dims out_dims;
  out_dims.nbDims = in_dims.nbDims;
  out_dims.d[0] = kernel_dims_[0];
  // Spatial axis `axis` lives at d[axis + 1]; its filter extent is
  // kernel_dims_[axis + 2].
  for (int axis = 0; axis < 2; ++axis) {
    out_dims.d[axis + 1] = ConvOutputSize(in_dims.d[axis + 1],
                                          kernel_dims_[axis + 2],
                                          dilations_[axis],
                                          paddings_[axis],
                                          strides_[axis]);
  }
  return out_dims;
}

// Accepts only the linear tensor format. The accepted data type is half when
// with_fp16_ is set and TRT_PLUGIN_FP16_AVALIABLE is defined, float otherwise.
bool DeformableConvPlugin::supportsFormat(
    nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
  nvinfer1::DataType expected_type = nvinfer1::DataType::kFLOAT;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
  if (with_fp16_) {
    expected_type = nvinfer1::DataType::kHALF;
  }
#endif
  return (type == expected_type) &&
         (format == nvinfer1::TensorFormat::kLINEAR);
}

// Returns the number of bytes needed for the im2col column buffer:
//   c_i * k_h * k_w * im2col_step_ * h_o * w_o * sizeof(element)
// where the element size is 4 bytes for kFLOAT and 2 bytes otherwise.
//
// The product is accumulated in size_t from the first factor on, so large
// shapes cannot overflow a 32-bit int before the result is widened.
size_t DeformableConvPlugin::getWorkspaceSize(int max_batch_size) const
    TRT_NOEXCEPT {
  const int c_i = input_dim_[0];
  const int k_h = kernel_dims_[2], k_w = kernel_dims_[3];
  const int h_o = output_dim_[1], w_o = output_dim_[2];
  const int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
  const size_t data_col_size = static_cast<size_t>(c_i) * k_h * k_w *
                               im2col_step_ * h_o * w_o * num_bytes;
  return data_col_size;
}

247 248
// Dispatches to the float or half implementation according to data_type_.
// Returns non-zero when cudaGetLastError() reports an error afterwards.
int DeformableConvPlugin::enqueue(int batch_size,
                                  const void* const* inputs,
#if IS_TRT_VERSION_LT(8000)
                                  void** outputs,
#else
                                  void* const* outputs,
#endif
                                  void* workspace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  switch (data_type_) {
    case nvinfer1::DataType::kFLOAT:
      enqueue_impl<float>(batch_size, inputs, outputs, workspace, stream);
      break;
    case nvinfer1::DataType::kHALF:
#if TRT_PLUGIN_FP16_AVALIABLE
      enqueue_impl<half>(batch_size, inputs, outputs, workspace, stream);
#else
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Current CUDA arch dose not support fp16. Please use fp32 instead."));
#endif
      break;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The DeformableConv TRT Plugin's input type should be float or "
          "half."));
  }
  const cudaError_t last_error = cudaGetLastError();
  return last_error != cudaSuccess;
}

// Device-side floor helper. Only declared here; this file provides explicit
// specializations for half and float.
template <typename T>
__device__ T kFloor(T x);

// Half-precision floor.
// Uses hfloor where CUDA_ARCH_FP16_SUPPORTED holds; otherwise the value is
// rounded in float and converted back, so the function returns a value on
// every compilation path instead of falling off the end.
template <>
__device__ half kFloor<half>(half x) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  return hfloor(x);
#else
  return __float2half(floorf(__half2float(x)));
#endif
}

// Single-precision floor.
template <>
__device__ float kFloor<float>(float x) {
  const float rounded_down = floorf(x);
  return rounded_down;
}

// Bilinear sample of `bottom_data` at fractional position (h, w).
// `data_width` is the row stride used to index bottom_data; `height` and
// `width` bound the lower/right neighbours. Only declared here; this file
// provides explicit specializations for float and half.
template <typename T>
__device__ T DmcnIm2colBilinear(const T* bottom_data,
                                const int data_width,
                                const int height,
                                const int width,
                                T h,
                                T w);
W
wangxinxin08 已提交
295 296 297 298

// Bilinear interpolation of `bottom_data` at (h, w) in single precision.
// The four neighbouring samples are read with row stride `data_width`; a
// neighbour whose row or column fails the bound checks below contributes 0.
template <>
__device__ float DmcnIm2colBilinear<float>(const float* bottom_data,
                                           const int data_width,
                                           const int height,
                                           const int width,
                                           float h,
                                           float w) {
  // Integer corners surrounding the sampling point.
  const int top = kFloor<float>(h);
  const int left = kFloor<float>(w);
  const int bottom = top + 1;
  const int right = left + 1;

  // Fractional distances to the top/left corner and their complements.
  const float top_f = top, left_f = left, one = 1.0f;
  const float frac_h = h - top_f;
  const float frac_w = w - left_f;
  const float rem_h = one - frac_h, rem_w = one - frac_w;

  const bool top_ok = top >= 0;
  const bool left_ok = left >= 0;
  const bool bottom_ok = bottom <= height - 1;
  const bool right_ok = right <= width - 1;

  float top_left = 0;
  if (top_ok && left_ok) top_left = bottom_data[top * data_width + left];
  float top_right = 0;
  if (top_ok && right_ok) top_right = bottom_data[top * data_width + right];
  float bottom_left = 0;
  if (bottom_ok && left_ok)
    bottom_left = bottom_data[bottom * data_width + left];
  float bottom_right = 0;
  if (bottom_ok && right_ok)
    bottom_right = bottom_data[bottom * data_width + right];

  const float w1 = rem_h * rem_w, w2 = rem_h * frac_w;
  const float w3 = frac_h * rem_w, w4 = frac_h * frac_w;

  return (w1 * top_left + w2 * top_right + w3 * bottom_left +
          w4 * bottom_right);
}

W
wangxinxin08 已提交
331 332
// Bilinear interpolation of `bottom_data` at (h, w) for half inputs.
// The four neighbouring samples are read with row stride `data_width`; a
// neighbour whose row or column fails the bound checks below contributes 0.
//
// Where CUDA_ARCH_FP16_SUPPORTED holds the arithmetic is done in half.
// Otherwise the same interpolation is carried out in float and converted
// back, so the function returns a value on every compilation path instead of
// falling off the end.
template <>
__device__ half DmcnIm2colBilinear<half>(const half* bottom_data,
                                         const int data_width,
                                         const int height,
                                         const int width,
                                         half h,
                                         half w) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  int h_low = kFloor<half>(h);
  int w_low = kFloor<half>(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  half h_low_t = h_low, w_low_t = w_low, one = 1.0f;
  half lh = h - h_low_t;
  half lw = w - w_low_t;
  half hh = one - lh, hw = one - lw;

  half v1 = 0;
  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
  half v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  half v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  half v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  half w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  half val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
#else
  // No native half arithmetic: interpolate in float.
  const float h_f = __half2float(h);
  const float w_f = __half2float(w);
  const int h_low = static_cast<int>(floorf(h_f));
  const int w_low = static_cast<int>(floorf(w_f));
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const float lh = h_f - static_cast<float>(h_low);
  const float lw = w_f - static_cast<float>(w_low);
  const float hh = 1.0f - lh, hw = 1.0f - lw;

  float v1 = 0.0f, v2 = 0.0f, v3 = 0.0f, v4 = 0.0f;
  if (h_low >= 0 && w_low >= 0)
    v1 = __half2float(bottom_data[h_low * data_width + w_low]);
  if (h_low >= 0 && w_high <= width - 1)
    v2 = __half2float(bottom_data[h_low * data_width + w_high]);
  if (h_high <= height - 1 && w_low >= 0)
    v3 = __half2float(bottom_data[h_high * data_width + w_low]);
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = __half2float(bottom_data[h_high * data_width + w_high]);

  return __float2half(hh * hw * v1 + hh * lw * v2 + lh * hw * v3 +
                      lh * lw * v4);
#endif
}

W
wangxinxin08 已提交
368 369
// Modulated deformable im2col kernel. Only declared here; this file provides
// explicit specializations for float and half.
//
// nthreads is the number of flat work items; data_im / data_offset / data_mask
// are the inputs and data_col receives the column buffer. The remaining
// arguments describe the image size, kernel size, padding, stride, dilation,
// channel grouping, batch size and output (column) size.
template <typename T>
__global__ void ModulatedDeformableIm2colGpuKernel(
    const int nthreads,
    const T* data_im,
    const T* data_offset,
    const T* data_mask,
    const int height,
    const int width,
    const int kernel_h,
    const int kernel_w,
    const int pad_h,
    const int pad_w,
    const int stride_h,
    const int stride_w,
    const int dilation_h,
    const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size,
    const int num_channels,
    const int deformable_group,
    const int height_col,
    const int width_col,
    T* data_col);
W
wangxinxin08 已提交
391 392 393

// Single-precision modulated deformable im2col.
//
// Each flat work item is decomposed into (c_im, b_col, h_col, w_col). For
// every kernel tap the thread reads the learned (h, w) offset and the mask,
// bilinearly samples the image at the shifted position (0 when the position
// lies outside (-1, height) x (-1, width)), multiplies by the mask and stores
// the result in data_col. Threads advance by blockDim.x * gridDim.x until all
// nthreads items are processed.
template <>
__global__ void ModulatedDeformableIm2colGpuKernel<float>(
    const int nthreads,
    const float* data_im,
    const float* data_offset,
    const float* data_mask,
    const int height,
    const int width,
    const int kernel_h,
    const int kernel_w,
    const int pad_h,
    const int pad_w,
    const int stride_h,
    const int stride_w,
    const int dilation_h,
    const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size,
    const int num_channels,
    const int deformable_group,
    const int height_col,
    const int width_col,
    float* data_col) {
  const int first_item = blockIdx.x * blockDim.x + threadIdx.x;
  const int item_stride = blockDim.x * gridDim.x;

  const float lower_bound = -1.0f;
  const float height_f = height;
  const float width_f = width;

  for (size_t item = first_item; item < nthreads; item += item_stride) {
    // Decompose the flat index as [c_im][b_col][h_col][w_col].
    const int w_col = item % width_col;
    const int h_col = (item / width_col) % height_col;
    const int b_col = (item / width_col) / height_col % batch_size;
    const int c_im = (item / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    const int group_in_batch =
        b_col * deformable_group + c_im / channel_per_deformable_group;

    // Top-left input coordinate of the receptive field.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    float* col_out =
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const float* image = data_im + (b_col * num_channels + c_im) * height * width;
    const float* offsets = data_offset + group_in_batch * 2 * kernel_h *
                                             kernel_w * height_col * width_col;
    const float* masks = data_mask + group_in_batch * kernel_h * kernel_w *
                                         height_col * width_col;

    for (int kh = 0; kh < kernel_h; ++kh) {
      for (int kw = 0; kw < kernel_w; ++kw) {
        const int tap = kh * kernel_w + kw;
        // Offsets are stored as interleaved (h, w) planes per tap.
        const int offset_h_index =
            ((2 * tap) * height_col + h_col) * width_col + w_col;
        const int offset_w_index =
            ((2 * tap + 1) * height_col + h_col) * width_col + w_col;
        const int mask_index = (tap * height_col + h_col) * width_col + w_col;

        const float offset_h = offsets[offset_h_index];
        const float offset_w = offsets[offset_w_index];
        const float mask = masks[mask_index];

        const float base_h = h_in + kh * dilation_h;
        const float base_w = w_in + kw * dilation_w;
        const float h_im = base_h + offset_h;
        const float w_im = base_w + offset_w;

        float val = 0;
        const bool inside = h_im > lower_bound && w_im > lower_bound &&
                            h_im < height_f && w_im < width_f;
        if (inside) {
          val = DmcnIm2colBilinear<float>(
              image, width, height, width, h_im, w_im);
        }
        *col_out = val * mask;
        // Next tap lives one (batch, h_col, w_col) plane further.
        col_out += batch_size * height_col * width_col;
      }
    }
  }
}

W
wangxinxin08 已提交
472 473
// Half-precision modulated deformable im2col.
//
// The body is compiled only where CUDA_ARCH_FP16_SUPPORTED holds; elsewhere
// the kernel is empty. Each flat work item is decomposed into
// (c_im, b_col, h_col, w_col). For every kernel tap the thread reads the
// learned (h, w) offset and the mask, bilinearly samples the image at the
// shifted position (0 when the position lies outside
// (-1, height) x (-1, width)), multiplies by the mask and stores the result in
// data_col. Threads advance by blockDim.x * gridDim.x until all nthreads
// items are processed.
template <>
__global__ void ModulatedDeformableIm2colGpuKernel<half>(
    const int nthreads,
    const half* data_im,
    const half* data_offset,
    const half* data_mask,
    const int height,
    const int width,
    const int kernel_h,
    const int kernel_w,
    const int pad_h,
    const int pad_w,
    const int stride_h,
    const int stride_w,
    const int dilation_h,
    const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size,
    const int num_channels,
    const int deformable_group,
    const int height_col,
    const int width_col,
    half* data_col) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const int first_item = blockIdx.x * blockDim.x + threadIdx.x;
  const int item_stride = blockDim.x * gridDim.x;

  const half lower_bound = -1.0f;
  const half height_h = height;
  const half width_h = width;

  for (size_t item = first_item; item < nthreads; item += item_stride) {
    // Decompose the flat index as [c_im][b_col][h_col][w_col].
    const int w_col = item % width_col;
    const int h_col = (item / width_col) % height_col;
    const int b_col = (item / width_col) / height_col % batch_size;
    const int c_im = (item / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    const int group_in_batch =
        b_col * deformable_group + c_im / channel_per_deformable_group;

    // Top-left input coordinate of the receptive field.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    half* col_out =
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const half* image = data_im + (b_col * num_channels + c_im) * height * width;
    const half* offsets = data_offset + group_in_batch * 2 * kernel_h *
                                            kernel_w * height_col * width_col;
    const half* masks = data_mask + group_in_batch * kernel_h * kernel_w *
                                        height_col * width_col;

    for (int kh = 0; kh < kernel_h; ++kh) {
      for (int kw = 0; kw < kernel_w; ++kw) {
        const int tap = kh * kernel_w + kw;
        // Offsets are stored as interleaved (h, w) planes per tap.
        const int offset_h_index =
            ((2 * tap) * height_col + h_col) * width_col + w_col;
        const int offset_w_index =
            ((2 * tap + 1) * height_col + h_col) * width_col + w_col;
        const int mask_index = (tap * height_col + h_col) * width_col + w_col;

        const half offset_h = offsets[offset_h_index];
        const half offset_w = offsets[offset_w_index];
        const half mask = masks[mask_index];

        const half base_h = h_in + kh * dilation_h;
        const half base_w = w_in + kw * dilation_w;
        const half h_im = base_h + offset_h;
        const half w_im = base_w + offset_w;

        half val = 0;
        const bool inside = h_im > lower_bound && w_im > lower_bound &&
                            h_im < height_h && w_im < width_h;
        if (inside) {
          val = DmcnIm2colBilinear<half>(
              image, width, height, width, h_im, w_im);
        }
        *col_out = val * mask;
        // Next tap lives one (batch, h_col, w_col) plane further.
        col_out += batch_size * height_col * width_col;
      }
    }
  }
#endif
}

554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594
// Maps a CUDA element type to the type handed to the BLAS wrapper:
// half -> platform::float16, float -> float.
template <typename T>
struct CUDATypeTraits;

template <>
struct CUDATypeTraits<half> {
  using TYPE = platform::float16;
};

template <>
struct CUDATypeTraits<float> {
  using TYPE = float;
};

// GEMM through phi's BLAS wrapper, using the GPU context that the device
// context pool returns for CUDAPlace(0).
//
// Note the argument order handed to blas.GEMM: (n, m, k) with B before A, both
// untransposed. Scalars are read from *alpha and *beta on the host.
template <typename T>
void gemm_impl_new(int m,
                   int n,
                   int k,
                   const T* alpha,
                   const T* A,
                   const T* B,
                   const T* beta,
                   T* C) {
  using BlasT = typename CUDATypeTraits<T>::TYPE;

  auto* pooled_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
  const phi::GPUContext& gpu_ctx = *pooled_ctx;
  auto blas = phi::funcs::GetBlas<phi::GPUContext, BlasT>(gpu_ctx);

  BlasT* first_operand = reinterpret_cast<BlasT*>(const_cast<T*>(B));
  BlasT* second_operand = reinterpret_cast<BlasT*>(const_cast<T*>(A));
  BlasT* result = reinterpret_cast<BlasT*>(C);

  // note: here calls GEMM like cblas, so do not use like cblas
  blas.GEMM(CblasNoTrans,
            CblasNoTrans,
            n,
            m,
            k,
            static_cast<BlasT>(*alpha),
            first_operand,
            second_operand,
            static_cast<BlasT>(*beta),
            result);
}

W
wangxinxin08 已提交
595
// cuBLAS GEMM wrapper taking the same arguments as cublas<t>gemm. Only
// declared here; this file provides explicit specializations for float and
// half.
template <typename T>
void gemm_impl(cublasHandle_t handle,
               cublasOperation_t transa,
               cublasOperation_t transb,
               int m,
               int n,
               int k,
               const T* alpha,
               const T* A,
               int lda,
               const T* B,
               int ldb,
               const T* beta,
               T* C,
               int ldc);

// Single-precision GEMM: forwards every argument to cublasSgemm.
// The returned cuBLAS status is checked with PADDLE_ENFORCE_GPU_SUCCESS so a
// failed call is reported instead of being silently dropped.
template <>
void gemm_impl<float>(cublasHandle_t handle,
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const float* alpha,
                      const float* A,
                      int lda,
                      const float* B,
                      int ldb,
                      const float* beta,
                      float* C,
                      int ldc) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasSgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc));
}

// Half-precision GEMM: forwards every argument to cublasHgemm when
// TRT_PLUGIN_FP16_AVALIABLE is set, and throws InvalidArgument otherwise.
// The returned cuBLAS status is checked with PADDLE_ENFORCE_GPU_SUCCESS so a
// failed call is reported instead of being silently dropped.
template <>
void gemm_impl<half>(cublasHandle_t handle,
                     cublasOperation_t transa,
                     cublasOperation_t transb,
                     int m,
                     int n,
                     int k,
                     const half* alpha,
                     const half* A,
                     int lda,
                     const half* B,
                     int ldb,
                     const half* beta,
                     half* C,
                     int ldc) {
#if TRT_PLUGIN_FP16_AVALIABLE
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasHgemm(
      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc));
#else
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Current CUDA arch dose not support fp16. Please use fp32 instead."));
#endif
}

// Runs the deformable convolution for element type T.
//
// The batch is processed in chunks of im2col_step_ samples
// (batch_size / im2col_step_ chunks, integer division). For each chunk the
// im2col kernel fills `workspace` with the sampled columns on `stream`, then
// one GEMM per group multiplies the columns by that group's filter slice and
// writes into the output tensor. Always returns 0.
template <typename T>
int DeformableConvPlugin::enqueue_impl(int batch_size,
                                       const void* const* inputs,
                                       void* const* outputs,
                                       void* workspace,
                                       cudaStream_t stream) {
  // inputs: [0] feature map, [1] offsets, [2] mask.
  const T* input = static_cast<const T*>(inputs[0]);
  const T* offset = static_cast<const T*>(inputs[1]);
  const T* mask = static_cast<const T*>(inputs[2]);
  const T* filter = static_cast<const T*>(weights_.values);
  T* output = static_cast<T*>(outputs[0]);
  T* columns = static_cast<T*>(workspace);

  const int c_i = input_dim_[0], h_i = input_dim_[1], w_i = input_dim_[2];
  const int k_h = kernel_dims_[2], k_w = kernel_dims_[3];
  const int c_o = output_dim_[0], h_o = output_dim_[1], w_o = output_dim_[2];

  // Elements per sample in each tensor.
  const int input_stride = c_i * h_i * w_i;
  const int offset_stride = offset_dim_[0] * offset_dim_[1] * offset_dim_[2];
  const int mask_stride = mask_dim_[0] * mask_dim_[1] * mask_dim_[2];
  const int output_stride = c_o * h_o * w_o;

  // Per-group GEMM problem size.
  const int M = c_o / groups_;
  const int N = im2col_step_ * h_o * w_o;
  const int K = c_i * k_h * k_w / groups_;

  const int channel_per_deformable_group = c_i / deformable_groups_;
  // One work item per (input channel, sample in chunk, output pixel).
  const int num_kernels = c_i * im2col_step_ * h_o * w_o;

  const int grid = NumBlocks(num_kernels);
  const int block = kNumCUDAThreads;

  T alpha = static_cast<T>(1.0f);
  T beta = static_cast<T>(0.0f);

  const int num_chunks = batch_size / im2col_step_;
  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    const int first_sample = chunk * im2col_step_;

    ModulatedDeformableIm2colGpuKernel<T>
        <<<grid, block, 0, stream>>>(num_kernels,
                                     input + first_sample * input_stride,
                                     offset + first_sample * offset_stride,
                                     mask + first_sample * mask_stride,
                                     h_i,
                                     w_i,
                                     k_h,
                                     k_w,
                                     paddings_[0],
                                     paddings_[1],
                                     strides_[0],
                                     strides_[1],
                                     dilations_[0],
                                     dilations_[1],
                                     channel_per_deformable_group,
                                     im2col_step_,
                                     c_i,
                                     deformable_groups_,
                                     h_o,
                                     w_o,
                                     columns);

    for (int g = 0; g < groups_; ++g) {
      const T* group_filter = filter + g * M * K;
      const T* group_columns = columns + g * K * N;
      T* group_output = output + first_sample * output_stride + g * M * N;
      gemm_impl<T>(cublasHandle_,
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   N,
                   M,
                   K,
                   &alpha,
                   group_columns,
                   N,
                   group_filter,
                   K,
                   &beta,
                   group_output,
                   N);
    }
  }
  return 0;
}

// Plugin initialization hook. Nothing is set up here; always returns 0.
int DeformableConvPlugin::initialize() TRT_NOEXCEPT {
  return 0;
}

// Plugin teardown hook. Intentionally a no-op.
void DeformableConvPlugin::terminate() TRT_NOEXCEPT {
}

// Returns the number of bytes needed to serialize this plugin.
//
// The terms are accumulated in the same order in which serialize() writes
// the fields: scalar/vector attributes, the weight element count, the raw
// weight payload, the four cached dimension vectors and the fp16 flag.
size_t DeformableConvPlugin::getSerializationSize() const TRT_NOEXCEPT {
  // Raw weights occupy 4 bytes per element for kFLOAT and 2 bytes otherwise.
  const size_t bytes_per_weight =
      (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;

  size_t serialize_size = 0;
  serialize_size += SerializedSize(data_type_);
  serialize_size += SerializedSize(strides_);
  serialize_size += SerializedSize(paddings_);
  serialize_size += SerializedSize(dilations_);
  serialize_size += SerializedSize(groups_);
  serialize_size += SerializedSize(deformable_groups_);
  serialize_size += SerializedSize(im2col_step_);
  serialize_size += SerializedSize(kernel_dims_);
  serialize_size += SerializedSize(weights_.count);
  serialize_size += static_cast<size_t>(weights_.count) * bytes_per_weight;
  serialize_size += SerializedSize(input_dim_);
  serialize_size += SerializedSize(offset_dim_);
  serialize_size += SerializedSize(mask_dim_);
  serialize_size += SerializedSize(output_dim_);
  serialize_size += SerializedSize(with_fp16_);
  return serialize_size;
}

// Writes the plugin state into `buffer`.
//
// Each SerializeValue call receives &buffer so that the write position is
// carried from one field to the next; serializeFromDevice does the same for
// the weight payload. The field order here mirrors the size computation in
// getSerializationSize() and must not be changed independently of it.
void DeformableConvPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, data_type_);
  SerializeValue(&buffer, strides_);
  SerializeValue(&buffer, paddings_);
  SerializeValue(&buffer, dilations_);
  SerializeValue(&buffer, groups_);
  SerializeValue(&buffer, deformable_groups_);
  SerializeValue(&buffer, im2col_step_);
  SerializeValue(&buffer, kernel_dims_);
  // Element count first, then the weight values themselves.
  SerializeValue(&buffer, weights_.count);
  serializeFromDevice(&buffer, weights_);
  SerializeValue(&buffer, input_dim_);
  SerializeValue(&buffer, offset_dim_);
  SerializeValue(&buffer, mask_dim_);
  SerializeValue(&buffer, output_dim_);
  SerializeValue(&buffer, with_fp16_);
}

// Destroy hook. Intentionally a no-op.
// NOTE(review): this does not release the plugin object itself; confirm that
// ownership is handled elsewhere.
void DeformableConvPlugin::destroy() TRT_NOEXCEPT {
}

// Stores a copy of `lib_namespace` as this plugin's namespace.
void DeformableConvPlugin::setPluginNamespace(const char* lib_namespace)
    TRT_NOEXCEPT {
  const std::string requested_namespace(lib_namespace);
  namespace_ = requested_namespace;
}

// Returns the stored namespace as a C string. The pointer refers to storage
// owned by namespace_.
const char* DeformableConvPlugin::getPluginNamespace() const TRT_NOEXCEPT {
  const char* stored_namespace = namespace_.c_str();
  return stored_namespace;
}

// Reports the data type of the output: it is the type of the first input.
// `index` and `nb_inputs` are not consulted.
nvinfer1::DataType DeformableConvPlugin::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_type,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_type[0];
}

bool DeformableConvPlugin::isOutputBroadcastAcrossBatch(
804 805
    int output_index,
    const bool* input_is_broadcast,
W
wangxinxin08 已提交
806 807 808 809 810 811 812 813 814 815
    int nb_inputs) const TRT_NOEXCEPT {
  return false;
}

// Always answers false, regardless of `input_index`.
bool DeformableConvPlugin::canBroadcastInputAcrossBatch(int input_index) const
    TRT_NOEXCEPT {
  constexpr bool kCanBroadcast = false;
  return kCanBroadcast;
}

void DeformableConvPlugin::attachToContext(
816 817
    cudnnContext* cudnnContext,
    cublasContext* cublasContext,
W
wangxinxin08 已提交
818 819 820 821 822
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {
  cublasHandle_ = cublasContext;
}

void DeformableConvPlugin::configurePlugin(
823 824 825 826
    const nvinfer1::Dims* input_dims,
    int nb_inputs,
    const nvinfer1::Dims* output_dims,
    int nb_outputs,
W
wangxinxin08 已提交
827
    const nvinfer1::DataType* input_types,
828 829 830 831
    const nvinfer1::DataType* output_types,
    const bool* input_is_broadcast,
    const bool* output_is_broadcast,
    nvinfer1::PluginFormat float_format,
W
wangxinxin08 已提交
832 833
    int max_batct_size) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
834 835
      nb_inputs,
      3,
W
wangxinxin08 已提交
836 837 838
      platform::errors::InvalidArgument(
          "The number of inputs should be equal to 3, but got %d", nb_inputs));
  PADDLE_ENFORCE_EQ(
839 840
      nb_outputs,
      1,
W
wangxinxin08 已提交
841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858
      platform::errors::InvalidArgument(
          "The number of inputs should be equal to 1, but got %d", nb_outputs));

  for (int i = 0; i < input_dims[0].nbDims; i++) {
    input_dim_.push_back(input_dims[0].d[i]);
  }
  for (int i = 0; i < input_dims[1].nbDims; i++) {
    offset_dim_.push_back(input_dims[1].d[i]);
  }
  for (int i = 0; i < input_dims[2].nbDims; i++) {
    mask_dim_.push_back(input_dims[2].d[i]);
  }
  for (int i = 0; i < output_dims[0].nbDims; i++) {
    output_dim_.push_back(output_dims[0].d[i]);
  }
}

// Creates a new heap-allocated DeformableConvPlugin from this object's
// attributes, weights and cached dimensions. The caller receives the raw
// pointer returned by `new`.
nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT {
  return new DeformableConvPlugin(data_type_,
                                  weights_,
                                  kernel_dims_,
                                  strides_,
                                  paddings_,
                                  dilations_,
                                  groups_,
                                  deformable_groups_,
                                  im2col_step_,
                                  input_dim_,
                                  offset_dim_,
                                  mask_dim_,
                                  output_dim_,
                                  with_fp16_);
}

// Stores a copy of `lib_namespace` as this creator's namespace.
void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace)
    TRT_NOEXCEPT {
  const std::string requested_namespace(lib_namespace);
  namespace_ = requested_namespace;
}

// Returns the stored namespace as a C string. The pointer refers to storage
// owned by namespace_.
const char* DeformableConvPluginCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  const char* stored_namespace = namespace_.c_str();
  return stored_namespace;
}

// Name under which this creator identifies its plugin.
const char* DeformableConvPluginCreator::getPluginName() const TRT_NOEXCEPT {
  static constexpr const char* kPluginName = "deformable_conv_plugin";
  return kPluginName;
}

// Version string reported for the plugin.
const char* DeformableConvPluginCreator::getPluginVersion() const TRT_NOEXCEPT {
  static constexpr const char* kPluginVersion = "1";
  return kPluginVersion;
}

// Returns a pointer to the creator's field collection member.
const nvinfer1::PluginFieldCollection*
DeformableConvPluginCreator::getFieldNames() TRT_NOEXCEPT {
  const nvinfer1::PluginFieldCollection* collection = &field_collection_;
  return collection;
}

// Builds a DeformableConvPlugin from a TensorRT plugin field collection.
//
// Recognised field names: "data_type", "strides", "paddings", "dilations",
// "groups", "deformable_groups", "im2col_step", "kernel_dims", "weights" and
// "with_fp16". Any other name raises InvalidArgument.
//
// Field names are matched with operator== on std::string. The previous code
// used `field_name.compare("...")` as a boolean; compare() returns 0 on
// equality, so every field other than "data_type" was routed into the
// "strides" branch.
//
// `name` is not used.
nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  // Defaults keep the locals well-defined when a field is absent.
  nvinfer1::DataType data_type = nvinfer1::DataType::kFLOAT;
  std::vector<int> strides, paddings, dilations, kernel_dims;
  nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
  int groups = -1;
  int deformable_groups = -1;
  int im2col_step = -1;
  bool with_fp16 = false;

  // Appends the `length` ints stored in `field` to `dst`.
  const auto append_ints = [](const nvinfer1::PluginField& field,
                              std::vector<int>* dst) {
    const int* data = static_cast<const int*>(field.data);
    dst->insert(dst->end(), data, data + field.length);
  };

  for (int i = 0; i < fc->nbFields; ++i) {
    const nvinfer1::PluginField& field = fc->fields[i];
    const std::string field_name(field.name);
    if (field_name == "data_type") {
      data_type = *static_cast<const nvinfer1::DataType*>(field.data);
    } else if (field_name == "strides") {
      append_ints(field, &strides);
    } else if (field_name == "paddings") {
      append_ints(field, &paddings);
    } else if (field_name == "dilations") {
      append_ints(field, &dilations);
    } else if (field_name == "groups") {
      groups = *static_cast<const int*>(field.data);
    } else if (field_name == "deformable_groups") {
      deformable_groups = *static_cast<const int*>(field.data);
    } else if (field_name == "im2col_step") {
      im2col_step = *static_cast<const int*>(field.data);
    } else if (field_name == "kernel_dims") {
      append_ints(field, &kernel_dims);
    } else if (field_name == "weights") {
      weights.count = field.length;
      weights.values = field.data;
    } else if (field_name == "with_fp16") {
      with_fp16 = *static_cast<const bool*>(field.data);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unknown plugin field name [%s] in the DeformableConv TRT Plugin.",
          field_name));
    }
  }
  weights.type = data_type;
  return new DeformableConvPlugin(data_type,
                                  weights,
                                  kernel_dims,
                                  strides,
                                  paddings,
                                  dilations,
                                  groups,
                                  deformable_groups,
                                  im2col_step,
                                  with_fp16);
}

// Reconstructs a DeformableConvPlugin from a serialized buffer and gives it
// this creator's namespace. `name` is not used.
nvinfer1::IPluginV2Ext* DeformableConvPluginCreator::deserializePlugin(
    const char* name,
    const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  auto plugin = new DeformableConvPlugin(serial_data, serial_length);
  plugin->setPluginNamespace(namespace_.c_str());
  return plugin;
}

969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 
1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408
#if IS_TRT_VERSION_GE(6000)

// Constructs the dynamic-shape plugin from explicit attributes.
//
// The attribute vectors are copied and validated BEFORE the weights are
// uploaded with copyToDevice(). If a check fails the constructor throws, and
// because the destructor does not run for a partially constructed object, a
// device buffer allocated earlier would never be freed. Validating first also
// guarantees copyToDevice() only sees a float or half data type, which is
// what its element-size computation assumes.
DeformableConvPluginDynamic::DeformableConvPluginDynamic(
    const nvinfer1::DataType data_type,
    const nvinfer1::Weights& weights,
    const std::vector<int>& kernel_dims,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    const std::vector<int>& dilations,
    const int groups,
    const int deformable_groups,
    const int im2col_step,
    const bool with_fp16)
    : data_type_(data_type),
      groups_(groups),
      deformable_groups_(deformable_groups),
      im2col_step_(im2col_step),
      with_fp16_(with_fp16) {
  kernel_dims_.insert(
      kernel_dims_.end(), kernel_dims.cbegin(), kernel_dims.cend());
  strides_.insert(strides_.end(), strides.cbegin(), strides.cend());
  paddings_.insert(paddings_.end(), paddings.cbegin(), paddings.cend());
  dilations_.insert(dilations_.end(), dilations.cbegin(), dilations.cend());

  PADDLE_ENFORCE_EQ(data_type_ == nvinfer1::DataType::kFLOAT ||
                        data_type_ == nvinfer1::DataType::kHALF,
                    true,
                    platform::errors::InvalidArgument(
                        "The DeformableConv TRT Plugin's input type "
                        "should be float or half."));
  PADDLE_ENFORCE_EQ(
      paddings_.size(),
      strides_.size(),
      platform::errors::InvalidArgument(
          "The size of paddings (%d) is not equal to the size of strides (%d).",
          paddings_.size(),
          strides_.size()));

  // Upload last so that no device memory is held if validation fails.
  weights_ = copyToDevice(weights.values, weights.count);
}
// Reconstructs the plugin from a buffer produced by serialize().
//
// Fields are read in the order serialize() writes them. The weight payload is
// consumed by deserializeToDevice(), which advances `data` but does not touch
// `length`; this constructor therefore checks that the payload fits in the
// remaining bytes before copying it, and subtracts it afterwards so the
// remaining-length bookkeeping stays correct for the trailing field.
DeformableConvPluginDynamic::DeformableConvPluginDynamic(const void* data,
                                                         size_t length) {
  DeserializeValue(&data, &length, &data_type_);
  DeserializeValue(&data, &length, &strides_);
  DeserializeValue(&data, &length, &paddings_);
  DeserializeValue(&data, &length, &dilations_);
  DeserializeValue(&data, &length, &groups_);
  DeserializeValue(&data, &length, &deformable_groups_);
  DeserializeValue(&data, &length, &im2col_step_);
  DeserializeValue(&data, &length, &kernel_dims_);

  int64_t count = 0;
  DeserializeValue(&data, &length, &count);

  // Same element size rule as deserializeToDevice(): 4 bytes for kFLOAT,
  // 2 bytes otherwise.
  const size_t bytes_per_weight =
      (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;
  const size_t weight_bytes = static_cast<size_t>(count) * bytes_per_weight;
  PADDLE_ENFORCE_EQ(
      count >= 0 && weight_bytes <= length,
      true,
      platform::errors::InvalidArgument(
          "The serialized DeformableConv TRT Plugin is truncated: the weights "
          "need %d bytes but only %d bytes remain.",
          weight_bytes,
          length));

  weights_ = deserializeToDevice(&data, count);
  length -= weight_bytes;

  DeserializeValue(&data, &length, &with_fp16_);
}

// Frees the weight buffer with cudaFree, if one is held, and clears the
// pointer. The cudaFree status is not inspected.
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
  if (!weights_.values) {
    return;
  }
  void* device_weights = const_cast<void*>(weights_.values);
  cudaFree(device_weights);
  weights_.values = nullptr;
}

// Allocates a buffer with cudaMalloc and copies `count` elements from
// `hostData` into it (cudaMemcpyHostToDevice).
//
// The element size is 4 bytes when data_type_ is kFLOAT and 2 bytes
// otherwise. Returns a Weights descriptor {data_type_, buffer, count}; the
// buffer is not freed here.
nvinfer1::Weights DeformableConvPluginDynamic::copyToDevice(
    const void* hostData, size_t count) {
  const int elem_bytes = (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;
  const size_t total_bytes = count * elem_bytes;

  void* deviceData;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&deviceData, total_bytes));
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemcpy(deviceData, hostData, total_bytes, cudaMemcpyHostToDevice));

  return nvinfer1::Weights{data_type_, deviceData, int64_t(count)};
}

// Copies the weight values into the buffer at *hostBuffer
// (cudaMemcpyDeviceToHost) and advances *hostBuffer past the bytes written.
//
// The element size is 4 bytes when data_type_ is kFLOAT and 2 bytes
// otherwise.
void DeformableConvPluginDynamic::serializeFromDevice(
    void** hostBuffer, const nvinfer1::Weights& deviceWeights) const {
  const int elem_bytes = (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;
  const auto payload_bytes = deviceWeights.count * elem_bytes;

  char* write_pos = static_cast<char*>(*hostBuffer);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(
      write_pos, deviceWeights.values, payload_bytes, cudaMemcpyDeviceToHost));

  *hostBuffer = write_pos + payload_bytes;
}

// Uploads `count` weight elements found at *hostBuffer via copyToDevice() and
// advances *hostBuffer past them.
//
// The element size is 4 bytes when data_type_ is kFLOAT and 2 bytes
// otherwise. No bounds check is performed against the source buffer here.
nvinfer1::Weights DeformableConvPluginDynamic::deserializeToDevice(
    const void** hostBuffer, size_t count) {
  const int elem_bytes = (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;
  const char* read_pos = static_cast<const char*>(*hostBuffer);

  nvinfer1::Weights uploaded = copyToDevice(read_pos, count);
  *hostBuffer = read_pos + count * elem_bytes;
  return uploaded;
}

// Plugin initialization hook. Nothing is set up here; always returns 0.
int DeformableConvPluginDynamic::initialize() TRT_NOEXCEPT {
  return 0;
}

// Returns the number of bytes needed to serialize this plugin: the attribute
// fields, the weight element count, the raw weight payload and the fp16 flag.
size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  // Raw weights occupy 4 bytes per element for kFLOAT and 2 bytes otherwise.
  const int elem_bytes = (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;

  size_t total = SerializedSize(data_type_);
  total += SerializedSize(strides_);
  total += SerializedSize(paddings_);
  total += SerializedSize(dilations_);
  total += SerializedSize(groups_);
  total += SerializedSize(deformable_groups_);
  total += SerializedSize(im2col_step_);
  total += SerializedSize(kernel_dims_);
  total += SerializedSize(weights_.count);
  total += weights_.count * elem_bytes;
  total += SerializedSize(with_fp16_);
  return total;
}

// Writes the plugin state into `buffer`.
//
// Every call receives &buffer so the write position carries over from one
// field to the next. The field order matches the order in which the
// (const void*, size_t) constructor reads them back, and the sizes summed in
// getSerializationSize(); keep all three in sync.
void DeformableConvPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, data_type_);
  SerializeValue(&buffer, strides_);
  SerializeValue(&buffer, paddings_);
  SerializeValue(&buffer, dilations_);
  SerializeValue(&buffer, groups_);
  SerializeValue(&buffer, deformable_groups_);
  SerializeValue(&buffer, im2col_step_);
  SerializeValue(&buffer, kernel_dims_);
  // Element count first, then the weight values copied back from the device.
  SerializeValue(&buffer, weights_.count);
  serializeFromDevice(&buffer, weights_);
  SerializeValue(&buffer, with_fp16_);
}

// Returns the scratch size, in bytes, of the im2col column buffer:
//   c_i * k_h * k_w * im2col_step_ * h_o * w_o * bytes_per_element
// where c_i comes from inputs[0].dims.d[1], k_h/k_w from kernel_dims_[2..3],
// h_o/w_o from outputs[0].dims.d[2..3], and the element size is 4 bytes for
// kFLOAT and 2 bytes otherwise.
//
// Every factor is widened to size_t before multiplying. The earlier version
// formed the whole product in `int` and cast afterwards, which overflows for
// large feature maps.
size_t DeformableConvPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs,
    int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      nbInputs,
      3,
      platform::errors::InvalidArgument(
          "The number of inputs should be equal to 3, but got %d", nbInputs));
  PADDLE_ENFORCE_EQ(
      nbOutputs,
      1,
      platform::errors::InvalidArgument(
          "The number of outputs should be equal to 1, but got %d",
          nbOutputs));

  const size_t c_i = static_cast<size_t>(inputs[0].dims.d[1]);
  const size_t k_h = static_cast<size_t>(kernel_dims_[2]);
  const size_t k_w = static_cast<size_t>(kernel_dims_[3]);
  const size_t h_o = static_cast<size_t>(outputs[0].dims.d[2]);
  const size_t w_o = static_cast<size_t>(outputs[0].dims.d[3]);
  const size_t step = static_cast<size_t>(im2col_step_);
  const size_t num_bytes =
      (data_type_ == nvinfer1::DataType::kFLOAT) ? 4 : 2;

  return c_i * k_h * k_w * step * h_o * w_o * num_bytes;
}

// Builds symbolic output dimensions.
//
//   d[0] = inputDims[0].d[0]
//   d[1] = kernel_dims_[0]
//   d[2], d[3] = floor_div(in + 2 * padding - dkernel, stride) + 1
// with dkernel = dilation * (filter_size - 1) + 1, using index 0 of
// dilations_/paddings_/strides_ for d[2] and index 1 for d[3].
//
// `output_index` is not consulted.
nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs* inputDims,
    int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      nb_inputs,
      3,
      platform::errors::InvalidArgument(
          "The number of inputs should be equal to 3, but got %d", nb_inputs));

  using DimOp = nvinfer1::DimensionOperation;
  using DimExpr = nvinfer1::IDimensionExpr;

  nvinfer1::DimsExprs out_dims;
  out_dims.nbDims = inputDims[0].nbDims;
  out_dims.d[0] = inputDims[0].d[0];

  // Symbolic convolution output extent along one spatial axis.
  const auto spatial_extent = [&expr_builder](const DimExpr* input_size,
                                              int filter_size,
                                              int dilation,
                                              int padding,
                                              int stride) -> const DimExpr* {
    const int dilated_kernel = dilation * (filter_size - 1) + 1;
    const DimExpr* padded = expr_builder.operation(
        DimOp::kSUM,
        *input_size,
        *expr_builder.constant(2 * padding - dilated_kernel));
    const DimExpr* strided = expr_builder.operation(
        DimOp::kFLOOR_DIV, *padded, *expr_builder.constant(stride));
    return expr_builder.operation(
        DimOp::kSUM, *strided, *expr_builder.constant(1));
  };

  out_dims.d[1] = expr_builder.constant(kernel_dims_[0]);
  out_dims.d[2] = spatial_extent(inputDims[0].d[2],
                                 kernel_dims_[2],
                                 dilations_[0],
                                 paddings_[0],
                                 strides_[0]);
  out_dims.d[3] = spatial_extent(inputDims[0].d[3],
                                 kernel_dims_[3],
                                 dilations_[1],
                                 paddings_[1],
                                 strides_[1]);
  return out_dims;
}

// Tells whether the tensor at position `pos` in `in_out` has an acceptable
// type/format combination.
//
// Position 0 sets the rule:
//   - with_fp16_ set:   type must be kHALF, format kLINEAR or kHWC8;
//   - with_fp16_ clear: type must be kFLOAT, format kLINEAR.
// Every later position must match the type and format of the position
// immediately before it.
//
// NOTE(review): kHWC8 is accepted in the fp16 case; confirm that the kernels
// launched from enqueue_impl handle that layout.
bool DeformableConvPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc* in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of DeformableConv plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));
  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
      return (in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR ||
              in.format == nvinfer1::TensorFormat::kHWC8);
    }
    return (in.type == nvinfer1::DataType::kFLOAT) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
  }
  // Remaining inputs and the output follow the preceding tensor.
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  return in.type == prev.type && in.format == prev.format;
}

// Reports the data type of the output: it is the type of the first input.
// `index` must be 0; any other value raises InvalidArgument. `nb_inputs` is
// not consulted.
nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The DeformableConv Plugin only has one output, so "
                        "the index value should be 0, but get %d.",
                        index));
  return input_types[0];
}

// Validates that the plugin is configured with exactly three inputs and one
// output. Nothing is cached; the tensor descriptors themselves are not read.
void DeformableConvPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      nbInputs,
      3,
      platform::errors::InvalidArgument(
          "The number of inputs should be equal to 3, but got %d", nbInputs));
  PADDLE_ENFORCE_EQ(
      nbOutputs,
      1,
      platform::errors::InvalidArgument(
          "The number of outputs should be equal to 1, but got %d",
          nbOutputs));
}

// Dispatches to enqueue_impl<float> or enqueue_impl<half> according to
// data_type_ and reports whether cudaGetLastError() saw an error afterwards
// (0 on success, 1 otherwise). The value returned by enqueue_impl itself is
// not used.
//
// A half data type in a build without TRT_PLUGIN_FP16_AVALIABLE, or any
// other data type, raises InvalidArgument.
int DeformableConvPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  switch (data_type_) {
    case nvinfer1::DataType::kFLOAT:
      enqueue_impl<float>(
          input_desc, output_desc, inputs, outputs, workspace, stream);
      break;
    case nvinfer1::DataType::kHALF:
#if TRT_PLUGIN_FP16_AVALIABLE
      enqueue_impl<half>(
          input_desc, output_desc, inputs, outputs, workspace, stream);
#else
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Current CUDA arch dose not support fp16. Please use fp32 instead."));
#endif
      break;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The DeformableConv TRT Plugin's input type should be float or "
          "half."));
  }
  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}

// Typed implementation of enqueue(): runs the deformable convolution on
// `stream` for element type T.
//
// inputs[0..2] are read as input, offset and mask; weights_.values is the
// filter; outputs[0] receives the result and `workspace` is used as the
// column buffer. The batch is processed in chunks of im2col_step_ samples:
// each chunk launches ModulatedDeformableIm2colGpuKernel to fill the column
// buffer and then calls gemm_impl_new once per group.
//
// Always returns 0; launch status is not checked in this function.
template <typename T>
int DeformableConvPluginDynamic::enqueue_impl(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) {
  const auto& input_dims = input_desc[0].dims;
  const auto& offset_dims = input_desc[1].dims;
  const auto& mask_dims = input_desc[2].dims;
  const auto& output_dims = output_desc[0].dims;

  int batch_size = input_dims.d[0];
  const T* input = reinterpret_cast<const T*>(inputs[0]);
  const T* offset = reinterpret_cast<const T*>(inputs[1]);
  const T* mask = reinterpret_cast<const T*>(inputs[2]);
  const T* filter = reinterpret_cast<const T*>(weights_.values);
  T* output = reinterpret_cast<T*>(outputs[0]);

  // Dimension indices 1..3 are read as channels, height, width.
  int c_i = input_dims.d[1], h_i = input_dims.d[2], w_i = input_dims.d[3];
  int k_h = kernel_dims_[2], k_w = kernel_dims_[3];
  int c_o = output_dims.d[1], h_o = output_dims.d[2], w_o = output_dims.d[3];

  // Element counts of one batch sample in each tensor.
  int input_stride = c_i * h_i * w_i;
  int offset_stride = offset_dims.d[1] * offset_dims.d[2] * offset_dims.d[3];
  int mask_stride = mask_dims.d[1] * mask_dims.d[2] * mask_dims.d[3];
  int output_stride = c_o * h_o * w_o;

  // Per-group GEMM extents passed to gemm_impl_new below.
  int M = c_o / groups_;
  int N = im2col_step_ * h_o * w_o;
  int K = c_i * k_h * k_w / groups_;

  // Number of input channels per deformable group.
  int channel_per_deformable_group = c_i / deformable_groups_;
  // Work-item count handed to the im2col kernel:
  // c_i * im2col_step * h_o * w_o.
  int num_kernels = c_i * im2col_step_ * h_o * w_o;

  int blocks = NumBlocks(num_kernels);
  int threads = kNumCUDAThreads;

  const T alpha = static_cast<T>(1.0f);
  const T beta = static_cast<T>(0.0f);

  // NOTE(review): the trip count uses integer division, so samples left over
  // when batch_size is not a multiple of im2col_step_ are not processed here.
  for (int i = 0; i < batch_size / im2col_step_; ++i) {
    const T* data_im = input + i * im2col_step_ * input_stride;
    const T* data_offset = offset + i * im2col_step_ * offset_stride;
    const T* data_mask = mask + i * im2col_step_ * mask_stride;
    // The same workspace is reused as the column buffer for every chunk.
    T* data_col = reinterpret_cast<T*>(workspace);

    ModulatedDeformableIm2colGpuKernel<T>
        <<<blocks, threads, 0, stream>>>(num_kernels,
                                         data_im,
                                         data_offset,
                                         data_mask,
                                         h_i,
                                         w_i,
                                         k_h,
                                         k_w,
                                         paddings_[0],
                                         paddings_[1],
                                         strides_[0],
                                         strides_[1],
                                         dilations_[0],
                                         dilations_[1],
                                         channel_per_deformable_group,
                                         im2col_step_,
                                         c_i,
                                         deformable_groups_,
                                         h_o,
                                         w_o,
                                         data_col);

    // One GEMM per group, each on its own slice of filter, column buffer and
    // output.
    for (int g = 0; g < groups_; ++g) {
      const T* weight = filter + g * M * K;
      const T* col = data_col + g * K * N;
      T* out = output + i * im2col_step_ * output_stride + g * M * N;

      gemm_impl_new<T>(N, M, K, &alpha, col, weight, &beta, out);
    }
  }
  return 0;
}

// Reconstructs a DeformableConvPluginDynamic from a serialized buffer.
// `name` is not used.
nvinfer1::IPluginV2Ext* DeformableConvPluginDynamicCreator::deserializePlugin(
    const char* name,
    const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  auto* restored = new DeformableConvPluginDynamic(serial_data, serial_length);
  return restored;
}

// Builds a DeformableConvPluginDynamic from a TensorRT plugin field
// collection.
//
// Recognised field names: "data_type", "strides", "paddings", "dilations",
// "groups", "deformable_groups", "im2col_step", "kernel_dims", "weights" and
// "with_fp16". Any other name raises InvalidArgument.
//
// Two defects of the earlier version are corrected here:
//   - field names are matched with operator==; the old
//     `field_name.compare("...")` used as a boolean is true for every name
//     that DIFFERS, so all fields except "data_type" fell into the "strides"
//     branch;
//   - the dynamic creator now instantiates DeformableConvPluginDynamic
//     instead of the static-shape DeformableConvPlugin.
//
// `name` is not used.
nvinfer1::IPluginV2Ext* DeformableConvPluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  // Defaults keep the locals well-defined when a field is absent.
  nvinfer1::DataType data_type = nvinfer1::DataType::kFLOAT;
  std::vector<int> strides, paddings, dilations, kernel_dims;
  nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
  int groups = -1;
  int deformable_groups = -1;
  int im2col_step = -1;
  bool with_fp16 = false;

  // Appends the `length` ints stored in `field` to `dst`.
  const auto append_ints = [](const nvinfer1::PluginField& field,
                              std::vector<int>* dst) {
    const int* data = static_cast<const int*>(field.data);
    dst->insert(dst->end(), data, data + field.length);
  };

  for (int i = 0; i < fc->nbFields; ++i) {
    const nvinfer1::PluginField& field = fc->fields[i];
    const std::string field_name(field.name);
    if (field_name == "data_type") {
      data_type = *static_cast<const nvinfer1::DataType*>(field.data);
    } else if (field_name == "strides") {
      append_ints(field, &strides);
    } else if (field_name == "paddings") {
      append_ints(field, &paddings);
    } else if (field_name == "dilations") {
      append_ints(field, &dilations);
    } else if (field_name == "groups") {
      groups = *static_cast<const int*>(field.data);
    } else if (field_name == "deformable_groups") {
      deformable_groups = *static_cast<const int*>(field.data);
    } else if (field_name == "im2col_step") {
      im2col_step = *static_cast<const int*>(field.data);
    } else if (field_name == "kernel_dims") {
      append_ints(field, &kernel_dims);
    } else if (field_name == "weights") {
      weights.count = field.length;
      weights.values = field.data;
    } else if (field_name == "with_fp16") {
      with_fp16 = *static_cast<const bool*>(field.data);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unknown plugin field name [%s] in the DeformableConv TRT Plugin.",
          field_name));
    }
  }
  weights.type = data_type;
  return new DeformableConvPluginDynamic(data_type,
                                         weights,
                                         kernel_dims,
                                         strides,
                                         paddings,
                                         dilations,
                                         groups,
                                         deformable_groups,
                                         im2col_step,
                                         with_fp16);
}

#endif

W
wangxinxin08 已提交
1409 1410 1411 1412
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle