// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Gathers a strided sub-tensor from `input` into the dense `output`.
// One thread handles one output element. `offsets_info` holds, for each
// dimension, the triple (start offset, output extent, input stride); it
// is staged into dynamic shared memory (3 * dims ints, sized at launch)
// before the index math.
template <typename T>
__global__ void SliceKernel(
    int num, int dims, const T *input, const int *offsets_info, T *output) {
  extern __shared__ int cached_info[];
  const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Cooperatively copy the per-dimension triples into shared memory.
  for (int i = threadIdx.x; i < dims * 3; i += blockDim.x) {
    cached_info[i] = offsets_info[i];
  }
  // Barrier sits outside any divergent branch: every thread reaches it.
  __syncthreads();

  if (out_idx < num) {
    int remaining = out_idx;
    int src_idx = 0;
    for (int d = dims - 1; d >= 0; --d) {
      const int start = cached_info[d * 3];
      const int extent = cached_info[d * 3 + 1];
      const int stride = cached_info[d * 3 + 2];
      // Coordinate of this output element along dimension d ...
      const int coord = remaining % extent;
      // ... mapped back into the offset input coordinate space.
      src_idx += (coord + start) * stride;
      remaining /= extent;
    }
    output[out_idx] = input[src_idx];
  }
}

// Builds a slice plugin from the op attributes: per-axis [start, end)
// positions and the axes they apply to. `with_fp16` enables the half
// precision kernel path.
SlicePlugin::SlicePlugin(std::vector<int> starts,
                         std::vector<int> ends,
                         std::vector<int> axes,
                         bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes) {
  // with_fp16_ lives in the base class, hence assignment in the body.
  with_fp16_ = with_fp16;
}

// Reconstructs the plugin from the byte stream produced by serialize().
// Field order must stay in sync with serialize()/getSerializationSize().
SlicePlugin::SlicePlugin(void const *serial_data, size_t serial_length) {
  deserializeBase(serial_data, serial_length);
  DeserializeValue(&serial_data, &serial_length, &starts_);
  DeserializeValue(&serial_data, &serial_length, &ends_);
  DeserializeValue(&serial_data, &serial_length, &axes_);
  DeserializeValue(&serial_data, &serial_length, &with_fp16_);
  DeserializeValue(&serial_data, &serial_length, &offset_info_);
}

// Releases the device-side scratch buffer holding the offset triples.
// cudaFree(nullptr) is a no-op, so a never-allocated buffer is safe.
SlicePlugin::~SlicePlugin() { cudaFree(offset_temp_data_); }

// Creates a functionally identical copy of this plugin; the caller
// (TensorRT) owns and destroys the returned object.
SlicePlugin *SlicePlugin::clone() const TRT_NOEXCEPT {
  return new SlicePlugin(starts_, ends_, axes_, with_fp16_);
}

// Accepts only linear (row-major) layouts. fp32 and int32 are always
// supported; fp16 additionally requires with_fp16_ to be set.
bool SlicePlugin::supportsFormat(
    nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
  if (format != nvinfer1::PluginFormat::kLINEAR) {
    return false;
  }
  const bool always_ok = (type == nvinfer1::DataType::kFLOAT ||
                          type == nvinfer1::DataType::kINT32);
  const bool half_ok = with_fp16_ && type == nvinfer1::DataType::kHALF;
  return always_ok || half_ok;
}

// Computes the sliced output shape for the implicit-batch path. Each
// sliced axis shrinks to its [start, end) extent. The `- 1` converts
// from batch-inclusive axis numbering to these batchless dims (enqueue
// re-inserts the batch dimension at index 0).
nvinfer1::Dims SlicePlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputs, int nb_input_dims) TRT_NOEXCEPT {
  nvinfer1::Dims out_dims = inputs[0];
  for (size_t i = 0; i < axes_.size(); i++) {
    out_dims.d[axes_[i] - 1] = ends_[i] - starts_[i];
  }
  return out_dims;
}

