Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5914b18a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5914b18a
编写于
5月 12, 2022
作者:
W
Wangzheee
提交者:
GitHub
5月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-Inference] support transformer generation: some passes (#42664)
* [Paddle-Inference] support transformer generation: some passes
上级
a7926ef2
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
812 addition
and
18 deletion
+812
-18
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-0
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
...ramework/ir/delete_remove_padding_recover_padding_pass.cc
+100
-0
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
...framework/ir/delete_remove_padding_recover_padding_pass.h
+59
-0
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
...fluid/framework/ir/remove_padding_recover_padding_pass.cc
+298
-0
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
.../fluid/framework/ir/remove_padding_recover_padding_pass.h
+94
-0
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
.../fluid/framework/ir/set_transformer_input_convert_pass.cc
+161
-0
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
...e/fluid/framework/ir/set_transformer_input_convert_pass.h
+80
-0
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+1
-6
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+15
-12
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
5914b18a
...
...
@@ -107,6 +107,9 @@ if(WITH_TENSORRT)
pass_library
(
trt_map_matmul_to_mul_pass inference
)
pass_library
(
preln_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
preln_skip_layernorm_fuse_pass inference
)
pass_library
(
set_transformer_input_convert_pass inference
)
pass_library
(
remove_padding_recover_padding_pass inference
)
pass_library
(
delete_remove_padding_recover_padding_pass inference
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
...
...
@@ -161,6 +164,7 @@ if(WITH_IPU)
pass_library
(
infer_shape_pass base DIR ipu
)
pass_library
(
delete_scale_op_pass base DIR ipu
)
pass_library
(
avg_shard_pass base DIR ipu
)
pass_library
(
transfer_cast_op_pass base DIR ipu
)
endif
()
cc_library
(
fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector
)
...
...
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
void
RecoverPadding
::
operator
()()
{
// Create nodes for recover_padding.
auto
*
recover_padding_input
=
pattern
->
NewNode
(
recover_padding_input_repr
())
->
assert_is_op_input
(
"recover_padding"
,
"Input"
);
auto
*
recover_padding_op
=
pattern
->
NewNode
(
recover_padding_op_repr
())
->
assert_is_op
(
"recover_padding"
);
auto
*
recover_padding_out
=
pattern
->
NewNode
(
recover_padding_out_repr
())
->
assert_is_op_output
(
"recover_padding"
,
"Out"
);
// Add links for recover_padding op.
recover_padding_op
->
LinksFrom
({
recover_padding_input
})
.
LinksTo
({
recover_padding_out
});
}
}
// namespace patterns
void
DeleteRemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
//
GraphPatternDetector
gpd
;
patterns
::
RecoverPadding
recover_padding
(
gpd
.
mutable_pattern
(),
"delete_remove_padding_recover_padding_pass"
);
recover_padding
();
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"delete_remove_padding_recover_padding_pass"
;
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_input
,
recover_padding_input
,
recover_padding
);
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_op
,
recover_padding_op
,
recover_padding
);
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_out
,
recover_padding_out
,
recover_padding
);
std
::
unordered_set
<
const
Node
*>
del_node_set
;
bool
delete_recover_padding
=
true
;
for
(
size_t
i
=
0
;
i
<
recover_padding_out
->
outputs
.
size
();
++
i
)
{
if
(
recover_padding_out
->
outputs
[
i
]
->
Name
()
==
"remove_padding"
)
{
// op_node
auto
*
remove_padding_out_node
=
recover_padding_out
->
outputs
[
i
]
->
outputs
[
0
];
// var_node
auto
*
out_op_node
=
remove_padding_out_node
->
outputs
[
0
];
// op_node
IR_NODE_LINK_TO
(
recover_padding_input
,
out_op_node
);
del_node_set
.
insert
(
recover_padding_out
->
outputs
[
i
]);
del_node_set
.
insert
(
remove_padding_out_node
);
out_op_node
->
Op
()
->
RenameInput
(
remove_padding_out_node
->
Name
(),
recover_padding_input
->
Name
());
found_subgraph_count
++
;
}
else
{
delete_recover_padding
=
false
;
}
}
if
(
delete_recover_padding
)
{
del_node_set
.
insert
(
recover_padding_op
);
del_node_set
.
insert
(
recover_padding_out
);
}
GraphSafeRemoveNodes
(
graph
,
del_node_set
);
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_remove_padding_recover_padding_pass
,
paddle
::
framework
::
ir
::
DeleteRemovePaddingRecoverPaddingPass
);
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
RecoverPadding
:
public
PatternBase
{
RecoverPadding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"recover_padding"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
recover_padding_input
);
PATTERN_DECL_NODE
(
recover_padding_op
);
PATTERN_DECL_NODE
(
recover_padding_out
);
};
}
// namespace patterns
class
DeleteRemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
public:
DeleteRemovePaddingRecoverPaddingPass
()
{}
virtual
~
DeleteRemovePaddingRecoverPaddingPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"delete_remove_padding_recover_padding_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
void
SkipLayernorm
::
operator
()()
{
// Create nodes for skip_layernorm.
auto
*
skip_layernorm_x
=
pattern
->
NewNode
(
skip_layernorm_x_repr
())
->
assert_is_op_input
(
"skip_layernorm"
,
"X"
);
auto
*
skip_layernorm_y
=
pattern
->
NewNode
(
skip_layernorm_y_repr
())
->
assert_is_op_input
(
"skip_layernorm"
,
"Y"
);
auto
*
skip_layernorm_op
=
pattern
->
NewNode
(
skip_layernorm_op_repr
())
->
assert_is_op
(
"skip_layernorm"
);
auto
*
skip_layernorm_out
=
pattern
->
NewNode
(
skip_layernorm_out_repr
())
->
assert_is_op_output
(
"skip_layernorm"
,
"Out"
);
// Add links for skip_layernorm op.
skip_layernorm_op
->
LinksFrom
({
skip_layernorm_x
,
skip_layernorm_y
})
.
LinksTo
({
skip_layernorm_out
});
}
void
MultiheadMatmul
::
operator
()()
{
// Create nodes for multihead_matmul.
auto
*
multihead_matmul_input
=
pattern
->
NewNode
(
multihead_matmul_input_repr
())
->
assert_is_op_input
(
"multihead_matmul"
,
"Input"
);
auto
*
multihead_matmul_op
=
pattern
->
NewNode
(
multihead_matmul_op_repr
())
->
assert_is_op
(
"multihead_matmul"
);
auto
*
multihead_matmul_out
=
pattern
->
NewNode
(
multihead_matmul_out_repr
())
->
assert_is_op_output
(
"multihead_matmul"
,
"Out"
);
// Add links for multihead_matmul op.
multihead_matmul_op
->
LinksFrom
({
multihead_matmul_input
})
.
LinksTo
({
multihead_matmul_out
});
}
void
Fc
::
operator
()()
{
// Create nodes for fc.
auto
*
fc_input
=
pattern
->
NewNode
(
fc_input_repr
())
->
assert_is_op_input
(
"fc"
,
"Input"
);
auto
*
fc_op
=
pattern
->
NewNode
(
fc_op_repr
())
->
assert_is_op
(
"fc"
);
auto
*
fc_out
=
pattern
->
NewNode
(
fc_out_repr
())
->
assert_is_op_output
(
"fc"
,
"Out"
);
// Add links for fc op.
fc_op
->
LinksFrom
({
fc_input
}).
LinksTo
({
fc_out
});
}
void
Activation
::
operator
()()
{
// Create nodes for activation.
std
::
unordered_set
<
std
::
string
>
activation_ops
{
"relu"
,
"sigmoid"
,
"tanh"
};
auto
*
activation_input
=
pattern
->
NewNode
(
activation_input_repr
())
->
assert_is_ops_input
(
activation_ops
);
auto
*
activation_op
=
pattern
->
NewNode
(
activation_op_repr
())
->
assert_is_ops
(
activation_ops
);
auto
*
activation_out
=
pattern
->
NewNode
(
activation_out_repr
())
->
assert_is_ops_output
(
activation_ops
);
// Add links for activation op.
activation_op
->
LinksFrom
({
activation_input
}).
LinksTo
({
activation_out
});
}
}
// namespace patterns
void
RemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
int
found_subgraph_count
=
0
;
// Create an remove_padding op node
auto
insert_remove_padding_op
=
[
&
](
Node
*
input_node
,
Node
*
op_node
)
{
// create op, var in graph
OpDesc
remove_padding
;
std
::
string
remove_padding_out_name
=
input_node
->
Name
()
+
".remove_padding"
;
VarDesc
remove_padding_out
(
remove_padding_out_name
);
remove_padding_out
.
SetDataType
(
input_node
->
Var
()
->
GetDataType
());
remove_padding_out
.
SetShape
(
input_node
->
Var
()
->
GetShape
());
remove_padding_out
.
SetPersistable
(
false
);
// remove_padding_op
remove_padding
.
SetType
(
"remove_padding"
);
// input
remove_padding
.
SetInput
(
"Input"
,
{
input_node
->
Name
()});
// output
remove_padding
.
SetOutput
(
"Out"
,
{
remove_padding_out_name
});
auto
remove_padding_op_node
=
graph
->
CreateOpNode
(
&
remove_padding
);
auto
remove_padding_out_node
=
graph
->
CreateVarNode
(
&
remove_padding_out
);
// replace link
for
(
size_t
i
=
0
;
i
<
input_node
->
outputs
.
size
();
++
i
)
{
if
(
input_node
->
outputs
[
i
]
==
op_node
)
{
input_node
->
outputs
[
i
]
=
remove_padding_op_node
;
remove_padding_op_node
->
inputs
.
push_back
(
input_node
);
}
}
// link node
IR_NODE_LINK_TO
(
remove_padding_op_node
,
remove_padding_out_node
);
// replace link
for
(
size_t
i
=
0
;
i
<
op_node
->
inputs
.
size
();
++
i
)
{
if
(
op_node
->
inputs
[
i
]
==
input_node
)
{
op_node
->
inputs
[
i
]
=
remove_padding_out_node
;
remove_padding_out_node
->
outputs
.
push_back
(
op_node
);
}
}
// create variable in scope
scope
->
Var
(
remove_padding_out_name
);
auto
*
remove_padding_out_tensor
=
scope
->
FindVar
(
remove_padding_out_name
)
->
GetMutable
<
LoDTensor
>
();
remove_padding_out_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
());
// rename
op_node
->
Op
()
->
RenameInput
(
input_node
->
Name
(),
remove_padding_out_node
->
Name
());
};
// create an remove_padding op node
auto
insert_recover_padding_op
=
[
&
](
Node
*
op_node
,
Node
*
out_node
)
{
// create op, var in graph
OpDesc
recover_padding
;
std
::
string
recover_padding_input_name
=
out_node
->
Name
()
+
".recover_padding"
;
VarDesc
recover_padding_input
(
recover_padding_input_name
);
recover_padding_input
.
SetDataType
(
out_node
->
Var
()
->
GetDataType
());
recover_padding_input
.
SetShape
(
out_node
->
Var
()
->
GetShape
());
recover_padding_input
.
SetPersistable
(
false
);
// recover_padding_op
recover_padding
.
SetType
(
"recover_padding"
);
// input
recover_padding
.
SetInput
(
"Input"
,
{
recover_padding_input_name
});
// output
recover_padding
.
SetOutput
(
"Out"
,
{
out_node
->
Name
()});
auto
recover_padding_op_node
=
graph
->
CreateOpNode
(
&
recover_padding
);
auto
recover_padding_input_node
=
graph
->
CreateVarNode
(
&
recover_padding_input
);
// replace link
for
(
size_t
i
=
0
;
i
<
op_node
->
outputs
.
size
();
++
i
)
{
if
(
op_node
->
outputs
[
i
]
==
out_node
)
{
op_node
->
outputs
[
i
]
=
recover_padding_input_node
;
recover_padding_input_node
->
inputs
.
push_back
(
op_node
);
}
}
// link node
IR_NODE_LINK_TO
(
recover_padding_input_node
,
recover_padding_op_node
);
// replace link
for
(
size_t
i
=
0
;
i
<
out_node
->
inputs
.
size
();
++
i
)
{
if
(
out_node
->
inputs
[
i
]
==
op_node
)
{
out_node
->
inputs
[
i
]
=
recover_padding_op_node
;
recover_padding_op_node
->
outputs
.
push_back
(
out_node
);
}
}
// create variable in scope
scope
->
Var
(
recover_padding_input_name
);
auto
*
recover_padding_input_tensor
=
scope
->
FindVar
(
recover_padding_input_name
)
->
GetMutable
<
LoDTensor
>
();
recover_padding_input_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
());
// rename
op_node
->
Op
()
->
RenameOutput
(
out_node
->
Name
(),
recover_padding_input_name
);
};
GraphPatternDetector
gpd1
;
patterns
::
SkipLayernorm
skip_layernorm
(
gpd1
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
skip_layernorm
();
auto
handler1
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: "
"skip_layernorm"
;
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_x
,
skip_layernorm_x
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_y
,
skip_layernorm_y
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_op
,
skip_layernorm_op
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_out
,
skip_layernorm_out
,
skip_layernorm
);
insert_remove_padding_op
(
skip_layernorm_x
,
skip_layernorm_op
);
insert_remove_padding_op
(
skip_layernorm_y
,
skip_layernorm_op
);
insert_recover_padding_op
(
skip_layernorm_op
,
skip_layernorm_out
);
found_subgraph_count
++
;
};
gpd1
(
graph
,
handler1
);
GraphPatternDetector
gpd2
;
patterns
::
MultiheadMatmul
multihead_matmul
(
gpd2
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
multihead_matmul
();
auto
handler2
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: "
"multihead_matmul"
;
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_input
,
multihead_matmul_input
,
multihead_matmul
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_op
,
multihead_matmul_op
,
multihead_matmul
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_out
,
multihead_matmul_out
,
multihead_matmul
);
insert_remove_padding_op
(
multihead_matmul_input
,
multihead_matmul_op
);
insert_recover_padding_op
(
multihead_matmul_op
,
multihead_matmul_out
);
found_subgraph_count
++
;
};
gpd2
(
graph
,
handler2
);
GraphPatternDetector
gpd3
;
patterns
::
Fc
fc
(
gpd3
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
fc
();
auto
handler3
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: fc"
;
GET_IR_NODE_FROM_SUBGRAPH
(
fc_input
,
fc_input
,
fc
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_op
,
fc_op
,
fc
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
fc_out
,
fc
);
insert_remove_padding_op
(
fc_input
,
fc_op
);
insert_recover_padding_op
(
fc_op
,
fc_out
);
found_subgraph_count
++
;
};
gpd3
(
graph
,
handler3
);
GraphPatternDetector
gpd4
;
patterns
::
Activation
activation
(
gpd4
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
activation
();
auto
handler4
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: activation"
;
GET_IR_NODE_FROM_SUBGRAPH
(
activation_input
,
activation_input
,
activation
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_op
,
activation_op
,
activation
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
activation
);
insert_remove_padding_op
(
activation_input
,
activation_op
);
insert_recover_padding_op
(
activation_op
,
activation_out
);
found_subgraph_count
++
;
};
gpd4
(
graph
,
handler4
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
remove_padding_recover_padding_pass
,
paddle
::
framework
::
ir
::
RemovePaddingRecoverPaddingPass
);
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
SkipLayernorm
:
public
PatternBase
{
SkipLayernorm
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"skip_layernorm"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
skip_layernorm_x
);
PATTERN_DECL_NODE
(
skip_layernorm_y
);
PATTERN_DECL_NODE
(
skip_layernorm_op
);
PATTERN_DECL_NODE
(
skip_layernorm_out
);
};
struct
MultiheadMatmul
:
public
PatternBase
{
MultiheadMatmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"multihead_matmul"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
multihead_matmul_input
);
PATTERN_DECL_NODE
(
multihead_matmul_op
);
PATTERN_DECL_NODE
(
multihead_matmul_out
);
};
struct
Fc
:
public
PatternBase
{
Fc
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fc"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
fc_input
);
PATTERN_DECL_NODE
(
fc_op
);
PATTERN_DECL_NODE
(
fc_out
);
};
struct
Activation
:
public
PatternBase
{
Activation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"activation"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
activation_input
);
PATTERN_DECL_NODE
(
activation_op
);
PATTERN_DECL_NODE
(
activation_out
);
};
}
// namespace patterns
class
RemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
public:
RemovePaddingRecoverPaddingPass
()
{}
virtual
~
RemovePaddingRecoverPaddingPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"remove_padding_recover_padding_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/set_transformer_input_convert_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
SetTransformerInputConvertPass
::
SetTransformerInputConvertPass
()
{
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
End
();
}
namespace
patterns
{
void
SetTransformerInputConvert
::
operator
()()
{
std
::
unordered_set
<
std
::
string
>
lookup_table_ops
{
"lookup_table"
,
"lookup_table_v2"
};
// Create nodes for lookup_table1 op.
auto
*
lookup_table1_x
=
pattern
->
NewNode
(
lookup_table1_x_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
auto
*
lookup_table1_w
=
pattern
->
NewNode
(
lookup_table1_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
auto
*
lookup_table1_op
=
pattern
->
NewNode
(
lookup_table1_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table1_out
=
pattern
->
NewNode
(
lookup_table1_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
// Create nodes for lookup_table2 op.
auto
*
lookup_table2_x
=
pattern
->
NewNode
(
lookup_table2_x_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
auto
*
lookup_table2_w
=
pattern
->
NewNode
(
lookup_table2_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
auto
*
lookup_table2_op
=
pattern
->
NewNode
(
lookup_table2_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table2_out
=
pattern
->
NewNode
(
lookup_table2_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
// Create nodes for elementwise_add op.
auto
*
elementwise_op
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
elementwise_out
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
AsOutput
()
->
assert_is_only_output_of_op
(
"elementwise_add"
);
// links nodes.
lookup_table1_op
->
LinksFrom
({
lookup_table1_x
,
lookup_table1_w
})
.
LinksTo
({
lookup_table1_out
});
lookup_table2_op
->
LinksFrom
({
lookup_table2_x
,
lookup_table2_w
})
.
LinksTo
({
lookup_table2_out
});
elementwise_op
->
LinksFrom
({
lookup_table1_out
,
lookup_table2_out
})
.
LinksTo
({
elementwise_out
});
}
}
// namespace patterns
void
SetTransformerInputConvertPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
GraphPatternDetector
gpd
;
patterns
::
SetTransformerInputConvert
fused_pattern
(
gpd
.
mutable_pattern
(),
"transformer_input_convert_pass"
);
fused_pattern
();
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"transformer_input_convert_pass in op compat failed."
;
return
;
}
VLOG
(
3
)
<<
"transformer_input_convert_pass for pos_id, max_seqlen"
;
GET_IR_NODE_FROM_SUBGRAPH
(
lookup_table2_x
,
lookup_table2_x
,
fused_pattern
);
// create op, var in graph
OpDesc
new_desc
;
new_desc
.
SetType
(
"transformer_input_convert"
);
// inputs
new_desc
.
SetInput
(
"X"
,
{
lookup_table2_x
->
Name
()});
// outputs
std
::
vector
<
std
::
string
>
output_0
=
{
"pos_id_tensor"
};
std
::
vector
<
std
::
string
>
output_1
=
{
"max_seqlen_tensor"
};
new_desc
.
SetOutput
(
"PosId"
,
output_0
);
new_desc
.
SetOutput
(
"MaxSeqlen"
,
output_1
);
std
::
string
transformer_input_convert_out0_name
=
"pos_id_tensor"
;
std
::
string
transformer_input_convert_out1_name
=
"max_seqlen_tensor"
;
VarDesc
transformer_input_convert_out0
(
transformer_input_convert_out0_name
);
VarDesc
transformer_input_convert_out1
(
transformer_input_convert_out1_name
);
transformer_input_convert_out0
.
SetDataType
(
proto
::
VarType
::
INT32
);
transformer_input_convert_out1
.
SetDataType
(
proto
::
VarType
::
INT32
);
transformer_input_convert_out0
.
SetShape
({
-
1
});
transformer_input_convert_out1
.
SetShape
({
-
1
});
transformer_input_convert_out0
.
SetPersistable
(
false
);
transformer_input_convert_out1
.
SetPersistable
(
false
);
auto
new_op_node
=
graph
->
CreateOpNode
(
&
new_desc
);
auto
transformer_input_convert_out0_node
=
graph
->
CreateVarNode
(
&
transformer_input_convert_out0
);
auto
transformer_input_convert_out1_node
=
graph
->
CreateVarNode
(
&
transformer_input_convert_out1
);
// needn't create variable in scope
IR_NODE_LINK_TO
(
lookup_table2_x
,
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
transformer_input_convert_out0_node
);
IR_NODE_LINK_TO
(
new_op_node
,
transformer_input_convert_out1_node
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
set_transformer_input_convert_pass
,
paddle
::
framework
::
ir
::
SetTransformerInputConvertPass
);
REGISTER_PASS_CAPABILITY
(
set_transformer_input_convert_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"lookup_table"
,
1
)
.
LE
(
"lookup_table_v2"
,
1
)
.
LE
(
"elementweise_add"
,
1
));
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
struct
SetTransformerInputConvert
:
public
PatternBase
{
SetTransformerInputConvert
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"transformer_input_convert"
)
{}
void
operator
()();
// declare operator node's name
PATTERN_DECL_NODE
(
lookup_table1
);
PATTERN_DECL_NODE
(
lookup_table2
);
PATTERN_DECL_NODE
(
elementwise
);
// declare variable node's name
PATTERN_DECL_NODE
(
lookup_table1_x
);
PATTERN_DECL_NODE
(
lookup_table1_w
);
PATTERN_DECL_NODE
(
lookup_table1_out
);
PATTERN_DECL_NODE
(
lookup_table2_x
);
PATTERN_DECL_NODE
(
lookup_table2_w
);
PATTERN_DECL_NODE
(
lookup_table2_out
);
PATTERN_DECL_NODE
(
elementwise_out
);
};
}
// namespace patterns
class
SetTransformerInputConvertPass
:
public
FusePassBase
{
public:
SetTransformerInputConvertPass
();
virtual
~
SetTransformerInputConvertPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"transformer_input_convert_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
5914b18a
...
...
@@ -377,12 +377,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine
->
SetUseDLA
(
Get
<
bool
>
(
"trt_use_dla"
));
trt_engine
->
SetDLACore
(
Get
<
int
>
(
"trt_dla_core"
));
trt_engine
->
SetUseInspector
(
Get
<
bool
>
(
"use_inspector"
));
trt_engine
->
SetWithErnie
(
(
graph
->
Has
(
framework
::
ir
::
kEmbEltwiseLayernormPass
)
&&
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
))
||
(
graph
->
Has
(
framework
::
ir
::
kPrelnEmbEltwiseLayernormPass
)
&&
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
)));
trt_engine
->
SetWithErnie
(
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
));
if
(
use_static_engine
)
{
trt_engine_serialized_data
=
GetTrtEngineSerializedData
(
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
5914b18a
...
...
@@ -98,6 +98,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"multihead_matmul_fuse_pass_v3"
,
//
"skip_layernorm_fuse_pass"
,
//
"preln_skip_layernorm_fuse_pass"
,
//
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
//
"trt_squeeze2_matmul_fuse_pass"
,
//
...
...
@@ -108,6 +109,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
// "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", //
"tensorrt_subgraph_pass"
,
//
"conv_bn_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录