diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index a610687a5b11999a7cb7426dbe961e5972ee1746..e09705e3c69eb2b2370bd1ad2d9cf178ef041ee6 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,4 +1,5 @@ nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) +add_subdirectory(plugin) add_subdirectory(convert) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 0a35e10f6936313928ab21a6f17c40335e8fc882..e34d5db6b830861a40f4a5801d2ccaf8465335b2 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -2,7 +2,7 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc - DEPS tensorrt_engine operator scope framework_proto op_registry) + DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc index 60c16e35ed39a4cb14cbc16ebacaba2bb72bcc81..cd1bb892bdf0ccb69114ed51accad40d263581d5 100644 --- a/paddle/fluid/inference/tensorrt/convert/concat_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc @@ -19,7 +19,7 @@ namespace inference { namespace tensorrt { /* - * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights. + * ConcatOp */ class ConcatOpConverter : public OpConverter { public: diff --git a/paddle/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp b/paddle/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp new file mode 100644 index 0000000000000000000000000000000000000000..08d1434089f792131d0e6a545ad8675b3ba4892c Binary files /dev/null and b/paddle/fluid/inference/tensorrt/plugin/.trt_plugin_utils.h.swp differ diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b91c864c9e4205d1fb13a64ad8666d38dc26795 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -0,0 +1,2 @@ +nv_library(tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc +trt_plugin.cc split_op_plugin.cu DEPS enforce) diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ebcd44611a0f79560d3639400277cb8193feebb --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, + const void* serial_data, + size_t serial_length) { + size_t parsed_byte = 0; + std::string encoded_op_name = + ExtractOpName(serial_data, serial_length, &parsed_byte); + + if (!IsPlugin(encoded_op_name)) { + return nullptr; + } + + auto plugin_ptr = + plugin_registry_[encoded_op_name].first(serial_data, serial_length); + owned_plugins_.emplace_back(plugin_ptr); + + return plugin_ptr; +} + +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( + const std::string& op_name) { + if (!IsPlugin(op_name)) return nullptr; + + auto plugin_ptr = plugin_registry_[op_name].second(); + owned_plugins_.emplace_back(plugin_ptr); + + return plugin_ptr; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const std::string& op_name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + if (IsPlugin(op_name)) return false; + + auto ret = plugin_registry_.emplace( + op_name, std::make_pair(deserialize_func, construct_func)); + + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); } + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..00435766f741f8e229dab4a50521818c0131585f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "NvInfer.h" +#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; + } + + // Deserialization method + PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; + + // Plugin construction, PluginFactoryTensorRT owns the plugin. 
+  PluginTensorRT* CreatePlugin(const std::string& op_name);
+
+  bool RegisterPlugin(const std::string& op_name,
+                      PluginDeserializeFunc deserialize_func,
+                      PluginConstructFunc construct_func);
+
+  bool IsPlugin(const std::string& op_name) {
+    return plugin_registry_.find(op_name) != plugin_registry_.end();
+  }
+
+  size_t CountOwnedPlugins() { return owned_plugins_.size(); }
+
+  void DestroyPlugins();
+
+ protected:
+  std::unordered_map<std::string,
+                     std::pair<PluginDeserializeFunc, PluginConstructFunc>>
+      plugin_registry_;
+  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
+};
+
+class TrtPluginRegistrar {
+ public:
+  TrtPluginRegistrar(const std::string& name,
+                     PluginDeserializeFunc deserialize_func,
+                     PluginConstructFunc construct_func) {
+    auto factory = PluginFactoryTensorRT::GetInstance();
+    factory->RegisterPlugin(name, deserialize_func, construct_func);
+  }
+};
+
+#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func)    \
+  REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
+                                  construct_func)
+#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
+                                        construct_func)              \
+  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
+#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
+  static ::paddle::inference::tensorrt::TrtPluginRegistrar                    \
+      trt_plugin_registrar##ctr __attribute__((unused)) =                     \
+          ::paddle::inference::tensorrt::TrtPluginRegistrar(                  \
+              name, deserialize_func, construct_func)
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
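Note: `REGISTER_TRT_PLUGIN` expands to a file-scope `TrtPluginRegistrar` whose constructor registers the deserialize and construct functors with the singleton factory. The patch defines this machinery but never invokes the macro itself, so the sketch below, which wires the `SplitPlugin` added later in this diff under the assumed op name "split", is illustrative usage rather than part of the change:

```cpp
// Sketch only: registering the split plugin with the factory.  The op name
// "split" and the lambdas are assumptions; the patch does not contain this.
#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

REGISTER_TRT_PLUGIN(
    "split",
    // Deserialize functor: rebuilds the plugin from an engine blob.
    [](const void* serial_data, size_t serial_length) -> PluginTensorRT* {
      return new SplitPlugin(serial_data, serial_length);
    },
    // Construct functor: builds a fresh plugin at network-definition time.
    []() -> PluginTensorRT* { return new SplitPlugin(); });

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
```

An op converter would then obtain an instance with `PluginFactoryTensorRT::GetInstance()->CreatePlugin("split")`, while engine deserialization reaches the registered functor through the `IPluginFactory::createPlugin` override.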
diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2cc4162aa74254250b2ff2ce0e6d80dd2bfd3c99
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
@@ -0,0 +1,37 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
+#include <cassert>
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+std::string ExtractOpName(const void* serial_data, size_t serial_length,
+                          size_t* incremental) {
+  size_t op_name_char_count = *static_cast<const size_t*>(serial_data);
+  *incremental = sizeof(size_t) + op_name_char_count;
+
+  assert(serial_length >= *incremental);
+
+  const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t);
+  std::string op_name(buffer, op_name_char_count);
+
+  return op_name;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..fb6608c12abc9a1576e63f4b29dec58b1b860d38
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <functional>
+
+#include "NvInfer.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+typedef std::function<PluginTensorRT*(const void*, size_t)>
+    PluginDeserializeFunc;
+typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
+
+std::string ExtractOpName(const void* serial_data, size_t serial_length,
+                          size_t* incremental);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
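Note: `ExtractOpName` assumes the serialized plugin blob begins with a `size_t` character count followed by the raw characters of the op name, with the plugin-specific payload after that. A writer-side helper matching that layout is not part of this patch; the hypothetical sketch below only documents the assumed format:

```cpp
// Hypothetical counterpart of ExtractOpName(): writes
//   [ size_t n ][ n bytes of op name ][ plugin-specific payload ... ]
#include <cstring>
#include <string>

inline size_t EncodeOpName(const std::string& op_name, void* buffer) {
  char* p = static_cast<char*>(buffer);
  size_t n = op_name.size();
  std::memcpy(p, &n, sizeof(size_t));                  // character count
  std::memcpy(p + sizeof(size_t), op_name.data(), n);  // raw characters
  return sizeof(size_t) + n;                           // bytes written
}
```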
diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.hpp b/paddle/fluid/inference/tensorrt/plugin/serialize.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..96df352feb5b1a85b0ff7adebb7baf5f30c115e6
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/serialize.hpp
@@ -0,0 +1,111 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+template <typename T>
+inline void serialize_value(void** buffer, T const& value);
+
+template <typename T>
+inline void deserialize_value(void const** buffer, size_t* buffer_size,
+                              T* value);
+
+namespace {
+
+template <typename T, class Enable = void>
+struct Serializer {};
+
+template <typename T>
+struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
+                                             std::is_enum<T>::value ||
+                                             std::is_pod<T>::value>::type> {
+  static size_t serialized_size(T const& value) { return sizeof(T); }
+  static void serialize(void** buffer, T const& value) {
+    ::memcpy(*buffer, &value, sizeof(T));
+    reinterpret_cast<char*&>(*buffer) += sizeof(T);
+  }
+  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
+    assert(*buffer_size >= sizeof(T));
+    ::memcpy(value, *buffer, sizeof(T));
+    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
+    *buffer_size -= sizeof(T);
+  }
+};
+
+template <>
+struct Serializer<const char*> {
+  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
+  static void serialize(void** buffer, const char* value) {
+    ::strcpy(static_cast<char*>(*buffer), value);
+    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
+  }
+  static void deserialize(void const** buffer, size_t* buffer_size,
+                          const char** value) {
+    *value = static_cast<char const*>(*buffer);
+    size_t data_size = strnlen(*value, *buffer_size) + 1;
+    assert(*buffer_size >= data_size);
+    reinterpret_cast<char const*&>(*buffer) += data_size;
+    *buffer_size -= data_size;
+  }
+};
+
+template <typename T>
+struct Serializer<std::vector<T>,
+                  typename std::enable_if<std::is_arithmetic<T>::value ||
+                                          std::is_enum<T>::value ||
+                                          std::is_pod<T>::value>::type> {
+  static size_t serialized_size(std::vector<T> const& value) {
+    return sizeof(value.size()) + value.size() * sizeof(T);
+  }
+  static void serialize(void** buffer, std::vector<T> const& value) {
+    serialize_value(buffer, value.size());
+    size_t nbyte = value.size() * sizeof(T);
+    ::memcpy(*buffer, value.data(), nbyte);
+    reinterpret_cast<char*&>(*buffer) += nbyte;
+  }
+  static void deserialize(void const** buffer, size_t* buffer_size,
+                          std::vector<T>* value) {
+    size_t size;
+    deserialize_value(buffer, buffer_size, &size);
+    value->resize(size);
+    size_t nbyte = value->size() * sizeof(T);
+    assert(*buffer_size >= nbyte);
+    ::memcpy(value->data(), *buffer, nbyte);
+    reinterpret_cast<char const*&>(*buffer) += nbyte;
+    *buffer_size -= nbyte;
+  }
+};
+
+}  // namespace
+
+template <typename T>
+inline size_t serialized_size(T const& value) {
+  return Serializer<T>::serialized_size(value);
+}
+
+template <typename T>
+inline void serialize_value(void** buffer, T const& value) {
+  return Serializer<T>::serialize(buffer, value);
+}
+
+template <typename T>
+inline void deserialize_value(void const** buffer, size_t* buffer_size,
+                              T* value) {
+  return Serializer<T>::deserialize(buffer, buffer_size, value);
+}
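Note: the helpers above advance the buffer pointer as they write or read, and `serialized_size()` reports exactly how many bytes `serialize_value()` will emit, which is how `SplitPlugin::getSerializationSize()` and `serialize()` stay consistent. A minimal host-only round trip (illustrative, not part of the patch):

```cpp
#include <cassert>
#include <vector>

void SerializeRoundTripExample() {
  int axis = 1;
  std::vector<int> lengths = {2, 3, 5};

  // Allocate exactly the number of bytes the two writes will produce.
  std::vector<char> blob(serialized_size(axis) + serialized_size(lengths));

  void* wptr = blob.data();
  serialize_value(&wptr, axis);     // advances wptr by sizeof(int)
  serialize_value(&wptr, lengths);  // writes size() followed by the elements

  const void* rptr = blob.data();
  size_t remaining = blob.size();
  int axis_out = 0;
  std::vector<int> lengths_out;
  deserialize_value(&rptr, &remaining, &axis_out);
  deserialize_value(&rptr, &remaining, &lengths_out);
  assert(axis_out == axis && lengths_out == lengths && remaining == 0);
}
```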
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..044c229b55c5171e34d7441c62cc6c5ec13d6ff2
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -0,0 +1,114 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+SplitPlugin* CreateSplitPlugin() { return new SplitPlugin(); }
+
+nvinfer1::Dims SplitPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims* inputDims, int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const& input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  output_dims.d[axis_] = output_length_.at(index);
+  return output_dims;
+}
+
+int SplitPlugin::initialize() {
+  std::vector<int> segment_offsets(1, 0);
+  for (int i = 0; i < this->getNbOutputs(); ++i) {
+    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
+  }
+  d_segment_offsets_ = segment_offsets;
+  nvinfer1::Dims dims = this->getInputDims(0);
+  nx_ = 1;
+  for (int i = dims.nbDims - 1; i > axis_; --i) {
+    nx_ *= dims.d[i];
+  }
+  ny_ = dims.d[axis_];
+  nz_ = 1;
+  for (int i = axis_ - 1; i >= 0; --i) {
+    nz_ *= dims.d[i];
+  }
+  return 0;
+}
+
+template <typename T>
+__device__ int upper_bound(T const* vals, int n, T const& key) {
+  int i = 0;
+  while (n > 0) {
+    int m = n / 2;
+    int j = i + m;
+    if (!(key < vals[j])) {
+      i = j + 1;
+      n -= m + 1;
+    } else {
+      n = m;
+    }
+  }
+  return i;
+}
+
+template <typename T>
+__global__ void split_kernel(int nsegment,
+                             int const* __restrict__ segment_offsets,
+                             T const* __restrict__ idata, T* const* odatas,
+                             int nx, int srcny_, int nz) {
+  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
+  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
+  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
+  for (int z = z0; z < nz; z += blockDim.z * gridDim.z) {
+    for (int src_y = src_y0; src_y < srcny_; src_y += blockDim.y * gridDim.y) {
+      for (int x = x0; x < nx; x += blockDim.x * gridDim.x) {
+        int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
+        int dst_y = src_y - segment_offsets[segment];
+        int dstny_ = segment_offsets[segment + 1] - segment_offsets[segment];
+        odatas[segment][x + nx * (dst_y + dstny_ * z)] =
+            idata[x + nx * (src_y + srcny_ * z)];
+      }
+    }
+  }
+}
+
+int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
+                         void** outputs, void* workspace,
+                         cudaStream_t stream) {
+  auto const& input_dims = this->getInputDims(0);
+  int const* d_segment_offsets_ptr =
+      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
+  float const* idata = reinterpret_cast<float const*>(inputs[0]);
+  float** odatas = reinterpret_cast<float**>(outputs);
+
+  int nz = nz_ * batchSize;
+  dim3 block(32, 16);
+  dim3 grid(std::min((nx_ - 1) / block.x + 1, 65535u),
+            std::min((ny_ - 1) / block.y + 1, 65535u),
+            std::min((nz_ - 1) / block.z + 1, 65535u));
+
+  split_kernel<<<grid, block, 0, stream>>>(d_segment_offsets_.size(),
+                                           d_segment_offsets_ptr, idata,
+                                           odatas, nx_, ny_, nz);
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
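Note: `initialize()` flattens the (batch-less) input into an `(nz, ny, nx)` volume where `ny` is the extent of the split axis, and `d_segment_offsets_` holds prefix sums of the output lengths. For example, splitting a CHW input of shape `(6, 4, 5)` on `axis_ = 0` into lengths `{2, 4}` gives `nx = 20`, `ny = 6`, `nz = 1` and offsets `{0, 2, 6}`; source row `src_y = 3` then lands in segment 1 at destination row `dst_y = 1`. A host-side mirror of the device-side lookup, shown only to document the arithmetic:

```cpp
#include <algorithm>
#include <vector>

// Host equivalent of the upper_bound() search used in split_kernel: the first
// offset strictly greater than src_y, minus one, is the segment index.
int SegmentOf(const std::vector<int>& offsets, int src_y) {
  auto it = std::upper_bound(offsets.begin(), offsets.end(), src_y);
  return static_cast<int>(it - offsets.begin()) - 1;
}
// With offsets = {0, 2, 6}: SegmentOf(offsets, 3) == 1, and the destination
// row is 3 - offsets[1] == 1.
```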
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..406c822bb5eb090f54ec49511c9aaa21666126be
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -0,0 +1,62 @@
+
+#pragma once
+
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include <thrust/device_vector.h>
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class SplitPlugin : public PluginTensorRT {
+  int axis_;
+  std::vector<int> output_length_;
+  int nx_, ny_, nz_;
+  thrust::device_vector<int> d_segment_offsets_;
+
+ protected:
+  virtual size_t getSerializationSize() override {
+    return serialized_size(axis_) + serialized_size(output_length_) +
+           getBaseSerializationSize();
+  }
+
+  virtual void serialize(void *buffer) override {
+    serializeBase(buffer);
+    serialize_value(&buffer, axis_);
+    serialize_value(&buffer, output_length_);
+  }
+
+ public:
+  SplitPlugin() {}
+  SplitPlugin(int axis, std::vector<int> const &output_lengths)
+      : axis_(axis), output_length_(output_lengths) {}
+  SplitPlugin(void const *serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    deserialize_value(&serialData, &serialLength, &axis_);
+    deserialize_value(&serialData, &serialLength, &output_length_);
+  }
+
+  SplitPlugin *clone() const override {
+    return new SplitPlugin(axis_, output_length_);
+  }
+
+  virtual const char *getPluginType() const override { return "split"; }
+  virtual int getNbOutputs() const override { return output_length_.size(); }
+  virtual nvinfer1::Dims getOutputDimensions(int index,
+                                             const nvinfer1::Dims *inputs,
+                                             int nbInputDims) override;
+  virtual int initialize() override;
+  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
+                      void *workspace, cudaStream_t stream) override;
+
+  void setAxis(int axis) { axis_ = axis; }
+
+  void setOutputLengths(const std::vector<int> &output_lengths) {
+    output_length_ = output_lengths;
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4eff6665d42755bfe5e28d1ef8d6c6d3df62fed3
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +void PluginTensorRT::serializeBase(void*& buffer) { + serialize_value(&buffer, input_dims_); + serialize_value(&buffer, max_batch_size_); + serialize_value(&buffer, data_type_); + serialize_value(&buffer, data_format_); +} + +void PluginTensorRT::deserializeBase(void const*& serialData, + size_t& serialLength) { + deserialize_value(&serialData, &serialLength, &input_dims_); + deserialize_value(&serialData, &serialLength, &max_batch_size_); + deserialize_value(&serialData, &serialLength, &data_type_); + deserialize_value(&serialData, &serialLength, &data_format_); +} + +size_t PluginTensorRT::getBaseSerializationSize() { + return (serialized_size(input_dims_) + serialized_size(max_batch_size_) + + serialized_size(data_type_) + serialized_size(data_format_)); +} + +bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, + nvinfer1::PluginFormat format) const { + return ((type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF) && + (format == nvinfer1::PluginFormat::kNCHW)); +} + +void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims, + int nbInputs, + const nvinfer1::Dims* outputDims, + int nbOutputs, nvinfer1::DataType type, + nvinfer1::PluginFormat format, + int maxBatchSize) { + data_type_ = type; + data_format_ = format; + input_dims_.assign(inputDims, inputDims + nbInputs); + max_batch_size_ = maxBatchSize; +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..8168646bdec3fcd19be2332dc3379bfd1e549d9e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -0,0 +1,72 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/inference/tensorrt/plugin/serialize.hpp" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PluginTensorRT : public nvinfer1::IPluginExt { + public: + PluginTensorRT() {} + PluginTensorRT(const void* serialized_data, size_t length) {} + nvinfer1::Dims const& getInputDims(int index) const { + return input_dims_.at(index); + } + size_t getMaxBatchSize() const { return max_batch_size_; } + nvinfer1::DataType getDataType() const { return data_type_; } + nvinfer1::PluginFormat getDataFormat() const { return data_format_; } + virtual const char* getPluginVersion() const { return "1"; } + size_t getWorkspaceSize(int) const override { return 0; } + void terminate() override {} + virtual ~PluginTensorRT() {} + + // The following functions need to be overrided in the subclass. 
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  int initialize() override { return 0; }
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override;
+  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
+                           const nvinfer1::Dims* outputDims, int nbOutputs,
+                           nvinfer1::DataType type,
+                           nvinfer1::PluginFormat format,
+                           int maxBatchSize) override;
+  virtual void serialize(void* buffer) = 0;
+  virtual size_t getSerializationSize() = 0;
+
+ protected:
+  void deserializeBase(void const*& serialData, size_t& serialLength);
+  size_t getBaseSerializationSize();
+  void serializeBase(void*& buffer);
+
+  std::vector<nvinfer1::Dims> input_dims_;
+  size_t max_batch_size_;
+  nvinfer1::DataType data_type_;
+  nvinfer1::PluginFormat data_format_;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
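Note: `PluginTensorRT` supplies the format negotiation and base-field serialization, so a concrete plugin only has to provide shape inference, execution, and its own extra state. A minimal subclass skeleton, mirroring what `SplitPlugin` does; `IdentityPlugin` is hypothetical and not part of this patch (it also assumes `kFLOAT` data):

```cpp
#include <cuda_runtime.h>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

// Sketch only: a pass-through plugin showing the methods a subclass of
// PluginTensorRT must override.
class IdentityPlugin : public paddle::inference::tensorrt::PluginTensorRT {
 public:
  IdentityPlugin() {}
  IdentityPlugin(const void* serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);  // restore base-class fields
  }

  IdentityPlugin* clone() const override { return new IdentityPlugin(); }
  const char* getPluginType() const override { return "identity"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) override {
    return inputs[0];  // same shape as the single input
  }

  int enqueue(int batchSize, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override {
    size_t count = static_cast<size_t>(batchSize);
    for (int i = 0; i < getInputDims(0).nbDims; ++i) {
      count *= getInputDims(0).d[i];
    }
    // A real plugin would launch its own kernel here.
    cudaMemcpyAsync(outputs[0], inputs[0], count * sizeof(float),
                    cudaMemcpyDeviceToDevice, stream);
    return cudaGetLastError() != cudaSuccess;
  }

  // No state beyond the base-class fields.
  size_t getSerializationSize() override { return getBaseSerializationSize(); }
  void serialize(void* buffer) override { serializeBase(buffer); }
};
```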