Unverified commit 6fd96a04, authored by W Wilber, committed by GitHub

Add mlir trt engine type. (#40197)

* infrt add trt engine

* update engine name
Parent c52a664e
......@@ -17,8 +17,8 @@
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
......@@ -86,7 +86,7 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); }
TEST(trt, run_static) {
TRTEngine static_trt_engine(0);
TrtEngine static_trt_engine(0);
auto net = ConstructNetwork(
static_trt_engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
BuildOptions static_build_options;
......@@ -164,7 +164,7 @@ TEST(trt, run_static) {
}
TEST(trt, run_dynamic) {
TRTEngine engine(0);
TrtEngine engine(0);
auto net = ConstructNetwork(
engine.GetTrtBuilder(), nvinfer1::Dims4{-1, 3, -1, -1}, false);
BuildOptions build_options;
......
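The run_dynamic test above builds the network with wildcard dimensions (nvinfer1::Dims4{-1, 3, -1, -1}); TensorRT engines built from such a network need at least one optimization profile before they can be built. The sketch below is not part of this diff: it uses the public TensorRT API, and the input name "input" plus the concrete min/opt/max shapes are placeholder assumptions.

#include <NvInfer.h>

// Hypothetical helper: attach a profile covering the dynamic dims used above.
void AddDynamicShapeProfile(nvinfer1::IBuilder* builder,
                            nvinfer1::IBuilderConfig* config) {
  nvinfer1::IOptimizationProfile* profile =
      builder->createOptimizationProfile();
  // Smallest, most common, and largest shapes the engine must handle.
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kMIN,
                         nvinfer1::Dims4{1, 3, 16, 16});
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kOPT,
                         nvinfer1::Dims4{4, 3, 28, 28});
  profile->setDimensions("input", nvinfer1::OptProfileSelector::kMAX,
                         nvinfer1::Dims4{8, 3, 64, 64});
  config->addOptimizationProfile(profile);
}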
......@@ -17,7 +17,7 @@
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include "glog/logging.h"
#include <glog/logging.h>
#include "paddle/phi/backends/dynload/tensorrt.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/ddim.h"
......@@ -40,26 +40,26 @@ static nvinfer1::IRuntime* createInferRuntime(
phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
TRTEngine::TRTEngine(int device_id) : device_id_(device_id) {
TrtEngine::TrtEngine(int device_id) : device_id_(device_id) {
FreshDeviceId();
logger_.reset(new TrtLogger());
builder_.reset(createInferBuilder(logger_->GetTrtLogger()));
phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
}
nvinfer1::IBuilder* TRTEngine::GetTrtBuilder() {
nvinfer1::IBuilder* TrtEngine::GetTrtBuilder() {
CHECK_NOTNULL(builder_);
return builder_.get();
}
void TRTEngine::Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
void TrtEngine::Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
const BuildOptions& build_options) {
FreshDeviceId();
ModelToBuildEnv(std::move(network), build_options);
CHECK_NOTNULL(engine_);
}
bool TRTEngine::ModelToBuildEnv(
bool TrtEngine::ModelToBuildEnv(
TrtUniquePtr<nvinfer1::INetworkDefinition> network,
const BuildOptions& build) {
CHECK_NOTNULL(builder_);
......@@ -70,7 +70,7 @@ bool TRTEngine::ModelToBuildEnv(
return true;
}
bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
bool TrtEngine::NetworkToEngine(const BuildOptions& build) {
TrtUniquePtr<IBuilderConfig> config{builder_->createBuilderConfig()};
CHECK_NOTNULL(config);
CHECK(SetupNetworkAndConfig(build, *network_, *config));
......@@ -91,7 +91,7 @@ bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
return true;
}
bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build,
INetworkDefinition& network,
IBuilderConfig& config) {
builder_->setMaxBatchSize(build.max_batch);
......@@ -235,7 +235,7 @@ bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
return true;
}
bool TRTEngine::SetUpInference(
bool TrtEngine::SetUpInference(
const InferenceOptions& inference,
const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
std::unordered_map<std::string, phi::DenseTensor*>* outputs) {
......@@ -261,7 +261,7 @@ bool TRTEngine::SetUpInference(
return true;
}
void TRTEngine::Run(const phi::GPUContext& ctx) {
void TrtEngine::Run(const phi::GPUContext& ctx) {
if (is_dynamic_shape_) {
DynamicRun(ctx);
} else {
......@@ -269,7 +269,7 @@ void TRTEngine::Run(const phi::GPUContext& ctx) {
}
}
void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
const int num_bindings = engine_->getNbBindings();
std::vector<void*> buffers(num_bindings, nullptr);
......@@ -303,7 +303,7 @@ void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
runtime_batch, buffers.data(), ctx.stream(), nullptr);
}
void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
const int num_bindings = engine_->getNbBindings();
std::vector<void*> buffers(num_bindings, nullptr);
......@@ -339,14 +339,14 @@ void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
contexts_.front()->enqueueV2(buffers.data(), ctx.stream(), nullptr);
}
void TRTEngine::FreshDeviceId() {
void TrtEngine::FreshDeviceId() {
int count;
cudaGetDeviceCount(&count);
CHECK_LT(device_id_, count);
phi::backends::gpu::SetDeviceId(device_id_);
}
void TRTEngine::GetEngineInfo() {
void TrtEngine::GetEngineInfo() {
#if IS_TRT_VERSION_GE(8200)
LOG(INFO) << "====== engine info ======";
std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
......
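StaticRun and DynamicRun both size a buffers vector by engine_->getNbBindings() and hand it to TensorRT. The sketch below only illustrates that binding-slot pattern and is not code from this commit; in the real engine the device addresses come from the phi::DenseTensor inputs and outputs, and the helper name is hypothetical.

#include <NvInfer.h>
#include <string>
#include <unordered_map>
#include <vector>

// Map named device buffers to TensorRT binding slots (illustrative only).
std::vector<void*> CollectBindings(
    const nvinfer1::ICudaEngine& engine,
    const std::unordered_map<std::string, void*>& device_buffers) {
  std::vector<void*> buffers(engine.getNbBindings(), nullptr);
  for (int i = 0; i < engine.getNbBindings(); ++i) {
    // Each binding slot corresponds to one named input or output tensor.
    auto it = device_buffers.find(engine.getBindingName(i));
    if (it != device_buffers.end()) buffers[i] = it->second;
  }
  return buffers;  // later passed to IExecutionContext::enqueueV2(...)
}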
......@@ -56,13 +56,18 @@ using namespace nvinfer1; // NOLINT
//
// We have encapsulated this logic, please use the following programming model.
//
// TRTEngine trt_engine;
// TrtEngine trt_engine;
// trt_engine.Build(...);
// trt_engine.SetUpInference(...);
// trt_engine.Run(...);
class TRTEngine {
class TrtEngine {
public:
explicit TRTEngine(int device_id);
explicit TrtEngine(int device_id = 0);
TrtEngine(const TrtEngine&) = delete;
TrtEngine& operator=(const TrtEngine&) = delete;
TrtEngine(TrtEngine&&) = default;
TrtEngine& operator=(TrtEngine&&) = default;
nvinfer1::IBuilder* GetTrtBuilder();
......
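A minimal usage sketch of the Build / SetUpInference / Run programming model described in the header comment. Namespaces and includes are elided, ConstructNetwork stands in for the test helper shown earlier, and the default-constructed option structs are assumptions; only the TrtEngine calls themselves come from this header.

void RunTrtEngineSketch(
    const phi::GPUContext& gpu_context,
    const std::unordered_map<std::string, phi::DenseTensor*>& inputs) {
  TrtEngine engine(/*device_id=*/0);
  // Build a network against the engine's builder (ConstructNetwork is the kind
  // of helper used in the unit test, not a library API).
  auto network = ConstructNetwork(engine.GetTrtBuilder(),
                                  nvinfer1::Dims3{3, 28, 28}, true);
  BuildOptions build_options;          // defaults assumed
  engine.Build(std::move(network), build_options);

  std::unordered_map<std::string, phi::DenseTensor*> outputs;
  InferenceOptions inference_options;  // defaults assumed
  engine.SetUpInference(inference_options, inputs, &outputs);
  engine.Run(gpu_context);             // dispatches to StaticRun or DynamicRun
}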
......@@ -15,16 +15,17 @@
#pragma once
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include <glog/logging.h>
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <unordered_map>
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Types.h"
namespace infrt {
namespace trt {
class EngineType
: public mlir::Type::TypeBase<EngineType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
};
} // namespace trt
} // namespace infrt
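EngineType carries no parameters, so the default mlir::TypeStorage suffices and instances are uniqued by the context. A small sketch of constructing and querying the type once the dialect (see the trt_ops.cc change below) has registered it; the context setup here is an assumption for illustration.

#include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"

bool HasTrtEngineType() {
  mlir::MLIRContext context;
  context.getOrLoadDialect<infrt::trt::TensorRTDialect>();  // registers EngineType
  mlir::Type engine_ty = infrt::trt::EngineType::get(&context);
  return engine_ty.isa<infrt::trt::EngineType>();  // textual form: !trt.engine
}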
......@@ -27,6 +27,9 @@ class TRT_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::trt::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">;
def TRT_EngineType :
Type<CPred<"$_self.isa<::infrt::trt::EngineType>()">, "!trt.engine">,
BuildableType<"getType<::infrt::trt::EngineType>()">;
//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
......
......@@ -13,23 +13,48 @@
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h"
namespace infrt {
namespace trt {
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addTypes<EngineType>();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>();
}
mlir::Type TensorRTDialect::parseType(mlir::DialectAsmParser &parser) const {
llvm::StringRef keyword;
if (parser.parseKeyword(&keyword)) return mlir::Type();
// parse trt dialect types, for example: !trt.engine
if (keyword == "engine") {
return infrt::trt::EngineType::get(getContext());
}
parser.emitError(parser.getCurrentLocation(), "unknown infrt::trt type: ")
<< keyword;
return mlir::Type();
}
void TensorRTDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const {
// print trt dialect types, for example: !trt.engine
if (type.isa<infrt::trt::EngineType>()) {
printer << "engine";
return;
}
llvm_unreachable("unknown infrt::trt type.");
}
} // namespace trt
} // namespace infrt
......
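With the parseType/printType hooks above, the type round-trips through textual IR as !trt.engine. Below is a hedged sketch of checking the printed form directly; the helper name and the direct use of mlir::Type::print are illustrative and not part of this commit.

#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/MLIRContext.h>
#include <string>

std::string PrintEngineType(mlir::MLIRContext* context) {
  context->getOrLoadDialect<infrt::trt::TensorRTDialect>();
  mlir::Type ty = infrt::trt::EngineType::get(context);
  std::string text;
  llvm::raw_string_ostream os(text);
  ty.print(os);     // goes through TensorRTDialect::printType
  return os.str();  // expected to contain "!trt.engine"
}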
......@@ -35,8 +35,11 @@ namespace trt {
class TensorRTDialect : public mlir::Dialect {
public:
explicit TensorRTDialect(mlir::MLIRContext* context);
explicit TensorRTDialect(mlir::MLIRContext *context);
static llvm::StringRef getDialectNamespace() { return "trt"; }
mlir::Type parseType(mlir::DialectAsmParser &parser) const; // NOLINT
void printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const; // NOLINT
};
} // namespace trt
......
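For completeness, a sketch of how a tool or pass pipeline would make the dialect (and therefore !trt.engine) available through a DialectRegistry; this registration code follows the standard MLIR API and is an assumption, not part of the diff.

#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>

void LoadTrtDialect(mlir::MLIRContext& context) {
  mlir::DialectRegistry registry;
  registry.insert<infrt::trt::TensorRTDialect>();
  context.appendDialectRegistry(registry);
  context.getOrLoadDialect<infrt::trt::TensorRTDialect>();  // eager load
}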