// Runs the slice on GPU for the implicit-batch path. Rebuilds the
// batch-inclusive input/output shapes, stages the per-dimension
// (offset, extent, stride) triples into a lazily-allocated device
// buffer, then launches SliceKernel with one thread per output element
// on the caller's stream. Returns nonzero on CUDA error.
int SlicePlugin::enqueue(int batch_size,
                         const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                         void **outputs,
                         void *workspace,
                         cudaStream_t stream) {
#else
                         void *const *outputs,
                         void *workspace,
                         cudaStream_t stream) TRT_NOEXCEPT {
#endif
  auto input_dims = getInputDims(0);

  // notice input dims is [C, H, W], add input batch dim here
  auto out_dims = getOutputDimensions(0, &input_dims, 1);
  input_dims.nbDims += 1;
  out_dims.nbDims += 1;
  // Shift the existing dims right by one slot to make room for the batch
  // dim. Start at nbDims - 1, the last valid slot of the grown shape; the
  // previous loop started at nbDims and wrote one element past the new
  // shape, which is out of range once the input already uses the maximum
  // number of dimensions.
  for (int i = input_dims.nbDims - 1; i > 0; --i) {
    input_dims.d[i] = input_dims.d[i - 1];
    out_dims.d[i] = out_dims.d[i - 1];
  }
  input_dims.d[0] = batch_size;
  out_dims.d[0] = batch_size;

  const int num_dims = input_dims.nbDims;
  const size_t out_num = ProductDim(out_dims);
  if (out_num == 0) {
    // Empty output: launching a zero-block grid would be an invalid
    // configuration, and there is nothing to write anyway.
    return cudaGetLastError() != cudaSuccess;
  }

  // Per-dimension slice metadata, flattened below into
  // (offset, extent, seg_offset) triples for the kernel.
  std::vector<int> seg_offsets(num_dims);
  std::vector<int> offsets(num_dims, 0);
  std::vector<int> extends(num_dims);

  // seg_offsets[i]: linear stride of one step along input dimension i.
  seg_offsets[num_dims - 1] = 1;
  for (int i = num_dims - 2; i >= 0; i--) {
    seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1];
  }
  for (int i = 0; i < num_dims; ++i) {
    extends[i] = out_dims.d[i];
  }
  for (size_t i = 0; i < axes_.size(); ++i) {
    offsets[axes_[i]] = starts_[i];
  }

  std::vector<int> offset_info;
  offset_info.reserve(3 * num_dims);
  for (int i = 0; i < num_dims; ++i) {
    offset_info.push_back(offsets[i]);
    offset_info.push_back(extends[i]);
    offset_info.push_back(seg_offsets[i]);
  }

  // Lazily allocate the device scratch buffer; it is reused across calls
  // and released in the destructor.
  if (offset_temp_data_ == nullptr) {
    cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
  }

  cudaMemcpyAsync(offset_temp_data_,
                  offset_info.data(),
                  sizeof(int) * 3 * num_dims,
                  cudaMemcpyHostToDevice,
                  stream);

  const int threads = 256;
  const int blocks = (out_num + threads - 1) / threads;
  auto input_type = getDataType();
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
    const float *input1 = static_cast<const float *>(inputs[0]);
    float *output = static_cast<float *>(outputs[0]);
    SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
    const half *input1 = static_cast<const half *>(inputs[0]);
    half *output = static_cast<half *>(outputs[0]);
    SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kINT32) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
    const int *input1 = static_cast<const int *>(inputs[0]);
    int *output = static_cast<int *>(outputs[0]);
    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Slice TRT Plugin's input type should be float, half or int."));
  }
  return cudaGetLastError() != cudaSuccess;
}

// Total byte count of the serialized plugin; must mirror serialize().
size_t SlicePlugin::getSerializationSize() const TRT_NOEXCEPT {
  return getBaseSerializationSize() + SerializedSize(starts_) +
         SerializedSize(ends_) + SerializedSize(axes_) +
         SerializedSize(with_fp16_) + SerializedSize(offset_info_);
}

