Unverified · Commit 8bc1c5d2 authored by Yiqun Liu, committed by GitHub

Implement the Tensorrt plugin for elementwise op (#14487)

* Initialize the elementwise plugin.

* Implement the basic CUDA kernel of elementwise plugin.
test=develop
Parent 7aa3aff3
@@ -114,7 +114,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // it is either an OP's input or an OP's output.
   auto &subgraph_nodes = *Agent(node).subgraph();
-  for (size_t index = 0; index < block_desc.OpSize(); index++) {
+  for (size_t index = 0; index < block_desc.OpSize(); ++index) {
     framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
     auto correspond_node = subgraph_nodes[index];
     PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
......
@@ -45,7 +45,8 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
   std::unordered_set<std::string> teller_set(
       {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
        "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-       "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"});
+       "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
+       "conv2d_transpose"});
   if (!node->IsOp()) return false;
   if (teller_set.count(node->Op()->Type())) {
......
# Add TRT tests
nv_library(tensorrt_converter
  SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
       batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
       pad_op.cc split_op.cc prelu_op.cc
  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
  ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
@@ -20,7 +20,8 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
 nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
         DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL)
 nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
-        DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine elementwise_add_op SERIAL)
+        DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
+        elementwise_add_op elementwise_mul_op SERIAL)
 nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
         DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine softmax_op SERIAL)
 nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
@@ -33,7 +34,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
         DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pad_op SERIAL)
 nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
         DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
         split_op concat_op SERIAL)
 nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
         DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
         prelu_op SERIAL)
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,11 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {

+static bool CheckDims(const nvinfer1::Dims& dims_x,
+                      const nvinfer1::Dims& dims_y) {
+  if (dims_x.nbDims != dims_y.nbDims) {
+    return false;
+  }
+  for (int i = 0; i < dims_x.nbDims; i++) {
+    if (dims_x.d[i] != dims_y.d[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
 class ElementwiseWeightOpConverter : public OpConverter {
  public:
   ElementwiseWeightOpConverter() {}
@@ -26,7 +40,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
     // Here the two nullptr looks strange, that's because the
     // framework::OpDesc's constructor is strange.
     framework::OpDesc op_desc(op, nullptr);
-    VLOG(3) << "convert a fluid elementwise op to tensorrt IScaleLayer";
+    VLOG(3) << "Convert a fluid elementwise op to TensorRT IScaleLayer";
     PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
     PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);  // Y is a weight
@@ -106,10 +120,12 @@ class ElementwiseTensorOpConverter : public OpConverter {
   ElementwiseTensorOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
+    auto op_pair = ops.find(op_type_);
+    PADDLE_ENFORCE(op_pair != ops.end(), "Wrong elementwise op type!");
+
     // Here the two nullptr looks strange, that's because the
     // framework::OpDesc's constructor is strange.
     framework::OpDesc op_desc(op, nullptr);
-    VLOG(3) << "convert a fluid elementwise op to tensorrt IScaleLayer";
     PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
     PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);  // Y is a weight
@@ -120,29 +136,35 @@ class ElementwiseTensorOpConverter : public OpConverter {
     nvinfer1::Dims dims_x = X->getDimensions();
     nvinfer1::Dims dims_y = Y->getDimensions();

-    // The two input tensor should have the same dims
-    PADDLE_ENFORCE(dims_x.nbDims >= 3);
-    if (dims_x.nbDims == dims_y.nbDims) {
-      for (int i = 0; i < dims_x.nbDims; i++) {
-        if (dims_x.d[i] != dims_y.d[i])
-          PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
-      }
-    } else {
-      PADDLE_THROW("TensorRT unsupported tensor shape for Elementwise op!");
-    }
-
-    auto op_pair = ops.find(op_type_);
-    if (op_pair == ops.end()) {
-      PADDLE_THROW("Wrong elementwise op type!");
-    }
-    nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
-        engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
-        *const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
-
-    auto output_name = op_desc.Output("Out")[0];
-    layer->setName(("elementwise (Output: " + output_name + ")").c_str());
-    layer->getOutput(0)->setName(output_name.c_str());
-    engine_->SetITensor(output_name, layer->getOutput(0));
+    int axis = boost::get<int>(op_desc.GetAttr("axis"));
+    auto output_name = op_desc.Output("Out")[0];
+    if (CheckDims(dims_x, dims_y)) {
+      // The two input tensor should have the same dims
+      VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
+
+      nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
+          engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
+          *const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
+
+      layer->setName(("elementwise (Output: " + output_name + ")").c_str());
+      layer->getOutput(0)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(0));
+    } else {
+      VLOG(3) << "Convert a fluid elementwise op to TensorRT "
+                 "ElementWisePluginLayer";
+
+      plugin::ElementWisePlugin* plugin =
+          new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis);
+      plugin->AddInput(X);
+      plugin->AddInput(Y);
+      nvinfer1::IPluginLayer* layer = engine_->AddPlugin(
+          const_cast<nvinfer1::ITensor* const*>(plugin->GetInputs().data()), 2,
+          reinterpret_cast<plugin::PluginTensorRT*>(plugin));
+
+      layer->setName(("elementwise (Output: " + output_name + ")").c_str());
+      layer->getOutput(0)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(0));
+    }
     if (test_mode) {  // the test framework can not determine which is the
                       //   output, so place the declaration inside.
       engine_->DeclareOutput(output_name);
......
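The dispatch rule above is the heart of the converter change: identical ranks and extents let it use TensorRT's native IElementWiseLayer, while any broadcast shape for Y (for example a per-channel bias) falls back to the new ElementWisePlugin. A standalone sketch of that decision for the two Y shapes exercised by the unit tests further down (an illustration, not part of the patch; SameDims is a hypothetical stand-in for CheckDims over nvinfer1::Dims):

// Illustration only: mirrors CheckDims' rule with plain vectors.
#include <cstdio>
#include <vector>

// Hypothetical stand-in for CheckDims (same rank and extents).
static bool SameDims(const std::vector<int>& x, const std::vector<int>& y) {
  return x == y;
}

int main() {
  const std::vector<int> x{10, 3, 3};  // per-sample dims of X in the tests
  std::printf("Y=(10,3,3) -> %s\n", SameDims(x, {10, 3, 3})
                                        ? "IElementWiseLayer"
                                        : "ElementWisePlugin");
  std::printf("Y=(10,1,1) -> %s\n", SameDims(x, {10, 1, 1})
                                        ? "IElementWiseLayer"
                                        : "ElementWisePlugin");
  return 0;
}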
@@ -61,7 +61,7 @@ class OpConverter {
     // TODO(xingzhaolong): all mul, sub, div
     // static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
     // "sub", "div"};
-    static std::unordered_set<std::string> add_weight_op_set{"add"};
+    static std::unordered_set<std::string> add_weight_op_set{"add", "mul"};
     PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
     int op_type_len = op_desc.Type().size();
     std::string op_type = op_desc.Type().substr(op_type_len - 3, op_type_len);
......
@@ -54,7 +54,7 @@ class PReluOpConverter : public OpConverter {
     TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
                                     static_cast<void*>(alpha_data),
                                     alpha_tensor_device->numel());
-    PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
+    plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode);
     nvinfer1::IPluginLayer* layer =
         engine_->AddPlugin(&input, input_num, plugin);
     // keep alpha tensor to avoid release it's memory
......
@@ -50,7 +50,7 @@ class SplitOpConverter : public OpConverter {
     PADDLE_ENFORCE(output_lengths.size() == output_num);
     //
-    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
+    plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
     nvinfer1::IPluginLayer* layer =
         engine_->AddPlugin(&input, input_num, plugin);
......
@@ -20,13 +20,12 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {

-TEST(elementwise_op, add_weight_test) {
+TEST(elementwise_op, add_weight) {
   std::unordered_set<std::string> parameters({"elementwise_add-Y"});
   framework::Scope scope;
   TRTConvertValidation validator(10, parameters, scope, 1 << 15);
   validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
   validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
-  // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
   validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));

   // Prepare Op description
@@ -44,30 +43,65 @@ TEST(elementwise_op, add_weight_test) {
   validator.Execute(8);
 }

-TEST(elementwise_op, add_tensor_test) {
-  std::unordered_set<std::string> parameters;
-  framework::Scope scope;
-  TRTConvertValidation validator(8, parameters, scope, 1 << 15);
-  validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
-  validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3));
-  // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
-  validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
-
-  // Prepare Op description
-  framework::OpDesc desc;
-  desc.SetType("elementwise_add");
-  desc.SetInput("X", {"elementwise_add-X"});
-  desc.SetInput("Y", {"elementwise_add-Y"});
-  desc.SetOutput("Out", {"elementwise_add-Out"});
-
-  // the defalut axis of elementwise op is -1
-
-  validator.SetOp(*desc.Proto());
-
-  validator.Execute(8);
+TEST(elementwise_op, native) {
+  for (std::string type : {"add", "mul"}) {
+    int batch_size = 8;
+    std::unordered_set<std::string> parameters;
+    framework::Scope scope;
+    TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
+    validator.DeclInputVar("elementwise_" + type + "-X",
+                           nvinfer1::DimsCHW(10, 3, 3));
+    validator.DeclInputVar("elementwise_" + type + "-Y",
+                           nvinfer1::Dims3(10, 3, 3));
+    validator.DeclOutputVar("elementwise_" + type + "-Out",
+                            nvinfer1::DimsCHW(10, 3, 3));
+
+    // Prepare Op description
+    framework::OpDesc desc;
+    desc.SetType("elementwise_" + type);
+    desc.SetInput("X", {"elementwise_" + type + "-X"});
+    desc.SetInput("Y", {"elementwise_" + type + "-Y"});
+    desc.SetOutput("Out", {"elementwise_" + type + "-Out"});
+
+    int axis = -1;
+    desc.SetAttr("axis", axis);
+
+    validator.SetOp(*desc.Proto());
+    validator.Execute(batch_size);
+  }
+}
+
+TEST(elementwise_op, plugin) {
+  for (std::string type : {"add", "mul"}) {
+    int batch_size = 8;
+    std::unordered_set<std::string> parameters;
+    framework::Scope scope;
+    TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
+    validator.DeclInputVar("elementwise_" + type + "-X",
+                           nvinfer1::DimsCHW(10, 3, 3));
+    validator.DeclInputVar("elementwise_" + type + "-Y",
+                           nvinfer1::Dims3(10, 1, 1));
+    validator.DeclOutputVar("elementwise_" + type + "-Out",
+                            nvinfer1::DimsCHW(10, 3, 3));
+
+    // Prepare Op description
+    framework::OpDesc desc;
+    desc.SetType("elementwise_" + type);
+    desc.SetInput("X", {"elementwise_" + type + "-X"});
+    desc.SetInput("Y", {"elementwise_" + type + "-Y"});
+    desc.SetOutput("Out", {"elementwise_" + type + "-Out"});
+
+    int axis = -1;
+    desc.SetAttr("axis", axis);
+
+    validator.SetOp(*desc.Proto());
+    validator.Execute(batch_size);
+  }
 }

 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle

 USE_OP(elementwise_add);
+USE_OP(elementwise_mul);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
......
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
......
@@ -257,9 +257,10 @@ void TensorRTEngine::freshDeviceId() {
 }

 nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
-    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
+    nvinfer1::ITensor *const *inputs, int num_inputs,
+    plugin::PluginTensorRT *plugin) {
   owned_plugin_.emplace_back(plugin);
-  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
+  return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin);
 }

 }  // namespace tensorrt
......
@@ -128,7 +128,7 @@ class TensorRTEngine : public EngineBase {
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
   nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
-                                    int nbInputs, PluginTensorRT*);
+                                    int num_inputs, plugin::PluginTensorRT*);

   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -171,7 +171,7 @@ class TensorRTEngine : public EngineBase {
   // The specific GPU id that the TensorRTEngine bounded to.
   int device_;

-  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
+  std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;

   // TensorRT related internal members
   template <typename T>
......
-nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce device_context)
+nv_library(tensorrt_plugin
+  SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
+  DEPS enforce device_context)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
namespace details {
template <typename T>
struct Add {
__device__ T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct Mul {
__device__ T operator()(const T& a, const T& b) const { return a * b; }
};
template <typename T, typename Operator>
__global__ void ColumnWiseKernel(Operator op, const T* x, const T* y, T* out,
int batch_size, int num_rows, int num_cols) {
for (int batch_id = 0; batch_id < batch_size; ++batch_id) {
int row = blockIdx.x;
for (; row < num_rows; row += gridDim.x) {
T value_y = y[batch_id * num_rows + row];
int col = threadIdx.x;
int offset = (batch_id * num_rows + row) * num_cols;
for (; col < num_cols; col += blockDim.x) {
T value_x = x[offset + col];
out[offset + col] = op(value_x, value_y);
}
}
}
}
template <typename T, typename Operator>
static void ElementWise(Operator op, const T* x, const T* y, T* out,
int batch_size, int prev, int midd, int post,
cudaStream_t stream) {
const int kThreadsPerBlock = 1024;
const int kMaximumBlocks = 65535;
if (prev == 1) {
int num_threads = (post > kThreadsPerBlock) ? kThreadsPerBlock
: (((post + 31) >> 5) << 5);
int num_blocks = (midd < kMaximumBlocks) ? midd : kMaximumBlocks;
ColumnWiseKernel<<<num_blocks, num_threads, 0, stream>>>(
op, x, y, out, batch_size, midd, post);
} else if (post == 1) {
PADDLE_THROW("Not implemented.");
} else {
PADDLE_THROW("Not implemented.");
}
}
} // namespace details
nvinfer1::Dims ElementWisePlugin::getOutputDimensions(
int index, const nvinfer1::Dims* input_dims, int num_inputs) {
PADDLE_ENFORCE_EQ(index, 0);
PADDLE_ENFORCE_EQ(num_inputs, 2);
PADDLE_ENFORCE_NOT_NULL(input_dims);
return input_dims[0];
}
int ElementWisePlugin::initialize() {
PADDLE_ENFORCE_GT(dims_y_.nbDims, 0);
axis_ = (axis_ == -1) ? dims_x_.nbDims - dims_y_.nbDims : axis_;
int trimed_nb_dims = dims_y_.nbDims;
for (; trimed_nb_dims > 0; --trimed_nb_dims) {
if (dims_y_.d[trimed_nb_dims - 1] != 1) {
break;
}
}
dims_y_.nbDims = trimed_nb_dims;
PADDLE_ENFORCE_GE(dims_x_.nbDims, dims_y_.nbDims + axis_);
PADDLE_ENFORCE_LT(axis_, dims_x_.nbDims);
prev_size_ = 1;
midd_size_ = 1;
post_size_ = 1;
for (int i = 0; i < axis_; ++i) {
prev_size_ *= dims_x_.d[i];
}
for (int i = 0; i < dims_y_.nbDims; ++i) {
PADDLE_ENFORCE_EQ(dims_x_.d[i + axis_], dims_y_.d[i],
"Broadcast dimension mismatch.");
midd_size_ *= dims_y_.d[i];
}
for (int i = axis_ + dims_y_.nbDims; i < dims_x_.nbDims; ++i) {
post_size_ *= dims_x_.d[i];
}
return 0;
}
int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs,
void** outputs, void* workspace,
cudaStream_t stream) {
const float* x = reinterpret_cast<const float*>(inputs[0]);
const float* y = reinterpret_cast<const float*>(inputs[1]);
float* out = reinterpret_cast<float*>(outputs[0]);
if (type_ == nvinfer1::ElementWiseOperation::kSUM) {
details::ElementWise(details::Add<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream);
} else if (type_ == nvinfer1::ElementWiseOperation::kPROD) {
details::ElementWise(details::Mul<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream);
} else {
PADDLE_THROW("Not implemented.");
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
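A worked instance of initialize()'s decomposition, using the shapes from the plugin unit test earlier in this commit (an illustration, not part of the patch): X has per-sample dims (10, 3, 3), Y has (10, 1, 1), and axis = -1 resolves to 0; trimming Y's trailing 1s leaves dims_y = (10), so prev_size_ = 1, midd_size_ = 10, post_size_ = 9, and ColumnWiseKernel launches with 10 blocks of 32 threads. A CPU reference of that prev == 1 broadcast:

// CPU reference for the column-wise broadcast computed by ColumnWiseKernel
// when prev == 1: out[b][m][p] = op(x[b][m][p], y[b][m]). Illustration only.
#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2, midd = 10, post = 9;  // prev == 1
  std::vector<float> x(batch_size * midd * post, 1.0f);
  std::vector<float> y(batch_size * midd);
  for (int i = 0; i < batch_size * midd; ++i) y[i] = static_cast<float>(i);

  std::vector<float> out(x.size());
  for (int b = 0; b < batch_size; ++b) {
    for (int m = 0; m < midd; ++m) {
      float value_y = y[b * midd + m];  // one Y value per row
      int offset = (b * midd + m) * post;
      for (int p = 0; p < post; ++p) {
        out[offset + p] = x[offset + p] + value_y;  // Add; Mul is analogous
      }
    }
  }
  std::printf("out[0] = %g, out[last] = %g\n", out.front(), out.back());
  return 0;
}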
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class ElementWisePlugin : public PluginTensorRT {
public:
ElementWisePlugin(nvinfer1::ElementWiseOperation type,
nvinfer1::Dims const &dims_x, nvinfer1::Dims const &dims_y,
int axis)
: type_(type),
dims_x_(dims_x),
dims_y_(dims_y),
axis_(axis),
prev_size_(1),
midd_size_(1),
post_size_(1) {}
ElementWisePlugin(void const *serial_data, size_t serial_length) {
deserializeBase(serial_data, serial_length);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_);
}
ElementWisePlugin *clone() const override {
// return new ElementWisePlugin(dims_x_, dims_y_, axis_);
return nullptr;
}
const char *getPluginType() const override { return "elementwise"; }
nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *input_dims,
int num_inputs) override;
int initialize() override;
// execute the layer
int enqueue(int batch_size, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream);
protected:
size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(dims_x_) +
SerializedSize(dims_y_) + getBaseSerializationSize();
}
void serialize(void *buffer) override {
serializeBase(buffer);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_);
}
nvinfer1::ElementWiseOperation type_;
nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_;
int axis_;
int prev_size_;
int midd_size_;
int post_size_;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -20,6 +20,7 @@
 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

 static const int CUDA_NUM_THREADS = 1024;
 static const int CUDA_MAX_NUM_BLOCKS = 65535;
@@ -126,6 +127,7 @@ int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
   return cudaGetLastError() != cudaSuccess;
 }

+}  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -21,6 +21,7 @@
 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

 class PReluPlugin : public PluginTensorRT {
   TensorRTEngine::Weight alpha_;
@@ -63,6 +64,7 @@ class PReluPlugin : public PluginTensorRT {
                void *workspace, cudaStream_t stream) override;
 };

+}  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -14,10 +14,15 @@
 #pragma once

-#include <cassert>
 #include <cstring>
 #include <type_traits>
 #include <vector>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {

 template <typename T>
 inline void SerializeValue(void** buffer, T const& value);
@@ -26,7 +31,7 @@ template <typename T>
 inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                              T* value);

-namespace {
+namespace details {

 template <typename T, class Enable = void>
 struct Serializer {};
@@ -36,10 +41,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                              std::is_enum<T>::value ||
                                              std::is_pod<T>::value>::type> {
   static size_t SerializedSize(T const& value) { return sizeof(T); }
+
   static void Serialize(void** buffer, T const& value) {
     std::memcpy(*buffer, &value, sizeof(T));
     reinterpret_cast<char*&>(*buffer) += sizeof(T);
   }
+
   static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
     assert(*buffer_size >= sizeof(T));
     std::memcpy(value, *buffer, sizeof(T));
@@ -51,10 +58,12 @@ struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
 template <>
 struct Serializer<const char*> {
   static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+
   static void Serialize(void** buffer, const char* value) {
-    std::strcpy(static_cast<char*>(*buffer), value);
+    std::strcpy(static_cast<char*>(*buffer), value);  // NOLINT
     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
   }
+
   static void Deserialize(void const** buffer, size_t* buffer_size,
                           const char** value) {
     *value = static_cast<char const*>(*buffer);
@@ -73,39 +82,46 @@ struct Serializer<std::vector<T>,
   static size_t SerializedSize(std::vector<T> const& value) {
     return sizeof(value.size()) + value.size() * sizeof(T);
   }
+
   static void Serialize(void** buffer, std::vector<T> const& value) {
     SerializeValue(buffer, value.size());
     size_t nbyte = value.size() * sizeof(T);
     std::memcpy(*buffer, value.data(), nbyte);
     reinterpret_cast<char*&>(*buffer) += nbyte;
   }
+
   static void Deserialize(void const** buffer, size_t* buffer_size,
                           std::vector<T>* value) {
     size_t size;
     DeserializeValue(buffer, buffer_size, &size);
     value->resize(size);
     size_t nbyte = value->size() * sizeof(T);
-    assert(*buffer_size >= nbyte);
+    PADDLE_ENFORCE_GE(*buffer_size, nbyte);
     std::memcpy(value->data(), *buffer, nbyte);
     reinterpret_cast<char const*&>(*buffer) += nbyte;
     *buffer_size -= nbyte;
   }
 };

-}  // namespace
+}  // namespace details

 template <typename T>
 inline size_t SerializedSize(T const& value) {
-  return Serializer<T>::SerializedSize(value);
+  return details::Serializer<T>::SerializedSize(value);
 }

 template <typename T>
 inline void SerializeValue(void** buffer, T const& value) {
-  return Serializer<T>::Serialize(buffer, value);
+  return details::Serializer<T>::Serialize(buffer, value);
 }

 template <typename T>
 inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                              T* value) {
-  return Serializer<T>::Deserialize(buffer, buffer_size, value);
+  return details::Serializer<T>::Deserialize(buffer, buffer_size, value);
 }
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
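These helpers are what the plugins' serialize()/deserializeBase() build on. A round-trip sketch (illustration only; SerializeRoundTripExample is a hypothetical function, and the include path assumes an in-tree build):

// Serializes an int and a std::vector<int> into a flat buffer, then reads
// them back in the same order. Not part of the patch.
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"

void SerializeRoundTripExample() {
  namespace plugin = paddle::inference::tensorrt::plugin;

  int axis = -1;
  std::vector<int> lengths{2, 3};
  std::vector<char> buffer(plugin::SerializedSize(axis) +
                           plugin::SerializedSize(lengths));

  void* write_ptr = buffer.data();
  plugin::SerializeValue(&write_ptr, axis);     // advances by sizeof(int)
  plugin::SerializeValue(&write_ptr, lengths);  // size() header + payload

  void const* read_ptr = buffer.data();
  size_t remaining = buffer.size();
  int axis_out = 0;
  std::vector<int> lengths_out;
  plugin::DeserializeValue(&read_ptr, &remaining, &axis_out);
  plugin::DeserializeValue(&read_ptr, &remaining, &lengths_out);
  // axis_out == -1, lengths_out == {2, 3}, remaining == 0
}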
@@ -12,26 +12,26 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <stdio.h>
-#include <cassert>
 #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

-nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
-                                                const nvinfer1::Dims* inputDims,
-                                                int nbInputs) {
-  assert(nbInputs == 1);
-  assert(index < this->getNbOutputs());
-  nvinfer1::Dims const& input_dims = inputDims[0];
-  nvinfer1::Dims output_dims = input_dims;
+nvinfer1::Dims SplitPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims* input_dims, int num_inputs) {
+  PADDLE_ENFORCE_EQ(num_inputs, 1);
+  PADDLE_ENFORCE_LT(index, this->getNbOutputs());
+
+  nvinfer1::Dims output_dims = input_dims[0];
   output_dims.d[axis_] = output_length_.at(index);
   return output_dims;
 }

 int SplitPlugin::initialize() {
+  PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS);
+
   std::vector<int> segment_offsets(1, 0);
   for (int i = 0; i < this->getNbOutputs(); ++i) {
     segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
@@ -76,6 +76,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
   return cudaGetLastError() != cudaSuccess;
 }

-}  // tensorrt
-}  // inference
-}  // paddle
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
@@ -14,61 +14,58 @@
 #pragma once

+#include <vector>
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

 class SplitPlugin : public PluginTensorRT {
-  int axis_;
-  std::vector<int> output_length_;
-  int nx_, ny_, nz_;
-  std::vector<int> segment_offsets_;
+ public:
+  SplitPlugin(int axis, std::vector<int> const &output_lengths)
+      : axis_(axis), output_length_(output_lengths) {}
+
+  SplitPlugin(void const *serial_data, size_t serial_length) {
+    deserializeBase(serial_data, serial_length);
+    DeserializeValue(&serial_data, &serial_length, &axis_);
+    DeserializeValue(&serial_data, &serial_length, &output_length_);
+  }
+
+  SplitPlugin *clone() const override {
+    return new SplitPlugin(axis_, output_length_);
+  }
+
+  const char *getPluginType() const override { return "split"; }
+  int getNbOutputs() const override { return output_length_.size(); }
+  nvinfer1::Dims getOutputDimensions(int index,
+                                     const nvinfer1::Dims *input_dims,
+                                     int num_inputs) override;
+  int initialize() override;
+  int enqueue(int batchSize, const void *const *inputs, void **outputs,
+              void *workspace, cudaStream_t stream) override;

  protected:
-  virtual size_t getSerializationSize() override {
+  size_t getSerializationSize() override {
     return SerializedSize(axis_) + SerializedSize(output_length_) +
            getBaseSerializationSize();
   }

-  // TRT will call this func when we need to serialize the configuration of
-  // tensorrt.
-  // It should not be called by users.
-  virtual void serialize(void *buffer) override {
+  void serialize(void *buffer) override {
     serializeBase(buffer);
     SerializeValue(&buffer, axis_);
     SerializeValue(&buffer, output_length_);
   }

- public:
-  SplitPlugin(int axis, std::vector<int> const &output_lengths)
-      : axis_(axis), output_length_(output_lengths) {
-    assert(axis <= nvinfer1::Dims::MAX_DIMS);
-  }
-
-  // It was used for tensorrt deserialization.
-  // It should not be called by users.
-  SplitPlugin(void const *serialData, size_t serialLength) {
-    deserializeBase(serialData, serialLength);
-    DeserializeValue(&serialData, &serialLength, &axis_);
-    DeserializeValue(&serialData, &serialLength, &output_length_);
-  }
-
-  SplitPlugin *clone() const override {
-    return new SplitPlugin(axis_, output_length_);
-  }
-
-  virtual const char *getPluginType() const override { return "split"; }
-  virtual int getNbOutputs() const override { return output_length_.size(); }
-  virtual nvinfer1::Dims getOutputDimensions(int index,
-                                             const nvinfer1::Dims *inputs,
-                                             int nbInputDims) override;
-  virtual int initialize() override;
-  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
-                      void *workspace, cudaStream_t stream) override;
+  int axis_;
+  std::vector<int> output_length_;
+  int nx_, ny_, nz_;
+  std::vector<int> segment_offsets_;
 };

-}  // tensorrt
-}  // inference
-}  // paddle
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
@@ -17,6 +17,7 @@
 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

 void PluginTensorRT::serializeBase(void*& buffer) {
   SerializeValue(&buffer, input_dims_);
@@ -25,12 +26,12 @@ void PluginTensorRT::serializeBase(void*& buffer) {
   SerializeValue(&buffer, data_format_);
 }

-void PluginTensorRT::deserializeBase(void const*& serialData,
-                                     size_t& serialLength) {
-  DeserializeValue(&serialData, &serialLength, &input_dims_);
-  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
-  DeserializeValue(&serialData, &serialLength, &data_type_);
-  DeserializeValue(&serialData, &serialLength, &data_format_);
+void PluginTensorRT::deserializeBase(void const*& serial_data,
+                                     size_t& serial_length) {
+  DeserializeValue(&serial_data, &serial_length, &input_dims_);
+  DeserializeValue(&serial_data, &serial_length, &max_batch_size_);
+  DeserializeValue(&serial_data, &serial_length, &data_type_);
+  DeserializeValue(&serial_data, &serial_length, &data_format_);
 }

 size_t PluginTensorRT::getBaseSerializationSize() {
@@ -44,18 +45,17 @@ bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
           (format == nvinfer1::PluginFormat::kNCHW));
 }

-void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
-                                         int nbInputs,
-                                         const nvinfer1::Dims* outputDims,
-                                         int nbOutputs, nvinfer1::DataType type,
-                                         nvinfer1::PluginFormat format,
-                                         int maxBatchSize) {
+void PluginTensorRT::configureWithFormat(
+    const nvinfer1::Dims* input_dims, int num_inputs,
+    const nvinfer1::Dims* output_dims, int num_outputs, nvinfer1::DataType type,
+    nvinfer1::PluginFormat format, int max_batch_size) {
   data_type_ = type;
   data_format_ = format;
-  input_dims_.assign(inputDims, inputDims + nbInputs);
-  max_batch_size_ = maxBatchSize;
+  input_dims_.assign(input_dims, input_dims + num_inputs);
+  max_batch_size_ = max_batch_size;
 }

+}  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -14,23 +14,30 @@
 #pragma once

-#include <cassert>
+#include <NvInfer.h>
 #include <cstring>
-#include <iostream>
 #include <unordered_map>
 #include <vector>

-#include "NvInfer.h"
 #include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/profiler.h"
+
+DECLARE_bool(profile);

 namespace paddle {
 namespace inference {
 namespace tensorrt {
+namespace plugin {

 class PluginTensorRT : public nvinfer1::IPluginExt {
  public:
   PluginTensorRT() {}
+  // It was used for TensorRT deserialization.
+  // It should not be called by users.
   PluginTensorRT(const void* serialized_data, size_t length) {}
+  virtual ~PluginTensorRT() {}
+
   nvinfer1::Dims const& getInputDims(int index) const {
     return input_dims_.at(index);
   }
@@ -38,43 +45,66 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   nvinfer1::DataType getDataType() const { return data_type_; }
   nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
   virtual const char* getPluginVersion() const { return "1"; }
+
+  void AddInput(nvinfer1::ITensor* input) { inputs_.push_back(input); }
+  std::vector<nvinfer1::ITensor*>& GetInputs() { return inputs_; }
+
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+
+  // Following functions are inherit from nvinfer1::IPluginExt
+  // Get the number of outputs from the layer
+  int getNbOutputs() const { return 1; }
+  // Get the dimension of an output tensor
+  virtual nvinfer1::Dims getOutputDimensions(int index,
+                                             const nvinfer1::Dims* input_dims,
+                                             int num_inputs) = 0;
+  // Find the workspace size required by the layer
   size_t getWorkspaceSize(int) const override { return 0; }
+  // Initialize the layer for execution.
+  // This is called when the engine is created.
+  int initialize() override { return 0; }
+  // Shutdown the layer. This is called when the engine is destroyed
   void terminate() override {}
-  virtual ~PluginTensorRT() {}
+  // Execute the layer
+  virtual int enqueue(int batch_size, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;
+
+  // Find the size of the serialization buffer required
+  virtual size_t getSerializationSize() = 0;
+  // Serialize the layer config to buffer.
+  // TensorRT will call this func to serialize the configuration of TensorRT
+  // engine. It should not be called by users.
+  virtual void serialize(void* buffer) = 0;

   // Check format support. The default is FLOAT32 and NCHW.
   bool supportsFormat(nvinfer1::DataType type,
                       nvinfer1::PluginFormat format) const override;
-  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
-                           const nvinfer1::Dims* outputDims, int nbOutputs,
+  // Configure the layer
+  void configureWithFormat(const nvinfer1::Dims* input_dims, int num_inputs,
+                           const nvinfer1::Dims* output_dims, int num_outputs,
                            nvinfer1::DataType type,
                            nvinfer1::PluginFormat format,
-                           int maxBatchSize) override;
-
-  // *NOTE* The following functions need to be overrided in the subclass.
-  virtual nvinfer1::IPluginExt* clone() const = 0;
-  virtual const char* getPluginType() const = 0;
-  // Initialize the layer for execution. This is called when the engine is
-  // created.
-  int initialize() override { return 0; }
-  // Serialize the layer config to buffer.
-  virtual void serialize(void* buffer) = 0;
-  virtual size_t getSerializationSize() = 0;
-  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
-                      void* workspace, cudaStream_t stream) = 0;
+                           int max_batch_size) override;

  protected:
   // Deserialize input_dims, max_batch_size, data_type, data_format
-  void deserializeBase(void const*& serialData, size_t& serialLength);
+  void deserializeBase(void const*& serial_data,  // NOLINT
+                       size_t& serial_length);    // NOLINT
   size_t getBaseSerializationSize();
   // Serialize input_dims, max_batch_size, data_type, data_format
-  void serializeBase(void*& buffer);
+  void serializeBase(void*& buffer);  // NOLINT

   std::vector<nvinfer1::Dims> input_dims_;
   size_t max_batch_size_;
   nvinfer1::DataType data_type_;
   nvinfer1::PluginFormat data_format_;
+
+  std::vector<nvinfer1::ITensor*> inputs_;
 };

+}  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
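For orientation when reading the interface above: a skeletal subclass (hypothetical, not from this patch) showing the minimum a new plugin has to provide on top of PluginTensorRT — clone, getPluginType, getOutputDimensions, enqueue, plus the two serialization hooks. It forwards its single float input unchanged:

// Hypothetical minimal plugin (illustration only). The real plugins in this
// commit (SplitPlugin, ElementWisePlugin, PReluPlugin) follow the same shape.
class IdentityPlugin : public plugin::PluginTensorRT {
 public:
  IdentityPlugin() {}
  // Called by TensorRT during deserialization, not by users.
  IdentityPlugin(void const* serial_data, size_t serial_length) {
    deserializeBase(serial_data, serial_length);
  }

  IdentityPlugin* clone() const override { return new IdentityPlugin(); }
  const char* getPluginType() const override { return "identity"; }

  nvinfer1::Dims getOutputDimensions(int index,
                                     const nvinfer1::Dims* input_dims,
                                     int num_inputs) override {
    return input_dims[0];  // output shape == input shape
  }

  int enqueue(int batch_size, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override {
    // getInputDims() is valid after configureWithFormat() has run.
    size_t count = static_cast<size_t>(batch_size);
    const nvinfer1::Dims& dims = getInputDims(0);
    for (int i = 0; i < dims.nbDims; ++i) count *= dims.d[i];
    cudaMemcpyAsync(outputs[0], inputs[0], count * sizeof(float),
                    cudaMemcpyDeviceToDevice, stream);
    return cudaGetLastError() != cudaSuccess;
  }

 protected:
  // Only the base fields need persisting for this plugin.
  size_t getSerializationSize() override { return getBaseSerializationSize(); }
  void serialize(void* buffer) override { serializeBase(buffer); }
};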
@@ -51,7 +51,7 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
     LOG(INFO) << *reinterpret_cast<const contrib::AnalysisConfig *>(config);
     return;
   }
-  LOG(INFO) << *config;
+  LOG(INFO) << *reinterpret_cast<const NativeConfig *>(config);
 }

 void CompareResult(const std::vector<PaddleTensor> &outputs,
......