special_slice_plugin.cu 6.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
// Default constructor; nothing needs to be initialized here.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() = default;

// Deserialization constructor. serialize() writes no bytes, so there is
// nothing to read back and both arguments are intentionally ignored.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(
    void const* /*serial_data*/, size_t /*serial_length*/) {}

// No resources are owned, so the defaulted destructor suffices.
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() = default;

// Returns a new, independent copy of this plugin.
// The plugin serializes no state (getSerializationSize() returns 0), so a
// default-constructed instance is an exact copy.
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const
    TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

// Plugin type string. It must stay identical to
// SpecialSlicePluginDynamicCreator::getPluginName() so the creator can be
// matched to this plugin.
const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return "special_slice_plugin";
}

// The plugin produces exactly one output tensor.
int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

// No resources to acquire; 0 signals success.
int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

// Number of bytes serialize() writes. The plugin has no parameters to
// persist, so this is always zero.
size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return 0;
}

// Writes nothing: getSerializationSize() reports 0 bytes, so `buffer` is
// never touched.
void SpecialSlicePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {}

// Computes the output shape from the sliced tensor (inputs[0]) and the
// cu_seqlens tensor (inputs[1]).
//
// The result has inputs[0].nbDims - 1 dimensions:
//   d[0] = inputs[1].d[0] - 1   (the batch size; assumes cu_seqlens holds
//                                batch + 1 entries — TODO confirm)
//   d[1] = 1
//   d[i] = inputs[0].d[i - 1]   for 2 <= i < nbDims
// For a 4-D input (N, H, 1, 1) this yields (batch, 1, H).
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs& slice_dims = inputs[0];
  auto one = expr_builder.constant(1);

  nvinfer1::DimsExprs output(slice_dims);
  // Build the result directly with nbDims - 1 entries, rather than growing
  // by one and trimming two. That way d[] is never indexed at
  // slice_dims.nbDims, which would be past the array's capacity when
  // inputs[0] already has the maximum rank.
  output.nbDims = slice_dims.nbDims - 1;
  for (int i = output.nbDims - 1; i > 1; --i) {
    output.d[i] = slice_dims.d[i - 1];
  }
  output.d[1] = one;
  output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *one);
  return output;
}

// Nothing to configure. The plugin keeps no shape-dependent state; enqueue()
// reads the dimensions it needs from its tensor descriptors.
void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}

// enqueue() never uses its workspace pointer, so no scratch memory is
// requested.
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

// Releases this heap-allocated plugin instance (created via `new` in clone()
// and in the creator).
void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }

// Counterpart of initialize(); nothing was acquired, so nothing is released.
void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {}

// Reports whether the (type, format) of the tensor at I/O position `pos` is
// supported:
//   pos 0: tensor being sliced — FP16, linear layout
//   pos 1: cu_seqlens          — INT32, linear layout
//   other: output              — FP16, linear layout
// FP32 is deliberately not advertised: enqueue() only implements the half
// path.
bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      desc, platform::errors::InvalidArgument(
                "The input of special_slice plugin should not be nullptr."));
  // desc holds nb_inputs + nb_outputs entries; reject anything beyond that
  // instead of reading out of bounds.
  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc& cur = desc[pos];
  if (cur.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }
  if (pos == 1) {  // cu_seqlens
    return cur.type == nvinfer1::DataType::kINT32;
  }
  // Sliced input (pos 0) and output share the same requirement.
  return cur.type == nvinfer1::DataType::kHALF;
}

// The single output (index 0) has the same element type as the sliced input.
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  return input_types[0];
}

// Gathers one row of `hidden` elements per sequence.
//
// Block b copies slice_input[cu_seqlens[b] * hidden + i] to
// output[b * hidden + i] for every i in [0, hidden). If cu_seqlens holds
// prefix sums of sequence lengths, this selects the first token of each
// sequence — TODO confirm against the converter.
//
// Launch layout: gridDim.x == number of sequences, any blockDim.x >= 1.
// Threads stride over the hidden dimension, so `hidden` is not bounded by the
// per-block thread limit. No shared memory is used.
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output,
                                   int32_t hidden) {
  const int64_t seq = blockIdx.x;
  // 64-bit offsets: row index * hidden can exceed INT32_MAX for large inputs.
  const int64_t src_base = static_cast<int64_t>(cu_seqlens[seq]) * hidden;
  const int64_t dst_base = seq * hidden;
  const int32_t stride = static_cast<int32_t>(blockDim.x);
  for (int32_t i = static_cast<int32_t>(threadIdx.x); i < hidden;
       i += stride) {
    output[dst_base + i] = slice_input[src_base + i];
  }
}

// Launches SpecialSliceKernel on `stream`.
//   inputs[0]:  FP16 tensor being sliced; d[1] is taken as the hidden size
//               (expected layout (sum(S), hidden, 1, 1) — TODO confirm).
//   inputs[1]:  INT32 cu_seqlens.
//   outputs[0]: FP16 tensor whose d[0] is the batch size
//               ((batch, 1, hidden) for a 4-D input, per getOutputDimensions).
// Returns 0 on success, non-zero if the launch reported a CUDA error.
int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  // Upper bound on threads per block. The kernel's stride loop covers any
  // hidden size beyond this.
  constexpr int32_t kMaxThreadsPerBlock = 1024;

  const auto& input_dims = input_desc[0].dims;
  const auto& out_dims = output_desc[0].dims;

  assert(input_desc[0].type == nvinfer1::DataType::kHALF);

  const int32_t hidden = input_dims.d[1];
  const int32_t batch = out_dims.d[0];
  // Nothing to copy. A zero-sized grid or block would also be rejected as an
  // invalid launch configuration.
  if (batch <= 0 || hidden <= 0) {
    return 0;
  }
  const int32_t num_threads =
      hidden < kMaxThreadsPerBlock ? hidden : kMaxThreadsPerBlock;

  const half* slice_input = static_cast<const half*>(inputs[0]);
  const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
  half* output = static_cast<half*>(outputs[0]);

  SpecialSliceKernel<<<batch, num_threads, 0, stream>>>(
      slice_input, cu_seqlens, output, hidden);

  return cudaGetLastError() != cudaSuccess;
}

// Nothing to initialize here; the defaulted constructor suffices.
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() = default;

// Name this creator answers to. It must equal
// SpecialSlicePluginDynamic::getPluginType().
const char* SpecialSlicePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "special_slice_plugin";
}

// Version string of the plugin produced by this creator.
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}

// Returns the creator's field collection. createPlugin() ignores any fields
// it is given, since the plugin has no creation-time parameters.
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  return &field_collection_;
}

// Builds a new plugin instance. `name` and `fc` are unused because the plugin
// takes no parameters.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}

// Reconstructs a plugin from the bytes previously produced by
// SpecialSlicePluginDynamic::serialize(). `name` is unused.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic(serial_data, serial_length);
}

// Stores the namespace reported by getPluginNamespace(). A null pointer is
// treated as the empty namespace, because constructing a std::string from
// nullptr is undefined behavior.
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) TRT_NOEXCEPT {
  plugin_namespace_ = lib_namespace != nullptr ? lib_namespace : "";
}

// Returns the namespace last stored by setPluginNamespace(). The pointer
// stays valid until the namespace is changed or the creator is destroyed.
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  return plugin_namespace_.c_str();
}

#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle