diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc index 233bfd6a42b7f123813d4ef5cecf353f7e88d208..38e9b1c5e7c19c89f94ce55324507b02da0c5160 100644 --- a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc @@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { std::unordered_set teller_set( {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", - "elementwise_add", "dropout", "split"}); + "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"}); if (!node->IsOp()) return false; if (teller_set.count(node->Op()->Type())) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 76d205b737aeb456f242037f2b375d9c537b39f3..d19505877bbc1110fcf5787fffc1436d242a7cdc 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(split); +USE_TRT_CONVERTER(prelu); +USE_TRT_CONVERTER(conv2d_transpose); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index ed4c398cee518af3211cab4e982082c46ebb36c2..aa4126392bf1f4b16a123403b68f506d6f76b5a2 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -2,7 +2,7 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc -pad_op.cc split_op.cc +pad_op.cc split_op.cc prelu_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -16,7 +16,7 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc - DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) + DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op conv_transpose_op SERIAL) nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc @@ -33,4 +33,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin -split_op concat_op SERIAL) + split_op concat_op SERIAL) +nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin + prelu_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 43950b8c048b4e1b8974956948caa639812b2f78..7900f56c9ce17ffc7c62c85a42c62ba326dea16e 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -18,92 +18,139 @@ namespace paddle { namespace inference { namespace tensorrt { -bool to_skip_merging_optimize(TensorRTEngine* engine_, +bool to_skip_merging_optimize(TensorRTEngine* engine, const std::vector& filters, const std::vector& strides, const std::vector& paddings, std::string input_name) { - if (engine_->itensor_quote_num[input_name] > 0) { + if (engine->itensor_quote_num[input_name] > 0) { return true; } if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) - engine_->itensor_quote_num[input_name] += 1; + engine->itensor_quote_num[input_name] += 1; return false; } +template +void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode, + RegistFunc fadd_layer, SetDilationFunc fset_dilation, + const std::string& name) { + VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight + PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); + + PADDLE_ENFORCE(engine != nullptr); + auto* X = engine->GetITensor(op_desc.Input("Input").front()); + + // Declare weights + auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); + PADDLE_ENFORCE_NOT_NULL(Y_v); + auto* Y_t = Y_v->GetMutable(); + + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); + + auto* weight_data = weight_tensor->mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); + const int n_output = weight_tensor->dims()[0]; + const int n_input = weight_tensor->dims()[1]; + const int filter_h = weight_tensor->dims()[2]; + const int filter_w = weight_tensor->dims()[3]; + const int groups = boost::get(op_desc.GetAttr("groups")); + const std::vector dilations = + boost::get>(op_desc.GetAttr("dilations")); + const std::vector strides = + boost::get>(op_desc.GetAttr("strides")); + const std::vector paddings = + boost::get>(op_desc.GetAttr("paddings")); + + nvinfer1::DimsHW nv_ksize(filter_h, filter_w); + nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); + nvinfer1::DimsHW nv_strides(strides[0], strides[1]); + nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_tensor->numel())}; + + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + auto* layer = fadd_layer(const_cast(X), n_output, n_input, + nv_ksize, weight, bias); + PADDLE_ENFORCE(layer != nullptr); + layer->setStride(nv_strides); + layer->setPadding(nv_paddings); + layer->setNbGroups(groups); + // set dilations + fset_dilation(layer, nv_dilations); + + auto output_name = op_desc.Output("Output").front(); + layer->setName((name + " (Output: " + output_name + ")").c_str()); + engine->weight_map[op_desc.Input("Filter").front()] = + std::move(weight_tensor); + layer->getOutput(0)->setName(output_name.c_str()); + engine->SetITensor(output_name, layer->getOutput(0)); + + if (test_mode || + to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings, + op_desc.Input("Input").front())) { + engine->DeclareOutput(output_name); + } +} + class Conv2dOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a fluid conv2d op to tensorrt conv layer without bias"; - - framework::OpDesc op_desc(op, nullptr); - PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); - PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight - PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); - - auto* X = engine_->GetITensor(op_desc.Input("Input").front()); - - // Declare weights - auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); - PADDLE_ENFORCE_NOT_NULL(Y_v); - auto* Y_t = Y_v->GetMutable(); - - platform::CPUPlace cpu_place; - std::unique_ptr weight_tensor( - new framework::LoDTensor()); - weight_tensor->Resize(Y_t->dims()); - TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); - - auto* weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - - PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); - const int n_output = weight_tensor->dims()[0]; - const int filter_h = weight_tensor->dims()[2]; - const int filter_w = weight_tensor->dims()[3]; - - const int groups = boost::get(op_desc.GetAttr("groups")); - const std::vector dilations = - boost::get>(op_desc.GetAttr("dilations")); - const std::vector strides = - boost::get>(op_desc.GetAttr("strides")); - const std::vector paddings = - boost::get>(op_desc.GetAttr("paddings")); - - nvinfer1::DimsHW nv_ksize(filter_h, filter_w); - nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); - nvinfer1::DimsHW nv_strides(strides[0], strides[1]); - nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); - - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - weight_tensor->memory_size() / sizeof(float)}; - - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; - auto* layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *const_cast(X), n_output, - nv_ksize, weight.get(), bias.get()); - PADDLE_ENFORCE(layer != nullptr); - layer->setStride(nv_strides); - layer->setPadding(nv_paddings); - layer->setDilation(nv_dilations); - layer->setNbGroups(groups); - - auto output_name = op_desc.Output("Output").front(); - layer->setName(("conv2d (Output: " + output_name + ")").c_str()); - engine_->weight_map[op_desc.Input("Filter").front()] = - std::move(weight_tensor); - layer->getOutput(0)->setName(output_name.c_str()); - engine_->SetITensor(output_name, layer->getOutput(0)); - - if (test_mode || - to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides, - paddings, op_desc.Input("Input").front())) { - engine_->DeclareOutput(output_name); - } + ConvertConv2d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */ + int n_input, /* Conv input maps */ + nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) { + layer->setDilation(dilations); + }, + "conv2d"); + } +}; + +class Deconv2dOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + ConvertConv2d( + engine_, op, scope, test_mode, + [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */ + int n_input, /* Deconv output maps */ + nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* { + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input, + ksize, weight.get(), bias.get()); + return layer; + }, + [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) { + PADDLE_ENFORCE( + dilations.d[0] == 1 && dilations.d[1] == 1, + "Dilations must be (1, 1) for tensorRT, but given (%d, %d)", + dilations.d[0], dilations.d[1]); + }, + "conv2d_transpose"); } }; @@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); +REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 671bcd8aa9a9fff34644a056499961cf6ab81287..1af091fabd2aea03a85b2d19fd556b18cdd65e3b 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* X = engine_->GetITensor(op_desc.Input("X").front()); nvinfer1::Dims dims_x = X->getDimensions(); - PADDLE_ENFORCE(dims_x.nbDims >= 3); + PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.", + dims_x.nbDims); auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc7cf7d8095043236373d19c3623a279be3fd453 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * PRelu converter from fluid to tensorRT. + */ +class PReluOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(40) << "convert fluid prelu op to tensorrt prelu layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + int input_num = op_desc.Input("X").size(); + PADDLE_ENFORCE(input_num == 1); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + // Get output + size_t output_num = op_desc.Output("Out").size(); + PADDLE_ENFORCE(output_num == 1); + // Get attrs + std::string mode = boost::get(op_desc.GetAttr("mode")); + // + auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]); + PADDLE_ENFORCE_NOT_NULL(alpha_var); + auto* alpha_tensor = alpha_var->GetMutable(); + + platform::CPUPlace place; + std::unique_ptr alpha_tensor_host( + new framework::LoDTensor()); + alpha_tensor_host->Resize(alpha_tensor->dims()); + TensorCopySync(*alpha_tensor, place, alpha_tensor_host.get()); + float* alpha_data = alpha_tensor_host->mutable_data(place); + + // Transform alpha to TensorRTEngine::Weight + TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT, + static_cast(alpha_data), + alpha_tensor_host->numel()); + engine_->weight_map[op_desc.Input("Alpha")[0]] = + std::move(alpha_tensor_host); + // + PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode); + nvinfer1::IPluginLayer* layer = + engine_->AddPlugin(&input, input_num, plugin); + + std::string layer_name = "prelu (Output: "; + for (size_t i = 0; i < output_num; i++) { + auto output_name = op_desc.Output("Out")[i]; + layer->getOutput(i)->setName(output_name.c_str()); + engine_->SetITensor(output_name, layer->getOutput(i)); + layer_name += output_name; + if (test_mode) { + engine_->DeclareOutput(output_name); + } + } + layer->setName((layer_name + ")").c_str()); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc index f8711c6b60d74639529624c25429bc245de46479..6c002dcdc1cf3660c1a00e9e426d24f317e4883a 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" +USE_OP(conv2d); +USE_OP(conv2d_transpose); + namespace paddle { namespace inference { namespace tensorrt { @@ -51,7 +54,38 @@ TEST(conv2d_op, test) { validator.Execute(3); } +TEST(conv2d_transpose_op, test) { + std::unordered_set parameters({"deconv2d-Y"}); + framework::Scope scope; + TRTConvertValidation validator(5, parameters, scope, 1 << 15); + + validator.DeclInputVar("deconv2d-X", nvinfer1::Dims3(3, 5, 5)); + validator.DeclParamVar("deconv2d-Y", nvinfer1::Dims4(3, 2, 3, 3)); + validator.DeclOutputVar("deconv2d-Out", nvinfer1::Dims3(2, 5, 5)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("conv2d_transpose"); + desc.SetInput("Input", {"deconv2d-X"}); + desc.SetInput("Filter", {"deconv2d-Y"}); + desc.SetOutput("Output", {"deconv2d-Out"}); + + const std::vector strides({1, 1}); + const std::vector paddings({1, 1}); + const std::vector dilations({1, 1}); + const int groups = 1; + + desc.SetAttr("strides", strides); + desc.SetAttr("paddings", paddings); + desc.SetAttr("dilations", dilations); + desc.SetAttr("groups", groups); + + validator.SetOp(*desc.Proto()); + + validator.Execute(3); +} + } // namespace tensorrt } // namespace inference } // namespace paddle -USE_OP(conv2d); + diff --git a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..453f222f1f1e3f3b9ee8fa7bd49f4cab2286e7ea --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(prelu_op, test_channel_wise) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("channel")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +TEST(prelu_op, test_element_wise) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("element")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +TEST(prelu_op, test_scalar) { + std::unordered_set parameters({"prelu_alpha"}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1)); + validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("prelu"); + desc.SetInput("X", {"prelu_input"}); + desc.SetInput("Alpha", {"prelu_alpha"}); + desc.SetOutput("Out", {"prelu_out"}); + + desc.SetAttr("mode", std::string("all")); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +// USE_OP(prelu); +USE_CPU_ONLY_OP(prelu); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index fdd8b56b0ce5c9b5cb6395bcb437aae5ae27829b..208bd12b83aa19f01de9bcf4ada630c87defad5d 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, Buffer &TensorRTEngine::buffer(const std::string &name) { PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); auto it = buffer_sizes_.find(name); - PADDLE_ENFORCE(it != buffer_sizes_.end()); + PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s", + name); auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); return buffers_[slot_offset]; } diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 335acdf653e55cc7f3ceccdba88992851c8e0310..7a920ebd10f28e32bc3e7fc825195594806d4d44 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -40,12 +40,13 @@ class TensorRTEngine : public EngineBase { // Weight is model parameter. class Weight { public: + Weight() = default; Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) { w_.type = dtype; w_.values = value; w_.count = num_elem; } - const nvinfer1::Weights& get() { return w_; } + nvinfer1::Weights& get() { return w_; } std::vector dims; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 71b7a551619a43e5300ad3205418d1174c7019ff..6611e2e4b35ee51e58cb6b7b088ad160acb659ad 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1 +1 @@ -nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) +nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce) diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..c2fc8028e953fd49bfd89df592840531b881a1ae --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -0,0 +1,145 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +static const int CUDA_NUM_THREADS = 1024; +static const int CUDA_MAX_NUM_BLOCKS = 65535; +inline static int GET_NUM_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +__global__ void PReluChannelWiseKernel(const float *input, const float *alpha, + float *output, int channel, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + float *out = output + offset; + float scale = alpha[blockIdx.x % channel]; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +__global__ void PReluElementWiseKernel(const float *input, const float *alpha, + float *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + const float *scale = alpha + offset; + float *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale[i] * x; + } +} + +__global__ void PReluScalarKernel(const float *input, const float *alpha, + float *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const float *in = input + offset; + float scale = *alpha; + float *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + float x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +static inline void PReluChannelWise(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, + const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, dims.d[0], spatial_size); +} + +static inline void PReluElementWise(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, + const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +static inline void PReluScalar(cudaStream_t stream, const float *input, + const float *alpha, float *output, + int batch_size, const nvinfer1::Dims &dims) { + size_t unroll = batch_size * dims.d[0]; + size_t spatial_size = dims.d[1] * dims.d[2]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, + const nvinfer1::Dims *inputDims, + int nbInputs) { + assert(nbInputs == 1); + assert(index < this->getNbOutputs()); + nvinfer1::Dims const &input_dims = inputDims[0]; + nvinfer1::Dims output_dims = input_dims; + return output_dims; +} + +int PReluPlugin::initialize() { + nvinfer1::Weights &alpha = cuda_alpha_.get(); + alpha.type = alpha_.get().type; + alpha.count = alpha_.get().count; + + CHECK_EQ(cudaMalloc(&alpha.values, alpha.count * sizeof(float)), cudaSuccess); + CHECK_EQ(cudaMemcpy(const_cast(alpha.values), alpha_.get().values, + alpha.count * sizeof(float), cudaMemcpyHostToDevice), + cudaSuccess); + return 0; +} + +int PReluPlugin::enqueue(int batchSize, const void *const *inputs, + void **outputs, void *workspace, cudaStream_t stream) { + // input dims is CHW. + const auto &input_dims = this->getInputDims(0); + const float *input = reinterpret_cast(inputs[0]); + const float *alpha = + reinterpret_cast(cuda_alpha_.get().values); + float *output = reinterpret_cast(outputs)[0]; + if (mode_ == "channel") { + PReluChannelWise(stream, input, alpha, output, batchSize, input_dims); + } else if (mode_ == "element") { + PReluElementWise(stream, input, alpha, output, batchSize, input_dims); + } else { + PReluScalar(stream, input, alpha, output, batchSize, input_dims); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..e2a97e042bb6acaa9dea074be70fc6563478c678 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -0,0 +1,71 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PReluPlugin : public PluginTensorRT { + TensorRTEngine::Weight alpha_; + TensorRTEngine::Weight cuda_alpha_; + std::string mode_; + + protected: + size_t getSerializationSize() override { + // return getBaseSerializationSize(alpha_) + SerializedSize(mode_); + return 0; + } + + // TRT will call this func when we need to serialize the configuration of + // tensorrt. + // It should not be called by users. + void serialize(void *buffer) override { + // serializeBase(buffer); + // SerializeValue(&buffer, alpha_); + // SerializeValue(&buffer, mode_); + } + + public: + PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode) + : alpha_(alpha), mode_(mode) {} + + // It was used for tensorrt deserialization. + // It should not be called by users. + PReluPlugin(void const *serialData, size_t serialLength) { + // deserializeBase(serialData, serialLength); + // DeserializeValue(&serialData, &serialLength, &alpha_); + // DeserializeValue(&serialData, &serialLength, &mode_); + } + + PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); } + + const char *getPluginType() const override { return "prelu"; } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + int nbInputDims) override; + int initialize() override; + int enqueue(int batchSize, const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) override; +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle +