From 6fd96a0400e5e618795ad20f8e85a2e975ea4194 Mon Sep 17 00:00:00 2001 From: Wilber Date: Mon, 7 Mar 2022 15:41:27 +0800 Subject: [PATCH] Add mlir trt engine type. (#40197) * infrt add trt engine * update engine name --- .../backends/tensorrt/test_trt_engine.cc | 8 ++--- paddle/infrt/backends/tensorrt/trt_engine.cc | 26 ++++++++--------- paddle/infrt/backends/tensorrt/trt_engine.h | 11 +++++-- paddle/infrt/backends/tensorrt/trt_utils.h | 9 +++--- .../dialect/tensorrt/trt_dilaect_types.h | 29 +++++++++++++++++++ paddle/infrt/dialect/tensorrt/trt_op_base.td | 3 ++ paddle/infrt/dialect/tensorrt/trt_ops.cc | 25 ++++++++++++++++ paddle/infrt/dialect/tensorrt/trt_ops.h | 5 +++- 8 files changed, 91 insertions(+), 25 deletions(-) create mode 100644 paddle/infrt/dialect/tensorrt/trt_dilaect_types.h diff --git a/paddle/infrt/backends/tensorrt/test_trt_engine.cc b/paddle/infrt/backends/tensorrt/test_trt_engine.cc index 54b7bc3e8af..12cf14060e2 100644 --- a/paddle/infrt/backends/tensorrt/test_trt_engine.cc +++ b/paddle/infrt/backends/tensorrt/test_trt_engine.cc @@ -17,8 +17,8 @@ #include #include #include -#include "glog/logging.h" -#include "gtest/gtest.h" +#include +#include #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" @@ -86,7 +86,7 @@ TrtUniquePtr ConstructNetwork( inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); } TEST(trt, run_static) { - TRTEngine static_trt_engine(0); + TrtEngine static_trt_engine(0); auto net = ConstructNetwork( static_trt_engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true); BuildOptions static_build_options; @@ -164,7 +164,7 @@ TEST(trt, run_static) { } TEST(trt, run_dynamic) { - TRTEngine engine(0); + TrtEngine engine(0); auto net = ConstructNetwork( engine.GetTrtBuilder(), nvinfer1::Dims4{-1, 3, -1, -1}, false); BuildOptions build_options; diff --git 
a/paddle/infrt/backends/tensorrt/trt_engine.cc b/paddle/infrt/backends/tensorrt/trt_engine.cc index a204fe42b45..232653e8c41 100644 --- a/paddle/infrt/backends/tensorrt/trt_engine.cc +++ b/paddle/infrt/backends/tensorrt/trt_engine.cc @@ -17,7 +17,7 @@ #include #include -#include "glog/logging.h" +#include #include "paddle/phi/backends/dynload/tensorrt.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/ddim.h" @@ -40,26 +40,26 @@ static nvinfer1::IRuntime* createInferRuntime( phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); } -TRTEngine::TRTEngine(int device_id) : device_id_(device_id) { +TrtEngine::TrtEngine(int device_id) : device_id_(device_id) { FreshDeviceId(); logger_.reset(new TrtLogger()); builder_.reset(createInferBuilder(logger_->GetTrtLogger())); phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), ""); } -nvinfer1::IBuilder* TRTEngine::GetTrtBuilder() { +nvinfer1::IBuilder* TrtEngine::GetTrtBuilder() { CHECK_NOTNULL(builder_); return builder_.get(); } -void TRTEngine::Build(TrtUniquePtr network, +void TrtEngine::Build(TrtUniquePtr network, const BuildOptions& build_options) { FreshDeviceId(); ModelToBuildEnv(std::move(network), build_options); CHECK_NOTNULL(engine_); } -bool TRTEngine::ModelToBuildEnv( +bool TrtEngine::ModelToBuildEnv( TrtUniquePtr network, const BuildOptions& build) { CHECK_NOTNULL(builder_); @@ -70,7 +70,7 @@ bool TRTEngine::ModelToBuildEnv( return true; } -bool TRTEngine::NetworkToEngine(const BuildOptions& build) { +bool TrtEngine::NetworkToEngine(const BuildOptions& build) { TrtUniquePtr config{builder_->createBuilderConfig()}; CHECK_NOTNULL(config); CHECK(SetupNetworkAndConfig(build, *network_, *config)); @@ -91,7 +91,7 @@ bool TRTEngine::NetworkToEngine(const BuildOptions& build) { return true; } -bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build, +bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build, INetworkDefinition& network, IBuilderConfig& 
config) { builder_->setMaxBatchSize(build.max_batch); @@ -235,7 +235,7 @@ bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build, return true; } -bool TRTEngine::SetUpInference( +bool TrtEngine::SetUpInference( const InferenceOptions& inference, const std::unordered_map& inputs, std::unordered_map* outputs) { @@ -261,7 +261,7 @@ bool TRTEngine::SetUpInference( return true; } -void TRTEngine::Run(const phi::GPUContext& ctx) { +void TrtEngine::Run(const phi::GPUContext& ctx) { if (is_dynamic_shape_) { DynamicRun(ctx); } else { @@ -269,7 +269,7 @@ void TRTEngine::Run(const phi::GPUContext& ctx) { } } -void TRTEngine::StaticRun(const phi::GPUContext& ctx) { +void TrtEngine::StaticRun(const phi::GPUContext& ctx) { const int num_bindings = engine_->getNbBindings(); std::vector buffers(num_bindings, nullptr); @@ -303,7 +303,7 @@ void TRTEngine::StaticRun(const phi::GPUContext& ctx) { runtime_batch, buffers.data(), ctx.stream(), nullptr); } -void TRTEngine::DynamicRun(const phi::GPUContext& ctx) { +void TrtEngine::DynamicRun(const phi::GPUContext& ctx) { const int num_bindings = engine_->getNbBindings(); std::vector buffers(num_bindings, nullptr); @@ -339,14 +339,14 @@ void TRTEngine::DynamicRun(const phi::GPUContext& ctx) { contexts_.front()->enqueueV2(buffers.data(), ctx.stream(), nullptr); } -void TRTEngine::FreshDeviceId() { +void TrtEngine::FreshDeviceId() { int count; cudaGetDeviceCount(&count); CHECK_LT(device_id_, count); phi::backends::gpu::SetDeviceId(device_id_); } -void TRTEngine::GetEngineInfo() { +void TrtEngine::GetEngineInfo() { #if IS_TRT_VERSION_GE(8200) LOG(INFO) << "====== engine info ======"; std::unique_ptr infer_inspector( diff --git a/paddle/infrt/backends/tensorrt/trt_engine.h b/paddle/infrt/backends/tensorrt/trt_engine.h index f72bdaf3ac0..3c8243e3c38 100644 --- a/paddle/infrt/backends/tensorrt/trt_engine.h +++ b/paddle/infrt/backends/tensorrt/trt_engine.h @@ -56,13 +56,18 @@ using namespace nvinfer1; // NOLINT // // We have encapsulated 
this logic, please use the following programming model. // -// TRTEngine trt_engine; +// TrtEngine trt_engine; // trt_engine.Build(...); // trt_engine.SetUpInference(...); // trt_engine.Run(...); -class TRTEngine { +class TrtEngine { public: - explicit TRTEngine(int device_id); + explicit TrtEngine(int device_id = 0); + + TrtEngine(const TrtEngine&) = delete; + TrtEngine& operator=(const TrtEngine&) = delete; + TrtEngine(TrtEngine&&) = default; + TrtEngine& operator=(TrtEngine&&) = default; nvinfer1::IBuilder* GetTrtBuilder(); diff --git a/paddle/infrt/backends/tensorrt/trt_utils.h b/paddle/infrt/backends/tensorrt/trt_utils.h index 4b129af1d53..c66a850ffb1 100644 --- a/paddle/infrt/backends/tensorrt/trt_utils.h +++ b/paddle/infrt/backends/tensorrt/trt_utils.h @@ -15,16 +15,17 @@ #pragma once +#include +#include +#include +#include + #include #include #include #include #include -#include -#include -#include -#include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" namespace infrt { diff --git a/paddle/infrt/dialect/tensorrt/trt_dilaect_types.h b/paddle/infrt/dialect/tensorrt/trt_dilaect_types.h new file mode 100644 index 00000000000..efcf7dd5be1 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_dilaect_types.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "mlir/IR/Types.h" + +namespace infrt { +namespace trt { + +class EngineType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_base.td b/paddle/infrt/dialect/tensorrt/trt_op_base.td index 5722f17d597..128960ee03e 100755 --- a/paddle/infrt/dialect/tensorrt/trt_op_base.td +++ b/paddle/infrt/dialect/tensorrt/trt_op_base.td @@ -27,6 +27,9 @@ class TRT_PaddleAttr : Attr()">, "PaddlePaddle " # description # " attribute">; +def TRT_EngineType : + Type()">, "!trt.engine">, + BuildableType<"getType<::infrt::trt::EngineType>()">; //===----------------------------------------------------------------------===// // PaddlePaddle type definitions diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc index 35b7967892c..f179939e232 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.cc +++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc @@ -13,23 +13,48 @@ // limitations under the License. 
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" +#include #include #include #include #include #include +#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h" namespace infrt { namespace trt { TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context) : mlir::Dialect("trt", context, mlir::TypeID::get()) { + addTypes(); addOperations< #define GET_OP_LIST #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT >(); } +mlir::Type TensorRTDialect::parseType(mlir::DialectAsmParser &parser) const { + llvm::StringRef keyword; + if (parser.parseKeyword(&keyword)) return mlir::Type(); + // parse trt dialect types, for example: !trt.engine + if (keyword == "engine") { + return infrt::trt::EngineType::get(getContext()); + } + parser.emitError(parser.getCurrentLocation(), "unknown infrt::trt type: ") + << keyword; + return mlir::Type(); +} + +void TensorRTDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // print trt dialect types, for example: !trt.engine + if (type.isa()) { + printer << "engine"; + return; + } + llvm_unreachable("unknown infrt::trt type."); +} + } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index 95b2ed41fdf..978b9906e5f 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -35,8 +35,11 @@ namespace trt { class TensorRTDialect : public mlir::Dialect { public: - explicit TensorRTDialect(mlir::MLIRContext* context); + explicit TensorRTDialect(mlir::MLIRContext *context); static llvm::StringRef getDialectNamespace() { return "trt"; } + mlir::Type parseType(mlir::DialectAsmParser &parser) const; // NOLINT + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const; // NOLINT }; } // namespace trt -- GitLab