// special_slice_plugin.cu
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
// Default construction: nothing to set up.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() = default;

// Deserialization constructor. getSerializationSize() reports 0 bytes and
// serialize() writes nothing, so there is no payload to read back here.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(
    void const* /*serial_data*/, size_t /*serial_length*/) {}

SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() = default;

// Returns a new, independent plugin instance for TensorRT. No plugin-specific
// state is serialized (getSerializationSize() returns 0), so a freshly
// default-constructed object is used as the copy.
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const
    TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

// Plugin type identifier. Kept identical to
// SpecialSlicePluginDynamicCreator::getPluginName(), which returns the same
// "special_slice_plugin" string.
const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return "special_slice_plugin";
}

// The plugin produces exactly one output tensor.
int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

// No resources need to be acquired before execution; always reports success.
int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

// The plugin has no parameters to persist, so the serialized payload is empty.
size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return 0;
}

// Nothing to write: getSerializationSize() reports a 0-byte payload.
void SpecialSlicePluginDynamic::serialize(void* /*buffer*/) const
    TRT_NOEXCEPT {}

// Computes the symbolic output shape from the two inputs:
//   inputs[0]: packed tensor to slice, e.g. (sum(S), hidden, 1, 1)
//   inputs[1]: cu_seqlens, 1-D; presumably cumulative sequence offsets of
//              length batch + 1 (TODO confirm against the converter)
// The result is inputs[0]'s shape with a unit axis inserted after axis 0,
// axis 0 replaced by batch = len(cu_seqlens) - 1, and the last two axes
// dropped. For a 4-D input (sum(S), hidden, 1, 1) this yields
// (batch, 1, hidden).
// NOTE(review): the "drop two trailing axes" step assumes inputs[0] carries
// two padding 1s at the end; with a lower-rank input the hidden axis would be
// dropped instead.
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output(inputs[0]);

  // Shift axes 1.. one position right to make room for a unit axis at 1.
  output.nbDims++;
  for (int i = output.nbDims - 1; i > 1; i--) {
    output.d[i] = inputs[0].d[i - 1];
  }
  auto one = expr_builder.constant(1);
  output.d[1] = one;

  // One output row per sequence: batch = len(cu_seqlens) - 1.
  output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *one);

  // Remove the two trailing padding axes.
  output.nbDims -= 2;

  return output;
}

// Nothing to configure ahead of time: enqueue() reads all shape and type
// information directly from the tensor descriptors it receives.
void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}

// The slice kernel needs no scratch memory; enqueue() never touches the
// workspace pointer.
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

// Frees this instance. Instances are heap-allocated with `new` (see clone()
// and the creator), so `delete this` is the matching release.
void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }

// Counterpart of initialize(): nothing was acquired, so nothing to release.
void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {}

// Reports whether tensor `pos` (inputs first, then outputs) may use the
// type/format in desc[pos]. Every tensor must use the linear layout:
//   pos 1 (cu_seqlens):                     int32
//   every other pos (sliced input, output): fp16
// Data tensors are half-only because enqueue() launches the kernel on half
// pointers; FP32 is not implemented.
bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  const nvinfer1::PluginTensorDesc& in_out = desc[pos];
  if (in_out.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }
  if (pos == 1) {  // cu_seqlens
    return in_out.type == nvinfer1::DataType::kINT32;
  }
  return in_out.type == nvinfer1::DataType::kHALF;
}

// The single output (index 0) has the same data type as the sliced input.
// Any other index is rejected via PADDLE_ENFORCE_EQ.
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  return input_types[0];
}

// Gathers one `hidden`-wide row per sequence from the packed input:
// output row b is the row of slice_input starting at row cu_seqlens[b].
//
// Launch layout (as configured by SpecialSlicePluginDynamic::enqueue):
//   gridDim.x  = batch, gridDim.y = hidden / blockDim.x
//   blockDim   = (threads, 1, 1)
// so hidden == blockDim.x * gridDim.y and each thread copies exactly one
// element. No bounds check is needed while hidden is a multiple of
// blockDim.x, which enqueue() enforces.
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output) {
  const int hidden = blockDim.x * gridDim.y;
  const int batch = blockIdx.x;
  // Column within the row. The stride between y-blocks is blockDim.x
  // (threads per block). blockDim.y is 1 for the 1-D block used here, so
  // using it made neighbouring blocks overlap and left most of each row
  // unwritten.
  const int local_idx = blockIdx.y * blockDim.x + threadIdx.x;

  // Row offsets in 64-bit: cu_seqlens[batch] * hidden can exceed INT_MAX for
  // large packed inputs.
  const int64_t src_offset =
      static_cast<int64_t>(cu_seqlens[batch]) * hidden + local_idx;
  const int64_t dst_offset = static_cast<int64_t>(batch) * hidden + local_idx;

  output[dst_offset] = slice_input[src_offset];
}

// Launches SpecialSliceKernel on `stream`. The kernel copies, for each
// sequence b, the hidden-wide row at offset cu_seqlens[b] of the packed fp16
// input into output row b.
//
// Preconditions (enforced here): input type is half, and hidden is a
// multiple of the block size (128).
// Returns 0 on success and 1 if the kernel launch reported a CUDA error.
int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;  // (sum(S), hidden, 1, 1)
  // Output shape per getOutputDimensions(): (batch, 1, hidden).
  auto out_dims = output_desc[0].dims;

  PADDLE_ENFORCE_EQ(
      input_desc[0].type, nvinfer1::DataType::kHALF,
      platform::errors::InvalidArgument("Type of input should be half."));

  constexpr int num_threads = 128;
  const int32_t hidden = input_dims.d[1];
  PADDLE_ENFORCE_EQ(
      hidden % num_threads, 0,
      platform::errors::InvalidArgument("hidden should be multiple of 128."));

  const int32_t batch = out_dims.d[0];
  // A zero-sized grid is an invalid launch configuration. With no sequences
  // or an empty hidden dimension there is nothing to copy.
  if (batch <= 0 || hidden <= 0) {
    return 0;
  }
  // One block per (sequence, 128-wide chunk of the hidden dimension).
  const dim3 blocks(batch, hidden / num_threads);

  const half* slice_input = static_cast<const half*>(inputs[0]);
  const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
  half* output = static_cast<half*>(outputs[0]);

  SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
                                                         cu_seqlens, output);

  // Launch-configuration errors surface here; non-zero signals failure.
  return cudaGetLastError() != cudaSuccess;
}

// Creator construction needs no setup.
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() = default;

// Name under which this creator is registered; identical to the string
// returned by SpecialSlicePluginDynamic::getPluginType().
const char* SpecialSlicePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "special_slice_plugin";
}

// Version string of this plugin.
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}

// Returns the creator's field collection. createPlugin() ignores the
// collection it is given, so these fields do not affect the created plugin.
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  return &field_collection_;
}

// Creates a new plugin instance. The plugin is parameterless, so both `name`
// and the field collection `fc` are ignored.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

// Rebuilds a plugin from an engine's serialized payload. The payload is
// forwarded to the deserialization constructor; `name` is unused.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic(serial_data, serial_length);
}

// Stores the namespace this creator is registered under. A null pointer is
// treated as the empty namespace: constructing a std::string from nullptr is
// undefined behavior.
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) TRT_NOEXCEPT {
  plugin_namespace_ = lib_namespace != nullptr ? lib_namespace : "";
}

// Returns the namespace most recently set via setPluginNamespace(). The
// pointer refers to this creator's own storage.
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  return plugin_namespace_.c_str();
}

#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle