Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5914b18a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5914b18a
编写于
5月 12, 2022
作者:
W
Wangzheee
提交者:
GitHub
5月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle-Inference] support transformer generation: some passes (#42664)
* [Paddle-Inference] support transformer generation: some passes
上级
a7926ef2
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
812 addition
and
18 deletion
+812
-18
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-0
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
...ramework/ir/delete_remove_padding_recover_padding_pass.cc
+100
-0
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
...framework/ir/delete_remove_padding_recover_padding_pass.h
+59
-0
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
...fluid/framework/ir/remove_padding_recover_padding_pass.cc
+298
-0
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
.../fluid/framework/ir/remove_padding_recover_padding_pass.h
+94
-0
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
.../fluid/framework/ir/set_transformer_input_convert_pass.cc
+161
-0
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
...e/fluid/framework/ir/set_transformer_input_convert_pass.h
+80
-0
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+1
-6
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+15
-12
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
5914b18a
...
...
@@ -107,6 +107,9 @@ if(WITH_TENSORRT)
pass_library
(
trt_map_matmul_to_mul_pass inference
)
pass_library
(
preln_embedding_eltwise_layernorm_fuse_pass inference
)
pass_library
(
preln_skip_layernorm_fuse_pass inference
)
pass_library
(
set_transformer_input_convert_pass inference
)
pass_library
(
remove_padding_recover_padding_pass inference
)
pass_library
(
delete_remove_padding_recover_padding_pass inference
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
...
...
@@ -161,6 +164,7 @@ if(WITH_IPU)
pass_library
(
infer_shape_pass base DIR ipu
)
pass_library
(
delete_scale_op_pass base DIR ipu
)
pass_library
(
avg_shard_pass base DIR ipu
)
pass_library
(
transfer_cast_op_pass base DIR ipu
)
endif
()
cc_library
(
fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector
)
...
...
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
void
RecoverPadding
::
operator
()()
{
// Create nodes for recover_padding.
auto
*
recover_padding_input
=
pattern
->
NewNode
(
recover_padding_input_repr
())
->
assert_is_op_input
(
"recover_padding"
,
"Input"
);
auto
*
recover_padding_op
=
pattern
->
NewNode
(
recover_padding_op_repr
())
->
assert_is_op
(
"recover_padding"
);
auto
*
recover_padding_out
=
pattern
->
NewNode
(
recover_padding_out_repr
())
->
assert_is_op_output
(
"recover_padding"
,
"Out"
);
// Add links for recover_padding op.
recover_padding_op
->
LinksFrom
({
recover_padding_input
})
.
LinksTo
({
recover_padding_out
});
}
}
// namespace patterns
void
DeleteRemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
//
GraphPatternDetector
gpd
;
patterns
::
RecoverPadding
recover_padding
(
gpd
.
mutable_pattern
(),
"delete_remove_padding_recover_padding_pass"
);
recover_padding
();
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"delete_remove_padding_recover_padding_pass"
;
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_input
,
recover_padding_input
,
recover_padding
);
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_op
,
recover_padding_op
,
recover_padding
);
GET_IR_NODE_FROM_SUBGRAPH
(
recover_padding_out
,
recover_padding_out
,
recover_padding
);
std
::
unordered_set
<
const
Node
*>
del_node_set
;
bool
delete_recover_padding
=
true
;
for
(
size_t
i
=
0
;
i
<
recover_padding_out
->
outputs
.
size
();
++
i
)
{
if
(
recover_padding_out
->
outputs
[
i
]
->
Name
()
==
"remove_padding"
)
{
// op_node
auto
*
remove_padding_out_node
=
recover_padding_out
->
outputs
[
i
]
->
outputs
[
0
];
// var_node
auto
*
out_op_node
=
remove_padding_out_node
->
outputs
[
0
];
// op_node
IR_NODE_LINK_TO
(
recover_padding_input
,
out_op_node
);
del_node_set
.
insert
(
recover_padding_out
->
outputs
[
i
]);
del_node_set
.
insert
(
remove_padding_out_node
);
out_op_node
->
Op
()
->
RenameInput
(
remove_padding_out_node
->
Name
(),
recover_padding_input
->
Name
());
found_subgraph_count
++
;
}
else
{
delete_recover_padding
=
false
;
}
}
if
(
delete_recover_padding
)
{
del_node_set
.
insert
(
recover_padding_op
);
del_node_set
.
insert
(
recover_padding_out
);
}
GraphSafeRemoveNodes
(
graph
,
del_node_set
);
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_remove_padding_recover_padding_pass
,
paddle
::
framework
::
ir
::
DeleteRemovePaddingRecoverPaddingPass
);
paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
RecoverPadding
:
public
PatternBase
{
RecoverPadding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"recover_padding"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
recover_padding_input
);
PATTERN_DECL_NODE
(
recover_padding_op
);
PATTERN_DECL_NODE
(
recover_padding_out
);
};
}
// namespace patterns
class
DeleteRemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
public:
DeleteRemovePaddingRecoverPaddingPass
()
{}
virtual
~
DeleteRemovePaddingRecoverPaddingPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"delete_remove_padding_recover_padding_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
void
SkipLayernorm
::
operator
()()
{
// Create nodes for skip_layernorm.
auto
*
skip_layernorm_x
=
pattern
->
NewNode
(
skip_layernorm_x_repr
())
->
assert_is_op_input
(
"skip_layernorm"
,
"X"
);
auto
*
skip_layernorm_y
=
pattern
->
NewNode
(
skip_layernorm_y_repr
())
->
assert_is_op_input
(
"skip_layernorm"
,
"Y"
);
auto
*
skip_layernorm_op
=
pattern
->
NewNode
(
skip_layernorm_op_repr
())
->
assert_is_op
(
"skip_layernorm"
);
auto
*
skip_layernorm_out
=
pattern
->
NewNode
(
skip_layernorm_out_repr
())
->
assert_is_op_output
(
"skip_layernorm"
,
"Out"
);
// Add links for skip_layernorm op.
skip_layernorm_op
->
LinksFrom
({
skip_layernorm_x
,
skip_layernorm_y
})
.
LinksTo
({
skip_layernorm_out
});
}
void
MultiheadMatmul
::
operator
()()
{
// Create nodes for multihead_matmul.
auto
*
multihead_matmul_input
=
pattern
->
NewNode
(
multihead_matmul_input_repr
())
->
assert_is_op_input
(
"multihead_matmul"
,
"Input"
);
auto
*
multihead_matmul_op
=
pattern
->
NewNode
(
multihead_matmul_op_repr
())
->
assert_is_op
(
"multihead_matmul"
);
auto
*
multihead_matmul_out
=
pattern
->
NewNode
(
multihead_matmul_out_repr
())
->
assert_is_op_output
(
"multihead_matmul"
,
"Out"
);
// Add links for multihead_matmul op.
multihead_matmul_op
->
LinksFrom
({
multihead_matmul_input
})
.
LinksTo
({
multihead_matmul_out
});
}
void
Fc
::
operator
()()
{
// Create nodes for fc.
auto
*
fc_input
=
pattern
->
NewNode
(
fc_input_repr
())
->
assert_is_op_input
(
"fc"
,
"Input"
);
auto
*
fc_op
=
pattern
->
NewNode
(
fc_op_repr
())
->
assert_is_op
(
"fc"
);
auto
*
fc_out
=
pattern
->
NewNode
(
fc_out_repr
())
->
assert_is_op_output
(
"fc"
,
"Out"
);
// Add links for fc op.
fc_op
->
LinksFrom
({
fc_input
}).
LinksTo
({
fc_out
});
}
void
Activation
::
operator
()()
{
// Create nodes for activation.
std
::
unordered_set
<
std
::
string
>
activation_ops
{
"relu"
,
"sigmoid"
,
"tanh"
};
auto
*
activation_input
=
pattern
->
NewNode
(
activation_input_repr
())
->
assert_is_ops_input
(
activation_ops
);
auto
*
activation_op
=
pattern
->
NewNode
(
activation_op_repr
())
->
assert_is_ops
(
activation_ops
);
auto
*
activation_out
=
pattern
->
NewNode
(
activation_out_repr
())
->
assert_is_ops_output
(
activation_ops
);
// Add links for activation op.
activation_op
->
LinksFrom
({
activation_input
}).
LinksTo
({
activation_out
});
}
}
// namespace patterns
void
RemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
int
found_subgraph_count
=
0
;
// Create an remove_padding op node
auto
insert_remove_padding_op
=
[
&
](
Node
*
input_node
,
Node
*
op_node
)
{
// create op, var in graph
OpDesc
remove_padding
;
std
::
string
remove_padding_out_name
=
input_node
->
Name
()
+
".remove_padding"
;
VarDesc
remove_padding_out
(
remove_padding_out_name
);
remove_padding_out
.
SetDataType
(
input_node
->
Var
()
->
GetDataType
());
remove_padding_out
.
SetShape
(
input_node
->
Var
()
->
GetShape
());
remove_padding_out
.
SetPersistable
(
false
);
// remove_padding_op
remove_padding
.
SetType
(
"remove_padding"
);
// input
remove_padding
.
SetInput
(
"Input"
,
{
input_node
->
Name
()});
// output
remove_padding
.
SetOutput
(
"Out"
,
{
remove_padding_out_name
});
auto
remove_padding_op_node
=
graph
->
CreateOpNode
(
&
remove_padding
);
auto
remove_padding_out_node
=
graph
->
CreateVarNode
(
&
remove_padding_out
);
// replace link
for
(
size_t
i
=
0
;
i
<
input_node
->
outputs
.
size
();
++
i
)
{
if
(
input_node
->
outputs
[
i
]
==
op_node
)
{
input_node
->
outputs
[
i
]
=
remove_padding_op_node
;
remove_padding_op_node
->
inputs
.
push_back
(
input_node
);
}
}
// link node
IR_NODE_LINK_TO
(
remove_padding_op_node
,
remove_padding_out_node
);
// replace link
for
(
size_t
i
=
0
;
i
<
op_node
->
inputs
.
size
();
++
i
)
{
if
(
op_node
->
inputs
[
i
]
==
input_node
)
{
op_node
->
inputs
[
i
]
=
remove_padding_out_node
;
remove_padding_out_node
->
outputs
.
push_back
(
op_node
);
}
}
// create variable in scope
scope
->
Var
(
remove_padding_out_name
);
auto
*
remove_padding_out_tensor
=
scope
->
FindVar
(
remove_padding_out_name
)
->
GetMutable
<
LoDTensor
>
();
remove_padding_out_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
());
// rename
op_node
->
Op
()
->
RenameInput
(
input_node
->
Name
(),
remove_padding_out_node
->
Name
());
};
// create an remove_padding op node
auto
insert_recover_padding_op
=
[
&
](
Node
*
op_node
,
Node
*
out_node
)
{
// create op, var in graph
OpDesc
recover_padding
;
std
::
string
recover_padding_input_name
=
out_node
->
Name
()
+
".recover_padding"
;
VarDesc
recover_padding_input
(
recover_padding_input_name
);
recover_padding_input
.
SetDataType
(
out_node
->
Var
()
->
GetDataType
());
recover_padding_input
.
SetShape
(
out_node
->
Var
()
->
GetShape
());
recover_padding_input
.
SetPersistable
(
false
);
// recover_padding_op
recover_padding
.
SetType
(
"recover_padding"
);
// input
recover_padding
.
SetInput
(
"Input"
,
{
recover_padding_input_name
});
// output
recover_padding
.
SetOutput
(
"Out"
,
{
out_node
->
Name
()});
auto
recover_padding_op_node
=
graph
->
CreateOpNode
(
&
recover_padding
);
auto
recover_padding_input_node
=
graph
->
CreateVarNode
(
&
recover_padding_input
);
// replace link
for
(
size_t
i
=
0
;
i
<
op_node
->
outputs
.
size
();
++
i
)
{
if
(
op_node
->
outputs
[
i
]
==
out_node
)
{
op_node
->
outputs
[
i
]
=
recover_padding_input_node
;
recover_padding_input_node
->
inputs
.
push_back
(
op_node
);
}
}
// link node
IR_NODE_LINK_TO
(
recover_padding_input_node
,
recover_padding_op_node
);
// replace link
for
(
size_t
i
=
0
;
i
<
out_node
->
inputs
.
size
();
++
i
)
{
if
(
out_node
->
inputs
[
i
]
==
op_node
)
{
out_node
->
inputs
[
i
]
=
recover_padding_op_node
;
recover_padding_op_node
->
outputs
.
push_back
(
out_node
);
}
}
// create variable in scope
scope
->
Var
(
recover_padding_input_name
);
auto
*
recover_padding_input_tensor
=
scope
->
FindVar
(
recover_padding_input_name
)
->
GetMutable
<
LoDTensor
>
();
recover_padding_input_tensor
->
mutable_data
<
float
>
(
platform
::
CUDAPlace
());
// rename
op_node
->
Op
()
->
RenameOutput
(
out_node
->
Name
(),
recover_padding_input_name
);
};
GraphPatternDetector
gpd1
;
patterns
::
SkipLayernorm
skip_layernorm
(
gpd1
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
skip_layernorm
();
auto
handler1
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: "
"skip_layernorm"
;
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_x
,
skip_layernorm_x
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_y
,
skip_layernorm_y
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_op
,
skip_layernorm_op
,
skip_layernorm
);
GET_IR_NODE_FROM_SUBGRAPH
(
skip_layernorm_out
,
skip_layernorm_out
,
skip_layernorm
);
insert_remove_padding_op
(
skip_layernorm_x
,
skip_layernorm_op
);
insert_remove_padding_op
(
skip_layernorm_y
,
skip_layernorm_op
);
insert_recover_padding_op
(
skip_layernorm_op
,
skip_layernorm_out
);
found_subgraph_count
++
;
};
gpd1
(
graph
,
handler1
);
GraphPatternDetector
gpd2
;
patterns
::
MultiheadMatmul
multihead_matmul
(
gpd2
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
multihead_matmul
();
auto
handler2
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: "
"multihead_matmul"
;
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_input
,
multihead_matmul_input
,
multihead_matmul
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_op
,
multihead_matmul_op
,
multihead_matmul
);
GET_IR_NODE_FROM_SUBGRAPH
(
multihead_matmul_out
,
multihead_matmul_out
,
multihead_matmul
);
insert_remove_padding_op
(
multihead_matmul_input
,
multihead_matmul_op
);
insert_recover_padding_op
(
multihead_matmul_op
,
multihead_matmul_out
);
found_subgraph_count
++
;
};
gpd2
(
graph
,
handler2
);
GraphPatternDetector
gpd3
;
patterns
::
Fc
fc
(
gpd3
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
fc
();
auto
handler3
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: fc"
;
GET_IR_NODE_FROM_SUBGRAPH
(
fc_input
,
fc_input
,
fc
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_op
,
fc_op
,
fc
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
fc_out
,
fc
);
insert_remove_padding_op
(
fc_input
,
fc_op
);
insert_recover_padding_op
(
fc_op
,
fc_out
);
found_subgraph_count
++
;
};
gpd3
(
graph
,
handler3
);
GraphPatternDetector
gpd4
;
patterns
::
Activation
activation
(
gpd4
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
activation
();
auto
handler4
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: activation"
;
GET_IR_NODE_FROM_SUBGRAPH
(
activation_input
,
activation_input
,
activation
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_op
,
activation_op
,
activation
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
activation
);
insert_remove_padding_op
(
activation_input
,
activation_op
);
insert_recover_padding_op
(
activation_op
,
activation_out
);
found_subgraph_count
++
;
};
gpd4
(
graph
,
handler4
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
remove_padding_recover_padding_pass
,
paddle
::
framework
::
ir
::
RemovePaddingRecoverPaddingPass
);
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
SkipLayernorm
:
public
PatternBase
{
SkipLayernorm
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"skip_layernorm"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
skip_layernorm_x
);
PATTERN_DECL_NODE
(
skip_layernorm_y
);
PATTERN_DECL_NODE
(
skip_layernorm_op
);
PATTERN_DECL_NODE
(
skip_layernorm_out
);
};
struct
MultiheadMatmul
:
public
PatternBase
{
MultiheadMatmul
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"multihead_matmul"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
multihead_matmul_input
);
PATTERN_DECL_NODE
(
multihead_matmul_op
);
PATTERN_DECL_NODE
(
multihead_matmul_out
);
};
struct
Fc
:
public
PatternBase
{
Fc
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fc"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
fc_input
);
PATTERN_DECL_NODE
(
fc_op
);
PATTERN_DECL_NODE
(
fc_out
);
};
struct
Activation
:
public
PatternBase
{
Activation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"activation"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
activation_input
);
PATTERN_DECL_NODE
(
activation_op
);
PATTERN_DECL_NODE
(
activation_out
);
};
}
// namespace patterns
class
RemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
public:
RemovePaddingRecoverPaddingPass
()
{}
virtual
~
RemovePaddingRecoverPaddingPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"remove_padding_recover_padding_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/set_transformer_input_convert_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
SetTransformerInputConvertPass
::
SetTransformerInputConvertPass
()
{
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
End
();
}
namespace
patterns
{
void
SetTransformerInputConvert
::
operator
()()
{
std
::
unordered_set
<
std
::
string
>
lookup_table_ops
{
"lookup_table"
,
"lookup_table_v2"
};
// Create nodes for lookup_table1 op.
auto
*
lookup_table1_x
=
pattern
->
NewNode
(
lookup_table1_x_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
auto
*
lookup_table1_w
=
pattern
->
NewNode
(
lookup_table1_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
auto
*
lookup_table1_op
=
pattern
->
NewNode
(
lookup_table1_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table1_out
=
pattern
->
NewNode
(
lookup_table1_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
// Create nodes for lookup_table2 op.
auto
*
lookup_table2_x
=
pattern
->
NewNode
(
lookup_table2_x_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"Ids"
);
auto
*
lookup_table2_w
=
pattern
->
NewNode
(
lookup_table2_w_repr
())
->
assert_is_ops_input
(
lookup_table_ops
,
"W"
);
auto
*
lookup_table2_op
=
pattern
->
NewNode
(
lookup_table2_repr
())
->
assert_is_ops
(
lookup_table_ops
);
auto
*
lookup_table2_out
=
pattern
->
NewNode
(
lookup_table2_out_repr
())
->
assert_is_ops_output
(
lookup_table_ops
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
// Create nodes for elementwise_add op.
auto
*
elementwise_op
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
elementwise_out
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
AsOutput
()
->
assert_is_only_output_of_op
(
"elementwise_add"
);
// links nodes.
lookup_table1_op
->
LinksFrom
({
lookup_table1_x
,
lookup_table1_w
})
.
LinksTo
({
lookup_table1_out
});
lookup_table2_op
->
LinksFrom
({
lookup_table2_x
,
lookup_table2_w
})
.
LinksTo
({
lookup_table2_out
});
elementwise_op
->
LinksFrom
({
lookup_table1_out
,
lookup_table2_out
})
.
LinksTo
({
elementwise_out
});
}
}
// namespace patterns
void
SetTransformerInputConvertPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
GraphPatternDetector
gpd
;
patterns
::
SetTransformerInputConvert
fused_pattern
(
gpd
.
mutable_pattern
(),
"transformer_input_convert_pass"
);
fused_pattern
();
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"transformer_input_convert_pass in op compat failed."
;
return
;
}
VLOG
(
3
)
<<
"transformer_input_convert_pass for pos_id, max_seqlen"
;
GET_IR_NODE_FROM_SUBGRAPH
(
lookup_table2_x
,
lookup_table2_x
,
fused_pattern
);
// create op, var in graph
OpDesc
new_desc
;
new_desc
.
SetType
(
"transformer_input_convert"
);
// inputs
new_desc
.
SetInput
(
"X"
,
{
lookup_table2_x
->
Name
()});
// outputs
std
::
vector
<
std
::
string
>
output_0
=
{
"pos_id_tensor"
};
std
::
vector
<
std
::
string
>
output_1
=
{
"max_seqlen_tensor"
};
new_desc
.
SetOutput
(
"PosId"
,
output_0
);
new_desc
.
SetOutput
(
"MaxSeqlen"
,
output_1
);
std
::
string
transformer_input_convert_out0_name
=
"pos_id_tensor"
;
std
::
string
transformer_input_convert_out1_name
=
"max_seqlen_tensor"
;
VarDesc
transformer_input_convert_out0
(
transformer_input_convert_out0_name
);
VarDesc
transformer_input_convert_out1
(
transformer_input_convert_out1_name
);
transformer_input_convert_out0
.
SetDataType
(
proto
::
VarType
::
INT32
);
transformer_input_convert_out1
.
SetDataType
(
proto
::
VarType
::
INT32
);
transformer_input_convert_out0
.
SetShape
({
-
1
});
transformer_input_convert_out1
.
SetShape
({
-
1
});
transformer_input_convert_out0
.
SetPersistable
(
false
);
transformer_input_convert_out1
.
SetPersistable
(
false
);
auto
new_op_node
=
graph
->
CreateOpNode
(
&
new_desc
);
auto
transformer_input_convert_out0_node
=
graph
->
CreateVarNode
(
&
transformer_input_convert_out0
);
auto
transformer_input_convert_out1_node
=
graph
->
CreateVarNode
(
&
transformer_input_convert_out1
);
// needn't create variable in scope
IR_NODE_LINK_TO
(
lookup_table2_x
,
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
transformer_input_convert_out0_node
);
IR_NODE_LINK_TO
(
new_op_node
,
transformer_input_convert_out1_node
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
set_transformer_input_convert_pass
,
paddle
::
framework
::
ir
::
SetTransformerInputConvertPass
);
REGISTER_PASS_CAPABILITY
(
set_transformer_input_convert_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"lookup_table"
,
1
)
.
LE
(
"lookup_table_v2"
,
1
)
.
LE
(
"elementweise_add"
,
1
));
paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
0 → 100644
浏览文件 @
5914b18a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
struct
SetTransformerInputConvert
:
public
PatternBase
{
SetTransformerInputConvert
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"transformer_input_convert"
)
{}
void
operator
()();
// declare operator node's name
PATTERN_DECL_NODE
(
lookup_table1
);
PATTERN_DECL_NODE
(
lookup_table2
);
PATTERN_DECL_NODE
(
elementwise
);
// declare variable node's name
PATTERN_DECL_NODE
(
lookup_table1_x
);
PATTERN_DECL_NODE
(
lookup_table1_w
);
PATTERN_DECL_NODE
(
lookup_table1_out
);
PATTERN_DECL_NODE
(
lookup_table2_x
);
PATTERN_DECL_NODE
(
lookup_table2_w
);
PATTERN_DECL_NODE
(
lookup_table2_out
);
PATTERN_DECL_NODE
(
elementwise_out
);
};
}
// namespace patterns
class
SetTransformerInputConvertPass
:
public
FusePassBase
{
public:
SetTransformerInputConvertPass
();
virtual
~
SetTransformerInputConvertPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"transformer_input_convert_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
5914b18a
...
...
@@ -377,12 +377,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine
->
SetUseDLA
(
Get
<
bool
>
(
"trt_use_dla"
));
trt_engine
->
SetDLACore
(
Get
<
int
>
(
"trt_dla_core"
));
trt_engine
->
SetUseInspector
(
Get
<
bool
>
(
"use_inspector"
));
trt_engine
->
SetWithErnie
(
(
graph
->
Has
(
framework
::
ir
::
kEmbEltwiseLayernormPass
)
&&
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
))
||
(
graph
->
Has
(
framework
::
ir
::
kPrelnEmbEltwiseLayernormPass
)
&&
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
)));
trt_engine
->
SetWithErnie
(
graph
->
Has
(
framework
::
ir
::
kMultiheadMatmulPass
));
if
(
use_static_engine
)
{
trt_engine_serialized_data
=
GetTrtEngineSerializedData
(
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
5914b18a
...
...
@@ -98,6 +98,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"multihead_matmul_fuse_pass_v3"
,
//
"skip_layernorm_fuse_pass"
,
//
"preln_skip_layernorm_fuse_pass"
,
//
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
//
"trt_squeeze2_matmul_fuse_pass"
,
//
...
...
@@ -108,6 +109,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
// "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", //
"tensorrt_subgraph_pass"
,
//
"conv_bn_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录