Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6fd96a04
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6fd96a04
编写于
3月 07, 2022
作者:
W
Wilber
提交者:
GitHub
3月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add mlir trt engine type. (#40197)
* infrt add trt engine * update engine name
上级
c52a664e
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
91 addition
and
25 deletion
+91
-25
paddle/infrt/backends/tensorrt/test_trt_engine.cc
paddle/infrt/backends/tensorrt/test_trt_engine.cc
+4
-4
paddle/infrt/backends/tensorrt/trt_engine.cc
paddle/infrt/backends/tensorrt/trt_engine.cc
+13
-13
paddle/infrt/backends/tensorrt/trt_engine.h
paddle/infrt/backends/tensorrt/trt_engine.h
+8
-3
paddle/infrt/backends/tensorrt/trt_utils.h
paddle/infrt/backends/tensorrt/trt_utils.h
+5
-4
paddle/infrt/dialect/tensorrt/trt_dilaect_types.h
paddle/infrt/dialect/tensorrt/trt_dilaect_types.h
+29
-0
paddle/infrt/dialect/tensorrt/trt_op_base.td
paddle/infrt/dialect/tensorrt/trt_op_base.td
+3
-0
paddle/infrt/dialect/tensorrt/trt_ops.cc
paddle/infrt/dialect/tensorrt/trt_ops.cc
+25
-0
paddle/infrt/dialect/tensorrt/trt_ops.h
paddle/infrt/dialect/tensorrt/trt_ops.h
+4
-1
未找到文件。
paddle/infrt/backends/tensorrt/test_trt_engine.cc
浏览文件 @
6fd96a04
...
@@ -17,8 +17,8 @@
...
@@ -17,8 +17,8 @@
#include <NvInfer.h>
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include <NvInferRuntimeCommon.h>
#include
"glog/logging.h"
#include
<glog/logging.h>
#include
"gtest/gtest.h"
#include
<gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
...
@@ -86,7 +86,7 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
...
@@ -86,7 +86,7 @@ TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
inline
float
sigmoid
(
float
x
)
{
return
1.
f
/
(
1.
f
+
exp
(
-
1
*
x
));
}
inline
float
sigmoid
(
float
x
)
{
return
1.
f
/
(
1.
f
+
exp
(
-
1
*
x
));
}
TEST
(
trt
,
run_static
)
{
TEST
(
trt
,
run_static
)
{
T
RT
Engine
static_trt_engine
(
0
);
T
rt
Engine
static_trt_engine
(
0
);
auto
net
=
ConstructNetwork
(
auto
net
=
ConstructNetwork
(
static_trt_engine
.
GetTrtBuilder
(),
nvinfer1
::
Dims3
{
3
,
28
,
28
},
true
);
static_trt_engine
.
GetTrtBuilder
(),
nvinfer1
::
Dims3
{
3
,
28
,
28
},
true
);
BuildOptions
static_build_options
;
BuildOptions
static_build_options
;
...
@@ -164,7 +164,7 @@ TEST(trt, run_static) {
...
@@ -164,7 +164,7 @@ TEST(trt, run_static) {
}
}
TEST
(
trt
,
run_dynamic
)
{
TEST
(
trt
,
run_dynamic
)
{
T
RT
Engine
engine
(
0
);
T
rt
Engine
engine
(
0
);
auto
net
=
ConstructNetwork
(
auto
net
=
ConstructNetwork
(
engine
.
GetTrtBuilder
(),
nvinfer1
::
Dims4
{
-
1
,
3
,
-
1
,
-
1
},
false
);
engine
.
GetTrtBuilder
(),
nvinfer1
::
Dims4
{
-
1
,
3
,
-
1
,
-
1
},
false
);
BuildOptions
build_options
;
BuildOptions
build_options
;
...
...
paddle/infrt/backends/tensorrt/trt_engine.cc
浏览文件 @
6fd96a04
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include <NvInferRuntime.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include <NvInferRuntimeCommon.h>
#include
"glog/logging.h"
#include
<glog/logging.h>
#include "paddle/phi/backends/dynload/tensorrt.h"
#include "paddle/phi/backends/dynload/tensorrt.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
...
@@ -40,26 +40,26 @@ static nvinfer1::IRuntime* createInferRuntime(
...
@@ -40,26 +40,26 @@ static nvinfer1::IRuntime* createInferRuntime(
phi
::
dynload
::
createInferRuntime_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
phi
::
dynload
::
createInferRuntime_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
}
}
T
RTEngine
::
TRT
Engine
(
int
device_id
)
:
device_id_
(
device_id
)
{
T
rtEngine
::
Trt
Engine
(
int
device_id
)
:
device_id_
(
device_id
)
{
FreshDeviceId
();
FreshDeviceId
();
logger_
.
reset
(
new
TrtLogger
());
logger_
.
reset
(
new
TrtLogger
());
builder_
.
reset
(
createInferBuilder
(
logger_
->
GetTrtLogger
()));
builder_
.
reset
(
createInferBuilder
(
logger_
->
GetTrtLogger
()));
phi
::
dynload
::
initLibNvInferPlugins
(
&
logger_
->
GetTrtLogger
(),
""
);
phi
::
dynload
::
initLibNvInferPlugins
(
&
logger_
->
GetTrtLogger
(),
""
);
}
}
nvinfer1
::
IBuilder
*
T
RT
Engine
::
GetTrtBuilder
()
{
nvinfer1
::
IBuilder
*
T
rt
Engine
::
GetTrtBuilder
()
{
CHECK_NOTNULL
(
builder_
);
CHECK_NOTNULL
(
builder_
);
return
builder_
.
get
();
return
builder_
.
get
();
}
}
void
T
RT
Engine
::
Build
(
TrtUniquePtr
<
nvinfer1
::
INetworkDefinition
>
network
,
void
T
rt
Engine
::
Build
(
TrtUniquePtr
<
nvinfer1
::
INetworkDefinition
>
network
,
const
BuildOptions
&
build_options
)
{
const
BuildOptions
&
build_options
)
{
FreshDeviceId
();
FreshDeviceId
();
ModelToBuildEnv
(
std
::
move
(
network
),
build_options
);
ModelToBuildEnv
(
std
::
move
(
network
),
build_options
);
CHECK_NOTNULL
(
engine_
);
CHECK_NOTNULL
(
engine_
);
}
}
bool
T
RT
Engine
::
ModelToBuildEnv
(
bool
T
rt
Engine
::
ModelToBuildEnv
(
TrtUniquePtr
<
nvinfer1
::
INetworkDefinition
>
network
,
TrtUniquePtr
<
nvinfer1
::
INetworkDefinition
>
network
,
const
BuildOptions
&
build
)
{
const
BuildOptions
&
build
)
{
CHECK_NOTNULL
(
builder_
);
CHECK_NOTNULL
(
builder_
);
...
@@ -70,7 +70,7 @@ bool TRTEngine::ModelToBuildEnv(
...
@@ -70,7 +70,7 @@ bool TRTEngine::ModelToBuildEnv(
return
true
;
return
true
;
}
}
bool
T
RT
Engine
::
NetworkToEngine
(
const
BuildOptions
&
build
)
{
bool
T
rt
Engine
::
NetworkToEngine
(
const
BuildOptions
&
build
)
{
TrtUniquePtr
<
IBuilderConfig
>
config
{
builder_
->
createBuilderConfig
()};
TrtUniquePtr
<
IBuilderConfig
>
config
{
builder_
->
createBuilderConfig
()};
CHECK_NOTNULL
(
config
);
CHECK_NOTNULL
(
config
);
CHECK
(
SetupNetworkAndConfig
(
build
,
*
network_
,
*
config
));
CHECK
(
SetupNetworkAndConfig
(
build
,
*
network_
,
*
config
));
...
@@ -91,7 +91,7 @@ bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
...
@@ -91,7 +91,7 @@ bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
return
true
;
return
true
;
}
}
bool
T
RT
Engine
::
SetupNetworkAndConfig
(
const
BuildOptions
&
build
,
bool
T
rt
Engine
::
SetupNetworkAndConfig
(
const
BuildOptions
&
build
,
INetworkDefinition
&
network
,
INetworkDefinition
&
network
,
IBuilderConfig
&
config
)
{
IBuilderConfig
&
config
)
{
builder_
->
setMaxBatchSize
(
build
.
max_batch
);
builder_
->
setMaxBatchSize
(
build
.
max_batch
);
...
@@ -235,7 +235,7 @@ bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
...
@@ -235,7 +235,7 @@ bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
return
true
;
return
true
;
}
}
bool
T
RT
Engine
::
SetUpInference
(
bool
T
rt
Engine
::
SetUpInference
(
const
InferenceOptions
&
inference
,
const
InferenceOptions
&
inference
,
const
std
::
unordered_map
<
std
::
string
,
phi
::
DenseTensor
*>&
inputs
,
const
std
::
unordered_map
<
std
::
string
,
phi
::
DenseTensor
*>&
inputs
,
std
::
unordered_map
<
std
::
string
,
phi
::
DenseTensor
*>*
outputs
)
{
std
::
unordered_map
<
std
::
string
,
phi
::
DenseTensor
*>*
outputs
)
{
...
@@ -261,7 +261,7 @@ bool TRTEngine::SetUpInference(
...
@@ -261,7 +261,7 @@ bool TRTEngine::SetUpInference(
return
true
;
return
true
;
}
}
void
T
RT
Engine
::
Run
(
const
phi
::
GPUContext
&
ctx
)
{
void
T
rt
Engine
::
Run
(
const
phi
::
GPUContext
&
ctx
)
{
if
(
is_dynamic_shape_
)
{
if
(
is_dynamic_shape_
)
{
DynamicRun
(
ctx
);
DynamicRun
(
ctx
);
}
else
{
}
else
{
...
@@ -269,7 +269,7 @@ void TRTEngine::Run(const phi::GPUContext& ctx) {
...
@@ -269,7 +269,7 @@ void TRTEngine::Run(const phi::GPUContext& ctx) {
}
}
}
}
void
T
RT
Engine
::
StaticRun
(
const
phi
::
GPUContext
&
ctx
)
{
void
T
rt
Engine
::
StaticRun
(
const
phi
::
GPUContext
&
ctx
)
{
const
int
num_bindings
=
engine_
->
getNbBindings
();
const
int
num_bindings
=
engine_
->
getNbBindings
();
std
::
vector
<
void
*>
buffers
(
num_bindings
,
nullptr
);
std
::
vector
<
void
*>
buffers
(
num_bindings
,
nullptr
);
...
@@ -303,7 +303,7 @@ void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
...
@@ -303,7 +303,7 @@ void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
runtime_batch
,
buffers
.
data
(),
ctx
.
stream
(),
nullptr
);
runtime_batch
,
buffers
.
data
(),
ctx
.
stream
(),
nullptr
);
}
}
void
T
RT
Engine
::
DynamicRun
(
const
phi
::
GPUContext
&
ctx
)
{
void
T
rt
Engine
::
DynamicRun
(
const
phi
::
GPUContext
&
ctx
)
{
const
int
num_bindings
=
engine_
->
getNbBindings
();
const
int
num_bindings
=
engine_
->
getNbBindings
();
std
::
vector
<
void
*>
buffers
(
num_bindings
,
nullptr
);
std
::
vector
<
void
*>
buffers
(
num_bindings
,
nullptr
);
...
@@ -339,14 +339,14 @@ void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
...
@@ -339,14 +339,14 @@ void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
contexts_
.
front
()
->
enqueueV2
(
buffers
.
data
(),
ctx
.
stream
(),
nullptr
);
contexts_
.
front
()
->
enqueueV2
(
buffers
.
data
(),
ctx
.
stream
(),
nullptr
);
}
}
void
T
RT
Engine
::
FreshDeviceId
()
{
void
T
rt
Engine
::
FreshDeviceId
()
{
int
count
;
int
count
;
cudaGetDeviceCount
(
&
count
);
cudaGetDeviceCount
(
&
count
);
CHECK_LT
(
device_id_
,
count
);
CHECK_LT
(
device_id_
,
count
);
phi
::
backends
::
gpu
::
SetDeviceId
(
device_id_
);
phi
::
backends
::
gpu
::
SetDeviceId
(
device_id_
);
}
}
void
T
RT
Engine
::
GetEngineInfo
()
{
void
T
rt
Engine
::
GetEngineInfo
()
{
#if IS_TRT_VERSION_GE(8200)
#if IS_TRT_VERSION_GE(8200)
LOG
(
INFO
)
<<
"====== engine info ======"
;
LOG
(
INFO
)
<<
"====== engine info ======"
;
std
::
unique_ptr
<
nvinfer1
::
IEngineInspector
>
infer_inspector
(
std
::
unique_ptr
<
nvinfer1
::
IEngineInspector
>
infer_inspector
(
...
...
paddle/infrt/backends/tensorrt/trt_engine.h
浏览文件 @
6fd96a04
...
@@ -56,13 +56,18 @@ using namespace nvinfer1; // NOLINT
...
@@ -56,13 +56,18 @@ using namespace nvinfer1; // NOLINT
//
//
// We have encapsulated this logic, please use the following programming model.
// We have encapsulated this logic, please use the following programming model.
//
//
// T
RT
Engine trt_engine;
// T
rt
Engine trt_engine;
// trt_engine.Build(...);
// trt_engine.Build(...);
// trt_engine.SetUpInference(...);
// trt_engine.SetUpInference(...);
// trt_engine.Run(...);
// trt_engine.Run(...);
class
T
RT
Engine
{
class
T
rt
Engine
{
public:
public:
explicit
TRTEngine
(
int
device_id
);
explicit
TrtEngine
(
int
device_id
=
0
);
TrtEngine
(
const
TrtEngine
&
)
=
delete
;
TrtEngine
&
operator
=
(
const
TrtEngine
&
)
=
delete
;
TrtEngine
(
TrtEngine
&&
)
=
default
;
TrtEngine
&
operator
=
(
TrtEngine
&&
)
=
default
;
nvinfer1
::
IBuilder
*
GetTrtBuilder
();
nvinfer1
::
IBuilder
*
GetTrtBuilder
();
...
...
paddle/infrt/backends/tensorrt/trt_utils.h
浏览文件 @
6fd96a04
...
@@ -15,16 +15,17 @@
...
@@ -15,16 +15,17 @@
#pragma once
#pragma once
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include <glog/logging.h>
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
#include <functional>
#include <functional>
#include <memory>
#include <memory>
#include <unordered_map>
#include <unordered_map>
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
infrt
{
namespace
infrt
{
...
...
paddle/infrt/dialect/tensorrt/trt_dilaect_types.h
0 → 100644
浏览文件 @
6fd96a04
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Types.h"
namespace
infrt
{
namespace
trt
{
class
EngineType
:
public
mlir
::
Type
::
TypeBase
<
EngineType
,
mlir
::
Type
,
mlir
::
TypeStorage
>
{
public:
using
Base
::
Base
;
};
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_op_base.td
浏览文件 @
6fd96a04
...
@@ -27,6 +27,9 @@ class TRT_PaddleAttr <string name, string description> :
...
@@ -27,6 +27,9 @@ class TRT_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::trt::" # name # "Attr>()">,
Attr<CPred<"$_self.isa<mlir::trt::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">;
"PaddlePaddle " # description # " attribute">;
def TRT_EngineType :
Type<CPred<"$_self.isa<::infrt::trt::EngineType>()">, "!trt.engine">,
BuildableType<"getType<::infrt::trt::EngineType>()">;
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
// PaddlePaddle type definitions
...
...
paddle/infrt/dialect/tensorrt/trt_ops.cc
浏览文件 @
6fd96a04
...
@@ -13,23 +13,48 @@
...
@@ -13,23 +13,48 @@
// limitations under the License.
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h"
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
TensorRTDialect
::
TensorRTDialect
(
mlir
::
MLIRContext
*
context
)
TensorRTDialect
::
TensorRTDialect
(
mlir
::
MLIRContext
*
context
)
:
mlir
::
Dialect
(
"trt"
,
context
,
mlir
::
TypeID
::
get
<
TensorRTDialect
>
())
{
:
mlir
::
Dialect
(
"trt"
,
context
,
mlir
::
TypeID
::
get
<
TensorRTDialect
>
())
{
addTypes
<
EngineType
>
();
addOperations
<
addOperations
<
#define GET_OP_LIST
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>
();
>
();
}
}
mlir
::
Type
TensorRTDialect
::
parseType
(
mlir
::
DialectAsmParser
&
parser
)
const
{
llvm
::
StringRef
keyword
;
if
(
parser
.
parseKeyword
(
&
keyword
))
return
mlir
::
Type
();
// parse trt dilaect types, for example: !trt.engine
if
(
keyword
==
"engine"
)
{
return
infrt
::
trt
::
EngineType
::
get
(
getContext
());
}
parser
.
emitError
(
parser
.
getCurrentLocation
(),
"unknown infrt::trt type: "
)
<<
keyword
;
return
mlir
::
Type
();
}
void
TensorRTDialect
::
printType
(
mlir
::
Type
type
,
mlir
::
DialectAsmPrinter
&
printer
)
const
{
// print trt dilaect types, for example: !trt.engien
if
(
type
.
isa
<
infrt
::
trt
::
EngineType
>
())
{
printer
<<
"engine"
;
return
;
}
llvm_unreachable
(
"unknown infrt::trt type."
);
}
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
...
...
paddle/infrt/dialect/tensorrt/trt_ops.h
浏览文件 @
6fd96a04
...
@@ -35,8 +35,11 @@ namespace trt {
...
@@ -35,8 +35,11 @@ namespace trt {
class
TensorRTDialect
:
public
mlir
::
Dialect
{
class
TensorRTDialect
:
public
mlir
::
Dialect
{
public:
public:
explicit
TensorRTDialect
(
mlir
::
MLIRContext
*
context
);
explicit
TensorRTDialect
(
mlir
::
MLIRContext
*
context
);
static
llvm
::
StringRef
getDialectNamespace
()
{
return
"trt"
;
}
static
llvm
::
StringRef
getDialectNamespace
()
{
return
"trt"
;
}
mlir
::
Type
parseType
(
mlir
::
DialectAsmParser
&
parser
)
const
;
// NOLINT
void
printType
(
mlir
::
Type
type
,
mlir
::
DialectAsmPrinter
&
printer
)
const
;
// NOLINT
};
};
}
// namespace trt
}
// namespace trt
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录