Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit aa0e84e3 (unverified)
Authored by wenbin on Sep 21, 2022; committed via GitHub on Sep 21, 2022.
Parent: 3d59fee5

residual_no_bias (#46129)

* residual_no_bias
* comments
* more ut
* fix input
Showing 8 changed files with 347 additions and 69 deletions (+347 / −69):
- paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc (+98 / −44)
- paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h (+11 / −0)
- paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc (+20 / −13)
- paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu (+15 / −12)
- python/paddle/fluid/tests/unittests/CMakeLists.txt (+1 / −0)
- python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt (+4 / −0)
- python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py (+166 / −0)
- python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py (+32 / −0)
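For context, the bias-free subgraph this commit teaches the fuse pass to match is just a residual add feeding layer_norm. A minimal static-graph sketch (mirroring the new PrelnResidualBiasFusePassNoBiasTest at the bottom of this commit; names are illustrative):

```python
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name="x", shape=[128, 768], dtype="float32")
    y = paddle.static.data(name="y", shape=[128, 768], dtype="float32")
    # Residual add with no elementwise bias -- the new case handled here.
    residual = x + y
    # Pre-LN: layer_norm consumes the residual sum directly.
    out = paddle.static.nn.layer_norm(input=residual)
```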
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc

```diff
@@ -33,11 +33,16 @@ namespace ir {
 namespace patterns {

 struct PrelnResidualBias : public PatternBase {
-  PrelnResidualBias(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "preln_residual_bias") {}
+  PrelnResidualBias(PDPattern *pattern,
+                    const std::string &name_scope,
+                    bool with_bias)
+      : PatternBase(pattern, name_scope, "preln_residual_bias") {
+    with_bias_ = with_bias;
+  }
   void operator()(PDNode *x, PDNode *y);
+  bool with_bias_;
   // declare operator node's name
   PATTERN_DECL_NODE(elementwise_bias);
   PATTERN_DECL_NODE(elementwise0);
@@ -55,15 +60,17 @@ struct PrelnResidualBias : public PatternBase {
 };

 void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
+  PDNode *elementwise0 = nullptr;
+  PDNode *elementwise_bias_var = nullptr;
+  PDNode *elementwise0_out_var = nullptr;
   // Create nodes for elementwise add op.
   x->assert_is_op_input("elementwise_add");
   y->assert_is_op_input("elementwise_add", "X");
-  auto *elementwise0 =
-      pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
-  auto *elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
-                                   ->assert_is_op_input("elementwise_add", "Y")
-                                   ->assert_is_persistable_var();
-  auto *elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
+  if (with_bias_) {
+    elementwise0 =
+        pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
+    elementwise_bias_var = pattern->NewNode(elementwise_bias_repr())
+                               ->assert_is_op_input("elementwise_add", "Y")
+                               ->assert_is_persistable_var();
+    elementwise0_out_var = pattern->NewNode(elementwise0_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->assert_is_op_input("elementwise_add")
                                ->assert_more([](Node *x) {
@@ -73,14 +80,21 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
                                  return false;
                                }
                              });
+  } else {
+    elementwise0_out_var = y;
+  }
   auto *elementwise1 =
       pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
   auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr())
                                    ->assert_is_op_output("elementwise_add")
                                    ->assert_is_op_input("layer_norm", "X");
   // Add links for elementwise_add op.
-  elementwise0->LinksFrom({y, elementwise_bias_var})
-      .LinksTo({elementwise0_out_var});
+  if (with_bias_) {
+    elementwise0->LinksFrom({y, elementwise_bias_var})
+        .LinksTo({elementwise0_out_var});
+    elementwise1_out_var->assert_is_op_output("elementwise_add");
+  }
   elementwise1->LinksFrom({x, elementwise0_out_var})
       .LinksTo({elementwise1_out_var});
   // Create nodes for layer_norm op.
@@ -115,7 +129,8 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) {
 }  // namespace patterns

-void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
+int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph,
+                                            bool with_bias) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
       platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init("preln_residual_bias_fuse", graph);
@@ -123,18 +138,32 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
   int found_subgraph_count = 0;
   GraphPatternDetector gpd;
-  auto *x = gpd.mutable_pattern()
-                ->NewNode("preln_residual_bias_fuse/x")
-                ->AsInput()
-                ->assert_is_op_input("elementwise_add")
-                ->assert_var_not_persistable();
-  auto *y = gpd.mutable_pattern()
-                ->NewNode("preln_residual_bias_fuse/y")
-                ->AsInput()
-                ->assert_is_op_input("elementwise_add", "X")
-                ->assert_var_not_persistable();
-  patterns::PrelnResidualBias fused_pattern(gpd.mutable_pattern(),
-                                            "preln_residual_bias_fuse");
+  PDNode *x = nullptr;
+  PDNode *y = nullptr;
+  if (with_bias) {
+    x = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/x")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add")
+            ->assert_var_not_persistable();
+    y = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/y")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "X")
+            ->assert_var_not_persistable();
+  } else {
+    x = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/x")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "X");
+    y = gpd.mutable_pattern()
+            ->NewNode("preln_residual_bias_fuse/y")
+            ->AsInput()
+            ->assert_is_op_input("elementwise_add", "Y");
+  }
+  patterns::PrelnResidualBias fused_pattern(
+      gpd.mutable_pattern(), "preln_residual_bias_fuse", with_bias);
   fused_pattern(x, y);

   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
@@ -145,11 +174,19 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     }
     VLOG(4) << "handle PrelnResidualBias fuse";
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise_bias, elementwise_bias, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise0_out, elementwise0_out, fused_pattern);
+    Node *elementwise_bias = nullptr;
+    Node *elementwise0 = nullptr;
+    Node *elementwise0_out = nullptr;
+    if (with_bias) {
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_elementwise_bias, elementwise_bias, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_elementwise0, elementwise0, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_elementwise0_out, elementwise0_out, fused_pattern);
+      elementwise_bias = tmp_elementwise_bias;
+      elementwise0 = tmp_elementwise0;
+      elementwise0_out = tmp_elementwise0_out;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise1_out, elementwise1_out, fused_pattern);
@@ -185,7 +222,9 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     new_desc.SetInput("Y", {subgraph.at(y)->Name()});
     new_desc.SetInput("Scale", {layer_norm_scale->Name()});
     new_desc.SetInput("Bias", {layer_norm_bias->Name()});
-    new_desc.SetInput("EleBias", {elementwise_bias->Name()});
+    if (with_bias) {
+      new_desc.SetInput("EleBias", {elementwise_bias->Name()});
+    }
     // outputs
     new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
     new_desc.SetOutput("Out_1", {elementwise1_out->Name()});
@@ -194,16 +233,20 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
     new_desc.SetAttr("begin_norm_axis",
                      layer_norm->Op()->GetAttr("begin_norm_axis"));
     auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
-    del_node_set.insert(elementwise0);
-    del_node_set.insert(elementwise1);
-    del_node_set.insert(elementwise0_out);
+    if (with_bias) {
+      del_node_set.insert(elementwise0);
+      del_node_set.insert(elementwise0_out);
+    }
+    del_node_set.insert(elementwise1);
     del_node_set.insert(layer_norm);
     del_node_set.insert(layer_norm_mean);
     del_node_set.insert(layer_norm_variance);
     GraphSafeRemoveNodes(graph, del_node_set);
     IR_NODE_LINK_TO(subgraph.at(x), fused_node);
     IR_NODE_LINK_TO(subgraph.at(y), fused_node);
-    IR_NODE_LINK_TO(elementwise_bias, fused_node);
+    if (with_bias) {
+      IR_NODE_LINK_TO(elementwise_bias, fused_node);
+    }
     IR_NODE_LINK_TO(layer_norm_scale, fused_node);
     IR_NODE_LINK_TO(layer_norm_bias, fused_node);
     IR_NODE_LINK_TO(fused_node, layer_norm_out);
@@ -212,6 +255,17 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
   };
   gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init("preln_residual_bias_fuse", graph);
+  int found_subgraph_count = 0;
+  found_subgraph_count = ApplyPattern(graph, true);
+  found_subgraph_count += ApplyPattern(graph, false);
   AddStatis(found_subgraph_count);
 }
```
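The rewritten ApplyImpl runs the detector twice: ApplyPattern(graph, true) matches the biased form first, then ApplyPattern(graph, false) picks up the bias-free form. For reference, the with-bias subgraph corresponds to a persistable bias added to y before the residual add. A minimal sketch inferred from the pattern asserts above (variable names are illustrative):

```python
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[128, 768], dtype="float32")
    y = paddle.static.data(name="y", shape=[128, 768], dtype="float32")
    # The pattern requires EleBias to be a persistable var (a parameter).
    ele_bias = paddle.static.create_parameter(
        shape=[768], dtype="float32", name="ele_bias")
    ele_out = y + ele_bias   # elementwise0: Y + EleBias
    residual = x + ele_out   # elementwise1: X + elementwise0_out
    out = paddle.static.nn.layer_norm(input=residual)  # consumed by layer_norm
```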
paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h

```diff
@@ -29,6 +29,16 @@ namespace ir {
 //        other_op4      layer_norm                other_op4    other_op3
 //            |
 //        other_op3
+// or
+//
+//      |           |                              |           |
+//  other_op1   other_op2                      other_op1   other_op2
+//      |           |             fuse             \          /
+//      |------elementwise_add     ->          preln_residual_bias
+//         |           |                           |         |
+//     other_op4   layer_norm                  other_op4  other_op3
+//                     |
+//                other_op3
 class Graph;

 class PrelnResidualBiasFusePass : public FusePassBase {
@@ -80,6 +90,7 @@ class PrelnResidualBiasFusePass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph *graph) const override;
+  int ApplyPattern(ir::Graph *graph, bool with_bias) const;
 };

 }  // namespace ir
```
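From the pass code, the fused op produces two outputs: Out_1 is the residual sum (with the optional EleBias) and Out_0 is its layer norm. A NumPy reference of that contract as I read it from the diff (a sketch, not Paddle's implementation):

```python
import numpy as np

def preln_residual_bias_ref(x, y, scale, bias, ele_bias=None, epsilon=1e-5):
    # Out_1: residual add, with the optional elementwise bias folded in.
    out_1 = x + y if ele_bias is None else x + (y + ele_bias)
    # Out_0: layer_norm over the last axis (begin_norm_axis = ndim - 1 here).
    mean = out_1.mean(axis=-1, keepdims=True)
    var = out_1.var(axis=-1, keepdims=True)
    out_0 = (out_1 - mean) / np.sqrt(var + epsilon) * scale + bias
    return out_0, out_1
```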
paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc

```diff
@@ -51,12 +51,15 @@ class PrelnResidualBiasOpConverter : public OpConverter {
   framework::DDim bias_dims, scale_dims, ele_bias_dims;
   auto *bias = get_persistable_data("Bias", &bias_dims);
   auto *scale = get_persistable_data("Scale", &scale_dims);
-  auto *ele_bias = get_persistable_data("EleBias", &ele_bias_dims);
+  auto const &vars = op_desc.Inputs(false);
+  bool has_bias = vars.find("EleBias") != vars.end();
+  float *ele_bias =
+      has_bias ? get_persistable_data("EleBias", &ele_bias_dims) : nullptr;

   int bias_size = phi::product(bias_dims);
   int scale_size = phi::product(scale_dims);
-  int ele_bias_size = phi::product(ele_bias_dims);
+  int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0;
   float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
   bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
   if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
@@ -66,13 +69,17 @@ class PrelnResidualBiasOpConverter : public OpConverter {
   nvinfer1::ILayer *layer = nullptr;
   plugin::DynamicPluginTensorRT *plugin = nullptr;
   if (with_fp16) {
-    auto half_ele_bias_data = new half[ele_bias_size];
-    for (int i = 0; i < ele_bias_size; i++) {
-      half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
+    half *half_ele_bias_data = nullptr;
+    if (ele_bias_size > 0) {
+      half_ele_bias_data = new half[ele_bias_size];
+      for (int i = 0; i < ele_bias_size; i++) {
+        half_ele_bias_data[i] = static_cast<half>(ele_bias[i]);
+      }
     }
-    plugin = new plugin::PrelnResidualBiasPluginDynamic(bias,
-                                                        scale,
-                                                        half_ele_bias_data,
+    plugin = new plugin::PrelnResidualBiasPluginDynamic(
+        bias,
+        scale,
+        ele_bias_size > 0 ? half_ele_bias_data : nullptr,
         bias_size,
         scale_size,
         ele_bias_size,
```
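The converter's fp16 branch runs when TensorRT is configured for half precision. A hypothetical inference setup that would reach it (model file paths are placeholders; the calls are standard paddle.inference APIs):

```python
import paddle.inference as paddle_infer

# Placeholder model files: any pre-LN transformer exported for inference.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(1000, 0)  # 1000 MB initial memory pool on GPU 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=4,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,  # takes the with_fp16 path
    use_static=False,
    use_calib_mode=False)
predictor = paddle_infer.create_predictor(config)
```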
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu

```diff
@@ -44,7 +44,7 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
              scale_.data(),
              scale_size_ * sizeof(float),
              cudaMemcpyHostToDevice);
+  if (ele_bias_size_ > 0) {
     if (with_fp16_) {
       cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
       cudaMemcpy(ele_bias_gpu_,
@@ -58,6 +58,9 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
                ele_bias_size_ * sizeof(float),
                cudaMemcpyHostToDevice);
     }
+  } else {
+    ele_bias_gpu_ = nullptr;
+  }

   return 0;
 }
```
python/paddle/fluid/tests/unittests/CMakeLists.txt

```diff
@@ -142,6 +142,7 @@ if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_complex_matmul)
   list(REMOVE_ITEM TEST_OPS test_ops_nms)
   list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
+  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_no_bias)
   list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
 endif()
 list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
```
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt

```diff
@@ -22,6 +22,10 @@ if(NOT WITH_DISTRIBUTE)
       "test_trt_convert_preln_residual_bias")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias")
+  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
+       "test_trt_convert_preln_residual_no_bias")
+  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_no_bias")
+  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_no_bias")
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce")
```
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py (new file, mode 100644)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest


class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        inputs = program_config.inputs
        weights = program_config.weights
        outputs = program_config.outputs

        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]

        # The input dimension should be less than or equal to the set axis.
        if 'begin_norm_axis' in attrs[0] and attrs[0]['begin_norm_axis'] >= 0:
            if len(inputs['inputX_data'].shape) <= attrs[0]['begin_norm_axis']:
                return False

        return True

    def sample_program_configs(self):

        def generate_input1(attrs: List[Dict[str, Any]], batch):
            return np.ones([batch, 128, 768]).astype(np.float32)

        def generate_input2(attrs: List[Dict[str, Any]], batch):
            return np.ones([batch, 128, 768]).astype(np.float32)

        def generate_weight1(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        def generate_weight2(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        for batch in [4]:
            for epsilon in [1e-5]:
                for begin_norm_axis in [2]:
                    for enable_int8 in [False, True]:
                        dics = [{
                            "epsilon": epsilon,
                            "begin_norm_axis": begin_norm_axis,
                        }, {}]
                        ops_config = [{
                            "op_type": "elementwise_add",
                            "op_inputs": {
                                "X": ["inputX_data"],
                                "Y": ["inputY_data"]
                            },
                            "op_outputs": {
                                "Out": ["ele_out"]
                            },
                            "op_attrs": {
                                "axis": -1
                            }
                        }, {
                            "op_type": "layer_norm",
                            "op_inputs": {
                                "X": ["ele_out"],
                                "Bias": ["Bias"],
                                "Scale": ["Scale"]
                            },
                            "op_outputs": {
                                "Y": ["layernorm_out"],
                                "Mean": ["Mean"],
                                "Variance": ["Variance"]
                            },
                            "op_attrs": dics[0]
                        }]
                        ops = self.generate_op_config(ops_config)

                        program_config = ProgramConfig(
                            ops=ops,
                            weights={
                                "Bias":
                                TensorConfig(
                                    data_gen=partial(generate_weight1, dics)),
                                "Scale":
                                TensorConfig(
                                    data_gen=partial(generate_weight2, dics))
                            },
                            inputs={
                                "inputX_data":
                                TensorConfig(data_gen=partial(
                                    generate_input1, dics, batch)),
                                "inputY_data":
                                TensorConfig(data_gen=partial(
                                    generate_input2, dics, batch))
                            },
                            outputs=["ele_out", "layernorm_out"])

                        yield program_config

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):

        def generate_dynamic_shape(attrs):
            self.dynamic_shape.min_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }
            self.dynamic_shape.max_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }
            self.dynamic_shape.opt_input_shape = {
                "inputX_data": [4, 128, 768],
                "inputY_data": [4, 128, 768],
                "Bias": [768],
                "Scale": [768]
            }

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            return 1, 4

        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]

        # just support dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True), 1e-2  # atol=1e-2 while rtol is 1e-8

    def add_skip_trt_case(self):
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()


if __name__ == "__main__":
    unittest.main()
```
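Note the test pins min/max/opt input shapes to the same values, so only one shape profile is ever exercised. If a genuinely dynamic batch dimension were wanted, the three dictionaries could differ, e.g. (a hypothetical variation on generate_dynamic_shape, placed inside the test class):

```python
def generate_dynamic_shape_ranged(self):
    # Hypothetical: profile batches 1..8 and optimize for 4.
    self.dynamic_shape.min_input_shape = {
        "inputX_data": [1, 128, 768], "inputY_data": [1, 128, 768],
        "Bias": [768], "Scale": [768]
    }
    self.dynamic_shape.max_input_shape = {
        "inputX_data": [8, 128, 768], "inputY_data": [8, 128, 768],
        "Bias": [768], "Scale": [768]
    }
    self.dynamic_shape.opt_input_shape = {
        "inputX_data": [4, 128, 768], "inputY_data": [4, 128, 768],
        "Bias": [768], "Scale": [768]
    }
```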
python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py

```diff
@@ -57,5 +57,37 @@ class PrelnResidualBiasFusePassTest(PassTest):
         self.check_program(opt_program)


+class PrelnResidualBiasFusePassNoBiasTest(PassTest):
+
+    def setUp(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(self.main_program,
+                                         self.startup_program):
+            x = paddle.static.data(
+                name="x", shape=[128, 768], dtype="float32", lod_level=0)
+            y = paddle.static.data(
+                name="y", shape=[128, 768], dtype="float32", lod_level=0)
+            elementwise_out = x + y
+            out = paddle.static.nn.layer_norm(input=elementwise_out)
+
+        self.fetch_list = [out, elementwise_out]
+        self.pass_names = "preln_residual_bias_fuse_pass"
+        self.fused_op_type = "preln_residual_bias"
+        self.num_fused_ops = 1
+
+    def test_check_program(self):
+        use_gpu_set = [False]
+        if paddle.device.is_compiled_with_cuda():
+            use_gpu_set.append(True)
+        for use_gpu in use_gpu_set:
+            place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()
+            opt_program = self._apply_ir_passes()
+            self.check_program(opt_program)
+
+
 if __name__ == "__main__":
     unittest.main()
```