Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4f066e31
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4f066e31
编写于
2月 03, 2021
作者:
A
Adam Osewski
提交者:
GitHub
2月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Layer normalization fuse pass. (#30721)
上级
b1026f64
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
824 addition
and
7 deletion
+824
-7
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+3
-1
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+116
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+35
-0
paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
+231
-0
paddle/fluid/framework/ir/layer_norm_fuse_pass.h
paddle/fluid/framework/ir/layer_norm_fuse_pass.h
+84
-0
paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
+199
-0
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
...id/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
+1
-1
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+1
-1
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
...uid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
+1
-1
paddle/fluid/framework/ir/pass_test_util.cc
paddle/fluid/framework/ir/pass_test_util.cc
+51
-2
paddle/fluid/framework/ir/pass_test_util.h
paddle/fluid/framework/ir/pass_test_util.h
+35
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+2
-1
python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
...tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
+64
-0
tools/static_mode_white_list.py
tools/static_mode_white_list.py
+1
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
4f066e31
...
...
@@ -92,6 +92,7 @@ pass_library(skip_layernorm_fuse_pass base)
pass_library
(
multihead_matmul_fuse_pass inference
)
pass_library
(
adaptive_pool2d_convert_global_pass inference
)
pass_library
(
unsqueeze2_eltwise_fuse_pass inference
)
pass_library
(
layer_norm_fuse_pass inference
)
if
(
WITH_GPU
)
pass_library
(
cudnn_placement_pass base DEPS placement_pass_base
)
pass_library
(
embedding_eltwise_layernorm_fuse_pass inference
)
...
...
@@ -129,6 +130,7 @@ cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc D
set
(
GLOB_PASS_LIB
${
PASS_LIBRARY
}
CACHE INTERNAL
"Global PASS library"
)
cc_library
(
pass_builder SRCS pass_builder.cc DEPS pass
)
cc_library
(
pass_test_util SRCS pass_test_util.cc DEPS graph pass
)
cc_test
(
node_test SRCS node_test.cc DEPS node
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
...
...
@@ -150,6 +152,7 @@ cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.c
cc_test
(
test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass
)
cc_test
(
test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass
)
cc_test
(
test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass
)
cc_test
(
test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor
)
if
(
WITH_GPU
)
cc_test
(
test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass
)
cc_test
(
test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass
)
...
...
@@ -158,7 +161,6 @@ if(NOT WIN32)
cc_test
(
test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass
)
endif
()
if
(
WITH_MKLDNN
)
cc_library
(
pass_test_util SRCS mkldnn/pass_test_util.cc DEPS graph pass
)
cc_test
(
test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass
)
cc_test
(
test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor
)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
4f066e31
...
...
@@ -2796,6 +2796,122 @@ PDNode *patterns::MultiGru::operator()() {
return
h
;
}
PDNode
*
patterns
::
LayerNorm
::
operator
()()
{
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
AsInput
()
->
assert_is_ops_input
(
{
"reduce_mean"
,
"elementwise_sub"
});
auto
*
x_mean
=
pattern
->
NewNode
(
x_mean_repr
())
->
assert_is_op
(
"reduce_mean"
);
auto
*
x_mean_out
=
pattern
->
NewNode
(
x_mean_out_repr
())
->
assert_is_op_output
(
"reduce_mean"
,
"Out"
)
->
assert_is_op_input
(
"elementwise_sub"
,
"Y"
)
->
AsIntermediate
();
auto
*
x_sub_mean
=
pattern
->
NewNode
(
x_sub_mean_repr
())
->
assert_is_op
(
"elementwise_sub"
);
auto
*
x_sub_mean_out
=
pattern
->
NewNode
(
x_sub_mean_out_repr
())
->
assert_is_op_output
(
"elementwise_sub"
)
->
assert_is_ops_input
({
"elementwise_pow"
,
"elementwise_div"
},
"X"
)
->
AsIntermediate
();
auto
*
sqr_pow
=
pattern
->
NewNode
(
sqr_pow_repr
())
->
assert_is_op_input
(
"elementwise_pow"
,
"Y"
)
->
assert_is_persistable_var
()
->
AsInput
();
auto
*
x_sub_mean_sqr
=
pattern
->
NewNode
(
x_sub_mean_sqr_repr
())
->
assert_is_op
(
"elementwise_pow"
);
auto
*
x_sub_mean_sqr_out
=
pattern
->
NewNode
(
x_sub_mean_sqr_out_repr
())
->
assert_is_op_output
(
"elementwise_pow"
)
->
assert_is_op_input
(
"reduce_mean"
)
->
AsIntermediate
();
auto
*
std_dev
=
pattern
->
NewNode
(
std_dev_repr
())
->
assert_is_op
(
"reduce_mean"
);
auto
*
std_dev_out
=
pattern
->
NewNode
(
std_dev_out_repr
())
->
assert_is_op_output
(
"reduce_mean"
)
->
assert_is_op_input
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
eps
=
pattern
->
NewNode
(
eps_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
assert_is_persistable_var
()
->
AsInput
();
auto
*
std_dev_eps
=
pattern
->
NewNode
(
std_dev_eps_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
std_dev_eps_out
=
pattern
->
NewNode
(
std_dev_eps_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
assert_is_op_input
(
"sqrt"
)
->
AsIntermediate
();
auto
*
std_dev_eps_sqrt
=
pattern
->
NewNode
(
std_dev_eps_sqrt_repr
())
->
assert_is_op
(
"sqrt"
);
auto
*
std_dev_eps_sqrt_out
=
pattern
->
NewNode
(
std_dev_eps_sqrt_out_repr
())
->
assert_is_op_output
(
"sqrt"
)
->
assert_is_op_input
(
"elementwise_div"
,
"Y"
)
->
AsIntermediate
();
auto
*
division
=
pattern
->
NewNode
(
division_repr
())
->
assert_is_op
(
"elementwise_div"
);
auto
*
division_out
=
pattern
->
NewNode
(
division_out_repr
())
->
assert_is_op_output
(
"elementwise_div"
)
->
assert_is_op_input
(
"elementwise_mul"
)
->
AsIntermediate
();
auto
*
gamma
=
pattern
->
NewNode
(
gamma_repr
())
->
assert_is_op_input
(
"elementwise_mul"
,
"Y"
)
->
assert_is_persistable_var
()
->
AsInput
();
auto
*
scale
=
pattern
->
NewNode
(
scale_repr
())
->
assert_is_op
(
"elementwise_mul"
);
auto
*
scale_out
=
pattern
->
NewNode
(
scale_out_repr
())
->
assert_is_op_output
(
"elementwise_mul"
)
->
assert_is_op_input
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
beta
=
pattern
->
NewNode
(
beta_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
assert_is_persistable_var
()
->
AsInput
();
auto
*
shift
=
pattern
->
NewNode
(
shift_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
shift_out
=
pattern
->
NewNode
(
shift_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
/*
* X
* / \
* / reduce_mean "u(x)"
* \ /
* elementwise_sub "x - u(x)"
* / \ 2
* | \ /
* | elementwise_pow "(x - u(x))^2"
* | |
* | reduce_mean "sigma^2 = 1/C*Sum{(x - u(x))^2}"
* | | eps
* | | /
* | elementwise_add "sigma^2 + epsilon"
* \ |
* \ sqrt "sqrt(sigma^2 + epsilon)"
* \ /
* \ /
* elementwise_div "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
* |
* gamma |
* \ |
* elementwise_mul "scale: gamma(C) * lnorm"
* |
* beta |
* \ |
* elementwise_add "shift: gamma(C) * lnorm + beta(C)"
*/
x_mean
->
LinksFrom
({
x
}).
LinksTo
({
x_mean_out
});
x_sub_mean
->
LinksFrom
({
x
,
x_mean_out
}).
LinksTo
({
x_sub_mean_out
});
x_sub_mean_sqr
->
LinksFrom
({
x_sub_mean_out
,
sqr_pow
})
.
LinksTo
({
x_sub_mean_sqr_out
});
std_dev
->
LinksFrom
({
x_sub_mean_sqr_out
}).
LinksTo
({
std_dev_out
});
std_dev_eps
->
LinksFrom
({
std_dev_out
,
eps
}).
LinksTo
({
std_dev_eps_out
});
std_dev_eps_sqrt
->
LinksFrom
({
std_dev_eps_out
})
.
LinksTo
({
std_dev_eps_sqrt_out
});
division
->
LinksFrom
({
x_sub_mean_out
,
std_dev_eps_sqrt_out
})
.
LinksTo
({
division_out
});
scale
->
LinksFrom
({
division_out
,
gamma
}).
LinksTo
({
scale_out
});
shift
->
LinksFrom
({
scale_out
,
beta
}).
LinksTo
({
shift_out
});
return
shift_out
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
4f066e31
...
...
@@ -1598,6 +1598,41 @@ struct MultiGru : public PatternBase {
PATTERN_DECL_NODE
(
h
);
};
//
// \brief Pattern looking for subgraph representing layer normalization
// operation.
//
struct
LayerNorm
:
public
PatternBase
{
LayerNorm
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"layer_norm"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
x_mean
);
PATTERN_DECL_NODE
(
x_mean_out
);
PATTERN_DECL_NODE
(
x_sub_mean
);
PATTERN_DECL_NODE
(
x_sub_mean_out
);
PATTERN_DECL_NODE
(
sqr_pow
);
PATTERN_DECL_NODE
(
x_sub_mean_sqr
);
PATTERN_DECL_NODE
(
x_sub_mean_sqr_out
);
PATTERN_DECL_NODE
(
std_dev
);
PATTERN_DECL_NODE
(
std_dev_out
);
PATTERN_DECL_NODE
(
eps
);
PATTERN_DECL_NODE
(
std_dev_eps
);
PATTERN_DECL_NODE
(
std_dev_eps_out
);
PATTERN_DECL_NODE
(
std_dev_eps_sqrt
);
PATTERN_DECL_NODE
(
std_dev_eps_sqrt_out
);
PATTERN_DECL_NODE
(
division
);
PATTERN_DECL_NODE
(
division_out
);
PATTERN_DECL_NODE
(
gamma
);
PATTERN_DECL_NODE
(
scale
);
PATTERN_DECL_NODE
(
scale_out
);
PATTERN_DECL_NODE
(
beta
);
PATTERN_DECL_NODE
(
shift
);
PATTERN_DECL_NODE
(
shift_out
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
0 → 100644
浏览文件 @
4f066e31
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// cpplint complaints (wrong!) for not included <string> header in below line.
using
string
::
PrettyLogDetail
;
// NOLINT
namespace
{
void
validateReduceOpAttrs
(
const
Node
*
node
,
const
std
::
string
&
name
)
{
const
auto
*
op
=
node
->
Op
();
if
(
op
->
HasAttr
(
"dim"
))
{
auto
dims
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"dim"
));
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The LayerNorm fusion "
,
name
,
" reduction must happen only over "
"single dimension."
));
PADDLE_ENFORCE_EQ
(
dims
.
front
(),
-
1
,
platform
::
errors
::
PreconditionNotMet
(
"The LayerNorm fusion "
,
name
,
" reduction must happen over last "
"dimension."
));
}
if
(
op
->
HasAttr
(
"reduce_all"
))
{
PADDLE_ENFORCE
(
!
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"reduce_all"
)),
platform
::
errors
::
PreconditionNotMet
(
"The LayerNorm fusion "
,
name
,
" reduction must have "
"
\'
reduce_all
\'
attribute set to false."
));
}
if
(
op
->
HasAttr
(
"keep_dim"
))
{
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"keep_dim"
)),
platform
::
errors
::
PreconditionNotMet
(
"The LayerNorm fusion "
,
name
,
" reduction must have "
"
\'
keep_dim
\'
attribute set to true."
));
}
}
void
setIntermediateOut
(
OpDesc
*
desc
,
const
std
::
string
&
out_name
,
const
std
::
string
&
scope_name
)
{
std
::
string
new_name
=
scope_name
+
"/at."
+
out_name
+
".new"
;
desc
->
SetOutput
(
out_name
,
{
new_name
});
}
void
addIntermediateOut
(
Node
*
op_node
,
const
std
::
string
&
out_name
,
const
std
::
string
&
scope_name
,
Graph
*
graph
)
{
std
::
string
new_name
=
scope_name
+
"/at."
+
out_name
+
".new"
;
VarDesc
out_var
(
new_name
);
out_var
.
SetPersistable
(
false
);
auto
*
node_var
=
graph
->
CreateVarNode
(
&
out_var
);
IR_NODE_LINK_TO
(
op_node
,
node_var
);
}
}
// namespace
void
LayerNormFusePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"The input graph of "
"LayerNormFusePass should not be nullptr."
));
FusePassBase
::
Init
(
scope_name_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
GraphPatternDetector
gpd
;
patterns
::
LayerNorm
layer_norm_pattern
(
gpd
.
mutable_pattern
(),
scope_name_
);
layer_norm_pattern
();
int
found_layer_norm_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse LayerNorm from subgraph."
;
GET_IR_NODE_FROM_SUBGRAPH
(
x
,
x
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_mean
,
x_mean
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_mean_out
,
x_mean_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_sub_mean
,
x_sub_mean
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_sub_mean_out
,
x_sub_mean_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
sqr_pow
,
sqr_pow
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_sub_mean_sqr
,
x_sub_mean_sqr
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
x_sub_mean_sqr_out
,
x_sub_mean_sqr_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev
,
std_dev
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev_out
,
std_dev_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eps
,
eps
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev_eps
,
std_dev_eps
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev_eps_out
,
std_dev_eps_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev_eps_sqrt
,
std_dev_eps_sqrt
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
std_dev_eps_sqrt_out
,
std_dev_eps_sqrt_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
division
,
division
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
division_out
,
division_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
gamma
,
gamma
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale
,
scale
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
scale_out
,
scale_out
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
beta
,
beta
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
shift
,
shift
,
layer_norm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
shift_out
,
shift_out
,
layer_norm_pattern
);
auto
*
eps_tensor
=
scope
->
FindVar
(
eps
->
Name
())
->
GetMutable
<
LoDTensor
>
();
// ------------------ subgraph node's validation ---------------------------
PADDLE_ENFORCE_EQ
(
eps_tensor
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm divisor "
"epsilon value must be one-element tensor, but has %s "
"elements."
,
eps_tensor
->
numel
()));
PADDLE_ENFORCE_EQ
(
eps_tensor
->
type
(),
proto
::
VarType
::
FP32
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm divisor "
"epsilon value must be of FP32 data type, but is %s."
,
eps_tensor
->
type
()));
const
auto
&
gamma_shape
=
gamma
->
Var
()
->
GetShape
();
const
auto
&
beta_shape
=
beta
->
Var
()
->
GetShape
();
const
auto
&
x_shape
=
x
->
Var
()
->
GetShape
();
int64_t
x_last_dim
=
x_shape
.
back
();
PADDLE_ENFORCE_EQ
(
gamma_shape
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm gamma "
"(scale) tensor shape must be one-dimensional, "
"but is %s."
,
gamma_shape
.
size
()));
PADDLE_ENFORCE_EQ
(
beta_shape
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm beta "
"(shift) tensor shape must be one-dimensional, "
"but is %s."
,
beta_shape
.
size
()));
PADDLE_ENFORCE_EQ
(
beta_shape
,
gamma_shape
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm beta "
"and gamma tensors shapes' must be equal."
));
PADDLE_ENFORCE_EQ
(
gamma_shape
.
front
(),
x_last_dim
,
platform
::
errors
::
InvalidArgument
(
"The LayerNorm beta "
"and gamma tensors shapes' must be equal to the last "
"input's dimension size."
));
validateReduceOpAttrs
(
x_mean
,
"input mean"
);
validateReduceOpAttrs
(
std_dev
,
"std_dev mean"
);
// ------------------ op creation and placement ---------------------------
OpDesc
ln_op_desc
;
ln_op_desc
.
SetType
(
"layer_norm"
);
ln_op_desc
.
SetInput
(
"X"
,
{
x
->
Name
()});
ln_op_desc
.
SetInput
(
"Scale"
,
{
gamma
->
Name
()});
ln_op_desc
.
SetInput
(
"Bias"
,
{
beta
->
Name
()});
ln_op_desc
.
SetOutput
(
"Y"
,
{
shift_out
->
Name
()});
setIntermediateOut
(
&
ln_op_desc
,
"Mean"
,
scope_name_
);
setIntermediateOut
(
&
ln_op_desc
,
"Variance"
,
scope_name_
);
ln_op_desc
.
SetAttr
(
"begin_norm_axis"
,
static_cast
<
int
>
(
x_shape
.
size
()
-
1
));
ln_op_desc
.
SetAttr
(
"epsilon"
,
*
(
eps_tensor
->
data
<
float
>
()));
ln_op_desc
.
SetAttr
(
"is_test"
,
true
);
Node
*
ln_op
=
g
->
CreateOpNode
(
&
ln_op_desc
);
addIntermediateOut
(
ln_op
,
"Mean"
,
scope_name_
,
g
);
addIntermediateOut
(
ln_op
,
"Variance"
,
scope_name_
,
g
);
IR_NODE_LINK_TO
(
x
,
ln_op
);
IR_NODE_LINK_TO
(
gamma
,
ln_op
);
IR_NODE_LINK_TO
(
beta
,
ln_op
);
IR_OP_VAR_LINK
(
ln_op
,
shift_out
);
GraphSafeRemoveNodes
(
g
,
{
x_mean
,
x_mean_out
,
x_sub_mean
,
x_sub_mean_out
,
sqr_pow
,
x_sub_mean_sqr
,
x_sub_mean_sqr_out
,
std_dev
,
std_dev_out
,
eps
,
std_dev_eps
,
std_dev_eps_out
,
std_dev_eps_sqrt
,
std_dev_eps_sqrt_out
,
division
,
division_out
,
scale
,
scale_out
,
shift
});
found_layer_norm_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_layer_norm_count
);
PrettyLogDetail
(
"--- Fused %d subgraphs into layer_norm op."
,
found_layer_norm_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
layer_norm_fuse_pass
,
paddle
::
framework
::
ir
::
LayerNormFusePass
);
REGISTER_PASS_CAPABILITY
(
layer_norm_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
GE
(
"elementwise_add"
,
0
)
.
LE
(
"elementwise_add"
,
1
)
.
GE
(
"elementwise_div"
,
0
)
.
LE
(
"elementwise_div"
,
1
)
.
GE
(
"elementwise_mul"
,
0
)
.
LE
(
"elementwise_mul"
,
1
)
.
GE
(
"elementwise_pow"
,
0
)
.
LE
(
"elementwise_pow"
,
1
)
.
GE
(
"elementwise_sub"
,
0
)
.
LE
(
"elementwise_sub"
,
1
)
.
EQ
(
"reduce_mean"
,
0
)
.
EQ
(
"sqrt"
,
0
));
paddle/fluid/framework/ir/layer_norm_fuse_pass.h
0 → 100644
浏览文件 @
4f066e31
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* \brief Fuse the subgraph representing layer normalization into
* layer_norm op.
*
* \note The following graph represents this equation:
*
* x - u(x)
* y(c) * ------------------- + b(c)
* sqrt(sigma^2 + eps)
*
* x - input data
* u(x) - mean
* sigma^2 - standard deviation
* eps - epsilon
* y(c) - gamma (scale) channelwise
* b(c) - beta (shift) channelwise
*
*
* X
* / \
* / reduce_mean "u(x)"
* \ /
* elementwise_sub "x - u(x)"
* / \ 2
* | \ /
* | elementwise_pow "(x - u(x))^2"
* | |
* | reduce_mean "sigma^2 = 1/C*Sum{(x - u(x))^2}"
* | | eps
* | | /
* | elementwise_add "sigma^2 + epsilon"
* \ |
* \ sqrt "sqrt(sigma^2 + epsilon)"
* \ /
* \ /
* elementwise_div "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
* |
* gamma |
* \ |
* elementwise_mul "scale: gamma(C) * lnorm"
* |
* beta |
* \ |
* elementwise_add "shift: gamma(C) * lnorm + beta(C)"
*/
class
LayerNormFusePass
:
public
FusePassBase
{
public:
virtual
~
LayerNormFusePass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
const
std
::
string
scope_name_
{
"layer_norm_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
0 → 100644
浏览文件 @
4f066e31
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
ProgramDesc
BuildGraphProgram
()
{
auto
prog
=
test
::
BuildProgramDesc
(
{
"x"
,
"x_mean_out"
,
"x_sub_mean_out"
,
"x_sub_mean_sqr_out"
,
"std_dev_out"
,
"std_dev_eps_out"
,
"std_dev_eps_sqrt_out"
,
"division_out"
,
"scale_out"
,
"shift_out"
},
{
"sqr_pow"
,
"eps"
,
"gamma"
,
"beta"
});
const
auto
&
block_desc
=
prog
.
Block
(
0
);
auto
*
x_var_desc
=
block_desc
.
FindVar
(
"x"
);
x_var_desc
->
SetDataType
(
proto
::
VarType
::
FP32
);
x_var_desc
->
SetShape
({
3
,
32
,
48
});
auto
*
eps_var_desc
=
block_desc
.
FindVar
(
"eps"
);
eps_var_desc
->
SetDataType
(
proto
::
VarType
::
FP32
);
eps_var_desc
->
SetShape
({
1
});
auto
*
gamma_var_desc
=
block_desc
.
FindVar
(
"gamma"
);
gamma_var_desc
->
SetDataType
(
proto
::
VarType
::
FP32
);
gamma_var_desc
->
SetShape
({
48
});
auto
*
beta_var_desc
=
block_desc
.
FindVar
(
"beta"
);
beta_var_desc
->
SetDataType
(
proto
::
VarType
::
FP32
);
beta_var_desc
->
SetShape
({
48
});
auto
*
x_mean
=
test
::
CreateOp
(
&
prog
,
"reduce_mean"
,
{{
"X"
,
"x"
}},
{{
"Out"
,
"x_mean_out"
}},
false
);
x_mean
->
SetAttr
(
"dim"
,
std
::
vector
<
int
>
{
-
1
});
x_mean
->
SetAttr
(
"keep_dim"
,
true
);
x_mean
->
SetAttr
(
"reduce_all"
,
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_sub"
,
{{
"X"
,
"x"
},
{
"Y"
,
"x_mean_out"
}},
{{
"Out"
,
"x_sub_mean_out"
}},
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_pow"
,
{{
"X"
,
"x_sub_mean_out"
},
{
"Y"
,
"sqr_pow"
}},
{{
"Out"
,
"x_sub_mean_sqr_out"
}},
false
);
auto
*
std_dev
=
test
::
CreateOp
(
&
prog
,
"reduce_mean"
,
{{
"X"
,
"x_sub_mean_sqr_out"
}},
{{
"Out"
,
"std_dev_out"
}},
false
);
std_dev
->
SetAttr
(
"dim"
,
std
::
vector
<
int
>
{
-
1
});
std_dev
->
SetAttr
(
"keep_dim"
,
true
);
std_dev
->
SetAttr
(
"reduce_all"
,
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"std_dev_out"
},
{
"Y"
,
"eps"
}},
{{
"Out"
,
"std_dev_eps_out"
}},
false
);
test
::
CreateOp
(
&
prog
,
"sqrt"
,
{{
"X"
,
"std_dev_eps_out"
}},
{{
"Out"
,
"std_dev_eps_sqrt_out"
}},
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_div"
,
{{
"X"
,
"x_sub_mean_out"
},
{
"Y"
,
"std_dev_eps_sqrt_out"
}},
{{
"Out"
,
"division_out"
}},
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_mul"
,
{{
"X"
,
"division_out"
},
{
"Y"
,
"gamma"
}},
{{
"Out"
,
"scale_out"
}},
false
);
test
::
CreateOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"scale_out"
},
{
"Y"
,
"beta"
}},
{{
"Out"
,
"shift_out"
}},
false
);
return
prog
;
}
bool
CheckFusedSubgraphOpsCount
(
const
Graph
&
graph
)
{
return
test
::
AssertOpsCount
(
graph
,
{{
"reduce_mean"
,
0
},
{
"elementwise_sub"
,
0
},
{
"elementwise_pow"
,
0
},
{
"elementwise_add"
,
0
},
{
"sqrt"
,
0
},
{
"elementwise_div"
,
0
},
{
"elementwise_mul"
,
0
},
{
"layer_norm"
,
1
}});
}
}
// namespace
// ------------------------------ Test cases -----------------------------------
TEST
(
FuseLayerNormPass
,
TestFuse
)
{
ProgramDesc
prog
=
BuildGraphProgram
();
Graph
graph
(
prog
);
constexpr
int
removed_nodes
=
19
;
// LayerNorm + outputs: {Mean, Variance}
constexpr
int
added_nodes
=
3
;
auto
place
=
paddle
::
platform
::
CPUPlace
();
NaiveExecutor
exe
{
place
};
Scope
scope
;
float
eps_value
=
1e-5
f
;
// Init scope, as it is used in pass
exe
.
CreateVariables
(
prog
,
0
,
true
,
&
scope
);
test
::
InitLoDTensorHolder
<
float
>
(
&
scope
,
place
,
"eps"
,
{
1
},
&
eps_value
);
graph
.
SetNotOwned
(
kParamScopeAttr
,
&
scope
);
EXPECT_TRUE
(
test
::
RunPassAndAssert
(
&
graph
,
"layer_norm_fuse_pass"
,
"x"
,
"shift_out"
,
removed_nodes
,
added_nodes
));
EXPECT_TRUE
(
CheckFusedSubgraphOpsCount
(
graph
));
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"layer_norm"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"is_test"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"is_test"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"begin_norm_axis"
));
ASSERT_TRUE
(
op
->
HasAttr
(
"epsilon"
));
}
}
}
TEST
(
FuseLayerNormPass
,
TestInvalidEpsNumel
)
{
ProgramDesc
prog
=
BuildGraphProgram
();
auto
*
eps_var_desc
=
prog
.
Block
(
0
).
FindVar
(
"eps"
);
eps_var_desc
->
SetDataType
(
proto
::
VarType
::
FP32
);
eps_var_desc
->
SetShape
({
2
});
Graph
graph
(
prog
);
constexpr
int
removed_nodes
=
19
;
constexpr
int
added_nodes
=
3
;
auto
place
=
paddle
::
platform
::
CPUPlace
();
NaiveExecutor
exe
{
place
};
Scope
scope
;
auto
eps_values
=
std
::
vector
<
float
>
{
1e-5
f
,
1e-5
f
};
// Init scope, as it is used in pass
exe
.
CreateVariables
(
prog
,
0
,
true
,
&
scope
);
test
::
InitLoDTensorHolder
<
float
>
(
&
scope
,
place
,
"eps"
,
{
2
},
eps_values
.
data
());
graph
.
SetNotOwned
(
kParamScopeAttr
,
&
scope
);
EXPECT_THROW
(
test
::
RunPassAndAssert
(
&
graph
,
"layer_norm_fuse_pass"
,
"x"
,
"shift_out"
,
removed_nodes
,
added_nodes
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseLayerNormPass
,
TestInvalidEpsDataType
)
{
ProgramDesc
prog
=
BuildGraphProgram
();
auto
*
eps_var_desc
=
prog
.
Block
(
0
).
FindVar
(
"eps"
);
eps_var_desc
->
SetDataType
(
proto
::
VarType
::
FP64
);
eps_var_desc
->
SetShape
({
1
});
Graph
graph
(
prog
);
constexpr
int
removed_nodes
=
19
;
constexpr
int
added_nodes
=
3
;
auto
place
=
paddle
::
platform
::
CPUPlace
();
NaiveExecutor
exe
{
place
};
Scope
scope
;
double
eps_value
=
1e-5
;
// Init scope, as it is used in pass
exe
.
CreateVariables
(
prog
,
0
,
true
,
&
scope
);
test
::
InitLoDTensorHolder
<
double
>
(
&
scope
,
place
,
"eps"
,
{
1
},
&
eps_value
);
graph
.
SetNotOwned
(
kParamScopeAttr
,
&
scope
);
EXPECT_THROW
(
test
::
RunPassAndAssert
(
&
graph
,
"layer_norm_fuse_pass"
,
"x"
,
"shift_out"
,
removed_nodes
,
added_nodes
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseLayerNormPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"layer_norm_fuse_pass"
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
layer_norm_fuse_pass
);
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
浏览文件 @
4f066e31
...
...
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/
mkldnn/
pass_test_util.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
4f066e31
...
...
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/
mkldnn/
pass_test_util.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
...
...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
浏览文件 @
4f066e31
...
...
@@ -15,7 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/
mkldnn/
pass_test_util.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
paddle/fluid/framework/ir/
mkldnn/
pass_test_util.cc
→
paddle/fluid/framework/ir/pass_test_util.cc
浏览文件 @
4f066e31
...
...
@@ -13,15 +13,19 @@
// limitations under the License.
#include <algorithm>
#include <cstring>
#include <exception>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -32,7 +36,7 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
const
std
::
vector
<
InOutVarNamePair
>&
inputs
,
const
std
::
vector
<
InOutVarNamePair
>&
outputs
,
bool
use_mkldnn
)
{
auto
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
op_type_name
);
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
...
...
@@ -43,6 +47,8 @@ OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
}
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
return
op
;
}
...
...
@@ -168,6 +174,49 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
return
expected_nodes_num
==
current_nodes_num
;
}
template
<
typename
T
>
void
InitLoDTensorHolder
(
Scope
*
scope
,
const
paddle
::
platform
::
Place
&
place
,
const
std
::
string
&
var_name
,
const
std
::
vector
<
int64_t
>&
dims
,
const
T
*
data
)
{
auto
var
=
scope
->
Var
(
var_name
);
auto
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
auto
*
tensor_mem_ptr
=
tensor
->
mutable_data
<
T
>
(
make_ddim
(
dims
),
place
);
if
(
data
!=
nullptr
)
{
std
::
memcpy
(
tensor_mem_ptr
,
data
,
tensor
->
memory_size
());
}
else
{
std
::
memset
(
tensor_mem_ptr
,
0
,
tensor
->
memory_size
());
}
}
// Instantiate for below data types.
template
void
InitLoDTensorHolder
<
float
>(
Scope
*
,
const
paddle
::
platform
::
Place
&
,
const
std
::
string
&
,
const
std
::
vector
<
int64_t
>&
,
const
float
*
);
template
void
InitLoDTensorHolder
<
int
>(
Scope
*
,
const
paddle
::
platform
::
Place
&
,
const
std
::
string
&
,
const
std
::
vector
<
int64_t
>&
,
const
int
*
);
template
void
InitLoDTensorHolder
<
double
>(
Scope
*
,
const
paddle
::
platform
::
Place
&
,
const
std
::
string
&
,
const
std
::
vector
<
int64_t
>&
,
const
double
*
);
OpDesc
*
GetOp
(
const
ProgramDesc
&
prog
,
const
std
::
string
&
op_type
,
const
std
::
string
&
output_name
,
const
std
::
string
&
output_arg_name
)
{
auto
all_ops
=
prog
.
Block
(
0
).
AllOps
();
for
(
auto
*
op_desc
:
all_ops
)
{
if
(
op_desc
->
Type
()
==
op_type
&&
op_desc
->
HasOutput
(
output_name
))
{
const
auto
&
arg_names
=
op_desc
->
Outputs
().
at
(
output_name
);
for
(
const
auto
&
name
:
arg_names
)
{
if
(
name
==
output_arg_name
)
return
op_desc
;
}
}
}
return
nullptr
;
}
}
// namespace test
}
// namespace ir
}
// namespace framework
...
...
paddle/fluid/framework/ir/
mkldnn/
pass_test_util.h
→
paddle/fluid/framework/ir/pass_test_util.h
浏览文件 @
4f066e31
...
...
@@ -18,9 +18,13 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -113,6 +117,37 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
const
std
::
string
&
from
,
const
std
::
string
&
to
,
int
removed_nodes_count
,
int
added_nodes_count
=
0
);
///
/// @brief Initializes the tensor memory holder.
///
/// @param[in] scope The scope that manages the variable.
/// @param[in] place The place where memory will be allocated.
/// @param[in] var_name The variable name.
/// @param[in] dims The dimensions of allocated tensor.
///
/// @tparam T Tensor data type.
///
template
<
typename
T
>
void
InitLoDTensorHolder
(
Scope
*
scope
,
const
paddle
::
platform
::
Place
&
place
,
const
std
::
string
&
var_name
,
const
std
::
vector
<
int64_t
>&
dims
,
const
T
*
data
=
nullptr
);
///
/// @brief Retrieve operator descriptor from program.
///
/// @param[in] prog The program descriptor containing the op we
/// search for.
/// @param[in] op_type The wanted operator type name.
/// @param[in] output_name The wanted operator output name.
/// @param[in] output_arg_name The wanted operator output argument name.
///
/// @return The operator descriptor.
///
OpDesc
*
GetOp
(
const
ProgramDesc
&
prog
,
const
std
::
string
&
op_type
,
const
std
::
string
&
output_name
,
const
std
::
string
&
output_arg_name
);
}
// namespace test
}
// namespace ir
}
// namespace framework
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
4f066e31
...
...
@@ -162,6 +162,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
passes_
.
assign
({
"simplify_with_basic_ops_pass"
,
//
"layer_norm_fuse_pass"
,
"attention_lstm_fuse_pass"
,
//
"seqconv_eltadd_relu_fuse_pass"
,
//
// "seqpool_concat_fuse_pass", //
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
0 → 100644
浏览文件 @
4f066e31
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of subgraph expressing layer normalization."""
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
from
paddle
import
enable_static
from
paddle.fluid.core
import
PassVersionChecker
class
LayerNormFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
3
,
64
,
120
],
dtype
=
"float32"
)
sqr_pow
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
2
,
dtype
=
"float32"
)
eps
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
1e-5
,
dtype
=
"float32"
)
gamma
=
fluid
.
layers
.
create_parameter
(
shape
=
[
120
],
dtype
=
"float32"
,
is_bias
=
True
)
beta
=
fluid
.
layers
.
create_parameter
(
shape
=
[
120
],
dtype
=
"float32"
,
is_bias
=
True
)
x_mean_out
=
fluid
.
layers
.
reduce_mean
(
data
,
dim
=-
1
,
keep_dim
=
True
)
x_sub_mean_out
=
fluid
.
layers
.
elementwise_sub
(
data
,
x_mean_out
)
x_sub_mean_sqr_out
=
fluid
.
layers
.
elementwise_pow
(
x_sub_mean_out
,
sqr_pow
)
std_dev_out
=
fluid
.
layers
.
reduce_mean
(
x_sub_mean_sqr_out
,
dim
=-
1
,
keep_dim
=
True
)
std_dev_eps_out
=
fluid
.
layers
.
elementwise_add
(
std_dev_out
,
eps
)
std_dev_eps_sqrt_out
=
fluid
.
layers
.
sqrt
(
std_dev_eps_out
)
division_out
=
fluid
.
layers
.
elementwise_div
(
x_sub_mean_out
,
std_dev_eps_sqrt_out
)
scale_out
=
fluid
.
layers
.
elementwise_mul
(
division_out
,
gamma
)
shift_out
=
fluid
.
layers
.
elementwise_add
(
scale_out
,
beta
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
3
,
64
,
120
)).
astype
(
"float32"
),
}
self
.
fetch_list
=
[
shift_out
]
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
"layer_norm_fuse_pass"
))
if
__name__
==
"__main__"
:
enable_static
()
unittest
.
main
()
tools/static_mode_white_list.py
浏览文件 @
4f066e31
...
...
@@ -296,6 +296,7 @@ STATIC_MODE_TESTING_LIST = [
'test_layer_norm_mkldnn_op'
,
'test_layer_norm_bf16_mkldnn_op'
,
'test_layer_norm_op_v2'
,
'test_layer_norm_fuse_pass'
,
'test_learning_rate_scheduler'
,
'test_linear_interp_op'
,
'test_linear_interp_v2_op'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录