slice_op_plugin.cu 13.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Gathers the sliced region of `input` into the densely packed `output`.
//
// Launch layout: 1-D grid, one thread per output element (threads whose global
// index is >= num do nothing after the staging phase).
// Dynamic shared memory: at least 3 * dims ints.
//
// num          - number of output elements.
// dims         - tensor rank.
// input        - source tensor.
// offsets_info - 3 ints per axis, stored consecutively:
//                [slice start, output extent, input stride].
// output       - destination tensor holding `num` elements.
template <typename T>
__global__ void SliceKernel(int num, int dims, const T *input,
                            const int *offsets_info, T *output) {
  extern __shared__ int shared_data[];

  // Every thread of the block helps stage the per-axis table into shared
  // memory so the index arithmetic below avoids repeated global reads.
  const int table_len = dims * 3;
  for (int k = threadIdx.x; k < table_len; k += blockDim.x) {
    shared_data[k] = offsets_info[k];
  }
  __syncthreads();

  const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_idx >= num) {
    return;
  }

  // Peel coordinates off the flat output index from the innermost axis
  // outwards and accumulate the matching flat input index.
  int remaining = out_idx;
  int src_idx = 0;
  for (int axis = dims - 1; axis >= 0; --axis) {
    const int *entry = shared_data + axis * 3;
    const int start = entry[0];
    const int extent = entry[1];
    const int stride = entry[2];
    const int coord = remaining % extent;
    src_idx = src_idx + stride * (coord + start);
    remaining = remaining / extent;
  }
  output[out_idx] = input[src_idx];
}

55
// Builds a static-shape slice plugin from per-axis start/end indices.
//
// starts, ends, axes - parallel arrays: axes[i] is sliced to [starts[i], ends[i]).
// with_fp16          - whether half-precision I/O may be selected.
//
// Also creates the CUDA stream/event pair that enqueue() uses to upload the
// offset table.
SlicePlugin::SlicePlugin(std::vector<int> starts, std::vector<int> ends,
                         std::vector<int> axes, bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes) {
  cudaEventCreate(&copy_event_);
  cudaStreamCreate(&copy_stream_);
  with_fp16_ = with_fp16;
}

// Rebuilds a static-shape slice plugin from a serialized buffer.
//
// Reads the base-class state first, then starts/ends/axes, i.e. the same
// order in which serialize() writes them after the plugin-type tag
// (presumably the tag itself was already consumed by the caller - confirm).
// Also creates the CUDA stream/event pair that enqueue() uses.
SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
  deserializeBase(serial_data, serial_length);
  DeserializeValue(&serial_data, &serial_length, &starts_);
  DeserializeValue(&serial_data, &serial_length, &ends_);
  DeserializeValue(&serial_data, &serial_length, &axes_);

  cudaStreamCreate(&copy_stream_);
  cudaEventCreate(&copy_event_);
}

// Releases the CUDA resources owned by the plugin: the copy stream, the copy
// event and the device-side offset table allocated lazily by enqueue().
SlicePlugin::~SlicePlugin() {
  cudaStreamDestroy(copy_stream_);
  cudaEventDestroy(copy_event_);
  // NOTE(review): relies on offset_temp_data_ being nullptr when enqueue()
  // never ran - confirm the default in the header.
  cudaFree(offset_temp_data_);
}

78
// Returns a heap-allocated plugin configured with the same starts/ends/axes
// and fp16 flag. The copy creates its own CUDA stream, event and (lazily)
// offset table; the caller takes ownership of the returned pointer.
SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
  SlicePlugin *copy = new SlicePlugin(starts_, ends_, axes_, with_fp16_);
  return copy;
}

82 83
// Reports whether the plugin can run with the given data type and layout.
//
// Only the linear layout is accepted. kFLOAT is always accepted; kHALF is
// accepted only when the plugin was created with with_fp16 enabled.
bool SlicePlugin::supportsFormat(
    nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
  if (format != nvinfer1::PluginFormat::kLINEAR) {
    return false;
  }
  if (type == nvinfer1::DataType::kFLOAT) {
    return true;
  }
  return with_fp16_ && type == nvinfer1::DataType::kHALF;
}

94 95
// Computes the output shape for the static-shape plugin.
//
// index, nb_input_dims - unused.
// inputs               - inputs[0] is the input shape.
//
// Returns the input shape with every sliced axis replaced by end - start.
// axes_ entries are shifted down by one when indexing the shape, matching
// enqueue(), which prepends a batch dimension to these dims.
nvinfer1::Dims SlicePlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputs, int nb_input_dims) TRT_NOEXCEPT {
  nvinfer1::Dims out_dims = inputs[0];
  for (size_t k = 0; k < axes_.size(); ++k) {
    const int extent = ends_[k] - starts_[k];
    out_dims.d[axes_[k] - 1] = extent;
  }
  return out_dims;
}

