Commit d38fd6a0 authored by nhzlx

add plugin support and offer a simple split sample

Parent: 2d7134bc
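At a high level, this change lets an op converter hand a custom PluginTensorRT object to the TensorRT engine, which takes ownership and wires it into the network as a plugin layer. A minimal sketch of that flow, condensed from the SplitOpConverter and TensorRTEngine::addPlugin code further down (the names input, axis and sections stand for values the converter has already read from the op description):

    // Sketch only; mirrors what SplitOpConverter does below.
    SplitPlugin* plugin = new SplitPlugin(axis, sections);   // sections: lengths along the split axis
    nvinfer1::IPluginLayer* layer = engine_->addPlugin(&input, 1 /*nbInputs*/, plugin);
    engine_->SetITensor("split_out0", layer->getOutput(0));  // expose each output tensor by name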
@@ -71,7 +71,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
     std::unordered_set<std::string> teller_set(
         {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
          "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-         "elementwise_add", "dropout"});
+         "elementwise_add", "dropout", "split"});
     if (!node->IsFunction()) return false;
     const auto* func = static_cast<const Function*>(node);
......
@@ -186,3 +186,4 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(split);
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
+  batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+  pad_op.cc split_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
 nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
 nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
+nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
+        split_op concat_op SERIAL)
@@ -12,53 +12,62 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {

-PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
-                                                    const void* serial_data,
-                                                    size_t serial_length) {
-  size_t parsed_byte = 0;
-  std::string encoded_op_name =
-      ExtractOpName(serial_data, serial_length, &parsed_byte);
-
-  if (!IsPlugin(encoded_op_name)) {
-    return nullptr;
-  }
-
-  auto plugin_ptr =
-      plugin_registry_[encoded_op_name].first(serial_data, serial_length);
-  owned_plugins_.emplace_back(plugin_ptr);
-
-  return plugin_ptr;
-}
-
-PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(
-    const std::string& op_name) {
-  if (!IsPlugin(op_name)) return nullptr;
-
-  auto plugin_ptr = plugin_registry_[op_name].second();
-  owned_plugins_.emplace_back(plugin_ptr);
-
-  return plugin_ptr;
-}
-
-bool PluginFactoryTensorRT::RegisterPlugin(
-    const std::string& op_name, PluginDeserializeFunc deserialize_func,
-    PluginConstructFunc construct_func) {
-  if (IsPlugin(op_name)) return false;
-
-  auto ret = plugin_registry_.emplace(
-      op_name, std::make_pair(deserialize_func, construct_func));
-
-  return ret.second;
-}
-
-void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }
+/*
+ * SplitOp.
+ */
+class SplitOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    int input_num = op_desc.Input("X").size();
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE(input_num == 1);
+    int axis = boost::get<int>(op_desc.GetAttr("axis"));
+    std::vector<int> output_lengths =
+        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
+    PADDLE_ENFORCE(axis != 0);
+    if (axis < 0) {
+      axis += input_dims.nbDims;
+    } else {
+      axis -= 1;
+    }
+
+    PADDLE_ENFORCE(output_lengths.size() == output_num);
+
+    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
+    nvinfer1::IPluginLayer* layer =
+        engine_->addPlugin(&input, input_num, plugin);
+
+    std::string layer_name = "split (Output: ";
+    for (size_t i = 0; i < output_num; i++) {
+      auto output_name = op_desc.Output("Out")[i];
+      layer->getOutput(i)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(i));
+      layer_name += output_name;
+      if (test_mode) {
+        engine_->DeclareOutput(output_name);
+      }
+    }
+    layer->setName((layer_name + ")").c_str());
+  }
+};

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
@@ -12,26 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
-#include <cassert>
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {

-std::string ExtractOpName(const void* serial_data, size_t serial_length,
-                          size_t* incremental) {
-  size_t op_name_char_count = *static_cast<const size_t*>(serial_data);
-  *incremental = sizeof(size_t) + op_name_char_count;
-
-  assert(serial_length >= *incremental);
-
-  const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t);
-  std::string op_name(buffer, op_name_char_count);
-
-  return op_name;
-}
+TEST(split_op, test) {
+  std::unordered_set<std::string> parameters({""});
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
+  validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
+  validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("split");
+  desc.SetInput("X", {"split_input"});
+  desc.SetOutput("Out", {"split_out1", "split_out2"});
+
+  int num = 0;
+  int axis = 1;
+  std::vector<int> output_lengths = {2, 1};
+  desc.SetAttr("axis", axis);
+  desc.SetAttr("num", num);
+  desc.SetAttr("sections", output_lengths);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_OP(split);
@@ -254,6 +254,12 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }

+nvinfer1::IPluginLayer *TensorRTEngine::addPlugin(
+    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
+  owned_plugin_.emplace_back(plugin);
+  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include "paddle/fluid/inference/utils/singleton.h"

 namespace paddle {
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
+  nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs,
+                                    int nbInputs, PluginTensorRT*);

   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
   std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
   std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
       itensor_map_;

   // The specific GPU id that the TensorRTEngine bounded to.
   int device_;
+  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;

   // TensorRT related internal members
   template <typename T>
......
-nv_library(tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc
-           trt_plugin.cc split_op_plugin.cu DEPS enforce)
+nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <unordered_map>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
public:
static PluginFactoryTensorRT* GetInstance() {
static PluginFactoryTensorRT* factory_instance =
new PluginFactoryTensorRT();
return factory_instance;
}
// Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
size_t serial_length) override;
// Plugin construction, PluginFactoryTensorRT owns the plugin.
PluginTensorRT* CreatePlugin(const std::string& op_name);
bool RegisterPlugin(const std::string& op_name,
PluginDeserializeFunc deserialize_func,
PluginConstructFunc construct_func);
bool IsPlugin(const std::string& op_name) {
return plugin_registry_.find(op_name) != plugin_registry_.end();
}
size_t CountOwnedPlugins() { return owned_plugins_.size(); }
void DestroyPlugins();
protected:
std::unordered_map<std::string,
std::pair<PluginDeserializeFunc, PluginConstructFunc>>
plugin_registry_;
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};
class TrtPluginRegistrar {
public:
TrtPluginRegistrar(const std::string& name,
PluginDeserializeFunc deserialize_func,
PluginConstructFunc construct_func) {
auto factory = PluginFactoryTensorRT::GetInstance();
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func), "Falied to register plugin [%s]", name);
// platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func,
// construct_func));
factory->RegisterPlugin(name, deserialize_func, construct_func);
}
};
#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
construct_func) \
REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
static ::paddle::inference::tensorrt::TrtPluginRegistrar \
trt_plugin_registrar##ctr __attribute__((unused)) = \
::paddle::inference::tensorrt::TrtPluginRegistrar( \
name, deserialize_func, construct_func)
} // namespace tensorrt
} // namespace inference
} // namespace paddle
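For context, the REGISTER_TRT_PLUGIN macro above is meant to be used roughly as follows. This is a hypothetical sketch, not code from this patch; the lambdas only need to satisfy the PluginDeserializeFunc and PluginConstructFunc signatures declared in plugin_utils.h:

    // Hypothetical usage of the plugin factory registration macro.
    REGISTER_TRT_PLUGIN(
        "split",
        /* deserialize_func */
        [](const void* serial_data, size_t serial_length) -> PluginTensorRT* {
          return new SplitPlugin(serial_data, serial_length);
        },
        /* construct_func */
        []() -> PluginTensorRT* { return new SplitPlugin(/*axis=*/0, {1}); });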
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
typedef std::function<PluginTensorRT*(const void*, size_t)>
PluginDeserializeFunc;
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
std::string ExtractOpName(const void* serial_data, size_t serial_length,
size_t* incremental);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
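The ExtractOpName helper declared here assumes a simple layout for the serialized plugin data; a short worked example of that layout, based on the ExtractOpName implementation earlier in this diff:

    // Layout assumed by ExtractOpName:
    // [ size_t op_name_char_count ][ op_name bytes ... ]
    // e.g. for "split": the first sizeof(size_t) bytes encode 5, the next five
    // bytes are 's' 'p' 'l' 'i' 't', and *incremental is set to sizeof(size_t) + 5.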
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <stdio.h>
 #include <cassert>
 #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

@@ -19,8 +20,6 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {

-SplitPlugin* CreateSplitPlugin() { return new SplitPlugin(); };
-
 nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
                                                 const nvinfer1::Dims* inputDims,
                                                 int nbInputs) {
@@ -28,15 +27,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
   assert(index < this->getNbOutputs());
   nvinfer1::Dims const& input_dims = inputDims[0];
   nvinfer1::Dims output_dims = input_dims;
-  output_dims.d[axis_] = output_lenght_.at(index);
+  output_dims.d[axis_] = output_length_.at(index);
   return output_dims;
 }

 int SplitPlugin::initialize() {
   std::vector<int> segment_offsets(1, 0);
   for (int i = 0; i < this->getNbOutputs(); ++i) {
-    segment_offsets.push_back(segment_offsets.back() + output_lenght_[i]);
+    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
+  segment_offsets_ = segment_offsets;
   d_segment_offsets_ = segment_offsets;
   nvinfer1::Dims dims = this->getInputDims(0);
   nx_ = 1;
@@ -51,60 +51,30 @@ int SplitPlugin::initialize() {
   return 0;
 }

-template <typename T>
-__device__ int upper_bound(T const* vals, int n, T const& key) {
-  int i = 0;
-  while (n > 0) {
-    int m = n / 2;
-    int j = i + m;
-    if (!(key < vals[j])) {
-      i = j + 1;
-      n -= m + 1;
-    } else {
-      n = m;
-    }
-  }
-  return i;
-}
-
-template <typename T>
-__global__ void split_kernel(int nsegment,
-                             int const* __restrict__ segment_offsets,
-                             T const* __restrict__ idata, T* const* odatas,
-                             int nx, int srcny_, int nz) {
-  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
-  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
-  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
-  for (int z = z0; z < nz; z += blockDim.z * gridDim.z) {
-    for (int src_y = src_y0; src_y < srcny_; src_y += blockDim.y * gridDim.y) {
-      for (int x = x0; x < nx; x += blockDim.x * gridDim.x) {
-        int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
-        int dst_y = src_y - segment_offsets[segment];
-        int dstny_ = segment_offsets[segment + 1] - segment_offsets[segment];
-        odatas[segment][x + nx * (dst_y + dstny_ * z)] =
-            idata[x + nx * (src_y + srcny_ * z)];
-      }
-    }
-  }
-}
-
 int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
   auto const& input_dims = this->getInputDims(0);
+  int input_size = 0;
   int const* d_segment_offsets_ptr =
       thrust::raw_pointer_cast(&d_segment_offsets_[0]);
   float const* idata = reinterpret_cast<float const*>(inputs[0]);
   float** odatas = reinterpret_cast<float**>(outputs);
-  int nz = nz_ * batchSize;
-  dim3 block(32, 16);
-  dim3 grid(std::min((nx_ - 1) / block.x + 1, 65535u),
-            std::min((ny_ - 1) / block.y + 1, 65535u),
-            std::min((nz_ - 1) / block.z + 1, 65535u));
-
-  split_kernel<<<grid, block, 0, stream>>>(d_segment_offsets_.size(),
-                                           d_segment_offsets_ptr, idata, odatas,
-                                           nx_, ny_, nz);
+  // kernel impl here.
+  int inputBatchOffset = nx_ * ny_ * nz_;
+  for (size_t i = 0; i < this->getNbOutputs(); i++) {
+    for (size_t j = 0; j < batchSize; j++) {
+      cudaMemcpyAsync(
+          odatas[i] +
+              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
+                  sizeof(float),
+          inputs[0] +
+              (inputBatchOffset * j + segment_offsets_[i] * nx_) *
+                  sizeof(float),
+          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
+          cudaMemcpyDeviceToDevice, stream);
+    }
+  }
   return cudaGetLastError() != cudaSuccess;
 }
......
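To make the copy scheme above concrete, here is a worked example using the shapes from the unit test in this patch: split_input is CHW (3, 2, 2), sections are {2, 1}, and the split runs along the channel axis. Assuming nx_, ny_ and nz_ keep the inner/axis/outer meaning they had for the replaced split_kernel, the numbers work out as follows (a sketch, not code from the patch):

    // segment_offsets_ = {0, 2, 3}; nx_ = 2 * 2 = 4; ny_ = 3; nz_ = 1
    // output 0: (2 - 0) * nx_ = 8 floats per batch, read from element offset 0 * nx_ = 0
    // output 1: (3 - 2) * nx_ = 4 floats per batch, read from element offset 2 * nx_ = 8
    // one cudaMemcpyAsync per (output, batch) pair, all queued on the same stream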
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #pragma once

-#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include <thrust/device_vector.h>
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

 namespace paddle {
 namespace inference {
@@ -10,53 +23,55 @@ namespace tensorrt {

 class SplitPlugin : public PluginTensorRT {
   int axis_;
-  std::vector<int> output_lenght_;
+  std::vector<int> output_length_;
   int nx_, ny_, nz_;
   thrust::device_vector<int> d_segment_offsets_;
+  std::vector<int> segment_offsets_;

  protected:
   virtual size_t getSerializationSize() override {
-    return serialized_size(axis_) + serialized_size(output_lenght_)
-           + getBaseSerializationSize();
+    return serialized_size(axis_) + serialized_size(output_length_) +
+           getBaseSerializationSize();
   }

   virtual void serialize(void *buffer) override {
     serializeBase(buffer);
     serialize_value(&buffer, axis_);
-    serialize_value(&buffer, output_lenght_);
+    serialize_value(&buffer, output_length_);
   }

  public:
-  Split() {}
+  SplitPlugin(int axis, std::vector<int> const &output_lengths)
+      : axis_(axis), output_length_(output_lengths) {
+    assert(axis <= nvinfer1::Dims::MAX_DIMS);
+  }
+
   SplitPlugin(void const *serialData, size_t serialLength) {
     deserializeBase(serialData, serialLength);
     deserialize_value(&serialData, &serialLength, &axis_);
-    deserialize_value(&serialData, &serialLength, &output_lenght_);
+    deserialize_value(&serialData, &serialLength, &output_length_);
   }

   SplitPlugin *clone() const override {
-    return new SplitPlugin(axis_, output_lenght_);
+    return new SplitPlugin(axis_, output_length_);
   }

   virtual const char *getPluginType() const override { return "split"; }
-  virtual int getNbOutputs() const override { return output_lenght_.size(); }
+  virtual int getNbOutputs() const override { return output_length_.size(); }
   virtual nvinfer1::Dims getOutputDimensions(int index,
                                              const nvinfer1::Dims *inputs,
                                              int nbInputDims) override;
   virtual int initialize() override;
   virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
                       void *workspace, cudaStream_t stream) override;

   void setAxis(int axis) { axis_ = axis; }

   void setOutputLengths(const std::vector<int> &output_lengths) {
     output_length_ = output_lengths;
   }
 };

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -13,7 +13,6 @@
 // limitations under the License.

 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
-#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h"

 namespace paddle {
 namespace inference {
@@ -41,8 +40,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {

 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
                                     nvinfer1::PluginFormat format) const {
-  return ((type == nvinfer1::DataType::kFLOAT ||
-           type == nvinfer1::DataType::kHALF) &&
+  return ((type == nvinfer1::DataType::kFLOAT) &&
          (format == nvinfer1::PluginFormat::kNCHW));
 }
......
@@ -14,14 +14,14 @@
 #pragma once

+#include <NvInfer.h>
 #include <cassert>
 #include <cstring>
 #include <iostream>
 #include <unordered_map>
 #include <vector>
-#include "NvInfer.h"
-#include "paddle/fluid/inference/tensorrt/plugin/serialize.hpp"
+#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"

 namespace paddle {
 namespace inference {
@@ -53,8 +53,8 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
                                      nvinfer1::DataType type,
                                      nvinfer1::PluginFormat format,
                                      int maxBatchSize) override;
-  virtual void serialize(void* buffer) override;
-  virtual size_t getSerializationSize() override;
+  virtual void serialize(void* buffer) = 0;
+  virtual size_t getSerializationSize() = 0;

  protected:
   void deserializeBase(void const*& serialData, size_t& serialLength);
......