diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
index 60d07859f3ab0a1eefe8bcc67e692c7e28ddae07..12179cccc76f8b0f595f41c135290dc0f3b50ad7 100644
--- a/paddle/fluid/inference/tensorrt/convert/split_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
     int input_num = op_desc.Input("X").size();
     size_t output_num = op_desc.Output("Out").size();
 
+    // Get Attrs
     PADDLE_ENFORCE(input_num == 1);
     int axis = boost::get<int>(op_desc.GetAttr("axis"));
     std::vector<int> output_lengths =
@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
 
     PADDLE_ENFORCE(output_lengths.size() == output_num);
 
+    //
     SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
     nvinfer1::IPluginLayer* layer =
-        engine_->addPlugin(&input, input_num, plugin);
+        engine_->AddPlugin(&input, input_num, plugin);
 
     std::string layer_name = "split (Output: ";
     for (size_t i = 0; i < output_num; i++) {
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 426bf169bbfd036e554a3695636866dcb761cb83..0e06a8f8041e415dcead5858bfa2262bc7b36c63 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }
 
-nvinfer1::IPluginLayer *TensorRTEngine::addPlugin(
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
     nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
   owned_plugin_.emplace_back(plugin);
   return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 216606a29118654f4bde72b7df21b6d1ab036128..335acdf653e55cc7f3ceccdba88992851c8e0310 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
-  nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs,
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
                                     int nbInputs, PluginTensorRT*);
 
   // A pointer to CPU memory is needed of the TRT weight.
diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.h b/paddle/fluid/inference/tensorrt/plugin/serialize.h
index 96df352feb5b1a85b0ff7adebb7baf5f30c115e6..50c0b17d78327e22b0aa81fdac6958e80a30dfe8 100644
--- a/paddle/fluid/inference/tensorrt/plugin/serialize.h
+++ b/paddle/fluid/inference/tensorrt/plugin/serialize.h
@@ -20,11 +20,11 @@
 #include <vector>
 
 template <typename T>
-inline void serialize_value(void** buffer, T const& value);
+inline void SerializeValue(void** buffer, T const& value);
 
 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
-                              T* value);
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value);
 
 namespace {
 
@@ -35,14 +35,14 @@ template <typename T>
 struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                              std::is_enum<T>::value ||
                                              std::is_pod<T>::value>::type> {
-  static size_t serialized_size(T const& value) { return sizeof(T); }
-  static void serialize(void** buffer, T const& value) {
-    ::memcpy(*buffer, &value, sizeof(T));
+  static size_t SerializedSize(T const& value) { return sizeof(T); }
+  static void Serialize(void** buffer, T const& value) {
+    std::memcpy(*buffer, &value, sizeof(T));
     reinterpret_cast<char*&>(*buffer) += sizeof(T);
   }
-  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
+  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
     assert(*buffer_size >= sizeof(T));
-    ::memcpy(value, *buffer, sizeof(T));
+    std::memcpy(value, *buffer, sizeof(T));
     reinterpret_cast<char const*&>(*buffer) += sizeof(T);
     *buffer_size -= sizeof(T);
   }
@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
 
 template <>
 struct Serializer<const char*> {
-  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
-  static void serialize(void** buffer, const char* value) {
-    ::strcpy(static_cast<char*>(*buffer), value);
+  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+  static void Serialize(void** buffer, const char* value) {
+    std::strcpy(static_cast<char*>(*buffer), value);
     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           const char** value) {
     *value = static_cast<const char*>(*buffer);
     size_t data_size = strnlen(*value, *buffer_size) + 1;
@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
                   typename std::enable_if<std::is_arithmetic<T>::value ||
                                           std::is_enum<T>::value ||
                                           std::is_pod<T>::value>::type> {
-  static size_t serialized_size(std::vector<T> const& value) {
+  static size_t SerializedSize(std::vector<T> const& value) {
     return sizeof(value.size()) + value.size() * sizeof(T);
   }
-  static void serialize(void** buffer, std::vector<T> const& value) {
-    serialize_value(buffer, value.size());
+  static void Serialize(void** buffer, std::vector<T> const& value) {
+    SerializeValue(buffer, value.size());
     size_t nbyte = value.size() * sizeof(T);
-    ::memcpy(*buffer, value.data(), nbyte);
+    std::memcpy(*buffer, value.data(), nbyte);
     reinterpret_cast<char*&>(*buffer) += nbyte;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           std::vector<T>* value) {
     size_t size;
-    deserialize_value(buffer, buffer_size, &size);
+    DeserializeValue(buffer, buffer_size, &size);
     value->resize(size);
     size_t nbyte = value->size() * sizeof(T);
     assert(*buffer_size >= nbyte);
-    ::memcpy(value->data(), *buffer, nbyte);
+    std::memcpy(value->data(), *buffer, nbyte);
     reinterpret_cast<char const*&>(*buffer) += nbyte;
     *buffer_size -= nbyte;
   }
@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
 }  // namespace
 
 template <typename T>
-inline size_t serialized_size(T const& value) {
-  return Serializer<T>::serialized_size(value);
+inline size_t SerializedSize(T const& value) {
+  return Serializer<T>::SerializedSize(value);
 }
 
 template <typename T>
-inline void serialize_value(void** buffer, T const& value) {
-  return Serializer<T>::serialize(buffer, value);
+inline void SerializeValue(void** buffer, T const& value) {
+  return Serializer<T>::Serialize(buffer, value);
 }
 
 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
-                              T* value) {
-  return Serializer<T>::deserialize(buffer, buffer_size, value);
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value) {
+  return Serializer<T>::Deserialize(buffer, buffer_size, value);
 }
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
index ed43c4d43547be5793b0e5979924ca67813185a9..bd6a44dcc14d50cddb879763a93abf4297494ec9 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
     segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
   segment_offsets_ = segment_offsets;
-  d_segment_offsets_ = segment_offsets;
   nvinfer1::Dims dims = this->getInputDims(0);
   nx_ = 1;
   for (int i = dims.nbDims - 1; i > axis_; --i) {
@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
   auto const& input_dims = this->getInputDims(0);
   int input_size = 0;
-  int const* d_segment_offsets_ptr =
-      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
   float const* idata = reinterpret_cast<float const*>(inputs[0]);
   float** odatas = reinterpret_cast<float**>(outputs);
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
index 59be609111e173268d470b87ae02c2ee90121c0c..7281e40c331550de472df49c57b1d9a5226842d5 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -14,7 +14,6 @@
 
 #pragma once
 
-#include <thrust/device_vector.h>
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 
 namespace paddle {
@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
   int axis_;
   std::vector<int> output_length_;
   int nx_, ny_, nz_;
-  thrust::device_vector<int> d_segment_offsets_;
   std::vector<int> segment_offsets_;
 
 protected:
   virtual size_t getSerializationSize() override {
-    return serialized_size(axis_) + serialized_size(output_length_) +
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
            getBaseSerializationSize();
   }
 
+  // TRT will call this func when we need to serialize the configuration of
+  // the plugin.
+  // It should not be called by users.
   virtual void serialize(void *buffer) override {
     serializeBase(buffer);
-    serialize_value(&buffer, axis_);
-    serialize_value(&buffer, output_length_);
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, output_length_);
   }
 
 public:
@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
     assert(axis <= nvinfer1::Dims::MAX_DIMS);
   }
 
+  // It is used for tensorrt deserialization.
+  // It should not be called by users.
   SplitPlugin(void const *serialData, size_t serialLength) {
     deserializeBase(serialData, serialLength);
-    deserialize_value(&serialData, &serialLength, &axis_);
-    deserialize_value(&serialData, &serialLength, &output_length_);
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &output_length_);
   }
 
   SplitPlugin *clone() const override {
@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
   virtual int initialize() override;
   virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
                       void *workspace, cudaStream_t stream) override;
-
-  void setAxis(int axis) { axis_ = axis; }
-
-  void setOutputLengths(const std::vector<int> &output_lengths) {
-    output_length_ = output_lengths;
-  }
 };
 
 }  // tensorrt
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
index 975a5ed162768728c9184174f17eedb2b5256169..08016d84b15bc750738f3183d8d61a5c90862288 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
@@ -19,23 +19,23 @@ namespace inference {
 namespace tensorrt {
 
 void PluginTensorRT::serializeBase(void*& buffer) {
-  serialize_value(&buffer, input_dims_);
-  serialize_value(&buffer, max_batch_size_);
-  serialize_value(&buffer, data_type_);
-  serialize_value(&buffer, data_format_);
+  SerializeValue(&buffer, input_dims_);
+  SerializeValue(&buffer, max_batch_size_);
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, data_format_);
 }
 
 void PluginTensorRT::deserializeBase(void const*& serialData,
                                      size_t& serialLength) {
-  deserialize_value(&serialData, &serialLength, &input_dims_);
-  deserialize_value(&serialData, &serialLength, &max_batch_size_);
-  deserialize_value(&serialData, &serialLength, &data_type_);
-  deserialize_value(&serialData, &serialLength, &data_format_);
+  DeserializeValue(&serialData, &serialLength, &input_dims_);
+  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
+  DeserializeValue(&serialData, &serialLength, &data_type_);
+  DeserializeValue(&serialData, &serialLength, &data_format_);
 }
 
 size_t PluginTensorRT::getBaseSerializationSize() {
-  return (serialized_size(input_dims_) + serialized_size(max_batch_size_) +
-          serialized_size(data_type_) + serialized_size(data_format_));
+  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
+          SerializedSize(data_type_) + SerializedSize(data_format_));
 }
 
 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
index 44869b390fa8c64d98ef8595fb5bdd7bc752d168..4d85e955a49b7dcccae158ea06b76419419797cf 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   size_t getWorkspaceSize(int) const override { return 0; }
   void terminate() override {}
   virtual ~PluginTensorRT() {}
-
-  // The following functions need to be overrided in the subclass.
-  virtual nvinfer1::IPluginExt* clone() const = 0;
-  virtual const char* getPluginType() const = 0;
-  int initialize() override { return 0; }
+  // Check format support. The default is FLOAT32 and NCHW.
   bool supportsFormat(nvinfer1::DataType type,
                       nvinfer1::PluginFormat format) const override;
   void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
                            nvinfer1::DataType type,
                            nvinfer1::PluginFormat format,
                            int maxBatchSize) override;
+
+  // *NOTE* The following functions need to be overridden in the subclass.
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  // Initialize the layer for execution. This is called when the engine is
+  // created.
+  int initialize() override { return 0; }
+  // Serialize the layer config to buffer.
   virtual void serialize(void* buffer) = 0;
   virtual size_t getSerializationSize() = 0;
+  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;
 
  protected:
+  // Deserialize input_dims, max_batch_size, data_type, data_format
   void deserializeBase(void const*& serialData, size_t& serialLength);
   size_t getBaseSerializationSize();
+  // Serialize input_dims, max_batch_size, data_type, data_format
   void serializeBase(void*& buffer);
 
   std::vector<nvinfer1::Dims> input_dims_;
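
Note for reviewers: the patch above renames the serialization helpers to SerializeValue / DeserializeValue and uses them in SplitPlugin::serialize() and the deserialization constructor. The standalone sketch below illustrates that round trip. The helper templates here are simplified stand-ins for the POD and std::vector overloads in serialize.h (not the actual header), and main() with its hand-computed buffer size is purely illustrative.

// Minimal sketch of the SerializeValue / DeserializeValue pattern used by
// SplitPlugin. Assumes a POD overload and a std::vector overload only.
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>

// Write a POD value into the buffer and advance the write pointer.
template <typename T>
void SerializeValue(void** buffer, const T& value) {
  std::memcpy(*buffer, &value, sizeof(T));
  reinterpret_cast<char*&>(*buffer) += sizeof(T);
}

// Write a vector as its element count followed by the raw element bytes.
template <typename T>
void SerializeValue(void** buffer, const std::vector<T>& value) {
  SerializeValue(buffer, value.size());
  size_t nbyte = value.size() * sizeof(T);
  std::memcpy(*buffer, value.data(), nbyte);
  reinterpret_cast<char*&>(*buffer) += nbyte;
}

// Read a POD value back, advancing the read pointer and shrinking the budget.
template <typename T>
void DeserializeValue(const void** buffer, size_t* buffer_size, T* value) {
  assert(*buffer_size >= sizeof(T));
  std::memcpy(value, *buffer, sizeof(T));
  reinterpret_cast<const char*&>(*buffer) += sizeof(T);
  *buffer_size -= sizeof(T);
}

// Read a vector back in the same order it was written.
template <typename T>
void DeserializeValue(const void** buffer, size_t* buffer_size,
                      std::vector<T>* value) {
  size_t size = 0;
  DeserializeValue(buffer, buffer_size, &size);
  value->resize(size);
  size_t nbyte = size * sizeof(T);
  assert(*buffer_size >= nbyte);
  std::memcpy(value->data(), *buffer, nbyte);
  reinterpret_cast<const char*&>(*buffer) += nbyte;
  *buffer_size -= nbyte;
}

int main() {
  // Fields that SplitPlugin serializes after serializeBase().
  int axis = 1;
  std::vector<int> output_length = {2, 3, 5};

  // serialize(): write the fields into a flat buffer, in order.
  std::vector<char> storage(sizeof(int) + sizeof(size_t) +
                            output_length.size() * sizeof(int));
  void* write_ptr = storage.data();
  SerializeValue(&write_ptr, axis);
  SerializeValue(&write_ptr, output_length);

  // Deserialization constructor: read the fields back in the same order.
  const void* read_ptr = storage.data();
  size_t remaining = storage.size();
  int axis2 = 0;
  std::vector<int> output_length2;
  DeserializeValue(&read_ptr, &remaining, &axis2);
  DeserializeValue(&read_ptr, &remaining, &output_length2);

  // Expected output: axis=1 sections=3 remaining=0
  std::cout << "axis=" << axis2 << " sections=" << output_length2.size()
            << " remaining=" << remaining << std::endl;
  return 0;
}

The only requirement the real code imposes is that serialize() and the deserializing constructor write and read the fields in the same order, after serializeBase()/deserializeBase() have handled input_dims_, max_batch_size_, data_type_, and data_format_.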