stack_op_plugin.cu 9.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
26 27 28 29
// Builds a stack plugin that joins `num_stack` inputs along `axis`.
// `with_fp16` enables half-precision inputs in supportsFormatCombination().
StackPluginDynamic::StackPluginDynamic(int axis, int num_stack, bool with_fp16)
    : axis_(axis), num_stack_(num_stack) {
  // with_fp16_ is assigned here rather than in the initializer list;
  // NOTE(review): presumably because it is declared in a base class — confirm
  // against the header.
  this->with_fp16_ = with_fp16;
}
30 31 32 33 34

// Reconstructs a plugin from a buffer written by serialize().
// Fields are read back in the same order serialize() writes them:
// axis_, num_stack_, with_fp16_.
StackPluginDynamic::StackPluginDynamic(void const* serial_data,
                                       size_t serial_length) {
  void const** cursor = &serial_data;
  size_t* remaining = &serial_length;
  DeserializeValue(cursor, remaining, &axis_);
  DeserializeValue(cursor, remaining, &num_stack_);
  DeserializeValue(cursor, remaining, &with_fp16_);
}

// Nothing to release: the plugin holds no owned resources in this file.
StackPluginDynamic::~StackPluginDynamic() = default;

40
// Returns a new heap-allocated plugin with the same axis, stack count and
// fp16 setting; the copy is released through destroy().
nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const TRT_NOEXCEPT {
  auto* copy = new StackPluginDynamic(axis_, num_stack_, with_fp16_);
  return copy;
}

44 45 46
// Plugin type string; the same literal the creator's getPluginName() returns.
const char* StackPluginDynamic::getPluginType() const TRT_NOEXCEPT {
  const char* const kPluginType = "stack_plugin";
  return kPluginType;
}
47

48
// Stacking always produces exactly one output tensor.
int StackPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
  const int kNumOutputs = 1;
  return kNumOutputs;
}
49

50
// No resources are acquired; always reports success (0).
int StackPluginDynamic::initialize() TRT_NOEXCEPT {
  return 0;
}
51

52
// Number of bytes serialize() writes: the encoded sizes of axis_,
// num_stack_ and with_fp16_.
size_t StackPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return SerializedSize(axis_) + SerializedSize(num_stack_) +
         SerializedSize(with_fp16_);
}

60
// Writes the plugin configuration into `buffer`. The field order
// (axis_, num_stack_, with_fp16_) is the wire format read back by the
// (serial_data, serial_length) constructor.
void StackPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
  void** cursor = &buffer;
  SerializeValue(cursor, axis_);
  SerializeValue(cursor, num_stack_);
  SerializeValue(cursor, with_fp16_);
}

// Output shape = shape of inputs[0] with a new dimension of size nb_inputs
// inserted at position axis_.
// NOTE(review): assumes axis_ is already normalized to [0, inputs[0].nbDims]
// (a negative axis_ would index out of bounds) — confirm at the call site.
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs& in = inputs[0];
  nvinfer1::DimsExprs output(in);
  output.nbDims = in.nbDims + 1;
  // Input dims at index >= axis_ move one slot to the right.
  for (int dst = output.nbDims - 1; dst > axis_; --dst) {
    output.d[dst] = in.d[dst - 1];
  }
  output.d[axis_] = expr_builder.constant(nb_inputs);
  return output;
}

// No configuration state is cached; enqueue() reads shapes and types
// directly from the descriptors it receives.
void StackPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {
  // Intentionally empty.
}
82 83 84

// Workspace holds one pointer per stacked input: enqueue() copies the host
// table of input pointers into it for the kernel to index.
size_t StackPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  const size_t pointer_bytes = sizeof(uintptr_t);
  return num_stack_ * pointer_bytes;
}

90
// Frees this instance; clone() and the creator allocate plugins with `new`.
void StackPluginDynamic::destroy() TRT_NOEXCEPT {
  delete this;
}
91

92
// initialize() acquires nothing, so there is nothing to release here.
void StackPluginDynamic::terminate() TRT_NOEXCEPT {
}
93 94 95

// Reports whether the tensor at `pos` may use the type/format in in_out[pos].
// in_out[0] must be linear fp32 (or linear fp16 when with_fp16_ is set); every
// later tensor must match the one immediately before it, so all inputs and
// the output share the first input's type and format.
bool StackPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of stack plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc& desc = in_out[pos];
  if (pos != 0) {
    const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
    return desc.type == prev.type && desc.format == prev.format;
  }

  const bool is_linear = desc.format == nvinfer1::TensorFormat::kLINEAR;
  const bool type_ok = desc.type == nvinfer1::DataType::kFLOAT ||
                       (with_fp16_ && desc.type == nvinfer1::DataType::kHALF);
  return type_ok && is_linear;
}

// The single output (index 0) has the element type of the first input.
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  const nvinfer1::DataType out_type = input_types[0];
  return out_type;
}

// Copies every stacked input into its slot of the output tensor.
//
// View the output as [lead_unit, num_stack, base_unit] and each input as
// [lead_unit, base_unit]; the kernel performs
//   output[(lead * num_stack + stack) * base_unit + k] =
//       input[stack][lead * base_unit + k].
//
// Launch layout:
//   gridDim.x == num_stack : one column of blocks per stacked input.
//   gridDim.y <= lead_unit : blocks stride over the lead rows, so the grid may
//                            be smaller than lead_unit (gridDim.y is limited
//                            to 65535).
//   blockDim.x             : threads stride over the base_unit elements.
// Uses no shared memory. `input` must be a device-readable array of
// num_stack device pointers. Offsets are 64-bit so that outputs with more
// than 2^31 elements do not overflow.
template <typename T>
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
                            int64_t base_unit, int64_t lead_unit) {
  const int stack_id = blockIdx.x;
  const T* src = input[stack_id];
  for (int64_t lead_id = blockIdx.y; lead_id < lead_unit;
       lead_id += gridDim.y) {
    const int64_t src_offset = lead_id * base_unit;
    const int64_t dst_offset = (lead_id * num_stack + stack_id) * base_unit;
    for (int64_t i = threadIdx.x; i < base_unit; i += blockDim.x) {
      output[dst_offset + i] = src[src_offset + i];
    }
  }
}

// Stacks num_stack_ same-shaped inputs along axis_ into outputs[0].
//
// The host-side `inputs` pointer table is copied into `workspace` (sized by
// getWorkspaceSize()) so that StackKernel can index it on the device.
// Returns 0 on success and non-zero if the pointer-table copy or the kernel
// launch reports a CUDA error. Throws (via PADDLE_ENFORCE/PADDLE_THROW) on
// non-positive dims, a stack-count mismatch, or an unsupported data type.
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                const nvinfer1::PluginTensorDesc* output_desc,
                                const void* const* inputs, void* const* outputs,
                                void* workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  const auto out_dims = output_desc[0].dims;
  const int out_num_dims = out_dims.nbDims;

  // Elements per contiguous slice: product of the output dims after axis_.
  int64_t base_unit = 1;
  for (int i = axis_ + 1; i < out_num_dims; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    base_unit *= out_dims.d[i];
  }

  // Slices per input: product of the output dims before axis_.
  int64_t lead_unit = 1;
  for (int i = 0; i < axis_; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    lead_unit *= out_dims.d[i];
  }

  PADDLE_ENFORCE_EQ(
      out_dims.d[axis_], num_stack_,
      platform::errors::InvalidArgument("number of stack axis should be same"));
  const int num_stacks = out_dims.d[axis_];

  // Upload the host array of input device pointers into the workspace.
  const cudaError_t copy_status =
      cudaMemcpyAsync(workspace, inputs, sizeof(void*) * num_stacks,
                      cudaMemcpyHostToDevice, stream);
  if (copy_status != cudaSuccess) {
    return 1;
  }

  // gridDim.y may not exceed 65535; StackKernel strides over any lead rows
  // beyond the grid.
  constexpr int64_t kMaxGridDimY = 65535;
  const unsigned int grid_y = static_cast<unsigned int>(
      lead_unit < kMaxGridDimY ? lead_unit : kMaxGridDimY);
  const dim3 num_blocks(num_stacks, grid_y);
  const int num_threads = 256;
  const auto infer_type = input_desc[0].type;

  if (infer_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Stack-->fp32";
    float* output = static_cast<float*>(outputs[0]);
    StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const float* const*>(workspace), output, num_stacks,
        base_unit, lead_unit);
  } else if (infer_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Stack-->fp16";
    __half* output = static_cast<__half*>(outputs[0]);
    StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
        base_unit, lead_unit);
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("The Stack TRT Plugin's input type only "
                                "support float or half currently."));
  }
  // Surfaces launch-configuration errors; non-zero tells TensorRT it failed.
  return cudaGetLastError() != cudaSuccess;
}

// The constructor sets nothing itself; members keep whatever initialization
// the header declares.
StackPluginDynamicCreator::StackPluginDynamicCreator() = default;

203
// Registered plugin name; the same literal StackPluginDynamic::getPluginType()
// returns.
const char* StackPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
  const char* const kPluginName = "stack_plugin";
  return kPluginName;
}

207 208 209
// Plugin version string reported to TensorRT.
const char* StackPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
  const char* const kPluginVersion = "1";
  return kPluginVersion;
}
210 211

// Returns a pointer to the creator-owned field collection.
// NOTE(review): nothing in this file populates field_collection_ (the
// constructor is empty); confirm the header initializes it, e.g. to an
// empty collection.
const nvinfer1::PluginFieldCollection*
StackPluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  const nvinfer1::PluginFieldCollection* fields = &field_collection_;
  return fields;
}

// Builds a StackPluginDynamic from named fields: "axis" (int),
// "num_stack" (int) and "with_fp16" (bool). Absent fields keep the defaults
// below; any other field name throws.
nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  int axis = -1;
  int num_stack = -1;
  bool with_fp16 = false;

  for (int idx = 0; idx < fc->nbFields; ++idx) {
    const nvinfer1::PluginField& field = fc->fields[idx];
    // Named field_name so it does not shadow the `name` parameter.
    const std::string field_name(field.name);
    if (field_name == "axis") {
      axis = *static_cast<const int*>(field.data);
    } else if (field_name == "num_stack") {
      num_stack = *static_cast<const int*>(field.data);
    } else if (field_name == "with_fp16") {
      with_fp16 = *static_cast<const bool*>(field.data);
    } else {
      PADDLE_THROW(platform::errors::Fatal("Meet an unknown plugin field '" +
                                           field_name +
                                           "' when creating stack op plugin."));
    }
  }
  return new StackPluginDynamic(axis, num_stack, with_fp16);
}

// Rebuilds a plugin from bytes produced by StackPluginDynamic::serialize().
nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new StackPluginDynamic(serial_data, serial_length);
}

246 247
// Stores an owned copy of the namespace string; getPluginNamespace() returns
// a pointer into it.
void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace)
    TRT_NOEXCEPT {
  plugin_namespace_.assign(lib_namespace);
}

251
// Returns the stored namespace as a C string. The pointer refers into
// plugin_namespace_ and may be invalidated by a later setPluginNamespace().
const char* StackPluginDynamicCreator::getPluginNamespace() const TRT_NOEXCEPT {
  const char* ns = plugin_namespace_.c_str();
  return ns;
}

#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle