Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
bf93050c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bf93050c
编写于
3月 28, 2022
作者:
王
王明冬
提交者:
GitHub
3月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[infrt] move graph op from pd dialect to infrt dialect. (#41003)
上级
29d2e949
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
38 addition
and
37 deletion
+38
-37
paddle/infrt/dialect/infrt/ir/infrt_ops.td
paddle/infrt/dialect/infrt/ir/infrt_ops.td
+10
-0
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
+7
-7
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+4
-4
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
+3
-4
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
+1
-1
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+7
-5
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
+1
-1
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+2
-2
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
+3
-3
tools/infrt/custom_pdop.td
tools/infrt/custom_pdop.td
+0
-10
未找到文件。
paddle/infrt/dialect/infrt/ir/infrt_ops.td
浏览文件 @
bf93050c
...
...
@@ -9,6 +9,16 @@ class Infrt_Op<string mnemonic, list<OpTrait> traits = []> : Op<Infrt_Dialect, m
// let parser = [{ return infrt::parse$cppClass(parser, result); }];
}
def PD_GraphOp : Infrt_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs Variadic<AnyType>:$outputs);
}
def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> {
let summary = "kernel op";
let description = [{kernel op!}];
...
...
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
浏览文件 @
bf93050c
...
...
@@ -55,8 +55,8 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
// merge the first&second graph op to a new graph op.
void
mergeTwoAdjacentGraphOp
(
mlir
::
OpBuilder
&
builder
,
// NOLINT
infrt
::
pd
::
GraphOp
first
,
infrt
::
pd
::
GraphOp
second
)
{
::
infrt
::
GraphOp
first
,
::
infrt
::
GraphOp
second
)
{
// comput inputs and outputs
::
llvm
::
SmallVector
<
mlir
::
Value
,
4
>
inputs
(
first
.
getOperands
()),
outputs
;
for
(
mlir
::
Value
input
:
second
.
getOperands
())
{
...
...
@@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op
builder
.
setInsertionPoint
(
first
);
auto
loc
=
first
.
getLoc
();
auto
graph_op
=
builder
.
create
<
infrt
::
pd
::
GraphOp
>
(
loc
,
return_types
,
inputs
);
auto
graph_op
=
builder
.
create
<
::
infrt
::
GraphOp
>
(
loc
,
return_types
,
inputs
);
mlir
::
Block
*
block
=
new
mlir
::
Block
;
auto
copy_range
=
second
.
getBody
()
->
without_terminator
();
block
->
getOperations
().
splice
(
block
->
begin
(),
...
...
@@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() {
do
{
changed
=
false
;
for
(
auto
&
op
:
body
)
{
infrt
::
pd
::
GraphOp
graph_op
=
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
GraphOp
>
(
&
op
);
::
infrt
::
GraphOp
graph_op
=
::
llvm
::
dyn_cast_or_null
<
::
infrt
::
GraphOp
>
(
&
op
);
if
(
nullptr
==
graph_op
)
continue
;
for
(
auto
user_op
:
op
.
getUsers
())
{
infrt
::
pd
::
GraphOp
user_graph_op
=
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
GraphOp
>
(
user_op
);
::
infrt
::
GraphOp
user_graph_op
=
::
llvm
::
dyn_cast_or_null
<
::
infrt
::
GraphOp
>
(
user_op
);
if
(
nullptr
==
user_graph_op
)
continue
;
// get all dst input nodes except src.
std
::
vector
<
mlir
::
Operation
*>
source_nodes
;
...
...
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
浏览文件 @
bf93050c
...
...
@@ -25,15 +25,15 @@ namespace trt {
* source func:
*
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "
pd
.graph"(%a) {
* %c = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m...
* } ...
* %d = "
pd
.graph"(%c) {
* %d = "
infrt
.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* infrt.return %m...
* } ...
* %f = "
pd
.graph"(%a) {
* %f = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m...
* } ...
...
...
@@ -42,7 +42,7 @@ namespace trt {
*
* destination func:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "
pd
.graph"(%a) {
* %d, %f = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
...
...
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
浏览文件 @
bf93050c
...
...
@@ -21,18 +21,17 @@ namespace infrt {
namespace
trt
{
// Implementation of the trtGraphSplitPass。
void
TRTGraphSplitPass
::
runOnFunction
()
{
std
::
vector
<
infrt
::
pd
::
GraphOp
>
worklist
;
std
::
vector
<
::
infrt
::
GraphOp
>
worklist
;
mlir
::
Block
&
block
=
getFunction
().
front
();
for
(
auto
&
op
:
block
)
{
infrt
::
pd
::
GraphOp
graph_op
=
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
GraphOp
>
(
&
op
);
::
infrt
::
GraphOp
graph_op
=
::
llvm
::
dyn_cast_or_null
<::
infrt
::
GraphOp
>
(
&
op
);
if
(
nullptr
!=
graph_op
&&
graph_op
.
getBody
()
->
getOperations
().
size
()
<=
min_subgraph_size_
)
{
worklist
.
push_back
(
graph_op
);
}
}
while
(
!
worklist
.
empty
())
{
infrt
::
pd
::
GraphOp
graph_op
=
worklist
.
back
();
::
infrt
::
GraphOp
graph_op
=
worklist
.
back
();
worklist
.
pop_back
();
mlir
::
Block
*
body
=
graph_op
.
getBody
();
auto
return_op
=
body
->
getTerminator
();
...
...
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
浏览文件 @
bf93050c
...
...
@@ -26,7 +26,7 @@ namespace trt {
* source func:
*
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "
pd
.graph"(%a) {
* %d, %f = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
...
...
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
浏览文件 @
bf93050c
...
...
@@ -41,14 +41,15 @@ namespace trt {
#endif // INFRT_WITH_TRT
template
<
typename
T
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
(
::
mlir
::
PatternRewriter
&
rewriter
,
T
enum_value
)
{
::
mlir
::
IntegerAttr
createNvinferEnumAttr
(
::
mlir
::
PatternRewriter
&
rewriter
,
// NOLINT
T
enum_value
)
{
return
rewriter
.
getSI32IntegerAttr
((
int32_t
)
enum_value
);
}
template
<
>
::
mlir
::
IntegerAttr
createNvinferEnumAttr
<
std
::
string
>
(
::
mlir
::
PatternRewriter
&
rewriter
,
std
::
string
enum_value
)
{
::
mlir
::
PatternRewriter
&
rewriter
,
std
::
string
enum_value
)
{
// NOLINT
(
void
)
enum_value
;
return
rewriter
.
getSI32IntegerAttr
(
-
1
);
}
...
...
@@ -57,10 +58,11 @@ template <>
struct
PD2TRT_GraphLower
:
public
::
mlir
::
RewritePattern
{
explicit
PD2TRT_GraphLower
(
::
mlir
::
MLIRContext
*
context
)
:
::
mlir
::
RewritePattern
(
"pd.graph"
,
1
,
context
,
{
"trt.create_engine"
})
{}
:
::
mlir
::
RewritePattern
(
"infrt.graph"
,
1
,
context
,
{
"trt.create_engine"
})
{}
::
mlir
::
LogicalResult
matchAndRewrite
(
::
mlir
::
Operation
*
op
,
::
mlir
::
PatternRewriter
&
rewriter
)
const
override
{
auto
casted_op
=
::
llvm
::
dyn_cast
<
infrt
::
pd
::
GraphOp
>
(
op
);
auto
casted_op
=
::
llvm
::
dyn_cast
<
::
infrt
::
GraphOp
>
(
op
);
::
mlir
::
Operation
::
operand_range
inputs
=
casted_op
.
inputs
();
auto
ods_loc
=
rewriter
.
getFusedLoc
(
op
->
getLoc
());
CreateEngineOp
create_engine_op
;
...
...
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
浏览文件 @
bf93050c
...
...
@@ -25,7 +25,7 @@ namespace trt {
*
* source ir:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %d, %f = "
pd
.graph"(%a) {
* %d, %f = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
...
...
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
浏览文件 @
bf93050c
...
...
@@ -40,12 +40,12 @@ void TRTOpTellerPass::runOnFunction() {
if
(
op
->
getName
().
getStringRef
().
substr
(
0
,
3
)
!=
"pd."
)
continue
;
if
(
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
FeedOp
>
(
op
))
continue
;
if
(
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
FetchOp
>
(
op
))
continue
;
if
(
::
llvm
::
dyn_cast_or_null
<
infrt
::
pd
::
GraphOp
>
(
op
))
continue
;
if
(
::
llvm
::
dyn_cast_or_null
<
::
infrt
::
GraphOp
>
(
op
))
continue
;
if
(
::
llvm
::
dyn_cast_or_null
<::
infrt
::
ReturnOp
>
(
op
))
continue
;
builder
.
setInsertionPoint
(
op
);
auto
loc
=
getFunction
().
getLoc
();
auto
graph_op
=
builder
.
create
<
infrt
::
pd
::
GraphOp
>
(
auto
graph_op
=
builder
.
create
<
::
infrt
::
GraphOp
>
(
loc
,
op
->
getResultTypes
(),
op
->
getOperands
());
::
llvm
::
SmallVector
<
mlir
::
Value
,
4
>
tblgen_repl_values
;
...
...
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
浏览文件 @
bf93050c
...
...
@@ -33,15 +33,15 @@ namespace trt {
*
* destination func:
* func @main(%a : tensor<?xf32>) -> tensor<?xf32> {
* %c = "
pd
.graph"(%a) {
* %c = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m:...
* } ...
* %d = "
pd
.graph"(%c) {
* %d = "
infrt
.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* infrt.return %m:...
* } ...
* %f = "
pd
.graph"(%a) {
* %f = "
infrt
.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* infrt.return %m:...
* } ...
...
...
tools/infrt/custom_pdop.td
浏览文件 @
bf93050c
...
...
@@ -23,16 +23,6 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<PD_Tensor>:$inputs);
let results = (outs Variadic<PD_Tensor>:$outputs);
}
def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
let summary = "constant Op";
let description = [{}];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录