// Writes all fields in the order the deserializing constructor reads
// them; keep in sync with getSerializationSize().
void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
  serializeBase(buffer);
  SerializeValue(&buffer, starts_);
  SerializeValue(&buffer, ends_);
  SerializeValue(&buffer, axes_);
  SerializeValue(&buffer, with_fp16_);
  SerializeValue(&buffer, offset_info_);
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
// Builds the dynamic-shape (explicit batch) slice plugin.
// `decrease_axis` is the axis to squeeze out of the output, or -1 for
// no squeeze; `with_fp16` enables the half precision kernel path.
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
                                       std::vector<int> ends,
                                       std::vector<int> axes,
                                       int decrease_axis,
                                       bool with_fp16)
    : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
  // with_fp16_ lives in the base class, hence assignment in the body.
  with_fp16_ = with_fp16;
}

// Rebuilds the plugin from its serialized byte stream; field order must
// match serialize()/getSerializationSize().
SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
                                       size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &starts_);
  DeserializeValue(&serialData, &serialLength, &ends_);
  DeserializeValue(&serialData, &serialLength, &axes_);
  DeserializeValue(&serialData, &serialLength, &decrease_axis_);
  DeserializeValue(&serialData, &serialLength, &with_fp16_);
  DeserializeValue(&serialData, &serialLength, &offset_info_);
}

// Frees the device scratch buffer and then the plugin object itself,
// as required by the IPluginV2 ownership contract.
void SlicePluginDynamic::destroy() TRT_NOEXCEPT {
  cudaFree(offset_temp_data_);
  delete this;
}

// No per-engine resources to set up; the offset buffer is allocated
// lazily in enqueue().
int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

// Total serialized size; must stay in sync with serialize().
size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return SerializedSize(starts_) + SerializedSize(ends_) +
         SerializedSize(axes_) + SerializedSize(decrease_axis_) +
         SerializedSize(with_fp16_) + SerializedSize(offset_info_);
}

// Writes fields in the order the deserializing constructor reads them.
void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, starts_);
  SerializeValue(&buffer, ends_);
  SerializeValue(&buffer, axes_);
  SerializeValue(&buffer, decrease_axis_);
  SerializeValue(&buffer, with_fp16_);
  SerializeValue(&buffer, offset_info_);
}

// Derives the sliced output shape symbolically. On TRT >= 7.2 the end
// index is clamped to the (possibly dynamic) input extent, giving
// min(end, dim) - start; older versions fall back to the static
// end - start. When decrease_axis_ is set (!= -1), that axis is dropped
// from the result (slice followed by a squeeze).
nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  auto in_dims = inputs[0];
  nvinfer1::DimsExprs ret = in_dims;
  // start, ends should greater 0
  for (size_t i = 0; i < axes_.size(); i++) {
    const int axis = axes_[i];
#if IS_TRT_VERSION_GE(7200)
    // Clamp the end index against the runtime extent of this axis.
    const auto *clamped_end =
        expr_builder.operation(nvinfer1::DimensionOperation::kMIN,
                               *expr_builder.constant(ends_[i]),
                               *in_dims.d[axis]);
    ret.d[axis] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                         *clamped_end,
                                         *expr_builder.constant(starts_[i]));
#else
    ret.d[axis] = expr_builder.constant(ends_[i] - starts_[i]);
#endif
  }
  if (decrease_axis_ == -1) {
    return ret;
  }
  // Drop the decreased axis, clamping each kept extent at >= 0.
  nvinfer1::DimsExprs squeezed;
  squeezed.nbDims = ret.nbDims - 1;
  int out_pos = 0;
  for (int i = 0; i < in_dims.nbDims; i++) {
    if (i == decrease_axis_) continue;
    squeezed.d[out_pos++] =
        expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
                               *expr_builder.constant(0),
                               *ret.d[i]);
  }
  return squeezed;
}

// Validates the type/format proposed for slot `pos`. The input (pos 0)
// must be linear fp32/int32, or fp16 when with_fp16_ is set; outputs
// must match the input's chosen type and format exactly.
bool SlicePluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  // Fixed copy-paste error: the message previously referred to the
  // "swish plugin" and misspelled "should".
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of slice plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    if (in.format != nvinfer1::TensorFormat::kLINEAR) {
      return false;
    }
    if (in.type == nvinfer1::DataType::kFLOAT ||
        in.type == nvinfer1::DataType::kINT32) {
      return true;
    }
    // fp16 is only offered when the plugin was built with fp16 enabled.
    return with_fp16_ && in.type == nvinfer1::DataType::kHALF;
  }
  // Output slots: require the same type/format as the preceding slot.
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  return in.type == prev.type && in.format == prev.format;
}

