split_op_plugin.cu 11.3 KB
Newer Older
N
nhzlx 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

H
hjchen2 已提交
15
#include <cuda_fp16.h>
16
#include <thrust/device_vector.h>
H
hjchen2 已提交
17
#include <algorithm>
18

N
nhzlx 已提交
19 20 21 22 23
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
24
namespace plugin {
N
nhzlx 已提交
25

H
hjchen2 已提交
26
template <typename T>
27 28 29 30 31 32 33 34 35 36
__device__ int upper_bound(T const* vals, int n, T const& key) {
  int i = 0;
  while (n > 0) {
    int m = n / 2;
    int j = i + m;
    if (!(key < vals[j])) {
      i = j + 1;
      n -= m + 1;
    } else {
      n = m;
H
hjchen2 已提交
37 38
    }
  }
39
  return i;
H
hjchen2 已提交
40 41
}

42
nvinfer1::Dims SplitPlugin::getOutputDimensions(
43
    int index, const nvinfer1::Dims* input_dims, int num_inputs) TRT_NOEXCEPT {
44 45 46 47 48 49 50 51 52 53 54
  PADDLE_ENFORCE_EQ(num_inputs, 1,
                    platform::errors::InvalidArgument(
                        "Invalid number of inputs of split TRT plugin. "
                        "Expected 1, received %d.",
                        num_inputs));
  PADDLE_ENFORCE_LT(
      index, this->getNbOutputs(),
      platform::errors::InvalidArgument(
          "Index of output should be less than the total number of outputs in "
          "split TensorRT plugin. Received index = %d >= total outputs = %d",
          index, this->getNbOutputs()));
55 56

  nvinfer1::Dims output_dims = input_dims[0];
57
  output_dims.d[axis_] = output_length_.at(index);
N
nhzlx 已提交
58 59 60
  return output_dims;
}

61 62 63 64 65 66 67 68
void SplitPlugin::shareData(const SplitPlugin* another) {
  outer_rows_ = another->outer_rows_;
  inner_cols_ = another->inner_cols_;
  same_shape_ = another->same_shape_;
  axis_shape_ = another->axis_shape_;
  segment_offsets_ = another->segment_offsets_;
}

69
int SplitPlugin::initialize() TRT_NOEXCEPT {
70 71 72 73 74
  PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS,
                    platform::errors::InvalidArgument(
                        "Axis dimension exceeds max dimension in TensorRT. "
                        "Received axis = %d > MAX_DIMS = %d",
                        axis_, nvinfer1::Dims::MAX_DIMS));
H
hjchen2 已提交
75 76 77 78 79 80 81 82 83 84 85
  // notice input dims is [C, H, W]
  nvinfer1::Dims dims = this->getInputDims(0);
  outer_rows_ = 1;
  inner_cols_ = 1;
  for (int i = 0; i < axis_; ++i) {
    outer_rows_ *= dims.d[i];
  }
  for (int i = axis_ + 1; i < dims.nbDims; ++i) {
    inner_cols_ *= dims.d[i];
  }
  same_shape_ = true;
N
nhzlx 已提交
86 87
  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); ++i) {
H
hjchen2 已提交
88 89 90
    if (output_length_[i] != output_length_[0]) {
      same_shape_ = false;
    }
91
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
N
nhzlx 已提交
92
  }
93
  axis_shape_ = dims.d[axis_];
H
hjchen2 已提交
94 95 96 97
  segment_offsets_ = std::move(segment_offsets);
  return 0;
}

98
// nothing to release according to initialize
99
void SplitPlugin::terminate() TRT_NOEXCEPT {}
100

101 102
// The following part of the code refers to onnx-tensorrt
// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
H
hjchen2 已提交
103
template <typename T>
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
__global__ void split_kernel(int nsegment,
                             int const* __restrict__ segment_offsets,
                             T const* __restrict__ idata, T* const* odatas,
                             int inner_cols, int axis_shape, int outer_rows) {
  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
  for (int z = z0; z < outer_rows; z += blockDim.z * gridDim.z) {
    for (int src_y = src_y0; src_y < axis_shape;
         src_y += blockDim.y * gridDim.y) {
      for (int x = x0; x < inner_cols; x += blockDim.x * gridDim.x) {
        int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
        int dst_y = src_y - segment_offsets[segment];
        int dst_ny = segment_offsets[segment + 1] - segment_offsets[segment];
        odatas[segment][x + inner_cols * (dst_y + dst_ny * z)] =
            idata[x + inner_cols * (src_y + axis_shape * z)];
      }
    }
N
nhzlx 已提交
122 123 124 125
  }
}

int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
126
#if IS_TRT_VERSION_LT(8000)
N
nhzlx 已提交
127
                         void** outputs, void* workspace, cudaStream_t stream) {
128 129
#else
                         void* const* outputs, void* workspace,
130
                         cudaStream_t stream) TRT_NOEXCEPT {
131
#endif
132 133 134 135 136
  // this two thrust variables decalared here , not with in .h
  // to avoid compiling error in cuda 11.6
  thrust::device_vector<int> d_segment_offsets = segment_offsets_;
  thrust::device_vector<float*> d_output_ptrs;
  d_output_ptrs.resize(segment_offsets_.size(), nullptr);
137
  const int* d_segment_offsets_ptr =
138
      thrust::raw_pointer_cast(&d_segment_offsets[0]);
H
hjchen2 已提交
139
  float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
140
  float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
141
  float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
142
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
143
      output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*),
144
      cudaMemcpyHostToDevice, stream));
145 146 147 148 149 150 151 152 153

  int outer_rows = outer_rows_ * batchSize;

  dim3 block(32, 16);
  dim3 grid(std::min((inner_cols_ - 1) / block.x + 1, 65535u),
            std::min((axis_shape_ - 1) / block.y + 1, 65535u),
            std::min((outer_rows_ - 1) / block.z + 1, 65535u));

  split_kernel<<<grid, block, 0, stream>>>(
154
      segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
155
      inner_cols_, axis_shape_, outer_rows);
N
nhzlx 已提交
156 157 158
  return cudaGetLastError() != cudaSuccess;
}

159 160
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
161
int SplitPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
162

163
size_t SplitPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
164 165 166
  return SerializedSize(axis_) + SerializedSize(output_length_) +
         SerializedSize(with_fp16_);
}
167

168
void SplitPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
169 170 171 172
  SerializeValue(&buffer, axis_);
  SerializeValue(&buffer, output_length_);
  SerializeValue(&buffer, with_fp16_);
}
173 174 175

nvinfer1::DimsExprs SplitPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
176
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
  PADDLE_ENFORCE_EQ(nb_inputs, 1,
                    platform::errors::InvalidArgument(
                        "The Split plugin should be only one input."));
  PADDLE_ENFORCE_LT(output_index, output_length_.size(),
                    platform::errors::InvalidArgument(
                        "When GetOutputDimensions, the index(%d) should not "
                        "greater the num(%d) of the outpus.",
                        output_index, output_length_.size()));

  nvinfer1::DimsExprs output_dims = inputs[0];
  output_dims.d[axis_] = expr_builder.constant(output_length_.at(output_index));

  return output_dims;
}

bool SplitPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
194
    int nb_outputs) TRT_NOEXCEPT {
195 196
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
197
                  "The input of split plugin should not be nullptr."));
198 199 200 201 202 203 204 205 206 207

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));
  (in_out && pos < (nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
208 209 210 211 212 213 214 215
    if (with_fp16_) {
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
216 217 218 219 220 221 222
  }
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  // output
  return in.type == prev.type && in.format == prev.format;
}

nvinfer1::DataType SplitPluginDynamic::getOutputDataType(
223 224
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
225 226 227 228 229 230
  return input_types[0];
}

int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                const nvinfer1::PluginTensorDesc* output_desc,
                                const void* const* inputs, void* const* outputs,
231 232
                                void* workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
  auto input_dims = input_desc[0].dims;
  int outer_rows = 1;
  int inner_cols = 1;
  // with batch
  for (int i = 0; i < axis_; i++) {
    outer_rows *= input_dims.d[i];
  }

  for (int i = axis_ + 1; i < input_dims.nbDims; i++) {
    inner_cols *= input_dims.d[i];
  }

  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); i++) {
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
  }
  int axis_shape = input_dims.d[axis_];
  thrust::device_vector<int> d_segment_offsets = segment_offsets;
  const int* d_segment_offsets_ptr =
      thrust::raw_pointer_cast(&d_segment_offsets[0]);

  dim3 block(32, 16);
  dim3 grid(std::min((inner_cols - 1) / block.x + 1, 65535u),
            std::min((axis_shape - 1) / block.y + 1, 65535u),
            std::min((outer_rows - 1) / block.z + 1, 65535u));

  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
261
    VLOG(1) << "TRT Plugin DataType selected. Split-->fp32";
262 263 264 265 266 267 268
    thrust::device_vector<float*> d_output_ptrs;
    d_output_ptrs.resize(this->getNbOutputs(), nullptr);

    const float* input_ptr = static_cast<const float*>(inputs[0]);
    float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
    float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);

269
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
270 271
        output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*),
        cudaMemcpyHostToDevice, stream));
272 273 274 275 276

    split_kernel<<<grid, block, 0, stream>>>(
        d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
        inner_cols, axis_shape, outer_rows);
  } else if (input_type == nvinfer1::DataType::kHALF) {
277
    VLOG(1) << "TRT Plugin DataType selected. Split-->fp16";
278 279 280 281 282 283 284
    thrust::device_vector<half*> d_output_ptrs;
    d_output_ptrs.resize(this->getNbOutputs(), nullptr);

    const half* input_ptr = static_cast<const half*>(inputs[0]);
    half* const* h_odatas = reinterpret_cast<half* const*>(outputs);
    half** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);

285
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
286 287
        output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(half*),
        cudaMemcpyHostToDevice, stream));
288 289 290 291 292 293 294 295 296

    split_kernel<<<grid, block, 0, stream>>>(
        d_segment_offsets.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
        inner_cols, axis_shape, outer_rows);
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

297 298 299 300
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle