diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
index dc4d0906c4f260c8f7a11832fc52eba7191c54e8..233bfd6a42b7f123813d4ef5cecf353f7e88d208 100644
--- a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
     std::unordered_set<std::string> teller_set(
         {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
          "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-         "elementwise_add", "dropout"});
+         "elementwise_add", "dropout", "split"});
     if (!node->IsOp()) return false;
 
     if (teller_set.count(node->Op()->Type())) {
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 7407a1ba2f63bfe31a9d3a6f33395575c5809dee..76d205b737aeb456f242037f2b375d9c537b39f3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(split);
 #endif
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index a610687a5b11999a7cb7426dbe961e5972ee1746..e09705e3c69eb2b2370bd1ad2d9cf178ef041ee6 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,4 +1,5 @@
 nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
+add_subdirectory(plugin)
 add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 0a35e10f6936313928ab21a6f17c40335e8fc882..ed4c398cee518af3211cab4e982082c46ebb36c2 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,8 +1,9 @@
 # Add TRT tests
 nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
-  DEPS tensorrt_engine operator scope framework_proto op_registry)
+batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+pad_op.cc split_op.cc
+  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
   ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
 nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
-
 nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
+nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
+split_op concat_op SERIAL)
diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
index b2e7c593e85974898012f8a353817a27ca212f4d..525ba9dc341c8c1343553ac9523611f79ac3aa2d 100644
--- a/paddle/fluid/inference/tensorrt/convert/concat_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
@@ -19,7 +19,7 @@ namespace inference {
 namespace tensorrt {
 
 /*
- * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
+ * ConcatOp
  */
 class ConcatOpConverter : public OpConverter {
  public:
diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..12179cccc76f8b0f595f41c135290dc0f3b50ad7
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * SplitOp.
+ */
+class SplitOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    int input_num = op_desc.Input("X").size();
+    size_t output_num = op_desc.Output("Out").size();
+
+    // Get Attrs
+    PADDLE_ENFORCE(input_num == 1);
+    int axis = boost::get<int>(op_desc.GetAttr("axis"));
+    std::vector<int> output_lengths =
+        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
+    PADDLE_ENFORCE(axis != 0);
+    if (axis < 0) {
+      axis += input_dims.nbDims;
+    } else {
+      axis -= 1;
+    }
+
+    PADDLE_ENFORCE(output_lengths.size() == output_num);
+
+    // Create the split plugin and add it to the network.
+    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
+    nvinfer1::IPluginLayer* layer =
+        engine_->AddPlugin(&input, input_num, plugin);
+
+    std::string layer_name = "split (Output: ";
+    for (size_t i = 0; i < output_num; i++) {
+      auto output_name = op_desc.Output("Out")[i];
+      layer->getOutput(i)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(i));
+      layer_name += output_name;
+      if (test_mode) {
+        engine_->DeclareOutput(output_name);
+      }
+    }
+    layer->setName((layer_name + ")").c_str());
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
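Note for readers of the converter above: TensorRT ITensor shapes omit the batch dimension, which is why a fluid `split` along axis 1 (the channel axis of an NCHW tensor) becomes axis 0 of the ITensor, and why splitting the batch axis is rejected. A minimal standalone sketch of that mapping (the helper name is illustrative and not part of the patch):

// Illustrative sketch, not part of the patch: maps a fluid split axis to the
// corresponding TensorRT axis, mirroring the checks in SplitOpConverter.
#include <stdexcept>

inline int FluidAxisToTrtAxis(int fluid_axis, int trt_nb_dims) {
  if (fluid_axis == 0) {
    // The converter enforces axis != 0: the batch dimension cannot be split.
    throw std::invalid_argument("cannot split the batch dimension");
  }
  // Negative axes count from the end; positive axes lose the batch dimension.
  return fluid_axis < 0 ? fluid_axis + trt_nb_dims : fluid_axis - 1;
}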
diff --git a/paddle/fluid/inference/tensorrt/convert/test_split_op.cc b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f81d011552c152c2df79e1a272f34b954ae2a3a1
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(split_op, test) {
+  std::unordered_set<std::string> parameters({""});
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
+  validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
+  validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("split");
+  desc.SetInput("X", {"split_input"});
+  desc.SetOutput("Out", {"split_out1", "split_out2"});
+
+  int num = 0;
+  int axis = 1;
+  std::vector<int> output_lengths = {2, 1};
+  desc.SetAttr("axis", axis);
+  desc.SetAttr("num", num);
+  desc.SetAttr("sections", output_lengths);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(split);
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 8adc3baca64845f596477a0abe61be31e7377d9f..fdd8b56b0ce5c9b5cb6395bcb437aae5ae27829b 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }
 
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
+    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
+  owned_plugin_.emplace_back(plugin);
+  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 828181200e300c370bbfa234c3c23ae44810878c..335acdf653e55cc7f3ceccdba88992851c8e0310 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
 namespace paddle {
@@ -125,6 +126,8 @@
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
+                                    int nbInputs, PluginTensorRT*);
 
   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -164,8 +167,10 @@
   std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
   std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
       itensor_map_;
 
+  // The specific GPU id that the TensorRTEngine is bound to.
   int device_;
+  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
 
   // TensorRT related internal members
   template <typename T>
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..71b7a551619a43e5300ad3205418d1174c7019ff
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -0,0 +1 @@
+nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.h b/paddle/fluid/inference/tensorrt/plugin/serialize.h
new file mode 100644
index 0000000000000000000000000000000000000000..50c0b17d78327e22b0aa81fdac6958e80a30dfe8
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/serialize.h
@@ -0,0 +1,111 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+template <typename T>
+inline void SerializeValue(void** buffer, T const& value);
+
+template <typename T>
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value);
+
+namespace {
+
+template <typename T, class Enable = void>
+struct Serializer {};
+
+template <typename T>
+struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
+                                             std::is_enum<T>::value ||
+                                             std::is_pod<T>::value>::type> {
+  static size_t SerializedSize(T const& value) { return sizeof(T); }
+  static void Serialize(void** buffer, T const& value) {
+    std::memcpy(*buffer, &value, sizeof(T));
+    reinterpret_cast<char*&>(*buffer) += sizeof(T);
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
+    assert(*buffer_size >= sizeof(T));
+    std::memcpy(value, *buffer, sizeof(T));
+    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
+    *buffer_size -= sizeof(T);
+  }
+};
+
+template <>
+struct Serializer<const char*> {
+  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+  static void Serialize(void** buffer, const char* value) {
+    std::strcpy(static_cast<char*>(*buffer), value);
+    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size,
+                          const char** value) {
+    *value = static_cast<const char*>(*buffer);
+    size_t data_size = strnlen(*value, *buffer_size) + 1;
+    assert(*buffer_size >= data_size);
+    reinterpret_cast<char const*&>(*buffer) += data_size;
+    *buffer_size -= data_size;
+  }
+};
+
+template <typename T>
+struct Serializer<std::vector<T>,
+                  typename std::enable_if<std::is_arithmetic<T>::value ||
+                                          std::is_enum<T>::value ||
+                                          std::is_pod<T>::value>::type> {
+  static size_t SerializedSize(std::vector<T> const& value) {
+    return sizeof(value.size()) + value.size() * sizeof(T);
+  }
+  static void Serialize(void** buffer, std::vector<T> const& value) {
+    SerializeValue(buffer, value.size());
+    size_t nbyte = value.size() * sizeof(T);
+    std::memcpy(*buffer, value.data(), nbyte);
+    reinterpret_cast<char*&>(*buffer) += nbyte;
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size,
+                          std::vector<T>* value) {
+    size_t size;
+    DeserializeValue(buffer, buffer_size, &size);
+    value->resize(size);
+    size_t nbyte = value->size() * sizeof(T);
+    assert(*buffer_size >= nbyte);
+    std::memcpy(value->data(), *buffer, nbyte);
+    reinterpret_cast<char const*&>(*buffer) += nbyte;
+    *buffer_size -= nbyte;
+  }
+};
+
+}  // namespace
+
+template <typename T>
+inline size_t SerializedSize(T const& value) {
+  return Serializer<T>::SerializedSize(value);
+}
+
+template <typename T>
+inline void SerializeValue(void** buffer, T const& value) {
+  return Serializer<T>::Serialize(buffer, value);
+}
+
+template <typename T>
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value) {
+  return Serializer<T>::Deserialize(buffer, buffer_size, value);
+}
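The helpers in serialize.h are what the plugin base class and SplitPlugin use to persist their configuration. A minimal round-trip sketch (standalone, not part of the patch; the buffer handling mirrors how serialize() and the deserialization constructor use these functions):

// Illustrative sketch, not part of the patch.
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"

void RoundTripExample() {
  int axis = 0;                        // TRT axis (batch dimension excluded)
  std::vector<int> sections = {2, 1};  // output lengths along the split axis

  // serialize() side: write values into a pre-sized buffer.
  std::vector<char> buf(SerializedSize(axis) + SerializedSize(sections));
  void* p = buf.data();
  SerializeValue(&p, axis);
  SerializeValue(&p, sections);

  // Deserialization-constructor side: read them back in the same order.
  void const* q = buf.data();
  size_t left = buf.size();
  int axis_out = 0;
  std::vector<int> sections_out;
  DeserializeValue(&q, &left, &axis_out);
  DeserializeValue(&q, &left, &sections_out);
}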
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bd6a44dcc14d50cddb879763a93abf4297494ec9
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -0,0 +1,81 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda_runtime.h>
+#include <cassert>
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
+                                                const nvinfer1::Dims* inputDims,
+                                                int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const& input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  output_dims.d[axis_] = output_length_.at(index);
+  return output_dims;
+}
+
+int SplitPlugin::initialize() {
+  std::vector<int> segment_offsets(1, 0);
+  for (int i = 0; i < this->getNbOutputs(); ++i) {
+    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
+  }
+  segment_offsets_ = segment_offsets;
+  nvinfer1::Dims dims = this->getInputDims(0);
+  // Flatten the input (batch dimension excluded) into [nz_][ny_][nx_],
+  // where ny_ is the extent of the split axis.
+  nx_ = 1;
+  for (int i = dims.nbDims - 1; i > axis_; --i) {
+    nx_ *= dims.d[i];
+  }
+  ny_ = dims.d[axis_];
+  nz_ = 1;
+  for (int i = axis_ - 1; i >= 0; --i) {
+    nz_ *= dims.d[i];
+  }
+  return 0;
+}
+
+int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
+                         void** outputs, void* workspace, cudaStream_t stream) {
+  auto const& input_dims = this->getInputDims(0);
+  float const* idata = reinterpret_cast<float const*>(inputs[0]);
+  float** odatas = reinterpret_cast<float**>(outputs);
+
+  // Memcpy-based implementation; a dedicated CUDA kernel could replace it.
+  // Each output receives one contiguous slab per batch element, which holds
+  // as long as the split axis is the outermost non-batch axis (nz_ == 1).
+  int inputBatchOffset = nx_ * ny_ * nz_;
+  for (size_t i = 0; i < this->getNbOutputs(); i++) {
+    for (size_t j = 0; j < batchSize; j++) {
+      cudaMemcpyAsync(
+          odatas[i] +
+              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_,
+          idata + inputBatchOffset * j + segment_offsets_[i] * nx_,
+          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
+          cudaMemcpyDeviceToDevice, stream);
+    }
+  }
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // tensorrt
+}  // inference
+}  // paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..7281e40c331550de472df49c57b1d9a5226842d5
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class SplitPlugin : public PluginTensorRT {
+  int axis_;
+  std::vector<int> output_length_;
+  int nx_, ny_, nz_;
+  std::vector<int> segment_offsets_;
+
+ protected:
+  virtual size_t getSerializationSize() override {
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
+           getBaseSerializationSize();
+  }
+
+  // TRT calls this func to serialize the plugin configuration.
+  // It should not be called by users.
+  virtual void serialize(void *buffer) override {
+    serializeBase(buffer);
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, output_length_);
+  }
+
+ public:
+  SplitPlugin(int axis, std::vector<int> const &output_lengths)
+      : axis_(axis), output_length_(output_lengths) {
+    assert(axis <= nvinfer1::Dims::MAX_DIMS);
+  }
+
+  // Used by TensorRT during engine deserialization.
+  // It should not be called by users.
+  SplitPlugin(void const *serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &output_length_);
+  }
+
+  SplitPlugin *clone() const override {
+    return new SplitPlugin(axis_, output_length_);
+  }
+
+  virtual const char *getPluginType() const override { return "split"; }
+  virtual int getNbOutputs() const override { return output_length_.size(); }
+  virtual nvinfer1::Dims getOutputDimensions(int index,
+                                             const nvinfer1::Dims *inputs,
+                                             int nbInputDims) override;
+  virtual int initialize() override;
+  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
+                      void *workspace, cudaStream_t stream) override;
+};
+
+}  // tensorrt
+}  // inference
+}  // paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
new file mode 100644
index 0000000000000000000000000000000000000000..08016d84b15bc750738f3183d8d61a5c90862288
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+void PluginTensorRT::serializeBase(void*& buffer) {
+  SerializeValue(&buffer, input_dims_);
+  SerializeValue(&buffer, max_batch_size_);
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, data_format_);
+}
+
+void PluginTensorRT::deserializeBase(void const*& serialData,
+                                     size_t& serialLength) {
+  DeserializeValue(&serialData, &serialLength, &input_dims_);
+  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
+  DeserializeValue(&serialData, &serialLength, &data_type_);
+  DeserializeValue(&serialData, &serialLength, &data_format_);
+}
+
+size_t PluginTensorRT::getBaseSerializationSize() {
+  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
+          SerializedSize(data_type_) + SerializedSize(data_format_));
+}
+
+bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
+                                    nvinfer1::PluginFormat format) const {
+  return ((type == nvinfer1::DataType::kFLOAT) &&
+          (format == nvinfer1::PluginFormat::kNCHW));
+}
+
+void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
+                                         int nbInputs,
+                                         const nvinfer1::Dims* outputDims,
+                                         int nbOutputs, nvinfer1::DataType type,
+                                         nvinfer1::PluginFormat format,
+                                         int maxBatchSize) {
+  data_type_ = type;
+  data_format_ = format;
+  input_dims_.assign(inputDims, inputDims + nbInputs);
+  max_batch_size_ = maxBatchSize;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..4d85e955a49b7dcccae158ea06b76419419797cf
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <unordered_map>
+#include <vector>
+#include "NvInfer.h"
+
+#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class PluginTensorRT : public nvinfer1::IPluginExt {
+ public:
+  PluginTensorRT() {}
+  PluginTensorRT(const void* serialized_data, size_t length) {}
+  nvinfer1::Dims const& getInputDims(int index) const {
+    return input_dims_.at(index);
+  }
+  size_t getMaxBatchSize() const { return max_batch_size_; }
+  nvinfer1::DataType getDataType() const { return data_type_; }
+  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
+  virtual const char* getPluginVersion() const { return "1"; }
+  size_t getWorkspaceSize(int) const override { return 0; }
+  void terminate() override {}
+  virtual ~PluginTensorRT() {}
+  // Check format support. The default is FLOAT32 and NCHW.
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override;
+  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
+                           const nvinfer1::Dims* outputDims, int nbOutputs,
+                           nvinfer1::DataType type,
+                           nvinfer1::PluginFormat format,
+                           int maxBatchSize) override;
+
+  // *NOTE* The following functions need to be overridden in the subclass.
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  // Initialize the layer for execution. This is called when the engine is
+  // created.
+  int initialize() override { return 0; }
+  // Serialize the layer config to buffer.
+  virtual void serialize(void* buffer) = 0;
+  virtual size_t getSerializationSize() = 0;
+  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;
+
+ protected:
+  // Deserialize input_dims, max_batch_size, data_type, data_format
+  void deserializeBase(void const*& serialData, size_t& serialLength);
+  size_t getBaseSerializationSize();
+  // Serialize input_dims, max_batch_size, data_type, data_format
+  void serializeBase(void*& buffer);
+
+  std::vector<nvinfer1::Dims> input_dims_;
+  size_t max_batch_size_;
+  nvinfer1::DataType data_type_;
+  nvinfer1::PluginFormat data_format_;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
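To make the plugin's copy geometry concrete, the following standalone sketch (illustrative, not part of the patch) recomputes what SplitPlugin::initialize() derives for the shapes used in test_split_op.cc: a CHW = 3x2x2 input split into sections {2, 1} along the channel axis, i.e. TRT axis 0.

// Illustrative sketch, not part of the patch: mirrors SplitPlugin::initialize().
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {3, 2, 2};   // C, H, W (batch dimension excluded)
  int axis = 0;                        // TRT split axis
  std::vector<int> sections = {2, 1};  // the "sections" attribute

  int nx = 1, nz = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i > axis; --i) nx *= dims[i];
  for (int i = axis - 1; i >= 0; --i) nz *= dims[i];
  int ny = dims[axis];

  std::vector<int> offsets = {0};
  for (int s : sections) offsets.push_back(offsets.back() + s);

  std::printf("nx=%d ny=%d nz=%d\n", nx, ny, nz);  // nx=4 ny=3 nz=1
  // Per batch element, output i receives (offsets[i+1] - offsets[i]) * nx
  // floats, read starting at element offset offsets[i] * nx of that element.
  for (int i = 0; i < static_cast<int>(sections.size()); ++i) {
    std::printf("output %d: %d floats from element offset %d\n", i,
                (offsets[i + 1] - offsets[i]) * nx, offsets[i] * nx);
  }
  return 0;
}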