Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9f0958fa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9f0958fa
编写于
1月 04, 2022
作者:
王
王明冬
提交者:
GitHub
1月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[infrt] add trt_graph_split_pass for infrt. test=develop (#38494)
上级
dfdc9960
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
238 addition
and
73 deletion
+238
-73
paddle/infrt/CMakeLists.txt
paddle/infrt/CMakeLists.txt
+8
-0
paddle/infrt/dialect/mlir_loader.cc
paddle/infrt/dialect/mlir_loader.cc
+3
-1
paddle/infrt/dialect/mlir_tests/paddle_ops.mlir
paddle/infrt/dialect/mlir_tests/paddle_ops.mlir
+2
-2
paddle/infrt/dialect/mlir_tests/rewrite.mlir
paddle/infrt/dialect/mlir_tests/rewrite.mlir
+7
-7
paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir
paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir
+1
-1
paddle/infrt/dialect/mlir_tests/trt_ops.mlir
paddle/infrt/dialect/mlir_tests/trt_ops.mlir
+6
-8
paddle/infrt/dialect/pd_ops.td
paddle/infrt/dialect/pd_ops.td
+2
-2
paddle/infrt/dialect/tensorrt/CMakeLists.txt
paddle/infrt/dialect/tensorrt/CMakeLists.txt
+4
-0
paddle/infrt/dialect/tensorrt/trt_exec.cc
paddle/infrt/dialect/tensorrt/trt_exec.cc
+48
-0
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
+37
-40
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+2
-3
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
+50
-0
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
+60
-0
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+0
-2
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
+2
-4
paddle/infrt/dialect/tensorrt/trt_ops.h
paddle/infrt/dialect/tensorrt/trt_ops.h
+0
-0
paddle/scripts/infrt_build.sh
paddle/scripts/infrt_build.sh
+6
-3
未找到文件。
paddle/infrt/CMakeLists.txt
浏览文件 @
9f0958fa
if
(
NOT WITH_INFRT
)
if
(
NOT WITH_INFRT
)
return
()
return
()
endif
()
endif
()
# compile flags
set
(
INFRT_FLAGS -Wno-comment
)
foreach
(
flag
${
INFRT_FLAGS
}
)
safe_set_cflag
(
CMAKE_C_FLAGS
${
flag
}
)
safe_set_cxxflag
(
CMAKE_CXX_FLAGS
${
flag
}
)
endforeach
()
set
(
INFRT_SOURCE_DIR
"
${
PADDLE_SOURCE_DIR
}
/paddle/infrt"
)
set
(
INFRT_SOURCE_DIR
"
${
PADDLE_SOURCE_DIR
}
/paddle/infrt"
)
set
(
INFRT_BINARY_DIR
"
${
PADDLE_BINARY_DIR
}
/paddle/infrt"
)
set
(
INFRT_BINARY_DIR
"
${
PADDLE_BINARY_DIR
}
/paddle/infrt"
)
set
(
INFRT_TEST_TARGETS CACHE INTERNAL
""
)
set
(
INFRT_TEST_TARGETS CACHE INTERNAL
""
)
...
...
paddle/infrt/dialect/mlir_loader.cc
浏览文件 @
9f0958fa
...
@@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
...
@@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const
std
::
string
&
mlir_source
)
{
const
std
::
string
&
mlir_source
)
{
// context->allowUnregisteredDialects();
// context->allowUnregisteredDialects();
RegisterCinnDialects
(
context
->
getDialectRegistry
());
RegisterCinnDialects
(
context
->
getDialectRegistry
());
context
->
getDialectRegistry
().
insert
<
mlir
::
StandardOpsDialect
>
();
// Currently, we only use the CinnDialect, and mlir::BuiltinDialect is
// enough. StandardOpsDialect is not needed.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir
::
ScopedDiagnosticHandler
scope_handler
(
mlir
::
ScopedDiagnosticHandler
scope_handler
(
context
,
[](
mlir
::
Diagnostic
&
diag
)
{
context
,
[](
mlir
::
Diagnostic
&
diag
)
{
...
...
paddle/infrt/dialect/mlir_tests/paddle_ops.mlir
浏览文件 @
9f0958fa
func @ops() {
func @ops() {
%a = pd.feed() : tensor<?xf32>
%a = pd.feed()
{name="input0"}
: tensor<?xf32>
%b = pd.feed() : tensor<?xf32>
%b = pd.feed()
{name="input1"}
: tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
...
...
paddle/infrt/dialect/mlir_tests/rewrite.mlir
浏览文件 @
9f0958fa
// CHECK-LABEL: @main
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32>
%a = "pd.feed"()
{name="input0"}
: () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32>
%b = "pd.feed"()
{name="input1"}
: () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32>
%bias = "pd.feed"()
{name="input2"}
: () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32>
%b1 = "pd.feed"()
{name="input3"}
: () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32>
%b2 = "pd.feed"()
{name="input4"}
: () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32>
%bias1 = "pd.feed"()
{name="input5"}
: () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"()
{name="input6"}
: () -> tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
...
...
paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir
浏览文件 @
9f0958fa
// CHECK-LABEL: @main
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?x3x256x256xf32>
%a = "pd.feed"()
{name="input0"}
: () -> tensor<?x3x256x256xf32>
%filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
%filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
%bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
...
...
paddle/infrt/dialect/mlir_tests/trt_ops.mlir
浏览文件 @
9f0958fa
// CHECK-LABEL: @main
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32>
%bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32>
%c = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32>
%b1 = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%c = "pd.feed"() : () -> tensor<?xf32>
%b2 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32>
%bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
...
...
paddle/infrt/dialect/pd_ops.td
浏览文件 @
9f0958fa
...
@@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td"
...
@@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td"
include "paddle/infrt/dialect/pd_op_base.td"
def PD_FeedOp : PD_Op<"feed"
, [NoSideEffect]
> {
def PD_FeedOp : PD_Op<"feed"> {
let summary = "Feed Op";
let summary = "Feed Op";
let description = [{
let description = [{
Feed a tensor into the model.
Feed a tensor into the model.
}];
}];
let arguments = (ins);
let arguments = (ins
StrAttr:$name
);
let results = (outs PD_Tensor:$out);
let results = (outs PD_Tensor:$out);
let assemblyFormat = [{
let assemblyFormat = [{
...
...
paddle/infrt/dialect/tensorrt/CMakeLists.txt
浏览文件 @
9f0958fa
...
@@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS
...
@@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS
trt_ops.cc
trt_ops.cc
trt_op_teller_pass.cc
trt_op_teller_pass.cc
trt_graph_fuse_pass.cc
trt_graph_fuse_pass.cc
trt_graph_split_pass.cc
)
)
mlir_tablegen_on
(
trt_ops
)
mlir_tablegen_on
(
trt_ops
)
add_executable
(
trt-exec trt_exec.cc
)
target_link_libraries
(
trt-exec infrt
${
MLIR_IR_LIBS
}
)
paddle/infrt/dialect/tensorrt/trt_exec.cc
0 → 100644
浏览文件 @
9f0958fa
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
// Entry point of the trt-exec tool.
//
// Reads the input file name from the single positional command-line argument
// (its default value is "-"), loads it as an MLIR module, dumps the module,
// runs the TensorRT sub-graph pipeline on each mlir::FuncOp
// (trtOpTellerPass -> trtGraphFusePass -> trtGraphSplitPass) and finally
// dumps the transformed module.
//
// Returns 0 on success, 1 when no module could be loaded and 4 when the pass
// pipeline reports a failure.
int main(int argc, char** argv) {
  static llvm::cl::opt<std::string> input_file(
      llvm::cl::Positional,
      llvm::cl::desc("Specify input filename"),
      llvm::cl::init("-"));

  llvm::cl::ParseCommandLineOptions(argc, argv);

  mlir::MLIRContext* context = infrt::Global::getMLIRContext();
  auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);

  // NOTE(review): LoadMlirFile is defined elsewhere; it presumably yields an
  // empty module reference when the file cannot be read or parsed. Guard
  // against that before dereferencing the module below.
  if (!module) {
    std::cerr << "failed to load mlir file: " << input_file.c_str()
              << std::endl;
    return 1;
  }

  // Show the module as loaded, before any transformation.
  module->dump();

  mlir::PassManager pm(context);

  // All three passes are function passes, so they are nested under FuncOp.
  mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
  trt_pass_manager.addPass(std::make_unique<infrt::trt::trtOpTellerPass>());
  trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphFusePass>());
  // 10 is the size threshold handed to the split pass.
  trt_pass_manager.addPass(
      std::make_unique<infrt::trt::trtGraphSplitPass>(10));

  if (mlir::failed(pm.run(*module))) {
    std::cout << "\npass failed!\n" << std::endl;
    return 4;
  }

  // Show the module after the pipeline has run.
  module->dump();
  return 0;
}
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
浏览文件 @
9f0958fa
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
...
@@ -25,42 +26,31 @@
...
@@ -25,42 +26,31 @@
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
namespace
{
namespace
{
// ReverseDfs
//
FlexibleDFS
//
do reverse dfs. calls "func" to search when visit a node.
//
do reverse dfs. calls leave(node) after visiting all parents of node
.
//
The elements in 'source' can't be nullptr
.
// Reference the function
with the same name
but defined in:
// Reference the function
named "FlexibleDFS"
but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
// paddle/fluid/framework/ir/subgraph_detector.cc.
void
FlexibleDFS
(
const
std
::
vector
<::
mlir
::
Operation
*>
&
source
,
const
std
::
function
<
bool
(
const
::
mlir
::
Operation
*
)
>
&
leave
)
{
typedef
struct
{
::
mlir
::
Operation
*
node
;
bool
leave
;
}
FNode
;
std
::
vector
<
FNode
>
stack
;
bool
reverseDfs
(
std
::
vector
<::
mlir
::
Operation
*>
source
,
for
(
auto
&
node
:
source
)
{
const
std
::
function
<
bool
(
const
::
mlir
::
Operation
*
)
>
&
func
)
{
stack
.
push_back
(
FNode
{
node
,
false
});
}
std
::
unordered_set
<
const
::
mlir
::
Operation
*>
visited
;
std
::
unordered_set
<
const
::
mlir
::
Operation
*>
visited
;
while
(
!
stack
.
empty
())
{
while
(
!
source
.
empty
())
{
auto
fnode
=
stack
.
back
();
auto
node
=
source
.
back
();
stack
.
pop_back
();
source
.
pop_back
();
if
(
visited
.
count
(
node
))
continue
;
if
(
fnode
.
leave
)
{
visited
.
insert
(
node
);
if
(
leave
&&
!
leave
(
fnode
.
node
))
return
;
if
(
func
(
node
))
return
true
;
}
auto
values
=
node
->
getOperands
();
if
(
visited
.
count
(
fnode
.
node
))
continue
;
visited
.
insert
(
fnode
.
node
);
if
(
leave
)
stack
.
push_back
(
FNode
{
fnode
.
node
,
true
});
auto
values
=
fnode
.
node
->
getOperands
();
for
(
auto
value
:
values
)
{
for
(
auto
value
:
values
)
{
// if the value is a block argument, the node is nullptr.
::
mlir
::
Operation
*
node
=
value
.
getDefiningOp
();
::
mlir
::
Operation
*
node
=
value
.
getDefiningOp
();
if
(
!
visited
.
count
(
node
))
{
if
(
node
!=
nullptr
&&
!
visited
.
count
(
node
))
{
s
tack
.
push_back
(
FNode
{
node
,
false
}
);
s
ource
.
emplace_back
(
node
);
}
}
}
}
}
}
return
false
;
}
}
// merge the first&second graph op to a new graph op.
// merge the first&second graph op to a new graph op.
...
@@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
...
@@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
second
.
erase
();
second
.
erase
();
}
}
// Reorders the operations of `body` topologically.
//
// Every operation of the block is collected while walking the block from its
// last operation to its first, the collection is ordered with
// ::mlir::topologicalSort, and each operation is then moved, in that order,
// to the position right before the block terminator.
void topoSortBlock(mlir::Block &body) {  // NOLINT
  // An empty block has nothing to reorder.
  if (body.empty()) return;

  llvm::SetVector<Operation *> pending;
  auto cursor = body.rbegin();
  const auto stop = body.rend();
  while (cursor != stop) {
    pending.insert(&*cursor);
    ++cursor;
  }

  llvm::SetVector<Operation *> sorted =
      ::mlir::topologicalSort(std::move(pending));
  for (auto *current : sorted) {
    current->moveBefore(body.getTerminator());
  }
}
}
// namespace
}
// namespace
// Implementation of the trtGraphFusePass.
// Implementation of the trtGraphFusePass.
...
@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() {
...
@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() {
std
::
vector
<::
mlir
::
Operation
*>
source_nodes
;
std
::
vector
<::
mlir
::
Operation
*>
source_nodes
;
for
(
auto
operand
:
user_op
->
getOperands
())
{
for
(
auto
operand
:
user_op
->
getOperands
())
{
auto
input
=
operand
.
getDefiningOp
();
auto
input
=
operand
.
getDefiningOp
();
if
(
input
!=
&
op
)
{
if
(
input
!=
&
op
&&
input
!=
nullptr
)
{
source_nodes
.
push_back
(
input
);
source_nodes
.
push_back
(
input
);
}
}
}
}
// Reverse DFS from the source_nodes.
// Reverse DFS from the source_nodes.
bool
have_excess_path
=
false
;
if
(
!
reverseDfs
(
source_nodes
,
[
&
op
](
const
::
mlir
::
Operation
*
n
)
{
FlexibleDFS
(
source_nodes
,
return
n
==
&
op
;
[
&
have_excess_path
,
&
op
](
const
::
mlir
::
Operation
*
n
)
{
}))
{
if
(
n
==
&
op
)
{
have_excess_path
=
true
;
return
false
;
}
return
true
;
});
if
(
!
have_excess_path
)
{
mergeTwoAdjacentGraphOp
(
builder
,
graph_op
,
user_graph_op
);
mergeTwoAdjacentGraphOp
(
builder
,
graph_op
,
user_graph_op
);
changed
=
true
;
changed
=
true
;
break
;
break
;
...
@@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() {
...
@@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() {
if
(
changed
)
break
;
if
(
changed
)
break
;
}
}
}
while
(
changed
);
}
while
(
changed
);
topoSortBlock
(
body
);
}
}
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
浏览文件 @
9f0958fa
...
@@ -25,7 +25,7 @@ namespace trt {
...
@@ -25,7 +25,7 @@ namespace trt {
* source func:
* source func:
*
*
* func @main() -> tensor<?xf32> {
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()
...
* %c = "pd.graph"(%a) {
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.fetch" %m
...
@@ -42,7 +42,7 @@ namespace trt {
...
@@ -42,7 +42,7 @@ namespace trt {
*
*
* destination func:
* destination func:
* func @main() -> tensor<?xf32> {
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()
...
* %d, %f = "pd.graph"(%a) {
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %n = "pd.conv3d"(%m)...
...
@@ -58,6 +58,5 @@ class trtGraphFusePass
...
@@ -58,6 +58,5 @@ class trtGraphFusePass
::
llvm
::
StringRef
getName
()
const
override
{
return
"trtGraphFusePass"
;
}
::
llvm
::
StringRef
getName
()
const
override
{
return
"trtGraphFusePass"
;
}
void
runOnFunction
()
override
;
void
runOnFunction
()
override
;
};
};
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
0 → 100644
浏览文件 @
9f0958fa
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace
infrt
{
namespace
trt
{
// Implementation of the trtGraphSplitPass。
void
trtGraphSplitPass
::
runOnFunction
()
{
std
::
vector
<::
mlir
::
pd
::
GraphOp
>
worklist
;
::
mlir
::
Block
&
block
=
getFunction
().
front
();
for
(
auto
&
op
:
block
)
{
::
mlir
::
pd
::
GraphOp
graph_op
=
::
llvm
::
dyn_cast_or_null
<::
mlir
::
pd
::
GraphOp
>
(
&
op
);
if
(
nullptr
!=
graph_op
&&
graph_op
.
getBody
()
->
getOperations
().
size
()
<=
min_subgraph_size_
)
{
worklist
.
push_back
(
graph_op
);
}
}
while
(
!
worklist
.
empty
())
{
::
mlir
::
pd
::
GraphOp
graph_op
=
worklist
.
back
();
worklist
.
pop_back
();
::
mlir
::
Block
*
body
=
graph_op
.
getBody
();
auto
fetch_op
=
body
->
getTerminator
();
graph_op
.
replaceAllUsesWith
(
fetch_op
->
getOperands
());
auto
copy_range
=
body
->
without_terminator
();
block
.
getOperations
().
splice
(
::
mlir
::
Block
::
iterator
(
graph_op
),
body
->
getOperations
(),
copy_range
.
begin
(),
copy_range
.
end
());
graph_op
.
erase
();
}
}
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
0 → 100644
浏览文件 @
9f0958fa
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
namespace
infrt
{
namespace
trt
{
/*
 * trtGraphSplitPass.
 *
 * Splits a graph op back into its individual operations when the number of
 * operations it contains is too small.
 * This undoes the grouping into graph ops shown in the example below.
 *
 * source func:
 *
 * func @main() -> tensor<?xf32> {
 *   %a = "pd.feed"()...
 *   %d, %f = "pd.graph"(%a) {
 *     %m = "pd.conv2d"(%a)...
 *     %n = "pd.conv3d"(%m)...
 *     %s = "pd.conv2d"(%a)...
 *     "pd.fetch" %n, %s
 *   } ...
 *   "pd.fetch" %d, %f
 * }
 *
 * destination func:
 * func @main() -> tensor<?xf32> {
 *   %a = "pd.feed"()...
 *   %c = "pd.conv2d"(%a) ...
 *   %d = "pd.conv3d"(%c) ...
 *   %f = "pd.conv2d"(%a) ...
 *   "pd.fetch" %d, %f
 * }
 */
class trtGraphSplitPass
    : public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
 public:
  // `min_subgraph_size` is the size threshold stored for runOnFunction();
  // it defaults to 3.
  explicit trtGraphSplitPass(size_t min_subgraph_size = 3)
      : min_subgraph_size_(min_subgraph_size) {}

  // Name under which the pass reports itself.
  ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }

  // Performs the split on the current function; defined in the .cc file.
  void runOnFunction() override;

 private:
  // Size threshold used to decide which graph ops are split.
  size_t min_subgraph_size_;
};
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
浏览文件 @
9f0958fa
...
@@ -20,7 +20,6 @@
...
@@ -20,7 +20,6 @@
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
// Implementation of the trtOpTellerPass.
// Implementation of the trtOpTellerPass.
void
trtOpTellerPass
::
runOnFunction
()
{
void
trtOpTellerPass
::
runOnFunction
()
{
::
mlir
::
Block
&
body
=
getFunction
().
front
();
::
mlir
::
Block
&
body
=
getFunction
().
front
();
...
@@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() {
...
@@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() {
builder
.
create
<
mlir
::
pd
::
FetchOp
>
(
loc
,
op
->
getResults
());
builder
.
create
<
mlir
::
pd
::
FetchOp
>
(
loc
,
op
->
getResults
());
}
}
}
}
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
浏览文件 @
9f0958fa
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
namespace
infrt
{
namespace
infrt
{
namespace
trt
{
namespace
trt
{
/*
/*
* trtOpTellerPass.
* trtOpTellerPass.
*
*
...
@@ -26,7 +25,7 @@ namespace trt {
...
@@ -26,7 +25,7 @@ namespace trt {
* source func:
* source func:
*
*
* func @main() -> tensor<?xf32> {
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()
...
* %c = "pd.conv2d"(%a) ...
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* %f = "pd.conv2d"(%a) ...
...
@@ -35,7 +34,7 @@ namespace trt {
...
@@ -35,7 +34,7 @@ namespace trt {
*
*
* destination func:
* destination func:
* func @main() -> tensor<?xf32> {
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()
...
* %c = "pd.graph"(%a) {
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.fetch" %m
...
@@ -59,6 +58,5 @@ class trtOpTellerPass
...
@@ -59,6 +58,5 @@ class trtOpTellerPass
::
llvm
::
StringRef
getName
()
const
override
{
return
"trtOpTellerPass"
;
}
::
llvm
::
StringRef
getName
()
const
override
{
return
"trtOpTellerPass"
;
}
void
runOnFunction
()
override
;
void
runOnFunction
()
override
;
};
};
}
// namespace trt
}
// namespace trt
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_ops.h
100755 → 100644
浏览文件 @
9f0958fa
文件模式从 100755 更改为 100644
paddle/scripts/infrt_build.sh
浏览文件 @
9f0958fa
...
@@ -65,12 +65,12 @@ function infrt_gen_and_build() {
...
@@ -65,12 +65,12 @@ function infrt_gen_and_build() {
mkdir
-p
${
PADDLE_ROOT
}
/build
mkdir
-p
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
rm
-f
infrt_summary.txt
rm
-f
infrt_summary.txt
cmake ..
-DWITH_MKL
=
OFF
-DWITH_GPU
=
OFF
-DCMAKE_BUILD_TYPE
=
Release
-DWITH_INFRT
=
ON
-DWITH_PYTHON
=
OFF
-DWITH_TESTING
==
${
WITH_TESTING
:-
ON
}
;
build_error
=
$?
cmake ..
-DWITH_MKL
=
OFF
-DWITH_GPU
=
OFF
-D
WITH_CRYPTO
=
OFF
-D
CMAKE_BUILD_TYPE
=
Release
-DWITH_INFRT
=
ON
-DWITH_PYTHON
=
OFF
-DWITH_TESTING
==
${
WITH_TESTING
:-
ON
}
;
build_error
=
$?
if
[
"
$build_error
"
!=
0
]
;
then
if
[
"
$build_error
"
!=
0
]
;
then
exit
7
;
exit
7
;
fi
fi
make
-j
${
parallel_number
}
infrt infrtopt infrt-exec test_infrt_exec infrt_lib_dist
;
build_error
=
$?
make
-j
${
parallel_number
}
infrt infrtopt infrt-exec test_infrt_exec
trt-exec
infrt_lib_dist
;
build_error
=
$?
if
[
"
$build_error
"
!=
0
]
;
then
if
[
"
$build_error
"
!=
0
]
;
then
exit
7
;
exit
7
;
fi
fi
...
@@ -115,6 +115,9 @@ function main() {
...
@@ -115,6 +115,9 @@ function main() {
build_only
)
build_only
)
infrt_gen_and_build
${
parallel_number
}
infrt_gen_and_build
${
parallel_number
}
;;
;;
test_only
)
test_infrt
;;
*
)
*
)
print_usage
print_usage
exit
1
exit
1
...
@@ -126,7 +129,7 @@ function main() {
...
@@ -126,7 +129,7 @@ function main() {
cat
${
PADDLE_ROOT
}
/build/infrt_summary.txt
cat
${
PADDLE_ROOT
}
/build/infrt_summary.txt
echo
"========================================================"
echo
"========================================================"
fi
fi
echo
"paddle_build script finished as expected"
echo
"paddle_build script finished as expected
!
"
}
}
main
$@
main
$@
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录