未验证 提交 77ac30e5 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #14386 from NHZlX/add_trt_plugin

add plugin support for paddle-trt
...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { ...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"}); "elementwise_add", "dropout", "split"});
if (!node->IsOp()) return false; if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) { if (teller_set.count(node->Op()->Type())) {
......
...@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm); ...@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split);
#endif #endif
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context) nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine) nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(plugin)
add_subdirectory(convert) add_subdirectory(convert)
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) pad_op.cc split_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
...@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc ...@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL)
...@@ -19,7 +19,7 @@ namespace inference { ...@@ -19,7 +19,7 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
/* /*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights. * ConcatOp
*/ */
class ConcatOpConverter : public OpConverter { class ConcatOpConverter : public OpConverter {
public: public:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* SplitOp.
*/
class SplitOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert a fluid split op to tensorrt split layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto input_dims = input->getDimensions();
int input_num = op_desc.Input("X").size();
size_t output_num = op_desc.Output("Out").size();
// Get Attrs
PADDLE_ENFORCE(input_num == 1);
int axis = boost::get<int>(op_desc.GetAttr("axis"));
std::vector<int> output_lengths =
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
PADDLE_ENFORCE(axis != 0);
if (axis < 0) {
axis += input_dims.nbDims;
} else {
axis -= 1;
}
PADDLE_ENFORCE(output_lengths.size() == output_num);
//
SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
std::string layer_name = "split (Output: ";
for (size_t i = 0; i < output_num; i++) {
auto output_name = op_desc.Output("Out")[i];
layer->getOutput(i)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(i));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
layer->setName((layer_name + ")").c_str());
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(split_op, test) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("split");
desc.SetInput("X", {"split_input"});
desc.SetOutput("Out", {"split_out1", "split_out2"});
int num = 0;
int axis = 1;
std::vector<int> output_lengths = {2, 1};
desc.SetAttr("axis", axis);
desc.SetAttr("num", num);
desc.SetAttr("sections", output_lengths);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(split);
...@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() { ...@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
cudaSetDevice(device_); cudaSetDevice(device_);
} }
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin);
return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
namespace paddle { namespace paddle {
...@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase { ...@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch(); int GetRuntimeBatch();
int GetDevice() { return device_; } int GetDevice() { return device_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int nbInputs, PluginTensorRT*);
// A pointer to CPU memory is needed of the TRT weight. // A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage. // Before TRT runs, fluid loads weight into GPU storage.
...@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase { ...@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_; std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/> std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_; itensor_map_;
// The specific GPU id that the TensorRTEngine bounded to. // The specific GPU id that the TensorRTEngine bounded to.
int device_; int device_;
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
// TensorRT related internal members // TensorRT related internal members
template <typename T> template <typename T>
......
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>
template <typename T>
inline void SerializeValue(void** buffer, T const& value);
template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
T* value);
namespace {
template <typename T, class Enable = void>
struct Serializer {};
template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t SerializedSize(T const& value) { return sizeof(T); }
static void Serialize(void** buffer, T const& value) {
std::memcpy(*buffer, &value, sizeof(T));
reinterpret_cast<char*&>(*buffer) += sizeof(T);
}
static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
assert(*buffer_size >= sizeof(T));
std::memcpy(value, *buffer, sizeof(T));
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
*buffer_size -= sizeof(T);
}
};
template <>
struct Serializer<const char*> {
static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
static void Serialize(void** buffer, const char* value) {
std::strcpy(static_cast<char*>(*buffer), value);
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
}
static void Deserialize(void const** buffer, size_t* buffer_size,
const char** value) {
*value = static_cast<char const*>(*buffer);
size_t data_size = strnlen(*value, *buffer_size) + 1;
assert(*buffer_size >= data_size);
reinterpret_cast<char const*&>(*buffer) += data_size;
*buffer_size -= data_size;
}
};
template <typename T>
struct Serializer<std::vector<T>,
typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t SerializedSize(std::vector<T> const& value) {
return sizeof(value.size()) + value.size() * sizeof(T);
}
static void Serialize(void** buffer, std::vector<T> const& value) {
SerializeValue(buffer, value.size());
size_t nbyte = value.size() * sizeof(T);
std::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte;
}
static void Deserialize(void const** buffer, size_t* buffer_size,
std::vector<T>* value) {
size_t size;
DeserializeValue(buffer, buffer_size, &size);
value->resize(size);
size_t nbyte = value->size() * sizeof(T);
assert(*buffer_size >= nbyte);
std::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;
}
};
} // namespace
template <typename T>
inline size_t SerializedSize(T const& value) {
return Serializer<T>::SerializedSize(value);
}
template <typename T>
inline void SerializeValue(void** buffer, T const& value) {
return Serializer<T>::Serialize(buffer, value);
}
template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
T* value) {
return Serializer<T>::Deserialize(buffer, buffer_size, value);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
const nvinfer1::Dims* inputDims,
int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const& input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
output_dims.d[axis_] = output_length_.at(index);
return output_dims;
}
int SplitPlugin::initialize() {
std::vector<int> segment_offsets(1, 0);
for (int i = 0; i < this->getNbOutputs(); ++i) {
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
}
segment_offsets_ = segment_offsets;
nvinfer1::Dims dims = this->getInputDims(0);
nx_ = 1;
for (int i = dims.nbDims - 1; i > axis_; --i) {
nx_ *= dims.d[i];
}
ny_ = dims.d[axis_];
nz_ = 1;
for (int i = axis_ - 1; i >= 0; --i) {
nz_ *= dims.d[i];
}
return 0;
}
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, cudaStream_t stream) {
auto const& input_dims = this->getInputDims(0);
int input_size = 0;
float const* idata = reinterpret_cast<float const*>(inputs[0]);
float** odatas = reinterpret_cast<float**>(outputs);
// kernel impl here.
int inputBatchOffset = nx_ * ny_ * nz_;
for (size_t i = 0; i < this->getNbOutputs(); i++) {
for (size_t j = 0; j < batchSize; j++) {
cudaMemcpyAsync(
odatas[i] +
j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
sizeof(float),
inputs[0] +
(inputBatchOffset * j + segment_offsets_[i] * nx_) *
sizeof(float),
(segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
}
}
return cudaGetLastError() != cudaSuccess;
}
} // tensorrt
} // inference
} // paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class SplitPlugin : public PluginTensorRT {
int axis_;
std::vector<int> output_length_;
int nx_, ny_, nz_;
std::vector<int> segment_offsets_;
protected:
virtual size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(output_length_) +
getBaseSerializationSize();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
virtual void serialize(void *buffer) override {
serializeBase(buffer);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_length_);
}
public:
SplitPlugin(int axis, std::vector<int> const &output_lengths)
: axis_(axis), output_length_(output_lengths) {
assert(axis <= nvinfer1::Dims::MAX_DIMS);
}
// It was used for tensorrt deserialization.
// It should not be called by users.
SplitPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &axis_);
DeserializeValue(&serialData, &serialLength, &output_length_);
}
SplitPlugin *clone() const override {
return new SplitPlugin(axis_, output_length_);
}
virtual const char *getPluginType() const override { return "split"; }
virtual int getNbOutputs() const override { return output_length_.size(); }
virtual nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *inputs,
int nbInputDims) override;
virtual int initialize() override;
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // tensorrt
} // inference
} // paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
void PluginTensorRT::serializeBase(void*& buffer) {
SerializeValue(&buffer, input_dims_);
SerializeValue(&buffer, max_batch_size_);
SerializeValue(&buffer, data_type_);
SerializeValue(&buffer, data_format_);
}
void PluginTensorRT::deserializeBase(void const*& serialData,
size_t& serialLength) {
DeserializeValue(&serialData, &serialLength, &input_dims_);
DeserializeValue(&serialData, &serialLength, &max_batch_size_);
DeserializeValue(&serialData, &serialLength, &data_type_);
DeserializeValue(&serialData, &serialLength, &data_format_);
}
size_t PluginTensorRT::getBaseSerializationSize() {
return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
SerializedSize(data_type_) + SerializedSize(data_format_));
}
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
return ((type == nvinfer1::DataType::kFLOAT) &&
(format == nvinfer1::PluginFormat::kNCHW));
}
void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
int nbInputs,
const nvinfer1::Dims* outputDims,
int nbOutputs, nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int maxBatchSize) {
data_type_ = type;
data_format_ = format;
input_dims_.assign(inputDims, inputDims + nbInputs);
max_batch_size_ = maxBatchSize;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PluginTensorRT : public nvinfer1::IPluginExt {
public:
PluginTensorRT() {}
PluginTensorRT(const void* serialized_data, size_t length) {}
nvinfer1::Dims const& getInputDims(int index) const {
return input_dims_.at(index);
}
size_t getMaxBatchSize() const { return max_batch_size_; }
nvinfer1::DataType getDataType() const { return data_type_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
virtual const char* getPluginVersion() const { return "1"; }
size_t getWorkspaceSize(int) const override { return 0; }
void terminate() override {}
virtual ~PluginTensorRT() {}
// Check format support. The default is FLOAT32 and NCHW.
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int maxBatchSize) override;
// *NOTE* The following functions need to be overrided in the subclass.
virtual nvinfer1::IPluginExt* clone() const = 0;
virtual const char* getPluginType() const = 0;
// Initialize the layer for execution. This is called when the engine is
// created.
int initialize() override { return 0; }
// Serialize the layer config to buffer.
virtual void serialize(void* buffer) = 0;
virtual size_t getSerializationSize() = 0;
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) = 0;
protected:
// Deserialize input_dims, max_batch_size, data_type, data_format
void deserializeBase(void const*& serialData, size_t& serialLength);
size_t getBaseSerializationSize();
// Serialize input_dims, max_batch_size, data_type, data_format
void serializeBase(void*& buffer);
std::vector<nvinfer1::Dims> input_dims_;
size_t max_batch_size_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册