// The single output's type mirrors the input type. Validates that only
// index 0 is requested and that the input type is one of the supported
// kinds (fp32 / fp16 / int32).
nvinfer1::DataType SlicePluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The Slice Plugin only has one input, so the "
                        "index value should be 0, but get %d.",
                        index));
  const bool supported = input_types[0] == nvinfer1::DataType::kFLOAT ||
                         input_types[0] == nvinfer1::DataType::kHALF ||
                         input_types[0] == nvinfer1::DataType::kINT32;
  PADDLE_ENFORCE_EQ(supported,
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be half, float or int"));
  return input_types[0];
}

// Runs the slice on GPU for the explicit-batch (dynamic shape) path.
// Stages per-dimension (offset, extent, stride) triples into a lazily
// allocated device buffer, then launches SliceKernel with one thread
// per output element on the caller's stream. Returns nonzero on CUDA
// error.
int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs,
                                void *const *outputs,
                                void *workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  auto out_dims = output_desc[0].dims;
  if (decrease_axis_ != -1) {
    // The bound output shape already has the decreased axis squeezed
    // away; rebuild the pre-squeeze shape (extent 1 on that axis) so the
    // index arithmetic below sees matching ranks.
    out_dims = input_dims;
    out_dims.d[decrease_axis_] = 1;
  }
  const int num_dims = input_dims.nbDims;
  const size_t out_num = ProductDim(out_dims);
  if (out_num == 0) {
    // Empty slice: launching a zero-block grid would be an invalid
    // configuration, and there is nothing to write anyway.
    return cudaGetLastError() != cudaSuccess;
  }

  std::vector<int> seg_offsets(num_dims);
  std::vector<int> offsets(num_dims, 0);
  std::vector<int> extends(num_dims);

  // seg_offsets[i]: linear stride of one step along input dimension i.
  seg_offsets[num_dims - 1] = 1;
  for (int i = num_dims - 2; i >= 0; i--) {
    seg_offsets[i] = input_dims.d[i + 1] * seg_offsets[i + 1];
  }
  for (int i = 0; i < num_dims; ++i) {
    extends[i] = out_dims.d[i];
  }
  for (size_t i = 0; i < axes_.size(); ++i) {
    offsets[axes_[i]] = starts_[i];
  }

  // Flatten to (offset, extent, stride) triples; kept as a member so it
  // is serialized along with the plugin.
  offset_info_.resize(num_dims * 3);
  for (int i = 0; i < num_dims; ++i) {
    offset_info_[i * 3 + 0] = offsets[i];
    offset_info_[i * 3 + 1] = extends[i];
    offset_info_[i * 3 + 2] = seg_offsets[i];
  }

  // Lazily allocate the device scratch buffer; freed in destroy().
  if (offset_temp_data_ == nullptr) {
    cudaMalloc(&offset_temp_data_, 3 * num_dims * sizeof(int));
  }

  cudaMemcpyAsync(offset_temp_data_,
                  offset_info_.data(),
                  sizeof(int) * 3 * num_dims,
                  cudaMemcpyHostToDevice,
                  stream);

  const int threads = 256;
  const int blocks = (out_num + threads - 1) / threads;
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp32";
    const float *input1 = static_cast<const float *>(inputs[0]);
    float *output = static_cast<float *>(outputs[0]);
    SliceKernel<float><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->fp16";
    const half *input1 = static_cast<const half *>(inputs[0]);
    half *output = static_cast<half *>(outputs[0]);
    SliceKernel<half><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else if (input_type == nvinfer1::DataType::kINT32) {
    VLOG(1) << "TRT Plugin DataType selected. Slice-->int32";
    const int *input1 = static_cast<const int *>(inputs[0]);
    int *output = static_cast<int *>(outputs[0]);
    SliceKernel<int><<<blocks, threads, 3 * num_dims * sizeof(int), stream>>>(
        out_num, num_dims, input1, offset_temp_data_, output);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Slice TRT Plugin's input type should be float, half or int."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle