提交 0b962680 编写于 作者: N nhzlx

fix comments

test=develop
上级 e5bf8616
...@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter { ...@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
int input_num = op_desc.Input("X").size(); int input_num = op_desc.Input("X").size();
size_t output_num = op_desc.Output("Out").size(); size_t output_num = op_desc.Output("Out").size();
// Get Attrs
PADDLE_ENFORCE(input_num == 1); PADDLE_ENFORCE(input_num == 1);
int axis = boost::get<int>(op_desc.GetAttr("axis")); int axis = boost::get<int>(op_desc.GetAttr("axis"));
std::vector<int> output_lengths = std::vector<int> output_lengths =
...@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter { ...@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
PADDLE_ENFORCE(output_lengths.size() == output_num); PADDLE_ENFORCE(output_lengths.size() == output_num);
//
SplitPlugin* plugin = new SplitPlugin(axis, output_lengths); SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
nvinfer1::IPluginLayer* layer = nvinfer1::IPluginLayer* layer =
engine_->addPlugin(&input, input_num, plugin); engine_->AddPlugin(&input, input_num, plugin);
std::string layer_name = "split (Output: "; std::string layer_name = "split (Output: ";
for (size_t i = 0; i < output_num; i++) { for (size_t i = 0; i < output_num; i++) {
......
...@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() { ...@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
cudaSetDevice(device_); cudaSetDevice(device_);
} }
nvinfer1::IPluginLayer *TensorRTEngine::addPlugin( nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) { nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin); owned_plugin_.emplace_back(plugin);
return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin); return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
......
...@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase { ...@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch(); int GetRuntimeBatch();
int GetDevice() { return device_; } int GetDevice() { return device_; }
nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs, nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int nbInputs, PluginTensorRT*); int nbInputs, PluginTensorRT*);
// A pointer to CPU memory is needed of the TRT weight. // A pointer to CPU memory is needed of the TRT weight.
......
...@@ -20,11 +20,11 @@ ...@@ -20,11 +20,11 @@
#include <vector> #include <vector>
template <typename T> template <typename T>
inline void serialize_value(void** buffer, T const& value); inline void SerializeValue(void** buffer, T const& value);
template <typename T> template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, inline void DeserializeValue(void const** buffer, size_t* buffer_size,
T* value); T* value);
namespace { namespace {
...@@ -35,14 +35,14 @@ template <typename T> ...@@ -35,14 +35,14 @@ template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value || struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value || std::is_enum<T>::value ||
std::is_pod<T>::value>::type> { std::is_pod<T>::value>::type> {
static size_t serialized_size(T const& value) { return sizeof(T); } static size_t SerializedSize(T const& value) { return sizeof(T); }
static void serialize(void** buffer, T const& value) { static void Serialize(void** buffer, T const& value) {
::memcpy(*buffer, &value, sizeof(T)); std::memcpy(*buffer, &value, sizeof(T));
reinterpret_cast<char*&>(*buffer) += sizeof(T); reinterpret_cast<char*&>(*buffer) += sizeof(T);
} }
static void deserialize(void const** buffer, size_t* buffer_size, T* value) { static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
assert(*buffer_size >= sizeof(T)); assert(*buffer_size >= sizeof(T));
::memcpy(value, *buffer, sizeof(T)); std::memcpy(value, *buffer, sizeof(T));
reinterpret_cast<char const*&>(*buffer) += sizeof(T); reinterpret_cast<char const*&>(*buffer) += sizeof(T);
*buffer_size -= sizeof(T); *buffer_size -= sizeof(T);
} }
...@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value || ...@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
template <> template <>
struct Serializer<const char*> { struct Serializer<const char*> {
static size_t serialized_size(const char* value) { return strlen(value) + 1; } static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
static void serialize(void** buffer, const char* value) { static void Serialize(void** buffer, const char* value) {
::strcpy(static_cast<char*>(*buffer), value); std::strcpy(static_cast<char*>(*buffer), value);
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1; reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
} }
static void deserialize(void const** buffer, size_t* buffer_size, static void Deserialize(void const** buffer, size_t* buffer_size,
const char** value) { const char** value) {
*value = static_cast<char const*>(*buffer); *value = static_cast<char const*>(*buffer);
size_t data_size = strnlen(*value, *buffer_size) + 1; size_t data_size = strnlen(*value, *buffer_size) + 1;
...@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>, ...@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
typename std::enable_if<std::is_arithmetic<T>::value || typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value || std::is_enum<T>::value ||
std::is_pod<T>::value>::type> { std::is_pod<T>::value>::type> {
static size_t serialized_size(std::vector<T> const& value) { static size_t SerializedSize(std::vector<T> const& value) {
return sizeof(value.size()) + value.size() * sizeof(T); return sizeof(value.size()) + value.size() * sizeof(T);
} }
static void serialize(void** buffer, std::vector<T> const& value) { static void Serialize(void** buffer, std::vector<T> const& value) {
serialize_value(buffer, value.size()); SerializeValue(buffer, value.size());
size_t nbyte = value.size() * sizeof(T); size_t nbyte = value.size() * sizeof(T);
::memcpy(*buffer, value.data(), nbyte); std::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte; reinterpret_cast<char*&>(*buffer) += nbyte;
} }
static void deserialize(void const** buffer, size_t* buffer_size, static void Deserialize(void const** buffer, size_t* buffer_size,
std::vector<T>* value) { std::vector<T>* value) {
size_t size; size_t size;
deserialize_value(buffer, buffer_size, &size); DeserializeValue(buffer, buffer_size, &size);
value->resize(size); value->resize(size);
size_t nbyte = value->size() * sizeof(T); size_t nbyte = value->size() * sizeof(T);
assert(*buffer_size >= nbyte); assert(*buffer_size >= nbyte);
::memcpy(value->data(), *buffer, nbyte); std::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte; reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte; *buffer_size -= nbyte;
} }
...@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>, ...@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
} // namespace } // namespace
template <typename T> template <typename T>
inline size_t serialized_size(T const& value) { inline size_t SerializedSize(T const& value) {
return Serializer<T>::serialized_size(value); return Serializer<T>::SerializedSize(value);
} }
template <typename T> template <typename T>
inline void serialize_value(void** buffer, T const& value) { inline void SerializeValue(void** buffer, T const& value) {
return Serializer<T>::serialize(buffer, value); return Serializer<T>::Serialize(buffer, value);
} }
template <typename T> template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size, inline void DeserializeValue(void const** buffer, size_t* buffer_size,
T* value) { T* value) {
return Serializer<T>::deserialize(buffer, buffer_size, value); return Serializer<T>::Deserialize(buffer, buffer_size, value);
} }
...@@ -37,7 +37,6 @@ int SplitPlugin::initialize() { ...@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
segment_offsets.push_back(segment_offsets.back() + output_length_[i]); segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
} }
segment_offsets_ = segment_offsets; segment_offsets_ = segment_offsets;
d_segment_offsets_ = segment_offsets;
nvinfer1::Dims dims = this->getInputDims(0); nvinfer1::Dims dims = this->getInputDims(0);
nx_ = 1; nx_ = 1;
for (int i = dims.nbDims - 1; i > axis_; --i) { for (int i = dims.nbDims - 1; i > axis_; --i) {
...@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, ...@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, cudaStream_t stream) { void** outputs, void* workspace, cudaStream_t stream) {
auto const& input_dims = this->getInputDims(0); auto const& input_dims = this->getInputDims(0);
int input_size = 0; int input_size = 0;
int const* d_segment_offsets_ptr =
thrust::raw_pointer_cast(&d_segment_offsets_[0]);
float const* idata = reinterpret_cast<float const*>(inputs[0]); float const* idata = reinterpret_cast<float const*>(inputs[0]);
float** odatas = reinterpret_cast<float**>(outputs); float** odatas = reinterpret_cast<float**>(outputs);
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#pragma once #pragma once
#include <thrust/device_vector.h>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle { namespace paddle {
...@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT { ...@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
int axis_; int axis_;
std::vector<int> output_length_; std::vector<int> output_length_;
int nx_, ny_, nz_; int nx_, ny_, nz_;
thrust::device_vector<int> d_segment_offsets_;
std::vector<int> segment_offsets_; std::vector<int> segment_offsets_;
protected: protected:
virtual size_t getSerializationSize() override { virtual size_t getSerializationSize() override {
return serialized_size(axis_) + serialized_size(output_length_) + return SerializedSize(axis_) + SerializedSize(output_length_) +
getBaseSerializationSize(); getBaseSerializationSize();
} }
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
virtual void serialize(void *buffer) override { virtual void serialize(void *buffer) override {
serializeBase(buffer); serializeBase(buffer);
serialize_value(&buffer, axis_); SerializeValue(&buffer, axis_);
serialize_value(&buffer, output_length_); SerializeValue(&buffer, output_length_);
} }
public: public:
...@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT { ...@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
assert(axis <= nvinfer1::Dims::MAX_DIMS); assert(axis <= nvinfer1::Dims::MAX_DIMS);
} }
// It was used for tensorrt deserialization.
// It should not be called by users.
SplitPlugin(void const *serialData, size_t serialLength) { SplitPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength); deserializeBase(serialData, serialLength);
deserialize_value(&serialData, &serialLength, &axis_); DeserializeValue(&serialData, &serialLength, &axis_);
deserialize_value(&serialData, &serialLength, &output_length_); DeserializeValue(&serialData, &serialLength, &output_length_);
} }
SplitPlugin *clone() const override { SplitPlugin *clone() const override {
...@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT { ...@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
virtual int initialize() override; virtual int initialize() override;
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override; void *workspace, cudaStream_t stream) override;
void setAxis(int axis) { axis_ = axis; }
void setOutputLengths(const std::vector<int> &output_lengths) {
output_length_ = output_lengths;
}
}; };
} // tensorrt } // tensorrt
......
...@@ -19,23 +19,23 @@ namespace inference { ...@@ -19,23 +19,23 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
void PluginTensorRT::serializeBase(void*& buffer) { void PluginTensorRT::serializeBase(void*& buffer) {
serialize_value(&buffer, input_dims_); SerializeValue(&buffer, input_dims_);
serialize_value(&buffer, max_batch_size_); SerializeValue(&buffer, max_batch_size_);
serialize_value(&buffer, data_type_); SerializeValue(&buffer, data_type_);
serialize_value(&buffer, data_format_); SerializeValue(&buffer, data_format_);
} }
void PluginTensorRT::deserializeBase(void const*& serialData, void PluginTensorRT::deserializeBase(void const*& serialData,
size_t& serialLength) { size_t& serialLength) {
deserialize_value(&serialData, &serialLength, &input_dims_); DeserializeValue(&serialData, &serialLength, &input_dims_);
deserialize_value(&serialData, &serialLength, &max_batch_size_); DeserializeValue(&serialData, &serialLength, &max_batch_size_);
deserialize_value(&serialData, &serialLength, &data_type_); DeserializeValue(&serialData, &serialLength, &data_type_);
deserialize_value(&serialData, &serialLength, &data_format_); DeserializeValue(&serialData, &serialLength, &data_format_);
} }
size_t PluginTensorRT::getBaseSerializationSize() { size_t PluginTensorRT::getBaseSerializationSize() {
return (serialized_size(input_dims_) + serialized_size(max_batch_size_) + return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
serialized_size(data_type_) + serialized_size(data_format_)); SerializedSize(data_type_) + SerializedSize(data_format_));
} }
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
......
...@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt { ...@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
size_t getWorkspaceSize(int) const override { return 0; } size_t getWorkspaceSize(int) const override { return 0; }
void terminate() override {} void terminate() override {}
virtual ~PluginTensorRT() {} virtual ~PluginTensorRT() {}
// Check format support. The default is FLOAT32 and NCHW.
// The following functions need to be overrided in the subclass.
virtual nvinfer1::IPluginExt* clone() const = 0;
virtual const char* getPluginType() const = 0;
int initialize() override { return 0; }
bool supportsFormat(nvinfer1::DataType type, bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override; nvinfer1::PluginFormat format) const override;
void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
...@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt { ...@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
nvinfer1::DataType type, nvinfer1::DataType type,
nvinfer1::PluginFormat format, nvinfer1::PluginFormat format,
int maxBatchSize) override; int maxBatchSize) override;
// *NOTE* The following functions need to be overrided in the subclass.
virtual nvinfer1::IPluginExt* clone() const = 0;
virtual const char* getPluginType() const = 0;
// Initialize the layer for execution. This is called when the engine is
// created.
int initialize() override { return 0; }
// Serialize the layer config to buffer.
virtual void serialize(void* buffer) = 0; virtual void serialize(void* buffer) = 0;
virtual size_t getSerializationSize() = 0; virtual size_t getSerializationSize() = 0;
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) = 0;
protected: protected:
// Deserialize input_dims, max_batch_size, data_type, data_format
void deserializeBase(void const*& serialData, size_t& serialLength); void deserializeBase(void const*& serialData, size_t& serialLength);
size_t getBaseSerializationSize(); size_t getBaseSerializationSize();
// Serialize input_dims, max_batch_size, data_type, data_format
void serializeBase(void*& buffer); void serializeBase(void*& buffer);
std::vector<nvinfer1::Dims> input_dims_; std::vector<nvinfer1::Dims> input_dims_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册