From 2942cf2ed136cd02da311415f1d8209cabdb6532 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Tue, 18 Dec 2018 06:37:11 +0000 Subject: [PATCH] cherry-pick to release 1.2 1. fix pool2d converter bug. 2. add paddle-trt multi thread support. 3. copy trt header and lib to fluid_inference_install_dir/thrid_party/install/tensorrt/ test=release/1.2 --- cmake/inference_lib.cmake | 7 + .../ir_passes/tensorrt_subgraph_pass.cc | 3 - .../inference/tensorrt/convert/op_converter.h | 2 + .../inference/tensorrt/convert/pool2d_op.cc | 8 +- .../fluid/operators/tensorrt/CMakeLists.txt | 2 +- .../operators/tensorrt/tensorrt_engine_op.cc | 5 +- .../tensorrt/tensorrt_engine_op.cu.cc | 24 --- .../operators/tensorrt/tensorrt_engine_op.h | 170 ++++++++---------- .../tensorrt/tensorrt_engine_op_test.cc | 3 +- 9 files changed, 93 insertions(+), 131 deletions(-) delete mode 100644 paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 0b95a7807..3e19b8e7e 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -182,6 +182,13 @@ if (WITH_ANAKIN AND WITH_MKL) list(APPEND inference_deps anakin_inference_lib) endif () +if (TENSORRT_FOUND) + copy(tensorrt_lib DEPS ${inference_deps} + SRCS ${TENSORRT_ROOT}/include/Nv*.h ${TENSORRT_ROOT}/lib/libnvinfer* + DSTS ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/include ${FLUID_INSTALL_DIR}/third_party/install/tensorrt/lib) +endif () + + set(module "inference") copy(inference_lib DEPS ${inference_deps} SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.* diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index c6b7c05f7..6a010f75f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -63,7 +63,6 @@ std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, Graph *graph) const { auto *op_desc = node->Op(); - static int counter{0}; auto &subgraph = *Agent(node).subgraph(); PADDLE_ENFORCE(!subgraph.empty()); @@ -191,8 +190,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, block_desc.Proto()->SerializeAsString()); SetAttr(op_desc->Proto(), "max_batch_size", Get("max_batch_size")); SetAttr(op_desc->Proto(), "workspace_size", Get("workspace_size")); - SetAttr(op_desc->Proto(), "engine_uniq_key", - "trt-" + std::to_string(counter++)); SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes())); SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); } diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index d61d635ed..91670ba8a 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -103,6 +103,7 @@ class OpConverter { void ConvertBlock(const framework::proto::BlockDesc& block, const std::unordered_set& parameters, const framework::Scope& scope, TensorRTEngine* engine) { + std::unique_lock lk(mut_); for (int i = 0; i < block.ops_size(); i++) { const auto& op = block.ops(i); ConvertOp(op, parameters, scope, engine); @@ -125,6 +126,7 @@ class OpConverter { std::unordered_map converters_; // fluid inference scope framework::Scope* scope_{nullptr}; + std::mutex mut_; }; } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 343fd3f7c..1d0d83d1f 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -109,8 +109,12 @@ class Pool2dOpConverter : public OpConverter { } if (pool_type == "max") { - nvinfer1::DimsHW pre_pad(paddings[0], paddings[1]); - nvinfer1::DimsHW post_pad(paddings[0], paddings[1]); + // Under ceil mode, the pre_pad and post_pad are used to + // record the the padding size. In some ceil mode cases, + // we do not need padding, so we initialize the two vars to 0. + + nvinfer1::DimsHW pre_pad(0, 0); + nvinfer1::DimsHW post_pad(0, 0); if (ceil_mode) { // If ceil mode is true, we will pad the appropriate size to the input. DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad, diff --git a/paddle/fluid/operators/tensorrt/CMakeLists.txt b/paddle/fluid/operators/tensorrt/CMakeLists.txt index eee0b90fb..6b551d13f 100644 --- a/paddle/fluid/operators/tensorrt/CMakeLists.txt +++ b/paddle/fluid/operators/tensorrt/CMakeLists.txt @@ -1,5 +1,5 @@ op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) -file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n") +file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(tensorrt_engine);\n") nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op analysis) diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc index 3cf2ce3c7..b993c55fa 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc @@ -21,8 +21,6 @@ namespace paddle { -DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); - namespace operators { class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { @@ -31,7 +29,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Xs", "A list of inputs.").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph."); - AddAttr("engine_uniq_key", "unique key for the TRT engine."); AddAttr("max_batch_size", "the maximum batch size."); AddAttr("workspace_size", "the workspace size."); AddComment("TensorRT engine operator."); @@ -50,6 +47,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference { namespace ops = paddle::operators; REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, - ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); + ops::TensorRTEngineOpMaker); #endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc deleted file mode 100644 index cbe1b426f..000000000 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cu.cc +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - tensorrt_engine, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel); diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 6eef4c98c..88c4f5084 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -27,8 +27,6 @@ namespace paddle { -DECLARE_int32(tensorrt_engine_batch_size); - namespace operators { using FluidDT = framework::proto::VarType_Type; @@ -49,7 +47,7 @@ TRT_DT FluidDataType2TRT(FluidDT type) { return TRT_DT::kINT32; } -nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { +nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { PADDLE_ENFORCE_GT(shape.size(), 1UL, "TensorRT' tensor input requires at least 2 dimensions"); PADDLE_ENFORCE_LE(shape.size(), 4UL, @@ -63,131 +61,119 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { } // namespace // NOLINT using inference::Singleton; -using inference::tensorrt::TRT_EngineManager; +using inference::tensorrt::TensorRTEngine; + +class TensorRTEngineOp : public framework::OperatorBase { + private: + std::vector input_names_; + std::unordered_set param_names_; + mutable std::unique_ptr trt_engine_; + int max_batch_size_; + int workspace_size_; -class TensorRTEngineOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; + TensorRTEngineOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) { + input_names_ = Inputs("Xs"); + max_batch_size_ = Attr("max_batch_size"); + workspace_size_ = Attr("workspace_size"); + + auto params = Attr>("parameters"); + for (const auto ¶m : params) { + param_names_.insert(param); + } + } protected: - void InferShape(framework::InferShapeContext* ctx) const override {} - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input0 = ctx.Inputs("Xs").front(); - framework::OpKernelType kt = framework::OpKernelType( - framework::ToDataType(ctx.scope() - .FindVar(input0) - ->GetMutable() - ->type()), - ctx.GetPlace()); - return kt; + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + RunTrt(scope, dev_place); } -}; -template -class TensorRTEngineKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto engine_name = context.Attr("engine_uniq_key"); - int max_batch_size = context.Attr("max_batch_size"); - if (!Singleton::Global().HasEngine(engine_name)) { - Prepare(context); + void RunTrt(const framework::Scope &scope, + const platform::Place &dev_place) const { + int runtime_batch = 1; + if (trt_engine_.get() == nullptr) { + trt_engine_.reset(new TensorRTEngine( + max_batch_size_, workspace_size_, nullptr, + boost::get(dev_place).device)); + Prepare(scope, dev_place, trt_engine_.get()); } - auto* engine = Singleton::Global().Get(engine_name); - auto input_names = context.op().Inputs("Xs"); - PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); - PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size); + + auto *engine = trt_engine_.get(); + PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs"); std::vector output_maps = - context.Attr>("output_name_mapping"); + Attr>("output_name_mapping"); - auto params = context.Attr>("parameters"); - std::unordered_set parameters; - for (const auto& param : params) { - parameters.insert(param); - } // Convert input tensor from fluid to engine. - for (const auto& x : context.Inputs("Xs")) { - if (parameters.count(x)) continue; + for (const auto &x : Inputs("Xs")) { + if (param_names_.count(x)) continue; // convert input and copy to TRT engine's buffer - auto& t = inference::analysis::GetFromScope( - context.scope(), x); + auto &t = + inference::analysis::GetFromScope(scope, x); + auto t_shape = framework::vectorize(t.dims()); + runtime_batch = t_shape[0]; if (platform::is_cpu_place(t.place())) { - engine->SetInputFromCPU(x, static_cast(t.data()), + engine->SetInputFromCPU(x, static_cast(t.data()), t.memory_size()); } else { - engine->SetInputFromGPU(x, static_cast(t.data()), + engine->SetInputFromGPU(x, static_cast(t.data()), t.memory_size()); } } + + PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_); // Execute the engine. - PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0); - engine->Execute(FLAGS_tensorrt_engine_batch_size); + engine->Execute(runtime_batch); // Convert output tensor from engine to fluid int output_index = 0; VLOG(4) << "TensorRT Engine Op Outputs:"; - for (const auto& y : context.Outputs("Ys")) { + for (const auto &y : Outputs("Ys")) { VLOG(4) << y; // convert output and copy to fluid. - nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); + nvinfer1::ITensor *trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. // The ITensor doesn't contain the batch size dim. std::vector ddim; - ddim.push_back(FLAGS_tensorrt_engine_batch_size); + ddim.push_back(runtime_batch); for (int i = 0; i < dims.nbDims; i++) { ddim.push_back(dims.d[i]); } - auto* fluid_v = context.scope().FindVar(y); + auto *fluid_v = scope.FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); - auto* fluid_t = fluid_v->GetMutable(); + auto *fluid_t = fluid_v->GetMutable(); fluid_t->Resize(framework::make_ddim(ddim)); - // TODO(Superjomn) find some way to determine which device to output the - // tensor. - // if (platform::is_cpu_place(fluid_t->place())) { // TODO(Superjomn) change this float to dtype size. - auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * - FLAGS_tensorrt_engine_batch_size; + auto size = + inference::analysis::AccuDims(dims.d, dims.nbDims) * runtime_batch; engine->GetOutputInGPU( output_maps[output_index], fluid_t->mutable_data(platform::CUDAPlace( - boost::get(context.GetPlace()).device)), + boost::get(dev_place).device)), size * sizeof(float)); - output_index += 1; } cudaStreamSynchronize(*engine->stream()); } - protected: - void Prepare(const framework::ExecutionContext& context) const { + void Prepare(const framework::Scope &scope, const platform::Place &dev_place, + TensorRTEngine *engine) const { VLOG(4) << "Prepare engine"; - // Get the ProgramDesc and pass to convert. framework::proto::BlockDesc block_desc; - block_desc.ParseFromString(context.Attr("subgraph")); - int max_batch_size = context.Attr("max_batch_size"); - int workspace_size = context.Attr("workspace_size"); - - auto params = context.Attr>("parameters"); - std::unordered_set parameters; - for (const auto& param : params) { - parameters.insert(param); - } + block_desc.ParseFromString(Attr("subgraph")); std::vector output_maps = - context.Attr>("output_name_mapping"); - - // TODO(Superjomn) replace this with a different stream - auto* engine = Singleton::Global().Create( - max_batch_size, workspace_size, nullptr /*engine hold its own stream*/, - context.Attr("engine_uniq_key"), - boost::get(context.GetPlace()).device); + Attr>("output_name_mapping"); engine->InitNetwork(); @@ -195,39 +181,33 @@ class TensorRTEngineKernel : public framework::OpKernel { VLOG(4) << "parsed var size " << block.AllVars().size(); // Add inputs VLOG(4) << "declare inputs"; - for (auto& input : context.Inputs("Xs")) { - if (parameters.count(input)) continue; + for (auto &input : Inputs("Xs")) { + if (param_names_.count(input)) continue; VLOG(4) << "declare input " << input; - auto* var = block.FindVar(input); + + auto &t = + inference::analysis::GetFromScope(scope, input); + auto t_shape = framework::vectorize(t.dims()); + + auto *var = block.FindVar(input); // TensorRT engine need to create parameters. The parameter's description // should be set in PADDLE_ENFORCE(var, "no variable called %s", input); PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, "TensorRT engine only takes LoDTensor as input"); - auto shape = var->GetShape(); - // For the special batch_size placeholder -1, drop it and pass the real - // shape of data. - // TODO(Superjomn) fix this with batch broadcast, or it can't handle - // variational batch size. - if (shape[0] == -1) { - shape[0] = FLAGS_tensorrt_engine_batch_size; - } + engine->DeclareInput( input, FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), - Vec2TRT_Dims(shape)); + Vec2TRT_Dims(t_shape)); } - inference::Singleton::Global() - .ConvertBlock(block_desc, parameters, context.scope(), engine); + .ConvertBlock(block_desc, param_names_, scope, engine); // Add outputs - for (auto& output : output_maps) { - if (!engine->HasDeclared(output)) { - engine->DeclareOutput(output); - } + for (auto &output : output_maps) { + engine->DeclareOutput(output); } - engine->FreezeNetwork(); } }; diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index 56bdd6c2f..287b0edc9 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -24,8 +24,7 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" -USE_CUDA_ONLY_OP(tensorrt_engine); - +USE_NO_KERNEL_OP(tensorrt_engine); namespace paddle { namespace operators { -- GitLab