// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class GeluPlugin : public PluginTensorRT {
28
 public:
29
  explicit GeluPlugin(const bool with_fp16) { with_fp16_ = with_fp16; }
30 31 32

  // It was used for tensorrt deserialization.
  // It should not be called by users.
P
Pei Yang 已提交
33 34
  GeluPlugin(void const* serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);
35 36 37
  }

  ~GeluPlugin() {}
38
  GeluPlugin* clone() const override { return new GeluPlugin(with_fp16_); }
39 40 41 42 43 44 45

  const char* getPluginType() const override { return "gelu_plugin"; }
  int getNbOutputs() const override { return 1; }
  int initialize() override { return 0; }
  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::PluginFormat format) const override;
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
P
Pei Yang 已提交
46
                                     int nb_input_dims) override;
47
#if IS_TRT_VERSION_LT(8000)
P
Pei Yang 已提交
48
  int enqueue(int batch_size, const void* const* inputs, void** outputs,
49 50 51
#else
  int enqueue(int batch_size, const void* const* inputs, void* const* outputs,
#endif
52 53
              void* workspace, cudaStream_t stream) override;

P
Pei Yang 已提交
54 55
 protected:
  size_t getSerializationSize() override {
56
    return getBaseSerializationSize() + SerializedSize(getPluginType());
P
Pei Yang 已提交
57 58 59 60
  }

  // TRT will call this func  to serialize the configuration of TRT
  // It should not be called by users.
61
  void serialize(void* buffer) override {
P
Pei Yang 已提交
62 63 64
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
  }
65
};
#if IS_TRT_VERSION_GE(6000)
class GeluPluginDynamic : public DynamicPluginTensorRT {
P
Pei Yang 已提交
69
 public:
70 71 72 73
  explicit GeluPluginDynamic(const bool with_fp16) { with_fp16_ = with_fp16; }
  GeluPluginDynamic(void const* serial_data, size_t serial_length) {
    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
  }
P
Pei Yang 已提交
74

75 76
  ~GeluPluginDynamic() {}
  nvinfer1::IPluginV2DynamicExt* clone() const override {
77
    return new GeluPluginDynamic(with_fp16_);
P
Pei Yang 已提交
78 79
  }

80 81
  const char* getPluginType() const override { return "gelu_plugin"; }
  int getNbOutputs() const override { return 1; }
P
Pei Yang 已提交
82 83
  int initialize() override { return 0; }

84 85 86 87 88 89
  size_t getSerializationSize() const override {
    return SerializedSize(with_fp16_);
  }
  void serialize(void* buffer) const override {
    SerializeValue(&buffer, with_fp16_);
  }
P
Pei Yang 已提交
90

91
  nvinfer1::DimsExprs getOutputDimensions(
P
Pei Yang 已提交
92 93
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;
94 95

  bool supportsFormatCombination(int pos,
P
Pei Yang 已提交
96 97
                                 const nvinfer1::PluginTensorDesc* in_out,
                                 int nb_inputs, int nb_outputs) override;
98 99

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
P
Pei Yang 已提交
100
                       int nb_inputs,
101
                       const nvinfer1::DynamicPluginTensorDesc* out,
P
Pei Yang 已提交
102
                       int nb_outputs) override {}
103 104

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
P
Pei Yang 已提交
105
                          int nb_inputs,
106
                          const nvinfer1::PluginTensorDesc* outputs,
P
Pei Yang 已提交
107
                          int nb_outputs) const override {
108 109 110
    return 0;
  }

P
Pei Yang 已提交
111 112
  int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
              const nvinfer1::PluginTensorDesc* output_desc,
113 114 115
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;
  nvinfer1::DataType getOutputDataType(int index,
P
Pei Yang 已提交
116 117
                                       const nvinfer1::DataType* input_types,
                                       int nb_inputs) const override;
P
Pei Yang 已提交
118

119
  void destroy() override { delete this; }
P
Pei Yang 已提交
120
};
class GeluPluginDynamicCreator : public nvinfer1::IPluginCreator {
P
Pei Yang 已提交
123
 public:
W
Wilber 已提交
124
  GeluPluginDynamicCreator() {}
P
Pei Yang 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
  const char* getPluginName() const override { return "gelu_plugin"; }

  const char* getPluginVersion() const override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames() override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length) override {
    auto plugin = new GeluPluginDynamic(serial_data, serial_length);
    return plugin;
  }

  void setPluginNamespace(const char* lib_namespace) override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

// NOTE(review): presumably registers GeluPluginDynamicCreator with TensorRT's
// plugin registry so serialized "gelu_plugin" layers can be deserialized --
// confirm against the REGISTER_TRT_PLUGIN_V2 definition, which is not in this
// file.
REGISTER_TRT_PLUGIN_V2(GeluPluginDynamicCreator);
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle