Commit 0b962680 authored by nhzlx

fix comments

test=develop
Parent e5bf8616
@@ -35,6 +35,7 @@ class SplitOpConverter : public OpConverter {
     int input_num = op_desc.Input("X").size();
     size_t output_num = op_desc.Output("Out").size();
     // Get Attrs
+    PADDLE_ENFORCE(input_num == 1);
     int axis = boost::get<int>(op_desc.GetAttr("axis"));
     std::vector<int> output_lengths =
@@ -48,9 +49,10 @@ class SplitOpConverter : public OpConverter {
     PADDLE_ENFORCE(output_lengths.size() == output_num);
     //
+    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
     nvinfer1::IPluginLayer* layer =
-        engine_->addPlugin(&input, input_num, plugin);
+        engine_->AddPlugin(&input, input_num, plugin);
     std::string layer_name = "split (Output: ";
     for (size_t i = 0; i < output_num; i++) {
...
@@ -254,7 +254,7 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }

-nvinfer1::IPluginLayer *TensorRTEngine::addPlugin(
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
     nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
   owned_plugin_.emplace_back(plugin);
   return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
...
@@ -126,7 +126,7 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
-  nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs,
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
                                     int nbInputs, PluginTensorRT*);

   // A pointer to CPU memory is needed for the TRT weight.
...
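Editor's note: for context, a minimal usage sketch of the renamed `AddPlugin` API. This is not part of the diff; `engine` is assumed to be an initialized `TensorRTEngine*` and `input` an `nvinfer1::ITensor*` already declared on its network, and the axis/lengths values are illustrative only.

```cpp
// Hypothetical sketch: register a SplitPlugin that splits one axis of
// width 10 into segments of widths 2, 3 and 5 (axis indexing follows the
// plugin's own convention).
SplitPlugin* plugin =
    new SplitPlugin(/*axis=*/1, /*output_lengths=*/{2, 3, 5});
nvinfer1::IPluginLayer* layer =
    engine->AddPlugin(&input, /*nbInputs=*/1, plugin);
// Ownership: AddPlugin stores the raw pointer in owned_plugin_, so the
// engine controls the plugin's lifetime; the caller must not delete it.
```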
@@ -20,11 +20,11 @@
 #include <vector>

 template <typename T>
-inline void serialize_value(void** buffer, T const& value);
+inline void SerializeValue(void** buffer, T const& value);

 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
-                              T* value);
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value);

 namespace {
@@ -35,14 +35,14 @@ template <typename T>
 struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                              std::is_enum<T>::value ||
                                              std::is_pod<T>::value>::type> {
-  static size_t serialized_size(T const& value) { return sizeof(T); }
-  static void serialize(void** buffer, T const& value) {
-    ::memcpy(*buffer, &value, sizeof(T));
+  static size_t SerializedSize(T const& value) { return sizeof(T); }
+  static void Serialize(void** buffer, T const& value) {
+    std::memcpy(*buffer, &value, sizeof(T));
     reinterpret_cast<char*&>(*buffer) += sizeof(T);
   }
-  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
+  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
     assert(*buffer_size >= sizeof(T));
-    ::memcpy(value, *buffer, sizeof(T));
+    std::memcpy(value, *buffer, sizeof(T));
     reinterpret_cast<char const*&>(*buffer) += sizeof(T);
     *buffer_size -= sizeof(T);
   }
@@ -50,12 +50,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
 template <>
 struct Serializer<const char*> {
-  static size_t serialized_size(const char* value) { return strlen(value) + 1; }
-  static void serialize(void** buffer, const char* value) {
-    ::strcpy(static_cast<char*>(*buffer), value);
+  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+  static void Serialize(void** buffer, const char* value) {
+    std::strcpy(static_cast<char*>(*buffer), value);
     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           const char** value) {
     *value = static_cast<char const*>(*buffer);
     size_t data_size = strnlen(*value, *buffer_size) + 1;
@@ -70,23 +70,23 @@ struct Serializer<std::vector<T>,
                   typename std::enable_if<std::is_arithmetic<T>::value ||
                                           std::is_enum<T>::value ||
                                           std::is_pod<T>::value>::type> {
-  static size_t serialized_size(std::vector<T> const& value) {
+  static size_t SerializedSize(std::vector<T> const& value) {
     return sizeof(value.size()) + value.size() * sizeof(T);
   }
-  static void serialize(void** buffer, std::vector<T> const& value) {
-    serialize_value(buffer, value.size());
+  static void Serialize(void** buffer, std::vector<T> const& value) {
+    SerializeValue(buffer, value.size());
     size_t nbyte = value.size() * sizeof(T);
-    ::memcpy(*buffer, value.data(), nbyte);
+    std::memcpy(*buffer, value.data(), nbyte);
     reinterpret_cast<char*&>(*buffer) += nbyte;
   }
-  static void deserialize(void const** buffer, size_t* buffer_size,
+  static void Deserialize(void const** buffer, size_t* buffer_size,
                           std::vector<T>* value) {
     size_t size;
-    deserialize_value(buffer, buffer_size, &size);
+    DeserializeValue(buffer, buffer_size, &size);
     value->resize(size);
     size_t nbyte = value->size() * sizeof(T);
     assert(*buffer_size >= nbyte);
-    ::memcpy(value->data(), *buffer, nbyte);
+    std::memcpy(value->data(), *buffer, nbyte);
     reinterpret_cast<char const*&>(*buffer) += nbyte;
     *buffer_size -= nbyte;
   }
@@ -95,17 +95,17 @@ struct Serializer<std::vector<T>,
 }  // namespace

 template <typename T>
-inline size_t serialized_size(T const& value) {
-  return Serializer<T>::serialized_size(value);
+inline size_t SerializedSize(T const& value) {
+  return Serializer<T>::SerializedSize(value);
 }

 template <typename T>
-inline void serialize_value(void** buffer, T const& value) {
-  return Serializer<T>::serialize(buffer, value);
+inline void SerializeValue(void** buffer, T const& value) {
+  return Serializer<T>::Serialize(buffer, value);
 }

 template <typename T>
-inline void deserialize_value(void const** buffer, size_t* buffer_size,
-                              T* value) {
-  return Serializer<T>::deserialize(buffer, buffer_size, value);
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value) {
+  return Serializer<T>::Deserialize(buffer, buffer_size, value);
 }
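Editor's note: a hedged round-trip sketch of the renamed helpers (not in the patch; `RoundTripExample` is a hypothetical function). The buffer is sized with `SerializedSize`, values are written and read back in the same order, and the cursors advance exactly as in the `Serializer` specializations above.

```cpp
#include <cassert>
#include <vector>
// Assumes the serialization header above is included.

void RoundTripExample() {
  int axis = 1;
  std::vector<int> lengths = {2, 3, 5};

  // Allocate exactly the bytes the two values need.
  std::vector<char> storage(SerializedSize(axis) + SerializedSize(lengths));

  void* wptr = storage.data();
  SerializeValue(&wptr, axis);     // advances wptr by sizeof(int)
  SerializeValue(&wptr, lengths);  // writes size() then the raw elements

  void const* rptr = storage.data();
  size_t remaining = storage.size();
  int axis_out = 0;
  std::vector<int> lengths_out;
  DeserializeValue(&rptr, &remaining, &axis_out);
  DeserializeValue(&rptr, &remaining, &lengths_out);
  assert(axis_out == axis && lengths_out == lengths && remaining == 0);
}
```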
@@ -37,7 +37,6 @@ int SplitPlugin::initialize() {
     segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
   segment_offsets_ = segment_offsets;
-  d_segment_offsets_ = segment_offsets;
   nvinfer1::Dims dims = this->getInputDims(0);
   nx_ = 1;
   for (int i = dims.nbDims - 1; i > axis_; --i) {
@@ -55,8 +54,6 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
   auto const& input_dims = this->getInputDims(0);
   int input_size = 0;
-  int const* d_segment_offsets_ptr =
-      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
   float const* idata = reinterpret_cast<float const*>(inputs[0]);
   float** odatas = reinterpret_cast<float**>(outputs);
...
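Editor's note: the removed `thrust::device_vector` was a device-side copy of the offsets that `enqueue` no longer reads, so only the host-side `segment_offsets_` remains. For illustration, a free-standing sketch of the prefix-sum logic in `initialize()` (`MakeSegmentOffsets` is a hypothetical helper name, not in the source):

```cpp
#include <vector>

// Prefix sums over the per-output lengths: output i covers
// [offsets[i], offsets[i + 1]) along the split axis.
std::vector<int> MakeSegmentOffsets(const std::vector<int>& output_length) {
  std::vector<int> offsets(1, 0);  // leading zero simplifies indexing
  for (size_t i = 0; i < output_length.size(); ++i) {
    offsets.push_back(offsets.back() + output_length[i]);
  }
  return offsets;  // e.g. {2, 3, 5} -> {0, 2, 5, 10}
}
```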
@@ -14,7 +14,6 @@
 #pragma once

-#include <thrust/device_vector.h>
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

 namespace paddle {
@@ -25,19 +24,21 @@ class SplitPlugin : public PluginTensorRT {
   int axis_;
   std::vector<int> output_length_;
   int nx_, ny_, nz_;
-  thrust::device_vector<int> d_segment_offsets_;
   std::vector<int> segment_offsets_;

  protected:
   virtual size_t getSerializationSize() override {
-    return serialized_size(axis_) + serialized_size(output_length_) +
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
            getBaseSerializationSize();
   }

+  // TRT will call this func when it needs to serialize the configuration
+  // of the plugin.
+  // It should not be called by users.
   virtual void serialize(void *buffer) override {
     serializeBase(buffer);
-    serialize_value(&buffer, axis_);
-    serialize_value(&buffer, output_length_);
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, output_length_);
   }

  public:
@@ -46,10 +47,12 @@ class SplitPlugin : public PluginTensorRT {
     assert(axis <= nvinfer1::Dims::MAX_DIMS);
   }

+  // It is used for TensorRT deserialization.
+  // It should not be called by users.
   SplitPlugin(void const *serialData, size_t serialLength) {
     deserializeBase(serialData, serialLength);
-    deserialize_value(&serialData, &serialLength, &axis_);
-    deserialize_value(&serialData, &serialLength, &output_length_);
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &output_length_);
   }

   SplitPlugin *clone() const override {
@@ -64,12 +67,6 @@ class SplitPlugin : public PluginTensorRT {
   virtual int initialize() override;
   virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
                       void *workspace, cudaStream_t stream) override;
-
-  void setAxis(int axis) { axis_ = axis; }
-  void setOutputLengths(const std::vector<int> &output_lengths) {
-    output_length_ = output_lengths;
-  }
 };

 } // tensorrt
...
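Editor's note: the two constructors pair up with the serialize path — TRT calls `getSerializationSize()` and `serialize()` when writing an engine plan, and the `(void const*, size_t)` constructor when loading it back, reading fields in the order they were written. A hedged sketch of that round trip (`SaveAndRestore` is a hypothetical helper):

```cpp
#include <vector>

void SaveAndRestore(SplitPlugin* plugin) {
  // Size covers the base fields (serializeBase) plus axis_ and output_length_.
  std::vector<char> blob(plugin->getSerializationSize());
  plugin->serialize(blob.data());
  // Reconstruct from the byte blob, exactly as TRT does at load time.
  SplitPlugin restored(blob.data(), blob.size());
}
```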
@@ -19,23 +19,23 @@ namespace inference {
 namespace tensorrt {

 void PluginTensorRT::serializeBase(void*& buffer) {
-  serialize_value(&buffer, input_dims_);
-  serialize_value(&buffer, max_batch_size_);
-  serialize_value(&buffer, data_type_);
-  serialize_value(&buffer, data_format_);
+  SerializeValue(&buffer, input_dims_);
+  SerializeValue(&buffer, max_batch_size_);
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, data_format_);
 }

 void PluginTensorRT::deserializeBase(void const*& serialData,
                                      size_t& serialLength) {
-  deserialize_value(&serialData, &serialLength, &input_dims_);
-  deserialize_value(&serialData, &serialLength, &max_batch_size_);
-  deserialize_value(&serialData, &serialLength, &data_type_);
-  deserialize_value(&serialData, &serialLength, &data_format_);
+  DeserializeValue(&serialData, &serialLength, &input_dims_);
+  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
+  DeserializeValue(&serialData, &serialLength, &data_type_);
+  DeserializeValue(&serialData, &serialLength, &data_format_);
 }

 size_t PluginTensorRT::getBaseSerializationSize() {
-  return (serialized_size(input_dims_) + serialized_size(max_batch_size_) +
-          serialized_size(data_type_) + serialized_size(data_format_));
+  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
+          SerializedSize(data_type_) + SerializedSize(data_format_));
 }

 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
...
@@ -41,11 +41,7 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   size_t getWorkspaceSize(int) const override { return 0; }
   void terminate() override {}
   virtual ~PluginTensorRT() {}
-
-  // The following functions need to be overrided in the subclass.
-  virtual nvinfer1::IPluginExt* clone() const = 0;
-  virtual const char* getPluginType() const = 0;
-  int initialize() override { return 0; }
+  // Check format support. The default is FLOAT32 and NCHW.
   bool supportsFormat(nvinfer1::DataType type,
                       nvinfer1::PluginFormat format) const override;
   void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
@@ -53,12 +49,24 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
                            nvinfer1::DataType type,
                            nvinfer1::PluginFormat format,
                            int maxBatchSize) override;
+
+  // *NOTE* The following functions need to be overridden in the subclass.
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  // Initialize the layer for execution. This is called when the engine is
+  // created.
+  int initialize() override { return 0; }
+  // Serialize the layer config to buffer.
+  virtual void serialize(void* buffer) = 0;
+  virtual size_t getSerializationSize() = 0;
+  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;

  protected:
+  // Deserialize input_dims, max_batch_size, data_type, data_format.
   void deserializeBase(void const*& serialData, size_t& serialLength);
   size_t getBaseSerializationSize();
+  // Serialize input_dims, max_batch_size, data_type, data_format.
   void serializeBase(void*& buffer);

   std::vector<nvinfer1::Dims> input_dims_;
...
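Editor's note: to make the documented contract concrete, a hypothetical minimal subclass (not in the patch; `IdentityPlugin` and its behavior are assumptions). It overrides the functions the *NOTE* comment names, plus the `nvinfer1::IPluginExt` pure virtuals the base class leaves open, following the same pattern as SplitPlugin above.

```cpp
class IdentityPlugin : public PluginTensorRT {
 public:
  IdentityPlugin() {}
  // Deserialization constructor, mirroring SplitPlugin's.
  IdentityPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
  }
  IdentityPlugin* clone() const override { return new IdentityPlugin(*this); }
  const char* getPluginType() const override { return "identity_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nbInputDims) override {
    return inputs[0];  // same shape out as in
  }
  // No fields of its own, so the base round trip is the whole payload.
  size_t getSerializationSize() override { return getBaseSerializationSize(); }
  void serialize(void* buffer) override { serializeBase(buffer); }
  int enqueue(int batchSize, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override {
    // A real plugin launches its CUDA kernels here (an identity plugin
    // would cudaMemcpyAsync input to output); returning 0 means success.
    return 0;
  }
};
```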