split_op_plugin.cu 4.8 KB
Newer Older
N
nhzlx 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

H
hjchen2 已提交
15 16
#include <cuda_fp16.h>
#include <algorithm>
N
nhzlx 已提交
17
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
N
nhzlx 已提交
18
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
N
nhzlx 已提交
19 20 21 22

namespace paddle {
namespace inference {
namespace tensorrt {
23
namespace plugin {
N
nhzlx 已提交
24

N
nhzlx 已提交
25 26 27 28 29
// Rebuilds a SplitPlugin from a serialized byte buffer of `length` bytes.
// Returns a newly heap-allocated plugin; ownership presumably passes to the
// caller (the TensorRT plugin factory) — confirm against trt_plugin_factory.h.
SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
  SplitPlugin* deserialized = new SplitPlugin(buffer, length);
  return deserialized;
}
// Registers the factory above under the name "split_plugin".
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);

H
hjchen2 已提交
30
// Returns the index of the first element of vals[0, n) that is strictly
// greater than `key`, or n if no such element exists (the same contract as
// std::upper_bound). `vals` must be sorted in non-decreasing order.
//
// Declared __host__ __device__ so the identical search can also be reused and
// unit-tested on the host; device call sites are unaffected.
template <typename T>
__host__ __device__ int upper_bound(T const* vals, int n, T const& key) {
  int i = 0;  // first index of the remaining search window
  while (n > 0) {  // n is the size of the remaining window
    int m = n / 2;
    int j = i + m;  // probe the middle of the window
    if (!(key < vals[j])) {
      // vals[j] <= key: the answer lies strictly to the right of j.
      i = j + 1;
      n -= m + 1;
    } else {
      // vals[j] > key: j is a candidate; keep only the left half.
      n = m;
    }
  }
  return i;
}

46 47 48 49 50 51
// Output `index` has the same shape as the single input, except that its
// extent along axis_ is output_length_[index].
nvinfer1::Dims SplitPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* input_dims, int num_inputs) {
  PADDLE_ENFORCE_EQ(num_inputs, 1);
  PADDLE_ENFORCE_LT(index, this->getNbOutputs());
  // axis_ is used directly as an index into Dims::d below, so it must name an
  // existing dimension of the input; otherwise the write is out of bounds.
  PADDLE_ENFORCE_GE(axis_, 0);
  PADDLE_ENFORCE_LT(axis_, input_dims[0].nbDims);

  nvinfer1::Dims output_dims = input_dims[0];
  output_dims.d[axis_] = output_length_.at(index);
  return output_dims;
}

// Precomputes the launch geometry and the per-output offsets used by
// enqueue(). The input is viewed as [outer_rows_, axis_shape_, inner_cols_]
// around axis_. Also uploads the offsets to the device and sizes the
// device-side table of output pointers. Returns 0 on success.
int SplitPlugin::initialize() {
  // The input dims exclude the batch dimension ([C, H, W]); enqueue() scales
  // outer_rows_ by batchSize to account for it.
  nvinfer1::Dims dims = this->getInputDims(0);
  // axis_ indexes dims.d directly, so it must name an existing dimension
  // (this also implies axis_ < nvinfer1::Dims::MAX_DIMS).
  PADDLE_ENFORCE_GE(axis_, 0);
  PADDLE_ENFORCE_LT(axis_, dims.nbDims);

  outer_rows_ = 1;
  for (int i = 0; i < axis_; ++i) {
    outer_rows_ *= dims.d[i];
  }
  inner_cols_ = 1;
  for (int i = axis_ + 1; i < dims.nbDims; ++i) {
    inner_cols_ *= dims.d[i];
  }
  axis_shape_ = dims.d[axis_];

  // segment_offsets[k] is the first position along axis_ that belongs to
  // output k; the last entry is the total length covered by all outputs.
  const int num_outputs = this->getNbOutputs();
  std::vector<int> segment_offsets(1, 0);
  segment_offsets.reserve(num_outputs + 1);
  same_shape_ = true;
  for (int i = 0; i < num_outputs; ++i) {
    if (output_length_[i] != output_length_[0]) {
      same_shape_ = false;
    }
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
  }
  // split_kernel maps every input position along axis_ to an output segment
  // via upper_bound over these offsets. If the lengths do not cover the axis
  // exactly, that lookup indexes past the offsets/output tables or leaves
  // output rows unwritten.
  PADDLE_ENFORCE_EQ(segment_offsets.back(), dims.d[axis_]);

  d_segment_offsets_ = segment_offsets;
  segment_offsets_ = std::move(segment_offsets);
  d_output_ptrs_.resize(num_outputs, nullptr);
  return 0;
}

83 84
// The following code is adapted from onnx-tensorrt:
// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
H
hjchen2 已提交
85
// Copies one input tensor into several outputs along a single axis.
//
// The input is viewed as a contiguous [outer_rows, axis_shape, inner_cols]
// array. Output k receives positions [segment_offsets[k], segment_offsets[k+1])
// of the middle axis and is viewed as [outer_rows, len_k, inner_cols].
//
// Launch layout: x covers inner_cols, y covers axis_shape, z covers
// outer_rows. Each dimension uses a grid-stride loop, so any grid size is
// correct. No shared memory or block barriers are used.
//
// Preconditions: nsegment is the number of entries in segment_offsets
// (number of outputs + 1); segment_offsets is non-decreasing, starts at 0 and
// ends at axis_shape; odatas is a device-accessible array of nsegment - 1
// output pointers.
template <typename T>
__global__ void split_kernel(int nsegment,
                             int const* __restrict__ segment_offsets,
                             T const* __restrict__ idata, T* const* odatas,
                             int inner_cols, int axis_shape, int outer_rows) {
  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
  // Threads past the last column never copy anything; exiting early spares
  // them the per-row segment lookup (safe: the kernel has no barriers).
  if (x0 >= inner_cols) {
    return;
  }
  for (int z = z0; z < outer_rows; z += blockDim.z * gridDim.z) {
    for (int src_y = src_y0; src_y < axis_shape;
         src_y += blockDim.y * gridDim.y) {
      // The destination segment depends only on src_y, so resolve it once
      // per row instead of once per element.
      int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
      int dst_y = src_y - segment_offsets[segment];
      int dst_ny = segment_offsets[segment + 1] - segment_offsets[segment];
      T* odata = odatas[segment];
      int dst_row = inner_cols * (dst_y + dst_ny * z);
      int src_row = inner_cols * (src_y + axis_shape * z);
      for (int x = x0; x < inner_cols; x += blockDim.x * gridDim.x) {
        odata[dst_row + x] = idata[src_row + x];
      }
    }
  }
}

// Splits the single float input into the plugin's outputs on `stream`.
// The host-side array of output pointers is copied into d_output_ptrs_ so the
// kernel can dereference it on the device. Returns 0 on success and non-zero
// if the kernel launch reported an error.
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                         void** outputs, void* workspace, cudaStream_t stream) {
  // The plugin's dims exclude the batch dimension, so each batch item adds
  // another outer_rows_ rows.
  const int outer_rows = outer_rows_ * batchSize;
  if (outer_rows <= 0 || inner_cols_ <= 0 || axis_shape_ <= 0) {
    // Empty tensor: nothing to copy. Returning here also avoids computing a
    // zero-sized grid below ((0 - 1) / 1 + 1 wraps to 0 in unsigned
    // arithmetic), which would be an invalid launch configuration.
    return 0;
  }

  const int* d_segment_offsets_ptr =
      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
  float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
  float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
  float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
  PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, h_odatas,
                                 d_output_ptrs_.size() * sizeof(float*),
                                 cudaMemcpyHostToDevice,
                                 stream) == cudaSuccess);

  dim3 block(32, 16);
  // Each grid dimension is ceil(extent / block) capped at 65535; the kernel's
  // stride loops cover anything beyond the cap. grid.z is sized from the
  // batch-scaled outer_rows so batch items are spread across blocks instead
  // of being processed serially by the z stride loop.
  dim3 grid(std::min((inner_cols_ - 1) / block.x + 1, 65535u),
            std::min((axis_shape_ - 1) / block.y + 1, 65535u),
            std::min((outer_rows - 1) / block.z + 1, 65535u));

  split_kernel<<<grid, block, 0, stream>>>(
      d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
      inner_cols_, axis_shape_, outer_rows);
  return cudaGetLastError() != cudaSuccess;
}

132 133 134 135
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle