未验证 提交 f4d9212d 编写于 作者: W Wilber 提交者: GitHub

trt plugin upgrade to pluginv2ext (#31670)

上级 372ac08a
......@@ -101,7 +101,7 @@ class SplitOpConverter : public OpConverter {
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::SplitPlugin* plugin =
new plugin::SplitPlugin(axis, output_lengths, with_fp16);
layer = engine_->AddPlugin(&input, input_num, plugin);
layer = engine_->AddPluginV2Ext(&input, input_num, plugin);
}
std::string layer_name = "split (Output: ";
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <string>
#include "cuda_runtime_api.h"
#include "cuda_runtime_api.h" // NOLINT
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -353,6 +353,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
return network()->addPluginExt(inputs, num_inputs, *plugin);
}
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
nvinfer1::ITensor *const *inputs, int num_inputs,
plugin::PluginTensorRTV2Ext *plugin) {
owned_plugin_v2ext_.emplace_back(plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);
}
void TensorRTEngine::freshDeviceId() {
int count;
cudaGetDeviceCount(&count);
......
......@@ -305,8 +305,14 @@ class TensorRTEngine {
}
int GetDeviceId() { return device_id_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
int num_inputs,
plugin::PluginTensorRTV2Ext* plugin);
void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
quant_dynamic_range_[tensor] = range;
}
......@@ -414,6 +420,7 @@ class TensorRTEngine {
itensor_map_;
std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
std::vector<std::unique_ptr<plugin::PluginTensorRTV2Ext>> owned_plugin_v2ext_;
// TensorRT related internal members
template <typename T>
......
......@@ -6,3 +6,6 @@ nv_library(tensorrt_plugin
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin)
......@@ -22,11 +22,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {
SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
return new SplitPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
template <typename T>
__device__ int upper_bound(T const* vals, int n, T const& key) {
int i = 0;
......
......@@ -25,7 +25,7 @@ namespace inference {
namespace tensorrt {
namespace plugin {
class SplitPlugin : public PluginTensorRT {
class SplitPlugin : public PluginTensorRTV2Ext {
public:
SplitPlugin() {}
SplitPlugin(int axis, std::vector<int> const& output_lengths, bool with_fp16)
......@@ -39,13 +39,20 @@ class SplitPlugin : public PluginTensorRT {
DeserializeValue(&serial_data, &serial_length, &output_length_);
}
SplitPlugin* clone() const override {
auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
nvinfer1::IPluginV2Ext* clone() const override {
SplitPlugin* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
ptr->setPluginNamespace(this->getPluginNamespace());
ptr->shareData(this);
return ptr;
}
const char* getPluginType() const override { return "split_plugin"; }
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const override {
return input_types[0];
}
const char* getPluginType() const override { return "split_plugin_v2ext"; }
int getNbOutputs() const override { return output_length_.size(); }
nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims* input_dims,
......@@ -53,17 +60,18 @@ class SplitPlugin : public PluginTensorRT {
int initialize() override;
void terminate() override;
int enqueue(int batchSize, const void* const* inputs, void** outputs,
int enqueue(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) override;
void destroy() override { delete this; }
protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
SerializedSize(output_length_) + getBaseSerializationSize();
size_t getSerializationSize() const override {
return SerializedSize(axis_) + SerializedSize(output_length_) +
getBaseSerializationSize();
}
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_length_);
......@@ -83,6 +91,47 @@ class SplitPlugin : public PluginTensorRT {
void shareData(const SplitPlugin* another);
};
class SplitPluginCreator : public nvinfer1::IPluginCreator {
public:
SplitPluginCreator() {}
const char* getPluginName() const override { return "split_plugin_v2ext"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
// not implemented
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new SplitPlugin(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(SplitPluginCreator);
#if IS_TRT_VERSION_GE(6000)
class SplitPluginDynamic : public DynamicPluginTensorRT {
public:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
TEST(split_op_plugin, test_plugin) {
int axis = 1;
std::vector<int> output_lengths{1, 1};
bool with_fp16 = false;
std::vector<nvinfer1::DataType> input_types{nvinfer1::DataType::kFLOAT};
std::vector<nvinfer1::Dims> input_dims;
SplitPlugin sp_plugin(axis, output_lengths, with_fp16);
nvinfer1::Dims in_dims;
in_dims.nbDims = 4;
input_dims.push_back(in_dims);
sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
input_types.data(), nullptr, nullptr, nullptr,
nvinfer1::PluginFormat::kNCHW, 4);
sp_plugin.initialize();
sp_plugin.getPluginType();
sp_plugin.canBroadcastInputAcrossBatch(0);
sp_plugin.getNbOutputs();
auto clone_plugin = sp_plugin.clone();
clone_plugin->setPluginNamespace("test");
clone_plugin->destroy();
sp_plugin.getOutputDataType(0, input_types.data(), 1);
sp_plugin.terminate();
}
TEST(split_op_plugin, test_plugin_creater) {
SplitPluginCreator creator;
creator.getFieldNames();
creator.createPlugin("test", nullptr);
creator.setPluginNamespace("test");
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -19,27 +19,50 @@ namespace inference {
namespace tensorrt {
namespace plugin {
inline void Seria(void*& buffer, // NOLINT
const std::vector<nvinfer1::Dims>& input_dims,
size_t max_batch_size, nvinfer1::DataType data_type,
nvinfer1::PluginFormat data_format, bool with_fp16) {
SerializeValue(&buffer, input_dims);
SerializeValue(&buffer, max_batch_size);
SerializeValue(&buffer, data_type);
SerializeValue(&buffer, data_format);
SerializeValue(&buffer, with_fp16);
}
inline void Deseria(void const*& serial_data, size_t& serial_length, // NOLINT
std::vector<nvinfer1::Dims>* input_dims,
size_t* max_batch_size, nvinfer1::DataType* data_type,
nvinfer1::PluginFormat* data_format, bool* with_fp16) {
DeserializeValue(&serial_data, &serial_length, input_dims);
DeserializeValue(&serial_data, &serial_length, max_batch_size);
DeserializeValue(&serial_data, &serial_length, data_type);
DeserializeValue(&serial_data, &serial_length, data_format);
DeserializeValue(&serial_data, &serial_length, with_fp16);
}
inline size_t SeriaSize(const std::vector<nvinfer1::Dims>& input_dims,
size_t max_batch_size, nvinfer1::DataType data_type,
nvinfer1::PluginFormat data_format, bool with_fp16) {
return (SerializedSize(input_dims) + SerializedSize(max_batch_size) +
SerializedSize(data_type) + SerializedSize(data_format) +
SerializedSize(with_fp16));
}
void PluginTensorRT::serializeBase(void*& buffer) {
SerializeValue(&buffer, input_dims_);
SerializeValue(&buffer, max_batch_size_);
SerializeValue(&buffer, data_type_);
SerializeValue(&buffer, data_format_);
SerializeValue(&buffer, with_fp16_);
Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
}
void PluginTensorRT::deserializeBase(void const*& serial_data,
size_t& serial_length) {
DeserializeValue(&serial_data, &serial_length, &input_dims_);
DeserializeValue(&serial_data, &serial_length, &max_batch_size_);
DeserializeValue(&serial_data, &serial_length, &data_type_);
DeserializeValue(&serial_data, &serial_length, &data_format_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
&data_type_, &data_format_, &with_fp16_);
}
size_t PluginTensorRT::getBaseSerializationSize() {
return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
SerializedSize(data_type_) + SerializedSize(data_format_) +
SerializedSize(with_fp16_));
return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
}
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
......@@ -58,6 +81,35 @@ void PluginTensorRT::configureWithFormat(
max_batch_size_ = max_batch_size;
}
void PluginTensorRTV2Ext::serializeBase(void*& buffer) const {
Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
}
void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data,
size_t& serial_length) {
Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
&data_type_, &data_format_, &with_fp16_);
}
size_t PluginTensorRTV2Ext::getBaseSerializationSize() const {
return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
with_fp16_);
}
void PluginTensorRTV2Ext::configurePlugin(
const nvinfer1::Dims* input_dims, int32_t nb_inputs,
const nvinfer1::Dims* output_dims, int32_t nb_outputs,
const nvinfer1::DataType* input_types,
const nvinfer1::DataType* output_types, const bool* input_is_broadcast,
const bool* output_is_broadcast, nvinfer1::PluginFormat float_format,
int32_t max_batch_size) {
input_dims_.assign(input_dims, input_dims + nb_inputs);
max_batch_size_ = max_batch_size;
data_format_ = float_format;
data_type_ = input_types[0];
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -44,6 +44,7 @@ typedef std::function<PluginTensorRT*(const void*, size_t)>
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
// Deprecated. Do not inherit this class, please refer to PluginTensorRTV2Ext
class PluginTensorRT : public nvinfer1::IPluginExt {
public:
PluginTensorRT() : with_fp16_(false) {}
......@@ -119,6 +120,114 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
bool with_fp16_;
};
// TensorRT introduced IPluginV2Ext after 5.1, Paddle no longer supports
// versions before 5.1
class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
public:
PluginTensorRTV2Ext() : with_fp16_(false) {}
PluginTensorRTV2Ext(const void* serialized_data, size_t length) {}
nvinfer1::Dims const& getInputDims(int index) const {
return input_dims_.at(index);
}
size_t getMaxBatchSize() const { return max_batch_size_; }
nvinfer1::DataType getDataType() const { return data_type_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
// The Func in IPluginV2Ext
virtual nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int32_t output_index,
const bool* input_is_broadcasted,
int32_t nb_inputs) const {
return false;
}
virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const {
return false;
}
void configurePlugin(const nvinfer1::Dims* input_dims, int32_t nb_inputs,
const nvinfer1::Dims* output_dims, int32_t nb_outputs,
const nvinfer1::DataType* input_types,
const nvinfer1::DataType* output_types,
const bool* input_is_broadcast,
const bool* output_is_broadcast,
nvinfer1::PluginFormat float_format,
int32_t max_batch_size) override;
virtual IPluginV2Ext* clone() const = 0;
void attachToContext(cudnnContext*, cublasContext*,
nvinfer1::IGpuAllocator*) override {}
void detachFromContext() override {}
// The Func in IPluginV2
virtual const char* getPluginType() const = 0;
const char* getPluginVersion() const override { return "1"; }
virtual int32_t getNbOutputs() const { return 1; }
virtual nvinfer1::Dims getOutputDimensions(int32_t index,
const nvinfer1::Dims* inputs,
int32_t nb_input) = 0;
// Check format support. The default is FLOAT32 and NCHW.
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
}
// Initialize the layer for execution.
// This is called when the engine is created.
int initialize() override { return 0; }
// Shutdown the layer. This is called when the engine is destroyed
void terminate() override {}
// Find the workspace size required by the layer
size_t getWorkspaceSize(int) const override { return 0; }
// Execute the layer
virtual int enqueue(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) = 0;
// Find the size of the serialization buffer required
virtual size_t getSerializationSize() const = 0;
// Serialize the layer config to buffer.
// TensorRT will call this func to serialize the configuration of TensorRT
// engine. It should not be called by users.
virtual void serialize(void* buffer) const = 0;
virtual void destroy() = 0;
void setPluginNamespace(const char* plugin_namespace) override {
name_space_ = plugin_namespace;
}
const char* getPluginNamespace() const override {
return name_space_.c_str();
}
protected:
void deserializeBase(void const*& serial_data, // NOLINT
size_t& serial_length); // NOLINT
size_t getBaseSerializationSize() const;
void serializeBase(void*& buffer) const; // NOLINT
protected:
std::vector<nvinfer1::Dims> input_dims_;
size_t max_batch_size_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
std::vector<nvinfer1::ITensor*> inputs_;
bool with_fp16_;
private:
std::string name_space_;
};
#if IS_TRT_VERSION_GE(6000)
class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
public:
......@@ -184,6 +293,7 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
std::string name_space_;
std::string plugin_base_;
};
#endif
template <typename T>
class TrtPluginRegistrarV2 {
......@@ -203,8 +313,6 @@ class TrtPluginRegistrarV2 {
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -336,6 +336,17 @@ if '${WITH_XPU_BKCL}' == 'ON':
shutil.copy('${XPU_BKCL_LIB}', libs_path)
package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}']
# Only for lite xpu inference.
if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '':
xpu_api_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/shlib/', 'libxpuapi.so')
xpu_rt_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/runtime/shlib/', 'libxpurt.so')
if os.path.exists(xpu_api_lib):
shutil.copy(xpu_api_lib, libs_path)
package_data['paddle.libs']+=['libxpuapi.so']
if os.path.exists(xpu_rt_lib):
shutil.copy(xpu_rt_lib, libs_path)
package_data['paddle.libs']+=['libxpurt.so']
### Old custom op extension mechanism related, will be removed in 2.1.0 ###
# copy libpaddle_framework.so to libs on linux
if sys.platform.startswith('linux'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册