// special_slice_plugin.cu
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
// Default constructor: performs no explicit initialization in this file.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() = default;

// Deserialization constructor. Nothing is read back from the buffer, which
// matches serialize() writing nothing and getSerializationSize() returning 0.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data,
                                                     size_t serial_length) {
  // Both arguments are intentionally ignored.
  (void)serial_data;
  (void)serial_length;
}

// Destructor: no explicit cleanup is performed in this file.
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() = default;

// Returns a newly heap-allocated, default-constructed plugin instance.
// No state from `this` is copied into the new object.
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const {
  auto* copy = new SpecialSlicePluginDynamic();
  return copy;
}

// Returns the plugin type string; it is the same literal that
// SpecialSlicePluginDynamicCreator::getPluginName() returns.
const char* SpecialSlicePluginDynamic::getPluginType() const {
  const char* plugin_type = "special_slice_plugin";
  return plugin_type;
}

// The plugin produces exactly one output tensor.
int SpecialSlicePluginDynamic::getNbOutputs() const {
  return 1;
}

// Nothing to set up; always returns 0.
int SpecialSlicePluginDynamic::initialize() {
  return 0;
}

// Returns the number of bytes serialize() writes, which is always zero.
size_t SpecialSlicePluginDynamic::getSerializationSize() const {
  return static_cast<size_t>(0);
}

// Writes nothing to `buffer`, consistent with getSerializationSize()
// returning 0.
void SpecialSlicePluginDynamic::serialize(void* buffer) const {
  (void)buffer;  // intentionally unused
}

// Builds the symbolic shape of the single output tensor.
//
// Parameters:
//   output_index  index of the requested output (not consulted here).
//   inputs        input shapes; inputs[0] is the tensor being sliced and
//                 inputs[1] is the cu_seqlens tensor.
//   nb_inputs     number of entries in `inputs` (not consulted here).
//   expr_builder  factory for the constant / subtraction expressions.
//
// Returns a shape whose leading extent is inputs[1].d[0] - 1 and whose rank
// is one less than that of inputs[0].
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  nvinfer1::DimsExprs output(inputs[0]);

  // Grow the rank by one and shift input dims 1.. one slot to the right,
  // freeing slot 1 for a constant extent of 1.
  output.nbDims++;
  for (int i = output.nbDims - 1; i > 1; i--) {
    output.d[i] = inputs[0].d[i - 1];
  }
  auto one = expr_builder.constant(1);
  output.d[1] = one;

  // Leading extent: number of cu_seqlens entries minus one.
  output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *one);

  // remove padding 1
  // Drops the last two entries of the widened shape (presumably trailing
  // size-1 padding dims of the input -- confirm against the converter).
  output.nbDims -= 2;

  return output;
}

// No configuration is performed: the plugin records nothing about the
// input/output descriptors it is given.
void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {
  // Intentionally empty.
}

// Reports that no scratch workspace is needed, regardless of the tensor
// descriptors passed in.
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
  const size_t workspace_bytes = 0;
  return workspace_bytes;
}

// Deletes this plugin object; the instance must not be used afterwards.
void SpecialSlicePluginDynamic::destroy() {
  delete this;
}

// Counterpart of initialize(); there is nothing to release.
void SpecialSlicePluginDynamic::terminate() {
  // Intentionally empty.
}

// Tells whether the tensor at slot `pos` of `desc` has a type/format this
// plugin accepts.
//
//   pos == 1 (cu_seqlens): INT32, linear format.
//   any other pos (the slice tensor at 0 and everything after 1): HALF,
//   linear format. FLOAT is not accepted.
bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) {
  const nvinfer1::PluginTensorDesc& tensor = desc[pos];

  // Only the cu_seqlens slot differs in its required element type.
  const nvinfer1::DataType required_type =
      (pos == 1) ? nvinfer1::DataType::kINT32 : nvinfer1::DataType::kHALF;

  return tensor.type == required_type &&
         tensor.format == nvinfer1::TensorFormat::kLINEAR;
}

// Returns the element type of the output, which is the type of input 0.
// `index` must be 0 (enforced through PADDLE_ENFORCE_EQ) because the plugin
// has a single output.
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(
      index, 0,
      platform::errors::InvalidArgument("The index should be equal to 0"));

  const nvinfer1::DataType slice_input_type = input_types[0];
  return slice_input_type;
}

// Copies one row of `slice_input` into each row of `output`.
//
// Launch layout: one block per output row (blockIdx.x) and one thread per
// element of a row (blockDim.x is the row width). Block b reads the source
// row whose index is cu_seqlens[b]. There is no bounds guard, so the launch
// configuration must match the tensor extents exactly.
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output) {
  const int row_width = blockDim.x;
  const int dst_row = blockIdx.x;
  const unsigned int lane = threadIdx.x;
  const int32_t src_row = cu_seqlens[dst_row];

  output[dst_row * row_width + lane] = slice_input[src_row * row_width + lane];
}

// Launches SpecialSliceKernel on `stream` to fill the output tensor.
//
// Parameters:
//   input_desc / output_desc  runtime descriptors; only input_desc[0].dims,
//                             input_desc[0].type and output_desc[0].dims are
//                             read.
//   inputs                    inputs[0] is the half-precision tensor to
//                             slice, inputs[1] is the int32 cu_seqlens
//                             tensor.
//   outputs                   outputs[0] receives the sliced rows.
//   workspace                 unused (getWorkspaceSize() returns 0).
//   stream                    CUDA stream the kernel is enqueued on.
//
// Returns 0 on success and non-zero if the kernel launch reported an error.
int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) {
  // Original author's example shape for the input: (sum(S), 768, 1, 1).
  auto input_dims = input_desc[0].dims;
  // Leading extent of the output is the batch size.
  auto out_dims = output_desc[0].dims;

  // Only half precision is handled; this check is compiled out when NDEBUG
  // is defined.
  assert(input_desc[0].type == nvinfer1::DataType::kHALF);

  const int32_t hidden = input_dims.d[1];
  const int num_blocks = out_dims.d[0];  // batch size
  // One thread per hidden element, so `hidden` must not exceed the device's
  // maximum threads per block; otherwise the launch fails and a non-zero
  // status is returned below.
  const int num_threads = hidden;

  // An empty output needs no work. Launching with a zero grid or block
  // extent would be rejected as an invalid configuration and reported as a
  // failure, so return success before launching.
  if (num_blocks <= 0 || num_threads <= 0) {
    return 0;
  }

  const half* slice_input = static_cast<const half*>(inputs[0]);
  const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
  half* output = static_cast<half*>(outputs[0]);

  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
      slice_input, cu_seqlens, output);

  // Launch-configuration errors surface here; the kernel itself runs
  // asynchronously on `stream`.
  return cudaGetLastError() != cudaSuccess;
}

// Default constructor: performs no explicit initialization in this file.
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() = default;

// Returns the plugin name; it is the same literal that
// SpecialSlicePluginDynamic::getPluginType() returns.
const char* SpecialSlicePluginDynamicCreator::getPluginName() const {
  const char* plugin_name = "special_slice_plugin";
  return plugin_name;
}

// Returns the plugin version string.
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const {
  const char* plugin_version = "1";
  return plugin_version;
}

// Returns a pointer to the creator's field collection member.
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() {
  const nvinfer1::PluginFieldCollection* fields = &field_collection_;
  return fields;
}

// Creates a new default-constructed plugin instance on the heap.
// Neither `name` nor the field collection `fc` is consulted.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) {
  nvinfer1::IPluginV2* plugin = new SpecialSlicePluginDynamic();
  return plugin;
}

// Creates a plugin instance on the heap through the deserialization
// constructor, forwarding the serialized buffer and its length.
// `name` is not consulted.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data, size_t serial_length) {
  return new SpecialSlicePluginDynamic(serial_data, serial_length);
}

// Stores `lib_namespace` in the creator's plugin_namespace_ member so that
// getPluginNamespace() can return it later.
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) {
  plugin_namespace_ = lib_namespace;
}

// Returns the namespace previously stored by setPluginNamespace() as a
// C string obtained from the plugin_namespace_ member.
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const {
  const char* stored_namespace = plugin_namespace_.c_str();
  return stored_namespace;
}

#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle