diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc index 768c6efaa6bd40529a509698e186fa66c2e8e711..5d494c2093b2a969d521c778b582b3b9f51dd259 100644 --- a/paddle/fluid/inference/tensorrt/convert/split_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc @@ -101,7 +101,7 @@ class SplitOpConverter : public OpConverter { engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths, with_fp16); - layer = engine_->AddPlugin(&input, input_num, plugin); + layer = engine_->AddPluginV2Ext(&input, input_num, plugin); } std::string layer_name = "split (Output: "; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 0bba4581ff90f931eba9399cb3b0b274342f4f16..99549fd6b5cbf96cf803e7f44b28c948daf0763d 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include #include -#include "cuda_runtime_api.h" +#include "cuda_runtime_api.h" // NOLINT #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" @@ -353,6 +353,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin( return network()->addPluginExt(inputs, num_inputs, *plugin); } +nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext( + nvinfer1::ITensor *const *inputs, int num_inputs, + plugin::PluginTensorRTV2Ext *plugin) { + owned_plugin_v2ext_.emplace_back(plugin); + return network()->addPluginV2(inputs, num_inputs, *plugin); +} + void TensorRTEngine::freshDeviceId() { int count; cudaGetDeviceCount(&count); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 0e399578fa446793756a23e76013c3ed9a8bb9c4..de2924824f09de45727ecc30ec52904f90e6adb3 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -305,8 +305,14 @@ class TensorRTEngine { } int GetDeviceId() { return device_id_; } + nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs, int num_inputs, plugin::PluginTensorRT*); + + nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs, + int num_inputs, + plugin::PluginTensorRTV2Ext* plugin); + void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) { quant_dynamic_range_[tensor] = range; } @@ -414,6 +420,7 @@ class TensorRTEngine { itensor_map_; std::vector> owned_plugin_; + std::vector> owned_plugin_v2ext_; // TensorRT related internal members template diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index e37beb3b8e5c3680eda481009699091dcc1ee7a3..7ee16a598d2d018cc56d301bae826024784cb51a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -6,3 +6,6 @@ nv_library(tensorrt_plugin qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) + +nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS + paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin) diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index 256aa28206ad1c21dc3245a6e78f7cdc59b29156..1b5c39f8fff855fac4ef8f2ee54faa872023ad05 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -22,11 +22,6 @@ namespace inference { namespace tensorrt { namespace plugin { -SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) { - return new SplitPlugin(buffer, length); -} -REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize); - template __device__ int upper_bound(T const* vals, int n, T const& key) { int i = 0; diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index 5c47ec3a990f584fd02b3515dbc642ffcd921709..e43b57357fb64f9c1528c1bf73b098ed4aaed8f2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -25,7 +25,7 @@ namespace inference { namespace tensorrt { namespace plugin { -class SplitPlugin : public PluginTensorRT { +class SplitPlugin : public PluginTensorRTV2Ext { public: SplitPlugin() {} SplitPlugin(int axis, std::vector const& output_lengths, bool with_fp16) @@ -39,13 +39,20 @@ class SplitPlugin : public PluginTensorRT { DeserializeValue(&serial_data, &serial_length, &output_length_); } - SplitPlugin* clone() const override { - auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_); + nvinfer1::IPluginV2Ext* clone() const override { + SplitPlugin* ptr = new SplitPlugin(axis_, output_length_, with_fp16_); + ptr->setPluginNamespace(this->getPluginNamespace()); ptr->shareData(this); return ptr; } - const char* getPluginType() const override { return "split_plugin"; } + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const override { + return input_types[0]; + } + + const char* getPluginType() const override { return "split_plugin_v2ext"; } int getNbOutputs() const override { return output_length_.size(); } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* input_dims, @@ -53,17 +60,18 @@ class SplitPlugin : public PluginTensorRT { int initialize() override; void terminate() override; - int enqueue(int batchSize, const void* const* inputs, void** outputs, + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; + void destroy() override { delete this; } + protected: - size_t getSerializationSize() override { - return SerializedSize(getPluginType()) + SerializedSize(axis_) + - SerializedSize(output_length_) + getBaseSerializationSize(); + size_t getSerializationSize() const override { + return SerializedSize(axis_) + SerializedSize(output_length_) + + getBaseSerializationSize(); } - void serialize(void* buffer) override { - SerializeValue(&buffer, getPluginType()); + void serialize(void* buffer) const override { serializeBase(buffer); SerializeValue(&buffer, axis_); SerializeValue(&buffer, output_length_); @@ -83,6 +91,47 @@ class SplitPlugin : public PluginTensorRT { void shareData(const SplitPlugin* another); }; +class SplitPluginCreator : public nvinfer1::IPluginCreator { + public: + SplitPluginCreator() {} + const char* getPluginName() const override { return "split_plugin_v2ext"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + // not implemented + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override { + auto plugin = new SplitPlugin(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; + +REGISTER_TRT_PLUGIN_V2(SplitPluginCreator); + #if IS_TRT_VERSION_GE(6000) class SplitPluginDynamic : public DynamicPluginTensorRT { public: diff --git a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc new file mode 100644 index 0000000000000000000000000000000000000000..6636513a555f9e638e1dfdb54986010c76785e2a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +TEST(split_op_plugin, test_plugin) { + int axis = 1; + std::vector output_lengths{1, 1}; + bool with_fp16 = false; + std::vector input_types{nvinfer1::DataType::kFLOAT}; + std::vector input_dims; + + SplitPlugin sp_plugin(axis, output_lengths, with_fp16); + nvinfer1::Dims in_dims; + in_dims.nbDims = 4; + input_dims.push_back(in_dims); + sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2, + input_types.data(), nullptr, nullptr, nullptr, + nvinfer1::PluginFormat::kNCHW, 4); + sp_plugin.initialize(); + sp_plugin.getPluginType(); + sp_plugin.canBroadcastInputAcrossBatch(0); + sp_plugin.getNbOutputs(); + auto clone_plugin = sp_plugin.clone(); + clone_plugin->setPluginNamespace("test"); + clone_plugin->destroy(); + sp_plugin.getOutputDataType(0, input_types.data(), 1); + sp_plugin.terminate(); +} + +TEST(split_op_plugin, test_plugin_creater) { + SplitPluginCreator creator; + creator.getFieldNames(); + creator.createPlugin("test", nullptr); + creator.setPluginNamespace("test"); +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc index fd721b161450d7a8d4660ca09ea3a1093d754664..55bc786746beafcf7b2df98d54e9391e6a59ba24 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc @@ -19,27 +19,50 @@ namespace inference { namespace tensorrt { namespace plugin { +inline void Seria(void*& buffer, // NOLINT + const std::vector& input_dims, + size_t max_batch_size, nvinfer1::DataType data_type, + nvinfer1::PluginFormat data_format, bool with_fp16) { + SerializeValue(&buffer, input_dims); + SerializeValue(&buffer, max_batch_size); + SerializeValue(&buffer, data_type); + SerializeValue(&buffer, data_format); + SerializeValue(&buffer, with_fp16); +} + +inline void Deseria(void const*& serial_data, size_t& serial_length, // NOLINT + std::vector* input_dims, + size_t* max_batch_size, nvinfer1::DataType* data_type, + nvinfer1::PluginFormat* data_format, bool* with_fp16) { + DeserializeValue(&serial_data, &serial_length, input_dims); + DeserializeValue(&serial_data, &serial_length, max_batch_size); + DeserializeValue(&serial_data, &serial_length, data_type); + DeserializeValue(&serial_data, &serial_length, data_format); + DeserializeValue(&serial_data, &serial_length, with_fp16); +} + +inline size_t SeriaSize(const std::vector& input_dims, + size_t max_batch_size, nvinfer1::DataType data_type, + nvinfer1::PluginFormat data_format, bool with_fp16) { + return (SerializedSize(input_dims) + SerializedSize(max_batch_size) + + SerializedSize(data_type) + SerializedSize(data_format) + + SerializedSize(with_fp16)); +} + void PluginTensorRT::serializeBase(void*& buffer) { - SerializeValue(&buffer, input_dims_); - SerializeValue(&buffer, max_batch_size_); - SerializeValue(&buffer, data_type_); - SerializeValue(&buffer, data_format_); - SerializeValue(&buffer, with_fp16_); + Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_, + with_fp16_); } void PluginTensorRT::deserializeBase(void const*& serial_data, size_t& serial_length) { - DeserializeValue(&serial_data, &serial_length, &input_dims_); - DeserializeValue(&serial_data, &serial_length, &max_batch_size_); - DeserializeValue(&serial_data, &serial_length, &data_type_); - DeserializeValue(&serial_data, &serial_length, &data_format_); - DeserializeValue(&serial_data, &serial_length, &with_fp16_); + Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_, + &data_type_, &data_format_, &with_fp16_); } size_t PluginTensorRT::getBaseSerializationSize() { - return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) + - SerializedSize(data_type_) + SerializedSize(data_format_) + - SerializedSize(with_fp16_)); + return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_, + with_fp16_); } bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, @@ -58,6 +81,35 @@ void PluginTensorRT::configureWithFormat( max_batch_size_ = max_batch_size; } +void PluginTensorRTV2Ext::serializeBase(void*& buffer) const { + Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_, + with_fp16_); +} + +void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data, + size_t& serial_length) { + Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_, + &data_type_, &data_format_, &with_fp16_); +} + +size_t PluginTensorRTV2Ext::getBaseSerializationSize() const { + return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_, + with_fp16_); +} + +void PluginTensorRTV2Ext::configurePlugin( + const nvinfer1::Dims* input_dims, int32_t nb_inputs, + const nvinfer1::Dims* output_dims, int32_t nb_outputs, + const nvinfer1::DataType* input_types, + const nvinfer1::DataType* output_types, const bool* input_is_broadcast, + const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, + int32_t max_batch_size) { + input_dims_.assign(input_dims, input_dims + nb_inputs); + max_batch_size_ = max_batch_size; + data_format_ = float_format; + data_type_ = input_types[0]; +} + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index b3a3abe5d01fc53e2ef3da7722df0e372d605af4..ce3133ae99e94c62c0c8e958065700373d270037 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -44,6 +44,7 @@ typedef std::function typedef std::function PluginConstructFunc; +// Deprecated. Do not inherit this class, please refer to PluginTensorRTV2Ext class PluginTensorRT : public nvinfer1::IPluginExt { public: PluginTensorRT() : with_fp16_(false) {} @@ -119,6 +120,114 @@ class PluginTensorRT : public nvinfer1::IPluginExt { bool with_fp16_; }; +// TensorRT introduced IPluginV2Ext after 5.1, Paddle no longer supports +// versions before 5.1 +class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext { + public: + PluginTensorRTV2Ext() : with_fp16_(false) {} + PluginTensorRTV2Ext(const void* serialized_data, size_t length) {} + + nvinfer1::Dims const& getInputDims(int index) const { + return input_dims_.at(index); + } + size_t getMaxBatchSize() const { return max_batch_size_; } + nvinfer1::DataType getDataType() const { return data_type_; } + nvinfer1::PluginFormat getDataFormat() const { return data_format_; } + + // The Func in IPluginV2Ext + virtual nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const = 0; + + virtual bool isOutputBroadcastAcrossBatch(int32_t output_index, + const bool* input_is_broadcasted, + int32_t nb_inputs) const { + return false; + } + + virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const { + return false; + } + + void configurePlugin(const nvinfer1::Dims* input_dims, int32_t nb_inputs, + const nvinfer1::Dims* output_dims, int32_t nb_outputs, + const nvinfer1::DataType* input_types, + const nvinfer1::DataType* output_types, + const bool* input_is_broadcast, + const bool* output_is_broadcast, + nvinfer1::PluginFormat float_format, + int32_t max_batch_size) override; + + virtual IPluginV2Ext* clone() const = 0; + + void attachToContext(cudnnContext*, cublasContext*, + nvinfer1::IGpuAllocator*) override {} + + void detachFromContext() override {} + + // The Func in IPluginV2 + virtual const char* getPluginType() const = 0; + const char* getPluginVersion() const override { return "1"; } + virtual int32_t getNbOutputs() const { return 1; } + virtual nvinfer1::Dims getOutputDimensions(int32_t index, + const nvinfer1::Dims* inputs, + int32_t nb_input) = 0; + // Check format support. The default is FLOAT32 and NCHW. + bool supportsFormat(nvinfer1::DataType type, + nvinfer1::PluginFormat format) const override { + return ((type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::PluginFormat::kNCHW)); + } + // Initialize the layer for execution. + // This is called when the engine is created. + int initialize() override { return 0; } + + // Shutdown the layer. This is called when the engine is destroyed + void terminate() override {} + + // Find the workspace size required by the layer + size_t getWorkspaceSize(int) const override { return 0; } + + // Execute the layer + virtual int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) = 0; + + // Find the size of the serialization buffer required + virtual size_t getSerializationSize() const = 0; + + // Serialize the layer config to buffer. + // TensorRT will call this func to serialize the configuration of TensorRT + // engine. It should not be called by users. + virtual void serialize(void* buffer) const = 0; + + virtual void destroy() = 0; + + void setPluginNamespace(const char* plugin_namespace) override { + name_space_ = plugin_namespace; + } + + const char* getPluginNamespace() const override { + return name_space_.c_str(); + } + + protected: + void deserializeBase(void const*& serial_data, // NOLINT + size_t& serial_length); // NOLINT + size_t getBaseSerializationSize() const; + void serializeBase(void*& buffer) const; // NOLINT + + protected: + std::vector input_dims_; + size_t max_batch_size_; + nvinfer1::DataType data_type_; + nvinfer1::PluginFormat data_format_; + std::vector inputs_; + bool with_fp16_; + + private: + std::string name_space_; +}; + #if IS_TRT_VERSION_GE(6000) class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { public: @@ -184,6 +293,7 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { std::string name_space_; std::string plugin_base_; }; +#endif template class TrtPluginRegistrarV2 { @@ -203,8 +313,6 @@ class TrtPluginRegistrarV2 { static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2 \ plugin_registrar_##name {} -#endif - } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/python/setup.py.in b/python/setup.py.in index 64cfe6e9ccff74aa95f1517c9c666a6c1ac0953d..69a8bc771aefb024872a4185e5723d553745e3d3 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -336,6 +336,17 @@ if '${WITH_XPU_BKCL}' == 'ON': shutil.copy('${XPU_BKCL_LIB}', libs_path) package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}'] +# Only for lite xpu inference. +if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '': + xpu_api_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/shlib/', 'libxpuapi.so') + xpu_rt_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/runtime/shlib/', 'libxpurt.so') + if os.path.exists(xpu_api_lib): + shutil.copy(xpu_api_lib, libs_path) + package_data['paddle.libs']+=['libxpuapi.so'] + if os.path.exists(xpu_rt_lib): + shutil.copy(xpu_rt_lib, libs_path) + package_data['paddle.libs']+=['libxpurt.so'] + ### Old custom op extension mechanism related, will be removed in 2.1.0 ### # copy libpaddle_framework.so to libs on linux if sys.platform.startswith('linux'):