// Launches the slice for one batch on `stream` (static-shape plugin).
//
// batch_size - size of the batch dimension that is prepended to the stored
//              dims.
// inputs     - inputs[0] is the device pointer of the source tensor.
// outputs    - outputs[0] is the device pointer of the destination tensor.
// workspace  - unused.
// stream     - stream on which the slice kernel is launched.
//
// Returns 0 on success and non-zero if a CUDA call failed. An empty output is
// treated as success without launching anything.
int SlicePlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                         void **outputs, void *workspace, cudaStream_t stream) {
#else
                         void *const *outputs, void *workspace,
                         cudaStream_t stream) TRT_NOEXCEPT {
#endif
  auto input_dims = getInputDims(0);

  // The stored input dims are [C, H, W]: compute the sliced dims first, then
  // prepend the batch dimension to both shapes.
  auto out_dims = getOutputDimensions(0, &input_dims, 1);
  input_dims.nbDims += 1;
  out_dims.nbDims += 1;
  // Shift every extent one slot to the right. The loop starts at the last
  // valid index (nbDims - 1); starting at nbDims would write one element past
  // the shape and overrun d[] once the rank reaches the array capacity.
  for (int i = input_dims.nbDims - 1; i > 0; --i) {
    input_dims.d[i] = input_dims.d[i - 1];
    out_dims.d[i] = out_dims.d[i - 1];
  }
  input_dims.d[0] = batch_size;
  out_dims.d[0] = batch_size;

  const int num_dims = input_dims.nbDims;
  const size_t out_num = ProductDim(out_dims);
  if (out_num == 0) {
    // Nothing to copy; a launch with zero blocks would be rejected as an
    // invalid configuration and reported as a failure.
    return 0;
  }

  // Per-axis table consumed by SliceKernel, 3 ints per axis:
  // [slice start, output extent, input stride].
  std::vector<int> offset_info(3 * num_dims, 0);
  offset_info[3 * (num_dims - 1) + 2] = 1;
  for (int i = num_dims - 2; i >= 0; --i) {
    offset_info[3 * i + 2] = input_dims.d[i + 1] * offset_info[3 * (i + 1) + 2];
  }
  for (int i = 0; i < num_dims; ++i) {
    offset_info[3 * i + 1] = out_dims.d[i];
  }
  for (size_t i = 0; i < axes_.size(); ++i) {
    offset_info[3 * axes_[i]] = starts_[i];
  }

  const size_t table_bytes = offset_info.size() * sizeof(int);
  if (offset_temp_data_ == nullptr) {
    if (cudaMalloc(&offset_temp_data_, table_bytes) != cudaSuccess) {
      // Do not launch with a null table; clear the error and report failure.
      offset_temp_data_ = nullptr;
      (void)cudaGetLastError();
      return 1;
    }
  }

  // Upload the table on the dedicated copy stream and make `stream` wait for
  // it before the kernel runs.
  // NOTE(review): the upload is not ordered after kernels of a previous
  // enqueue() that may still be reading offset_temp_data_ on `stream`; this
  // matters only if the table contents change between calls - confirm.
  if (cudaMemcpyAsync(offset_temp_data_, offset_info.data(), table_bytes,
                      cudaMemcpyHostToDevice, copy_stream_) != cudaSuccess ||
      cudaEventRecord(copy_event_, copy_stream_) != cudaSuccess ||
      cudaStreamWaitEvent(stream, copy_event_, 0) != cudaSuccess) {
    (void)cudaGetLastError();
    return 1;
  }

  const int threads = 256;
  const int blocks = static_cast<int>((out_num + threads - 1) / threads);
  auto input_type = getDataType();
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
    const float *input1 = static_cast<const float *>(inputs[0]);
    float *output = static_cast<float *>(outputs[0]);
    SliceKernel<float><<<blocks, threads, table_bytes, stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
    const half *input1 = static_cast<const half *>(inputs[0]);
    half *output = static_cast<half *>(outputs[0]);
    SliceKernel<half><<<blocks, threads, table_bytes, stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Slice TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}

189
// Returns the number of bytes serialize() writes: the plugin-type tag, the
// base-class state and the starts/ends/axes vectors.
size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT {
  size_t total = getBaseSerializationSize();
  total += SerializedSize(getPluginType());
  total += SerializedSize(starts_);
  total += SerializedSize(ends_);
  total += SerializedSize(axes_);
  return total;
}

195
// Writes the plugin into `buffer`, which must hold at least
// getSerializationSize() bytes. Field order: plugin-type tag, base-class
// state, starts, ends, axes. The deserializing constructor reads the fields
// after the tag in this same order, so the order must not change.
void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, getPluginType());
  serializeBase(buffer);
  SerializeValue(&buffer, starts_);
  SerializeValue(&buffer, ends_);
  SerializeValue(&buffer, axes_);
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
// Builds a dynamic-shape slice plugin from per-axis start/end indices.
//
// starts, ends, axes - parallel arrays: axes[i] is sliced starting at
//                      starts[i] and ending at ends[i].
// with_fp16          - whether half-precision I/O may be selected.
//
// Also creates the CUDA stream/event pair that enqueue() uses to upload the
// offset table.
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                       std::vector<int> ends,
                                       std::vector<int> axes, bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes) {
  cudaEventCreate(&copy_event_);
  cudaStreamCreate(&copy_stream_);
  with_fp16_ = with_fp16;
}

// Rebuilds a dynamic-shape slice plugin from a serialized buffer.
//
// Fields are read in the order serialize() writes them: starts, ends, axes,
// with_fp16. Also creates the CUDA stream/event pair that enqueue() uses.
SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
                                       size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &starts_);
  DeserializeValue(&serialData, &serialLength, &ends_);
  DeserializeValue(&serialData, &serialLength, &axes_);
  DeserializeValue(&serialData, &serialLength, &with_fp16_);

  cudaStreamCreate(&copy_stream_);
  cudaEventCreate(&copy_event_);
}

224
// Releases the plugin's CUDA resources (copy stream, copy event and the
// device-side offset table allocated lazily by enqueue()) and then deletes
// the plugin object itself. The object must not be used after this call.
void SlicePluginDynamic::destroy() TRT_NOEXCEPT {
  cudaStreamDestroy(copy_stream_);
  cudaEventDestroy(copy_event_);
  cudaFree(offset_temp_data_);
  // Must stay last: no member may be touched once the object is deleted.
  delete this;
}

231
// No per-engine setup is needed; always reports success (0).
int SlicePluginDynamic::initialize() TRT_NOEXCEPT {
  return 0;
}
232

233
// Returns the number of bytes serialize() writes: the starts/ends/axes
// vectors followed by the fp16 flag.
size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  size_t total = SerializedSize(starts_);
  total += SerializedSize(ends_);
  total += SerializedSize(axes_);
  total += SerializedSize(with_fp16_);
  return total;
}

240
// Writes the plugin into `buffer`, which must hold at least
// getSerializationSize() bytes. Field order: starts, ends, axes, with_fp16.
// The deserializing constructor reads the fields in this same order, so the
// order must not change.
void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, starts_);
  SerializeValue(&buffer, ends_);
  SerializeValue(&buffer, axes_);
  SerializeValue(&buffer, with_fp16_);
}
246 247 248

// Builds the symbolic output shape for the dynamic-shape plugin.
//
// output_index, nb_inputs - unused.
// inputs                  - inputs[0] is the symbolic input shape.
// expr_builder            - factory for the dimension expressions.
//
// Returns the input shape with every sliced axis replaced:
//   TensorRT >= 7.2: min(end, input extent) - start
//   older versions : the constant end - start
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs &in_dims = inputs[0];
  nvinfer1::DimsExprs ret = in_dims;
  // Starts and ends are expected to be non-negative here.
  for (size_t k = 0; k < axes_.size(); ++k) {
    const int axis = axes_[k];
#if IS_TRT_VERSION_GE(7200)
    // Clamp the requested end to the runtime extent of the axis.
    const auto *clamped_end = expr_builder.operation(
        nvinfer1::DimensionOperation::kMIN, *expr_builder.constant(ends_[k]),
        *in_dims.d[axis]);
    ret.d[axis] =
        expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                               *clamped_end,
                               *expr_builder.constant(starts_[k]));
#else
    ret.d[axis] = expr_builder.constant(ends_[k] - starts_[k]);
#endif
  }
  return ret;
}

// Reports whether the tensor at position `pos` may use its proposed
// type/format, given the tensors before it.
//
// pos        - index into in_out (inputs first, then outputs).
// in_out     - descriptors of all inputs and outputs; must not be null.
// nb_inputs  - number of inputs.
// nb_outputs - number of outputs.
//
// Position 0 (the input) must be linear and kFLOAT, or kHALF when the plugin
// was created with with_fp16. Every later position must match the type and
// format of the position before it.
bool SlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of slice plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  if (pos == 0) {
    if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
      return false;
    }
    if (desc.type == nvinfer1::DataType::kFLOAT) {
      return true;
    }
    return with_fp16_ && desc.type == nvinfer1::DataType::kHALF;
  }

  // Output: must mirror the tensor that precedes it.
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  return desc.type == prev.type && desc.format == prev.format;
}

// Returns the data type of the plugin's single output, which is the type of
// its input.
//
// index       - output index; must be 0.
// input_types - input_types[0] must be kFLOAT or kHALF.
// nb_inputs   - unused.
nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  // `index` selects an output, so the message refers to the output count.
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The Slice Plugin only has one output, so the "
                                  "index value should be 0, but get %d.",
                                  index));
  const bool type_ok = input_types[0] == nvinfer1::DataType::kFLOAT ||
                       input_types[0] == nvinfer1::DataType::kHALF;
  PADDLE_ENFORCE_EQ(type_ok, true,
                    platform::errors::InvalidArgument(
                        "The input type should be half or float"));
  return input_types[0];
}

// Launches the slice on `stream` (dynamic-shape plugin).
//
// input_desc  - input_desc[0] carries the runtime input shape and type.
// output_desc - output_desc[0] carries the runtime output shape.
// inputs      - inputs[0] is the device pointer of the source tensor.
// outputs     - outputs[0] is the device pointer of the destination tensor.
// workspace   - unused.
// stream      - stream on which the slice kernel is launched.
//
// Returns 0 on success and non-zero if a CUDA call failed. An empty output is
// treated as success without launching anything.
int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs, void *const *outputs,
                                void *workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  auto out_dims = output_desc[0].dims;
  const int num_dims = input_dims.nbDims;
  const size_t out_num = ProductDim(out_dims);
  if (out_num == 0) {
    // Nothing to copy; a launch with zero blocks would be rejected as an
    // invalid configuration and reported as a failure.
    return 0;
  }

  // Per-axis table consumed by SliceKernel, 3 ints per axis:
  // [slice start, output extent, input stride].
  std::vector<int> offset_info(3 * num_dims, 0);
  offset_info[3 * (num_dims - 1) + 2] = 1;
  for (int i = num_dims - 2; i >= 0; --i) {
    offset_info[3 * i + 2] = input_dims.d[i + 1] * offset_info[3 * (i + 1) + 2];
  }
  for (int i = 0; i < num_dims; ++i) {
    offset_info[3 * i + 1] = out_dims.d[i];
  }
  for (size_t i = 0; i < axes_.size(); ++i) {
    offset_info[3 * axes_[i]] = starts_[i];
  }

  const size_t table_bytes = offset_info.size() * sizeof(int);
  if (offset_temp_data_ == nullptr) {
    if (cudaMalloc(&offset_temp_data_, table_bytes) != cudaSuccess) {
      // Do not launch with a null table; clear the error and report failure.
      offset_temp_data_ = nullptr;
      (void)cudaGetLastError();
      return 1;
    }
  }

  // Upload the table on the dedicated copy stream and make `stream` wait for
  // it before the kernel runs.
  // NOTE(review): the upload is not ordered after kernels of a previous
  // enqueue() that may still be reading offset_temp_data_ on `stream`; with
  // dynamic shapes the table can differ between calls - confirm this is safe.
  if (cudaMemcpyAsync(offset_temp_data_, offset_info.data(), table_bytes,
                      cudaMemcpyHostToDevice, copy_stream_) != cudaSuccess ||
      cudaEventRecord(copy_event_, copy_stream_) != cudaSuccess ||
      cudaStreamWaitEvent(stream, copy_event_, 0) != cudaSuccess) {
    (void)cudaGetLastError();
    return 1;
  }

  const int threads = 256;
  const int blocks = static_cast<int>((out_num + threads - 1) / threads);
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
    const float *input1 = static_cast<const float *>(inputs[0]);
    float *output = static_cast<float *>(outputs[0]);
    SliceKernel<float><<<blocks, threads, table_bytes, stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
    const half *input1 = static_cast<const half *>(inputs[0]);
    half *output = static_cast<half *>(outputs[0]);
    SliceKernel<half><<<blocks, threads, table_bytes, stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Slice TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle