// stack_op_plugin.cu
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
// Builds a stack plugin from its attributes.
//   axis      - index of the new dimension in the output; getOutputDimensions()
//               and enqueue() index dims with it directly, so it is presumably
//               already normalized to be non-negative by the caller (confirm
//               in the converter).
//   num_stack - number of tensors being stacked; enqueue() requires the output
//               extent along `axis` to equal this value.
//   with_fp16 - when true, supportsFormatCombination() also accepts kHALF.
StackPluginDynamic::StackPluginDynamic(int axis, int num_stack, bool with_fp16)
    : axis_(axis), num_stack_(num_stack) {
  // Assigned in the body rather than the init list (as in the original code);
  // presumably `with_fp16_` is declared in a base class.
  with_fp16_ = with_fp16;
}

// Rebuilds a plugin from the buffer written by serialize(). Fields are read in
// exactly the order serialize() writes them: axis_, num_stack_, with_fp16_.
StackPluginDynamic::StackPluginDynamic(void const* serial_data,
                                       size_t serial_length) {
  DeserializeValue(&serial_data, &serial_length, &axis_);
  DeserializeValue(&serial_data, &serial_length, &num_stack_);
  DeserializeValue(&serial_data, &serial_length, &with_fp16_);
}

// The plugin owns no resources of its own, so the destructor does nothing.
StackPluginDynamic::~StackPluginDynamic() = default;

// Returns a new heap-allocated plugin carrying the same axis_, num_stack_ and
// with_fp16_ as this one.
nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const {
  return new StackPluginDynamic(axis_, num_stack_, with_fp16_);
}

// Plugin type string; identical to StackPluginDynamicCreator::getPluginName().
const char* StackPluginDynamic::getPluginType() const {
  static const char* const kPluginType = "stack_plugin";
  return kPluginType;
}

// Stacking always produces exactly one output tensor.
int StackPluginDynamic::getNbOutputs() const {
  const int kNumOutputs = 1;
  return kNumOutputs;
}

// Nothing to acquire before execution; always returns 0.
int StackPluginDynamic::initialize() {
  return 0;
}

// Number of bytes serialize() writes: the sizes of axis_, num_stack_ and
// with_fp16_. Must stay in sync with serialize() and the deserializing
// constructor.
size_t StackPluginDynamic::getSerializationSize() const {
  size_t serialize_size = 0;
  serialize_size += SerializedSize(axis_);
  serialize_size += SerializedSize(num_stack_);
  serialize_size += SerializedSize(with_fp16_);
  return serialize_size;
}

// Writes axis_, num_stack_ and with_fp16_ to `buffer`, in the same order the
// deserializing constructor reads them. `buffer` is presumably at least
// getSerializationSize() bytes (provided by the caller).
void StackPluginDynamic::serialize(void* buffer) const {
  SerializeValue(&buffer, axis_);
  SerializeValue(&buffer, num_stack_);
  SerializeValue(&buffer, with_fp16_);
}

// Output shape: the first input's shape with a new dimension of extent
// `nb_inputs` inserted at position axis_. Dims before axis_ are kept, and dims
// from axis_ onward move one slot to the right.
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  const nvinfer1::DimsExprs& in_dims = inputs[0];
  nvinfer1::DimsExprs out_dims = in_dims;  // leading dims carried over as-is
  out_dims.nbDims = in_dims.nbDims + 1;

  int dst = in_dims.nbDims;
  while (dst > axis_) {
    out_dims.d[dst] = in_dims.d[dst - 1];
    --dst;
  }

  out_dims.d[axis_] = expr_builder.constant(nb_inputs);
  return out_dims;
}

// No per-configuration state is cached: everything enqueue() needs is read
// from the tensor descriptors at execution time.
void StackPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* /*in*/, int /*nbInputs*/,
    const nvinfer1::DynamicPluginTensorDesc* /*out*/, int /*nbOutputs*/) {
  // intentionally empty
}

// Scratch space needed by enqueue(): enqueue() copies one device pointer
// (sizeof(void*)) per stacked input into the workspace, so reserve exactly
// num_stack_ such slots. Uses the same element size as that copy.
size_t StackPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
  // A non-positive num_stack_ (createPlugin() defaults it to -1 when the field
  // is missing) would otherwise wrap to an enormous size_t request.
  if (num_stack_ <= 0) {
    return 0;
  }
  return static_cast<size_t>(num_stack_) * sizeof(void*);
}

// Releases this heap-allocated plugin object (see clone()/createPlugin()).
void StackPluginDynamic::destroy() {
  delete this;
}

// Counterpart of initialize(); there is nothing to release.
void StackPluginDynamic::terminate() {
  // intentionally empty
}

// Reports whether the tensor at `pos` (inputs first, then the output) may use
// the type/format in in_out[pos].
//   - pos 0: linear kFLOAT, or also linear kHALF when with_fp16_ is set.
//   - pos > 0: must match the tensor at pos - 1, which chains every remaining
//     input and the output to the first input's type and format.
// Throws (via PADDLE_ENFORCE) on a null descriptor array or out-of-range pos.
bool StackPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of stack plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  // Remaining inputs and the output: same type/format as the previous tensor.
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  return in.type == prev.type && in.format == prev.format;
}

// The single output (index 0) has the same element type as the first input.
// Any other index is rejected via PADDLE_ENFORCE_EQ.
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  const nvinfer1::DataType first_input_type = input_types[0];
  return first_input_type;
}

// Copies slice `blockIdx.y` of input `blockIdx.x` into its place in the stacked
// output.
//
// Expected launch layout (see enqueue()):
//   gridDim.x  == num_stack  (one block column per stacked input)
//   gridDim.y  == number of leading slices (product of dims before the axis)
//   blockDim.x threads stride over the `base_unit` contiguous elements of a
//   slice; no shared memory is used.
// `input` must be a device-readable array of `num_stack` device pointers.
//
// Offsets are computed in 64-bit so that
// lead_id * num_stack * base_unit cannot overflow `int` for large outputs.
template <typename T>
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
                            int base_unit) {
  const int stack_id = blockIdx.x;
  const int lead_id = blockIdx.y;

  const int64_t in_offset = static_cast<int64_t>(lead_id) * base_unit;
  const int64_t out_offset =
      (static_cast<int64_t>(lead_id) * num_stack + stack_id) * base_unit;

  const T* src = input[stack_id] + in_offset;
  T* dst = output + out_offset;
  for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
    dst[i] = src[i];
  }
}

// Stacks the num_stack_ inputs along axis_ into outputs[0] on `stream`.
//
// The output is viewed as [lead_unit, num_stack, base_unit]:
//   lead_unit = product of output dims before axis_
//   base_unit = product of output dims after axis_
// The host-side `inputs` pointer array is staged into `workspace` (sized by
// getWorkspaceSize()) so StackKernel can index it on the device.
//
// Returns 0 on success and 1 if the pointer copy or the kernel launch reports
// a CUDA error. Throws (via PADDLE_ENFORCE/PADDLE_THROW) on non-positive dims,
// on a stack-axis extent different from num_stack_, or on an unsupported type.
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                const nvinfer1::PluginTensorDesc* output_desc,
                                const void* const* inputs, void* const* outputs,
                                void* workspace, cudaStream_t stream) {
  const auto out_dims = output_desc[0].dims;
  const int out_num_dims = out_dims.nbDims;

  int base_unit = 1;
  for (int i = axis_ + 1; i < out_num_dims; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    base_unit *= out_dims.d[i];
  }

  int lead_unit = 1;
  for (int i = 0; i < axis_; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    lead_unit *= out_dims.d[i];
  }

  PADDLE_ENFORCE_EQ(
      out_dims.d[axis_], num_stack_,
      platform::errors::InvalidArgument("number of stack axis should be same"));

  const int num_stacks = out_dims.d[axis_];

  // Stage the device pointers of all inputs into the workspace. If this fails,
  // do not launch the kernel with an unpopulated pointer table; clear the
  // recorded error so a later, unrelated cudaGetLastError() does not see it.
  if (cudaMemcpyAsync(workspace, static_cast<const void*>(inputs),
                      sizeof(void*) * num_stacks, cudaMemcpyHostToDevice,
                      stream) != cudaSuccess) {
    cudaGetLastError();
    return 1;
  }

  // NOTE(review): gridDim.y is limited to 65535 by CUDA; a lead_unit beyond
  // that makes the launch fail, which is reported through the return value.
  dim3 num_blocks(num_stacks, lead_unit);
  const int num_threads = 256;
  const auto infer_type = input_desc[0].type;

  if (infer_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. Stack-->fp32";
    float* output = static_cast<float*>(outputs[0]);
    StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const float* const*>(workspace), output, num_stacks,
        base_unit);
  } else if (infer_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. Stack-->fp16";
    __half* output = static_cast<__half*>(outputs[0]);
    StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
        base_unit);
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("The Stack TRT Plugin's input type only "
                                "support float or half currently."));
  }
  // Surfaces launch-configuration errors (e.g. an oversized grid).
  return cudaGetLastError() != cudaSuccess;
}

// No setup is needed; members use their default initialization.
StackPluginDynamicCreator::StackPluginDynamicCreator() = default;

// Name under which this creator is registered; identical to the string
// returned by StackPluginDynamic::getPluginType().
const char* StackPluginDynamicCreator::getPluginName() const {
  static const char* const kPluginName = "stack_plugin";
  return kPluginName;
}

// Version string reported for this plugin.
const char* StackPluginDynamicCreator::getPluginVersion() const {
  static const char* const kPluginVersion = "1";
  return kPluginVersion;
}

// Returns a pointer to the creator's field_collection_ member.
const nvinfer1::PluginFieldCollection*
StackPluginDynamicCreator::getFieldNames() {
  return &this->field_collection_;
}

// Builds a StackPluginDynamic from named plugin fields.
// Recognized fields: "axis" (int), "num_stack" (int), "with_fp16" (bool).
// Missing fields keep their defaults (axis = -1, num_stack = -1,
// with_fp16 = false). Throws (via PADDLE_ENFORCE/PADDLE_THROW) on a null field
// collection or an unknown field name.
nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) {
  PADDLE_ENFORCE_NOT_NULL(
      fc, platform::errors::InvalidArgument(
              "The field collection of stack plugin should not be nullptr."));

  int axis = -1;
  int num_stack = -1;
  bool with_fp16 = false;

  for (int i = 0; i < fc->nbFields; ++i) {
    const auto& field = fc->fields[i];
    // Named `field_name` so it does not shadow the `name` parameter.
    const std::string field_name(field.name);
    if (field_name == "axis") {
      axis = static_cast<const int*>(field.data)[0];
    } else if (field_name == "num_stack") {
      num_stack = static_cast<const int*>(field.data)[0];
    } else if (field_name == "with_fp16") {
      with_fp16 = static_cast<const bool*>(field.data)[0];
    } else {
      PADDLE_THROW(platform::errors::Fatal("Meet an unknown plugin field '" +
                                           field_name +
                                           "' when creating stack op plugin."));
    }
  }
  return new StackPluginDynamic(axis, num_stack, with_fp16);
}

// Reconstructs a heap-allocated plugin from bytes produced by
// StackPluginDynamic::serialize(); `name` is not used.
nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data, size_t serial_length) {
  return new StackPluginDynamic(serial_data, serial_length);
}

// Stores the namespace string; returned later by getPluginNamespace().
void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace) {
  this->plugin_namespace_ = lib_namespace;
}

// Returns the namespace stored by setPluginNamespace() as a C string owned by
// this creator.
const char* StackPluginDynamicCreator::getPluginNamespace() const {
  return this->plugin_namespace_.c_str();
}

#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle