Unverified · Commit 6fd96a04, authored by Wilber, committed by GitHub

Add mlir trt engine type. (#40197)

* infrt add trt engine

* update engine name
Parent c52a664e
@@ -17,8 +17,8 @@
 #include <NvInfer.h>
 #include <NvInferRuntime.h>
 #include <NvInferRuntimeCommon.h>
-#include "glog/logging.h"
-#include "gtest/gtest.h"
+#include <glog/logging.h>
+#include <gtest/gtest.h>
 #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
@@ -86,7 +86,7 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
 inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); }

 TEST(trt, run_static) {
-  TRTEngine static_trt_engine(0);
+  TrtEngine static_trt_engine(0);
   auto net = ConstructNetwork(
       static_trt_engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
   BuildOptions static_build_options;
@@ -164,7 +164,7 @@ TEST(trt, run_static) {
 }

 TEST(trt, run_dynamic) {
-  TRTEngine engine(0);
+  TrtEngine engine(0);
   auto net = ConstructNetwork(
       engine.GetTrtBuilder(), nvinfer1::Dims4{-1, 3, -1, -1}, false);
   BuildOptions build_options;
......
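Both tests drive the same engine; the only difference is how the network's input dimensions are declared. A minimal sketch of that distinction, assuming only the public TensorRT builder API (the helper name and comments are illustrative, not part of the patch):

#include <NvInfer.h>
#include <cstdint>

// Sketch: explicit-batch network creation as both tests above rely on.
// TensorRT requires the explicit-batch flag whenever any dimension is
// marked dynamic (-1), as in run_dynamic's Dims4{-1, 3, -1, -1}.
nvinfer1::INetworkDefinition* CreateExplicitBatchNetwork(
    nvinfer1::IBuilder* builder) {
  const uint32_t flags = 1U << static_cast<uint32_t>(
      nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  // run_static fixes every dimension (Dims3{3, 28, 28}); run_dynamic leaves
  // batch, height, and width to be resolved by an optimization profile.
  return builder->createNetworkV2(flags);
}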
@@ -17,7 +17,7 @@
 #include <NvInferRuntime.h>
 #include <NvInferRuntimeCommon.h>
-#include "glog/logging.h"
+#include <glog/logging.h>
 #include "paddle/phi/backends/dynload/tensorrt.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/ddim.h"
@@ -40,26 +40,26 @@ static nvinfer1::IRuntime* createInferRuntime(
       phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
 }

-TRTEngine::TRTEngine(int device_id) : device_id_(device_id) {
+TrtEngine::TrtEngine(int device_id) : device_id_(device_id) {
   FreshDeviceId();
   logger_.reset(new TrtLogger());
   builder_.reset(createInferBuilder(logger_->GetTrtLogger()));
   phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
 }

-nvinfer1::IBuilder* TRTEngine::GetTrtBuilder() {
+nvinfer1::IBuilder* TrtEngine::GetTrtBuilder() {
   CHECK_NOTNULL(builder_);
   return builder_.get();
 }

-void TRTEngine::Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
+void TrtEngine::Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
                       const BuildOptions& build_options) {
   FreshDeviceId();
   ModelToBuildEnv(std::move(network), build_options);
   CHECK_NOTNULL(engine_);
 }

-bool TRTEngine::ModelToBuildEnv(
+bool TrtEngine::ModelToBuildEnv(
     TrtUniquePtr<nvinfer1::INetworkDefinition> network,
     const BuildOptions& build) {
   CHECK_NOTNULL(builder_);
@@ -70,7 +70,7 @@ bool TRTEngine::ModelToBuildEnv(
   return true;
 }

-bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
+bool TrtEngine::NetworkToEngine(const BuildOptions& build) {
   TrtUniquePtr<IBuilderConfig> config{builder_->createBuilderConfig()};
   CHECK_NOTNULL(config);
   CHECK(SetupNetworkAndConfig(build, *network_, *config));
@@ -91,7 +91,7 @@ bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
   return true;
 }

-bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
+bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build,
                                       INetworkDefinition& network,
                                       IBuilderConfig& config) {
   builder_->setMaxBatchSize(build.max_batch);
@@ -235,7 +235,7 @@ bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
   return true;
 }

-bool TRTEngine::SetUpInference(
+bool TrtEngine::SetUpInference(
     const InferenceOptions& inference,
     const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
     std::unordered_map<std::string, phi::DenseTensor*>* outputs) {
@@ -261,7 +261,7 @@ bool TRTEngine::SetUpInference(
   return true;
 }

-void TRTEngine::Run(const phi::GPUContext& ctx) {
+void TrtEngine::Run(const phi::GPUContext& ctx) {
   if (is_dynamic_shape_) {
     DynamicRun(ctx);
   } else {
@@ -269,7 +269,7 @@ void TRTEngine::Run(const phi::GPUContext& ctx) {
   }
 }

-void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
+void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
   const int num_bindings = engine_->getNbBindings();
   std::vector<void*> buffers(num_bindings, nullptr);
@@ -303,7 +303,7 @@ void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
       runtime_batch, buffers.data(), ctx.stream(), nullptr);
 }

-void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
+void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
   const int num_bindings = engine_->getNbBindings();
   std::vector<void*> buffers(num_bindings, nullptr);
@@ -339,14 +339,14 @@ void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
   contexts_.front()->enqueueV2(buffers.data(), ctx.stream(), nullptr);
 }

-void TRTEngine::FreshDeviceId() {
+void TrtEngine::FreshDeviceId() {
   int count;
   cudaGetDeviceCount(&count);
   CHECK_LT(device_id_, count);
   phi::backends::gpu::SetDeviceId(device_id_);
 }

-void TRTEngine::GetEngineInfo() {
+void TrtEngine::GetEngineInfo() {
 #if IS_TRT_VERSION_GE(8200)
   LOG(INFO) << "====== engine info ======";
   std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
......
@@ -56,13 +56,18 @@ using namespace nvinfer1;  // NOLINT
 //
 // We have encapsulated this logic, please use the following programming model.
 //
-// TRTEngine trt_engine;
+// TrtEngine trt_engine;
 // trt_engine.Build(...);
 // trt_engine.SetUpInference(...);
 // trt_engine.Run(...);
-class TRTEngine {
+class TrtEngine {
  public:
-  explicit TRTEngine(int device_id);
+  explicit TrtEngine(int device_id = 0);
+
+  TrtEngine(const TrtEngine&) = delete;
+  TrtEngine& operator=(const TrtEngine&) = delete;
+  TrtEngine(TrtEngine&&) = default;
+  TrtEngine& operator=(TrtEngine&&) = default;

   nvinfer1::IBuilder* GetTrtBuilder();
......
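The programming-model comment above maps directly onto the renamed API. A minimal end-to-end sketch, assuming the ConstructNetwork helper and options structs from the unit tests and an existing phi::GPUContext named gpu_ctx (both are assumptions, not part of this header):

// Hedged usage sketch of the Build -> SetUpInference -> Run sequence.
TrtEngine engine(0);  // device_id now defaults to 0 with the new constructor
auto network = ConstructNetwork(
    engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
BuildOptions build_options;
engine.Build(std::move(network), build_options);  // network is consumed here

InferenceOptions inference_options;
std::unordered_map<std::string, phi::DenseTensor*> inputs, outputs;
// ... fill `inputs`; SetUpInference binds buffers and populates `outputs`.
CHECK(engine.SetUpInference(inference_options, inputs, &outputs));
engine.Run(gpu_ctx);  // dispatches to StaticRun or DynamicRun internally

The newly deleted copy operations make the engine move-only, which fits the unique ownership of the underlying TensorRT builder and engine handles.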
@@ -15,16 +15,17 @@
 #pragma once

+#include <NvInfer.h>
+#include <NvInferRuntime.h>
+#include <NvInferRuntimeCommon.h>
+#include <glog/logging.h>
 #include <algorithm>
 #include <cassert>
 #include <functional>
 #include <memory>
 #include <unordered_map>
-#include <NvInfer.h>
-#include <NvInferRuntime.h>
-#include <NvInferRuntimeCommon.h>
-#include "glog/logging.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace infrt {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Types.h"
namespace infrt {
namespace trt {
class EngineType
: public mlir::Type::TypeBase<EngineType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
};
} // namespace trt
} // namespace infrt
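EngineType uses the default mlir::TypeStorage and carries no parameters, so MLIR uniques a single instance per context. A minimal sketch of constructing it, assuming the dialect registration added in trt_ops.cc below:

mlir::MLIRContext context;
// Loading the dialect runs addTypes<EngineType>() (see trt_ops.cc below).
context.getOrLoadDialect<infrt::trt::TensorRTDialect>();
mlir::Type engine_ty = infrt::trt::EngineType::get(&context);
assert(engine_ty.isa<infrt::trt::EngineType>());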
@@ -27,6 +27,9 @@ class TRT_PaddleAttr <string name, string description> :
     Attr<CPred<"$_self.isa<mlir::trt::" # name # "Attr>()">,
         "PaddlePaddle " # description # " attribute">;

+def TRT_EngineType :
+    Type<CPred<"$_self.isa<::infrt::trt::EngineType>()">, "!trt.engine">,
+    BuildableType<"getType<::infrt::trt::EngineType>()">;
+
 //===----------------------------------------------------------------------===//
 // PaddlePaddle type definitions
......
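In this ODS definition, the CPred supplies the predicate used to verify operands and results of the type, while the BuildableType string is spliced into generated builder code. A sketch of what that expansion amounts to, with builder standing in for the mlir::Builder in scope at the expansion site (an assumption about the generated code, not text from the patch):

// Equivalent handwritten form of the BuildableType expansion (sketch).
mlir::Type engine_ty = builder.getType<::infrt::trt::EngineType>();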
@@ -13,23 +13,48 @@
 // limitations under the License.

 #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+#include <mlir/IR/DialectImplementation.h>
 #include <mlir/IR/Matchers.h>
 #include <mlir/IR/OpImplementation.h>
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/CallInterfaces.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
+#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h"

 namespace infrt {
 namespace trt {

 TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
     : mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
+  addTypes<EngineType>();
   addOperations<
 #define GET_OP_LIST
 #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc"  // NOLINT
       >();
 }

+mlir::Type TensorRTDialect::parseType(mlir::DialectAsmParser &parser) const {
+  llvm::StringRef keyword;
+  if (parser.parseKeyword(&keyword)) return mlir::Type();
+  // Parse trt dialect types, for example: !trt.engine
+  if (keyword == "engine") {
+    return infrt::trt::EngineType::get(getContext());
+  }
+  parser.emitError(parser.getCurrentLocation(), "unknown infrt::trt type: ")
+      << keyword;
+  return mlir::Type();
+}
+
+void TensorRTDialect::printType(mlir::Type type,
+                                mlir::DialectAsmPrinter &printer) const {
+  // Print trt dialect types, for example: !trt.engine
+  if (type.isa<infrt::trt::EngineType>()) {
+    printer << "engine";
+    return;
+  }
+  llvm_unreachable("unknown infrt::trt type.");
+}
+
 }  // namespace trt
 }  // namespace infrt
......
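With these hooks registered, the textual form !trt.engine round-trips through the dialect. A sketch of that round trip, assuming the mlir::parseType free function from the MLIR parser headers:

mlir::MLIRContext context;
context.getOrLoadDialect<infrt::trt::TensorRTDialect>();
// The generic parser strips the "!trt." prefix and hands the "engine"
// keyword to TensorRTDialect::parseType above.
mlir::Type ty = mlir::parseType("!trt.engine", &context);
assert(ty && ty.isa<infrt::trt::EngineType>());
ty.dump();  // printType above emits "engine", displayed as !trt.engine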
@@ -35,8 +35,11 @@ namespace trt {
 class TensorRTDialect : public mlir::Dialect {
  public:
-  explicit TensorRTDialect(mlir::MLIRContext* context);
+  explicit TensorRTDialect(mlir::MLIRContext *context);
   static llvm::StringRef getDialectNamespace() { return "trt"; }
+
+  mlir::Type parseType(mlir::DialectAsmParser &parser) const;  // NOLINT
+  void printType(mlir::Type type,
+                 mlir::DialectAsmPrinter &printer) const;  // NOLINT
 };
} // namespace trt } // namespace trt
......