Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
8d2351c3
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8d2351c3
编写于
9月 20, 2020
作者:
W
weihaoji
提交者:
GitHub
9月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] add resnet50-D fusion (#4276)
上级
7fb2261d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
897 addition
and
17 deletion
+897
-17
lite/api/paddle_use_passes.h
lite/api/paddle_use_passes.h
+1
-0
lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc
lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc
+835
-17
lite/core/optimizer.h
lite/core/optimizer.h
+1
-0
lite/kernels/xpu/__xpu__resnet50_compute.cc
lite/kernels/xpu/__xpu__resnet50_compute.cc
+44
-0
lite/kernels/xpu/__xpu__resnet50_compute.h
lite/kernels/xpu/__xpu__resnet50_compute.h
+15
-0
lite/operators/__xpu__resnet50_op.cc
lite/operators/__xpu__resnet50_op.cc
+1
-0
未找到文件。
lite/api/paddle_use_passes.h
浏览文件 @
8d2351c3
...
...
@@ -62,6 +62,7 @@ USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS
(
control_flow_op_unused_inputs_and_outputs_eliminate_pass
)
USE_MIR_PASS
(
lite_scale_activation_fuse_pass
);
USE_MIR_PASS
(
__xpu__resnet_fuse_pass
);
USE_MIR_PASS
(
__xpu__resnet_d_fuse_pass
);
USE_MIR_PASS
(
__xpu__resnet_cbam_fuse_pass
);
USE_MIR_PASS
(
__xpu__multi_encoder_fuse_pass
);
USE_MIR_PASS
(
__xpu__embedding_with_eltwise_add_fuse_pass
);
...
...
lite/core/mir/fusion/__xpu__resnet_fuse_pass.cc
浏览文件 @
8d2351c3
...
...
@@ -307,7 +307,7 @@ class XPUResNetBlock0Fuser : public FuseBase {
matched
.
at
(
"right_bn1_variance"
)
->
arg
()
->
name
,
});
op_desc
.
SetOutput
(
"Outputs"
,
{
matched
.
at
(
"relu_out"
)
->
arg
()
->
name
});
//
XXX:
keep these to fool SubgraphOp::AttachImpl()
// keep these to fool SubgraphOp::AttachImpl()
op_desc
.
SetAttr
<
int
>
(
"sub_block"
,
0
);
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"input_data_names"
,
{});
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"output_data_names"
,
{});
...
...
@@ -570,7 +570,7 @@ class XPUResNetBlock1Fuser : public FuseBase {
matched
.
at
(
"right_bn3_variance"
)
->
arg
()
->
name
,
});
op_desc
.
SetOutput
(
"Outputs"
,
{
matched
.
at
(
"relu_out"
)
->
arg
()
->
name
});
//
XXX:
keep these to fool SubgraphOp::AttachImpl()
// keep these to fool SubgraphOp::AttachImpl()
op_desc
.
SetAttr
<
int
>
(
"sub_block"
,
0
);
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"input_data_names"
,
{});
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"output_data_names"
,
{});
...
...
@@ -599,9 +599,658 @@ class XPUResNetBlock1Fuser : public FuseBase {
}
};
class
XPUResNetDtypeBlock0Fuser
:
public
FuseBase
{
public:
XPUResNetDtypeBlock0Fuser
()
{}
void
BuildPattern
()
override
{
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
assert_is_op_input
(
"pool2d"
,
"X"
)
->
AsInput
();
auto
*
left_conv1_weight
=
VarNode
(
"left_conv1_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
left_conv1
=
OpNode
(
"left_conv1"
,
"conv2d"
);
auto
*
left_conv1_out
=
VarNode
(
"left_conv1_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
left_bn1_scale
=
VarNode
(
"left_bn1_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
left_bn1_bias
=
VarNode
(
"left_bn1_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
left_bn1_mean
=
VarNode
(
"left_bn1_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
left_bn1_var
=
VarNode
(
"left_bn1_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
left_bn1
=
OpNode
(
"left_bn1"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
left_bn1_out
=
VarNode
(
"left_bn1_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
left_bn1_mean_out
=
VarNode
(
"left_bn1_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
left_bn1_var_out
=
VarNode
(
"left_bn1_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
left_bn1_saved_mean
=
VarNode
(
"left_bn1_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
left_bn1_saved_var
=
VarNode
(
"left_bn1_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
left_relu1
=
OpNode
(
"left_relu1"
,
"relu"
)
->
AsIntermediate
();
auto
*
left_relu1_out
=
VarNode
(
"left_relu1_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsIntermediate
();
auto
*
left_conv2_weight
=
VarNode
(
"left_conv2_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
left_conv2
=
OpNode
(
"left_conv2"
,
"conv2d"
)
->
AsIntermediate
();
auto
*
left_conv2_out
=
VarNode
(
"left_conv2_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
left_bn2_scale
=
VarNode
(
"left_bn2_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
left_bn2_bias
=
VarNode
(
"left_bn2_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
left_bn2_mean
=
VarNode
(
"left_bn2_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
left_bn2_var
=
VarNode
(
"left_bn2_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
left_bn2
=
OpNode
(
"left_bn2"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
left_bn2_out
=
VarNode
(
"left_bn2_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
left_bn2_mean_out
=
VarNode
(
"left_bn2_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
left_bn2_var_out
=
VarNode
(
"left_bn2_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
left_bn2_saved_mean
=
VarNode
(
"left_bn2_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
left_bn2_saved_var
=
VarNode
(
"left_bn2_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
left_relu2
=
OpNode
(
"left_relu2"
,
"relu"
)
->
AsIntermediate
();
auto
*
left_relu2_out
=
VarNode
(
"left_relu2_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsIntermediate
();
auto
*
left_conv3_weight
=
VarNode
(
"left_conv3_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
left_conv3
=
OpNode
(
"left_conv3"
,
"conv2d"
)
->
AsIntermediate
();
auto
*
left_conv3_out
=
VarNode
(
"left_conv3_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
left_bn3_scale
=
VarNode
(
"left_bn3_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
left_bn3_bias
=
VarNode
(
"left_bn3_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
left_bn3_mean
=
VarNode
(
"left_bn3_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
left_bn3_var
=
VarNode
(
"left_bn3_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
left_bn3
=
OpNode
(
"left_bn3"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
left_bn3_out
=
VarNode
(
"left_bn3_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsIntermediate
();
auto
*
left_bn3_mean_out
=
VarNode
(
"left_bn3_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
left_bn3_var_out
=
VarNode
(
"left_bn3_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
left_bn3_saved_mean
=
VarNode
(
"left_bn3_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
left_bn3_saved_var
=
VarNode
(
"left_bn3_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
right_pool
=
OpNode
(
"right_pool"
,
"pool2d"
)
->
AsIntermediate
();
auto
*
right_pool_out
=
VarNode
(
"right_pool_out"
)
->
assert_is_op_output
(
"pool2d"
,
"Out"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsIntermediate
();
auto
*
right_conv1_weight
=
VarNode
(
"right_conv1_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
right_conv1
=
OpNode
(
"right_conv1"
,
"conv2d"
)
->
AsIntermediate
();
auto
*
right_conv1_out
=
VarNode
(
"right_conv1_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
right_bn1_scale
=
VarNode
(
"right_bn1_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
right_bn1_bias
=
VarNode
(
"right_bn1_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
right_bn1_mean
=
VarNode
(
"right_bn1_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
right_bn1_var
=
VarNode
(
"right_bn1_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
right_bn1
=
OpNode
(
"right_bn1"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
right_bn1_out
=
VarNode
(
"right_bn1_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsIntermediate
();
auto
*
right_bn1_mean_out
=
VarNode
(
"right_bn1_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
right_bn1_var_out
=
VarNode
(
"right_bn1_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
right_bn1_saved_mean
=
VarNode
(
"right_bn1_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
right_bn1_saved_var
=
VarNode
(
"right_bn1_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
AsIntermediate
();
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
)
->
AsIntermediate
();
auto
*
relu_out
=
VarNode
(
"relu_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
AsOutput
();
*
input
>>
*
left_conv1
>>
*
left_conv1_out
>>
*
left_bn1
>>
*
left_bn1_out
>>
*
left_relu1
>>
*
left_relu1_out
>>
*
left_conv2
>>
*
left_conv2_out
>>
*
left_bn2
>>
*
left_bn2_out
>>
*
left_relu2
>>
*
left_relu2_out
>>
*
left_conv3
>>
*
left_conv3_out
>>
*
left_bn3
>>
*
left_bn3_out
>>
*
add
;
*
left_conv1_weight
>>
*
left_conv1
;
*
left_bn1_scale
>>
*
left_bn1
;
*
left_bn1_bias
>>
*
left_bn1
;
*
left_bn1_mean
>>
*
left_bn1
;
*
left_bn1_var
>>
*
left_bn1
;
*
left_bn1
>>
*
left_bn1_mean_out
;
*
left_bn1
>>
*
left_bn1_var_out
;
*
left_bn1
>>
*
left_bn1_saved_mean
;
*
left_bn1
>>
*
left_bn1_saved_var
;
*
left_conv2_weight
>>
*
left_conv2
;
*
left_bn2_scale
>>
*
left_bn2
;
*
left_bn2_bias
>>
*
left_bn2
;
*
left_bn2_mean
>>
*
left_bn2
;
*
left_bn2_var
>>
*
left_bn2
;
*
left_bn2
>>
*
left_bn2_mean_out
;
*
left_bn2
>>
*
left_bn2_var_out
;
*
left_bn2
>>
*
left_bn2_saved_mean
;
*
left_bn2
>>
*
left_bn2_saved_var
;
*
left_conv3_weight
>>
*
left_conv3
;
*
left_bn3_scale
>>
*
left_bn3
;
*
left_bn3_bias
>>
*
left_bn3
;
*
left_bn3_mean
>>
*
left_bn3
;
*
left_bn3_var
>>
*
left_bn3
;
*
left_bn3
>>
*
left_bn3_mean_out
;
*
left_bn3
>>
*
left_bn3_var_out
;
*
left_bn3
>>
*
left_bn3_saved_mean
;
*
left_bn3
>>
*
left_bn3_saved_var
;
*
input
>>
*
right_pool
>>
*
right_pool_out
>>
*
right_conv1
>>
*
right_conv1_out
>>
*
right_bn1
>>
*
right_bn1_out
>>
*
add
;
*
right_conv1_weight
>>
*
right_conv1
;
*
right_bn1_scale
>>
*
right_bn1
;
*
right_bn1_bias
>>
*
right_bn1
;
*
right_bn1_mean
>>
*
right_bn1
;
*
right_bn1_var
>>
*
right_bn1
;
*
right_bn1
>>
*
right_bn1_mean_out
;
*
right_bn1
>>
*
right_bn1_var_out
;
*
right_bn1
>>
*
right_bn1_saved_mean
;
*
right_bn1
>>
*
right_bn1_saved_var
;
*
add
>>
*
add_out
>>
*
relu
>>
*
relu_out
;
}
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
{
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"resnet_block0_d"
);
op_desc
.
SetInput
(
"Inputs"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Filter"
,
{
matched
.
at
(
"left_conv1_weight"
)
->
arg
()
->
name
,
matched
.
at
(
"left_conv2_weight"
)
->
arg
()
->
name
,
matched
.
at
(
"left_conv3_weight"
)
->
arg
()
->
name
,
matched
.
at
(
"right_conv1_weight"
)
->
arg
()
->
name
,
});
op_desc
.
SetInput
(
"Scale"
,
{
matched
.
at
(
"left_bn1_scale"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn2_scale"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn3_scale"
)
->
arg
()
->
name
,
matched
.
at
(
"right_bn1_scale"
)
->
arg
()
->
name
,
});
op_desc
.
SetInput
(
"Bias"
,
{
matched
.
at
(
"left_bn1_bias"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn2_bias"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn3_bias"
)
->
arg
()
->
name
,
matched
.
at
(
"right_bn1_bias"
)
->
arg
()
->
name
,
});
op_desc
.
SetInput
(
"Mean"
,
{
matched
.
at
(
"left_bn1_mean"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn2_mean"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn3_mean"
)
->
arg
()
->
name
,
matched
.
at
(
"right_bn1_mean"
)
->
arg
()
->
name
,
});
op_desc
.
SetInput
(
"Var"
,
{
matched
.
at
(
"left_bn1_variance"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn2_variance"
)
->
arg
()
->
name
,
matched
.
at
(
"left_bn3_variance"
)
->
arg
()
->
name
,
matched
.
at
(
"right_bn1_variance"
)
->
arg
()
->
name
,
});
op_desc
.
SetOutput
(
"Outputs"
,
{
matched
.
at
(
"relu_out"
)
->
arg
()
->
name
});
// keep these to fool SubgraphOp::AttachImpl()
op_desc
.
SetAttr
<
int
>
(
"sub_block"
,
0
);
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"input_data_names"
,
{});
op_desc
.
SetAttr
<
std
::
vector
<
std
::
string
>>
(
"output_data_names"
,
{});
auto
block0_stmt
=
matched
.
at
(
"left_conv1"
)
->
stmt
();
// block0_stmt->ResetOp(op_desc, graph->valid_places());
auto
fake_subgraph_op
=
LiteOpRegistry
::
Global
().
Create
(
"subgraph"
);
auto
sub_program_desc
=
std
::
make_shared
<
cpp
::
ProgramDesc
>
();
sub_program_desc
->
AddBlock
<
cpp
::
BlockDesc
>
();
static_cast
<
operators
::
SubgraphOp
*>
(
fake_subgraph_op
.
get
())
->
SetProgramDesc
(
sub_program_desc
);
fake_subgraph_op
->
Attach
(
op_desc
,
block0_stmt
->
op
()
->
scope
());
fake_subgraph_op
->
SetValidPlaces
(
block0_stmt
->
op
()
->
valid_places
());
block0_stmt
->
SetOp
(
fake_subgraph_op
);
std
::
vector
<
std
::
string
>
froms
=
{
"left_conv2_weight"
,
"left_conv3_weight"
,
"right_conv1_weight"
,
"left_bn1_bias"
,
"left_bn2_bias"
,
"left_bn3_bias"
,
"right_bn1_bias"
,
};
for
(
auto
&
from
:
froms
)
{
IR_NODE_LINK_TO
(
matched
.
at
(
from
),
matched
.
at
(
"left_conv1"
));
}
IR_OP_VAR_LINK
(
matched
.
at
(
"left_conv1"
),
matched
.
at
(
"relu_out"
));
}
};
class
XPUResNet50Fuser
:
public
xpu
::
XPUFuseBase
{
public:
XPUResNet50Fuser
()
{}
XPUResNet50Fuser
()
{}
void
BuildPattern
()
override
{
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsInput
();
auto
*
top_conv_weight
=
VarNode
(
"top_conv_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
top_conv
=
OpNode
(
"top_conv"
,
"conv2d"
);
auto
*
top_conv_out
=
VarNode
(
"top_conv_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
top_bn_scale
=
VarNode
(
"top_bn_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
top_bn_bias
=
VarNode
(
"top_bn_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
top_bn_mean
=
VarNode
(
"top_bn_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
top_bn_var
=
VarNode
(
"top_bn_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
top_bn
=
OpNode
(
"top_bn"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
top_bn_out
=
VarNode
(
"top_bn_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
top_bn_mean_out
=
VarNode
(
"top_bn_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
top_bn_var_out
=
VarNode
(
"top_bn_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
top_bn_saved_mean
=
VarNode
(
"top_bn_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
top_bn_saved_var
=
VarNode
(
"top_bn_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
top_relu
=
OpNode
(
"top_relu"
,
"relu"
)
->
AsIntermediate
();
auto
*
top_relu_out
=
VarNode
(
"top_relu_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"pool2d"
,
"X"
)
->
AsIntermediate
();
auto
*
top_pool
=
OpNode
(
"top_pool"
,
"pool2d"
)
->
AsIntermediate
();
auto
*
top_pool_out
=
VarNode
(
"top_pool_out"
)
->
assert_is_op_output
(
"pool2d"
,
"Out"
)
->
assert_is_op_input
(
"resnet_block0"
,
"Inputs"
)
->
AsIntermediate
();
// args are left out
auto
*
resnet_block0_1
=
OpNode
(
"resnet_block0_1"
,
"resnet_block0"
)
->
AsIntermediate
();
auto
*
resnet_block0_1_out
=
VarNode
(
"resnet_block0_1_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_1_1
=
OpNode
(
"resnet_block1_1_1"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_1_1_out
=
VarNode
(
"resnet_block1_1_1_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_1_2
=
OpNode
(
"resnet_block1_1_2"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_1_2_out
=
VarNode
(
"resnet_block1_1_2_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block0_2
=
OpNode
(
"resnet_block0_2"
,
"resnet_block0"
)
->
AsIntermediate
();
auto
*
resnet_block0_2_out
=
VarNode
(
"resnet_block0_2_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_1
=
OpNode
(
"resnet_block1_2_1"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_1_out
=
VarNode
(
"resnet_block1_2_1_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_2
=
OpNode
(
"resnet_block1_2_2"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_2_out
=
VarNode
(
"resnet_block1_2_2_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_3
=
OpNode
(
"resnet_block1_2_3"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_3_out
=
VarNode
(
"resnet_block1_2_3_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block0_3
=
OpNode
(
"resnet_block0_3"
,
"resnet_block0"
)
->
AsIntermediate
();
auto
*
resnet_block0_3_out
=
VarNode
(
"resnet_block0_3_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_1
=
OpNode
(
"resnet_block1_3_1"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_1_out
=
VarNode
(
"resnet_block1_3_1_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_2
=
OpNode
(
"resnet_block1_3_2"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_2_out
=
VarNode
(
"resnet_block1_3_2_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_3
=
OpNode
(
"resnet_block1_3_3"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_3_out
=
VarNode
(
"resnet_block1_3_3_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_4
=
OpNode
(
"resnet_block1_3_4"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_4_out
=
VarNode
(
"resnet_block1_3_4_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_5
=
OpNode
(
"resnet_block1_3_5"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_5_out
=
VarNode
(
"resnet_block1_3_5_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block0_4
=
OpNode
(
"resnet_block0_4"
,
"resnet_block0"
)
->
AsIntermediate
();
auto
*
resnet_block0_4_out
=
VarNode
(
"resnet_block0_4_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_4_1
=
OpNode
(
"resnet_block1_4_1"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_4_1_out
=
VarNode
(
"resnet_block1_4_1_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_4_2
=
OpNode
(
"resnet_block1_4_2"
,
"resnet_block1"
)
->
AsIntermediate
();
auto
*
resnet_block1_4_2_out
=
VarNode
(
"resnet_block1_4_2_out"
)
->
assert_is_op_output
(
"resnet_block1"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
bottom_pool
=
OpNode
(
"bottom_pool"
,
"pool2d"
)
->
AsIntermediate
();
auto
*
bottom_pool_out
=
VarNode
(
"bottom_pool_out"
)
->
assert_is_op_output
(
"pool2d"
,
"Out"
)
->
AsOutput
();
*
input
>>
*
top_conv
>>
*
top_conv_out
>>
*
top_bn
>>
*
top_bn_out
>>
*
top_relu
>>
*
top_relu_out
>>
*
top_pool
>>
*
top_pool_out
>>
*
resnet_block0_1
>>
*
resnet_block0_1_out
>>
*
resnet_block1_1_1
>>
*
resnet_block1_1_1_out
>>
*
resnet_block1_1_2
>>
*
resnet_block1_1_2_out
>>
*
resnet_block0_2
>>
*
resnet_block0_2_out
>>
*
resnet_block1_2_1
>>
*
resnet_block1_2_1_out
>>
*
resnet_block1_2_2
>>
*
resnet_block1_2_2_out
>>
*
resnet_block1_2_3
>>
*
resnet_block1_2_3_out
>>
*
resnet_block0_3
>>
*
resnet_block0_3_out
>>
*
resnet_block1_3_1
>>
*
resnet_block1_3_1_out
>>
*
resnet_block1_3_2
>>
*
resnet_block1_3_2_out
>>
*
resnet_block1_3_3
>>
*
resnet_block1_3_3_out
>>
*
resnet_block1_3_4
>>
*
resnet_block1_3_4_out
>>
*
resnet_block1_3_5
>>
*
resnet_block1_3_5_out
>>
*
resnet_block0_4
>>
*
resnet_block0_4_out
>>
*
resnet_block1_4_1
>>
*
resnet_block1_4_1_out
>>
*
resnet_block1_4_2
>>
*
resnet_block1_4_2_out
>>
*
bottom_pool
>>
*
bottom_pool_out
;
*
top_conv_weight
>>
*
top_conv
;
*
top_bn_scale
>>
*
top_bn
;
*
top_bn_bias
>>
*
top_bn
;
*
top_bn_mean
>>
*
top_bn
;
*
top_bn_var
>>
*
top_bn
;
*
top_bn
>>
*
top_bn_mean_out
;
*
top_bn
>>
*
top_bn_var_out
;
*
top_bn
>>
*
top_bn_saved_mean
;
*
top_bn
>>
*
top_bn_saved_var
;
}
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
,
const
std
::
vector
<
Node
*>&
extra_input_vars
)
override
{
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"__xpu__resnet50"
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
std
::
vector
<
std
::
string
>
filter_name
=
{
matched
.
at
(
"top_conv_weight"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
scale_name
=
{
matched
.
at
(
"top_bn_scale"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
bias_name
=
{
matched
.
at
(
"top_bn_bias"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
mean_name
=
{
matched
.
at
(
"top_bn_mean"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
var_name
=
{
matched
.
at
(
"top_bn_variance"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
max_filter_name
;
std
::
vector
<
std
::
string
>
resnet_block_vec
=
{
"resnet_block0_1"
,
"resnet_block1_1_1"
,
"resnet_block1_1_2"
,
"resnet_block0_2"
,
"resnet_block1_2_1"
,
"resnet_block1_2_2"
,
"resnet_block1_2_3"
,
"resnet_block0_3"
,
"resnet_block1_3_1"
,
"resnet_block1_3_2"
,
"resnet_block1_3_3"
,
"resnet_block1_3_4"
,
"resnet_block1_3_5"
,
"resnet_block0_4"
,
"resnet_block1_4_1"
,
"resnet_block1_4_2"
,
};
for
(
auto
&
block
:
resnet_block_vec
)
{
auto
*
block_op_info
=
matched
.
at
(
block
)
->
stmt
()
->
op_info
();
auto
block_filter_name
=
block_op_info
->
Input
(
"Filter"
);
std
::
copy
(
block_filter_name
.
begin
(),
block_filter_name
.
end
(),
std
::
back_inserter
(
filter_name
));
auto
block_scale_name
=
block_op_info
->
Input
(
"Scale"
);
std
::
copy
(
block_scale_name
.
begin
(),
block_scale_name
.
end
(),
std
::
back_inserter
(
scale_name
));
auto
block_bias_name
=
block_op_info
->
Input
(
"Bias"
);
std
::
copy
(
block_bias_name
.
begin
(),
block_bias_name
.
end
(),
std
::
back_inserter
(
bias_name
));
auto
block_mean_name
=
block_op_info
->
Input
(
"Mean"
);
std
::
copy
(
block_mean_name
.
begin
(),
block_mean_name
.
end
(),
std
::
back_inserter
(
mean_name
));
auto
block_var_name
=
block_op_info
->
Input
(
"Var"
);
std
::
copy
(
block_var_name
.
begin
(),
block_var_name
.
end
(),
std
::
back_inserter
(
var_name
));
}
op_desc
.
SetInput
(
"Filter"
,
filter_name
);
op_desc
.
SetInput
(
"Bias"
,
bias_name
);
op_desc
.
SetOutput
(
"Output"
,
{
matched
.
at
(
"bottom_pool_out"
)
->
arg
()
->
name
});
op_desc
.
SetAttr
<
int
>
(
"xpu"
,
1
);
auto
*
resnet50_stmt
=
matched
.
at
(
"top_conv"
)
->
stmt
();
auto
*
scope
=
resnet50_stmt
->
op
()
->
scope
();
for
(
size_t
i
=
0
;
i
<
filter_name
.
size
();
++
i
)
{
auto
*
filter_t
=
scope
->
FindMutableTensor
(
filter_name
[
i
]);
auto
*
scale_t
=
scope
->
FindMutableTensor
(
scale_name
[
i
]);
auto
*
bias_t
=
scope
->
FindMutableTensor
(
bias_name
[
i
]);
auto
*
mean_t
=
scope
->
FindMutableTensor
(
mean_name
[
i
]);
auto
*
var_t
=
scope
->
FindMutableTensor
(
var_name
[
i
]);
int
mean_len
=
mean_t
->
numel
();
int
filter_len
=
filter_t
->
numel
();
int
filter_stride
=
filter_len
/
mean_len
;
float
*
filter_on_host
=
filter_t
->
mutable_data
<
float
>
();
float
*
scale_on_host
=
scale_t
->
mutable_data
<
float
>
();
float
*
bias_on_host
=
bias_t
->
mutable_data
<
float
>
();
float
*
mean_on_host
=
mean_t
->
mutable_data
<
float
>
();
float
*
var_on_host
=
var_t
->
mutable_data
<
float
>
();
// Perform preprocess
for
(
int
i
=
0
;
i
<
mean_len
;
++
i
)
{
scale_on_host
[
i
]
=
scale_on_host
[
i
]
/
sqrtf
(
var_on_host
[
i
]
+
0.00001
f
);
}
for
(
int
i
=
0
;
i
<
mean_len
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filter_stride
;
++
j
)
{
filter_on_host
[
i
*
filter_stride
+
j
]
*=
scale_on_host
[
i
];
}
}
for
(
int
i
=
0
;
i
<
mean_len
;
++
i
)
{
bias_on_host
[
i
]
+=
-
mean_on_host
[
i
]
*
scale_on_host
[
i
];
}
float
max_f
=
paddle
::
lite
::
xpu
::
math
::
FindMaxAbs
(
filter_on_host
,
filter_len
);
std
::
unique_ptr
<
int16_t
[]
>
filter_int16
(
new
int16_t
[
filter_len
]);
paddle
::
lite
::
xpu
::
math
::
ConvertFP32ToInt16
(
filter_on_host
,
filter_int16
.
get
(),
max_f
,
filter_len
);
memcpy
(
filter_on_host
,
filter_int16
.
get
(),
filter_len
*
sizeof
(
int16_t
));
// create new arg in graph and scope
std
::
string
max_name
=
filter_name
[
i
]
+
"_max"
;
max_filter_name
.
push_back
(
max_name
);
auto
*
max_filter_node
=
graph
->
NewArgumentNode
(
max_name
);
max_filter_node
->
arg
()
->
is_weight
=
true
;
max_filter_node
->
arg
()
->
type
=
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
DirectedLink
(
max_filter_node
,
matched
.
at
(
"top_conv"
));
auto
*
max_filter_t
=
scope
->
NewTensor
(
max_name
);
max_filter_t
->
Resize
({
4
});
float
*
max_ptr
=
max_filter_t
->
mutable_data
<
float
>
();
max_ptr
[
0
]
=
max_f
;
max_ptr
[
1
]
=
max_f
;
max_ptr
[
2
]
=
max_f
;
max_ptr
[
3
]
=
max_f
;
}
op_desc
.
SetInput
(
"MaxFilter"
,
max_filter_name
);
auto
resnet50_op
=
LiteOpRegistry
::
Global
().
Create
(
op_desc
.
Type
());
resnet50_op
->
Attach
(
op_desc
,
scope
);
resnet50_op
->
SetValidPlaces
(
resnet50_stmt
->
op
()
->
valid_places
());
auto
kernels
=
resnet50_op
->
CreateKernels
(
resnet50_op
->
valid_places
());
resnet50_stmt
->
SetOp
(
resnet50_op
);
resnet50_stmt
->
SetKernels
(
std
::
move
(
kernels
));
IR_NODE_LINK_TO
(
matched
.
at
(
"top_bn_bias"
),
matched
.
at
(
"top_conv"
));
for
(
auto
*
node
:
extra_input_vars
)
{
IR_NODE_LINK_TO
(
node
,
matched
.
at
(
"top_conv"
));
}
IR_OP_VAR_LINK
(
matched
.
at
(
"top_conv"
),
matched
.
at
(
"bottom_pool_out"
));
}
};
class
XPUResNet50DtypeFuser
:
public
xpu
::
XPUFuseBase
{
public:
XPUResNet50DtypeFuser
()
{}
void
BuildPattern
()
override
{
auto
*
input
=
...
...
@@ -650,8 +1299,102 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
auto
*
top_relu
=
OpNode
(
"top_relu"
,
"relu"
)
->
AsIntermediate
();
auto
*
top_relu_out
=
VarNode
(
"top_relu_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"pool2d"
,
"X"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsIntermediate
();
auto
*
second_conv_weight
=
VarNode
(
"second_conv_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
second_conv
=
OpNode
(
"second_conv"
,
"conv2d"
)
->
AsIntermediate
();
auto
*
second_conv_out
=
VarNode
(
"second_conv_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
second_bn_scale
=
VarNode
(
"second_bn_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
second_bn_bias
=
VarNode
(
"second_bn_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
second_bn_mean
=
VarNode
(
"second_bn_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
second_bn_var
=
VarNode
(
"second_bn_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
second_bn
=
OpNode
(
"second_bn"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
second_bn_out
=
VarNode
(
"second_bn_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
second_bn_mean_out
=
VarNode
(
"second_bn_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
second_bn_var_out
=
VarNode
(
"second_bn_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
second_bn_saved_mean
=
VarNode
(
"second_bn_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
second_bn_saved_var
=
VarNode
(
"second_bn_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
second_relu
=
OpNode
(
"second_relu"
,
"relu"
)
->
AsIntermediate
();
auto
*
second_relu_out
=
VarNode
(
"second_relu_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"conv2d"
,
"Input"
)
->
AsIntermediate
();
auto
*
third_conv_weight
=
VarNode
(
"third_conv_weight"
)
->
assert_is_op_input
(
"conv2d"
,
"Filter"
)
->
AsInput
();
auto
*
third_conv
=
OpNode
(
"third_conv"
,
"conv2d"
)
->
AsIntermediate
();
auto
*
third_conv_out
=
VarNode
(
"third_conv_out"
)
->
assert_is_op_output
(
"conv2d"
,
"Output"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
)
->
AsIntermediate
();
auto
*
third_bn_scale
=
VarNode
(
"third_bn_scale"
)
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
AsIntermediate
();
auto
*
third_bn_bias
=
VarNode
(
"third_bn_bias"
)
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
AsInput
();
auto
*
third_bn_mean
=
VarNode
(
"third_bn_mean"
)
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
AsIntermediate
();
auto
*
third_bn_var
=
VarNode
(
"third_bn_variance"
)
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
AsIntermediate
();
auto
*
third_bn
=
OpNode
(
"third_bn"
,
"batch_norm"
)
->
AsIntermediate
();
auto
*
third_bn_out
=
VarNode
(
"third_bn_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
auto
*
third_bn_mean_out
=
VarNode
(
"third_bn_mean_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
)
->
AsIntermediate
();
auto
*
third_bn_var_out
=
VarNode
(
"third_bn_var_out"
)
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
)
->
AsIntermediate
();
auto
*
third_bn_saved_mean
=
VarNode
(
"third_bn_saved_mean"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
)
->
AsIntermediate
();
auto
*
third_bn_saved_var
=
VarNode
(
"third_bn_saved_var"
)
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
)
->
AsIntermediate
();
auto
*
third_relu
=
OpNode
(
"third_relu"
,
"relu"
)
->
AsIntermediate
();
auto
*
third_relu_out
=
VarNode
(
"third_relu_out"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
assert_is_op_input
(
"pool2d"
,
"X"
)
->
AsIntermediate
();
auto
*
top_pool
=
OpNode
(
"top_pool"
,
"pool2d"
)
->
AsIntermediate
();
auto
*
top_pool_out
=
VarNode
(
"top_pool_out"
)
->
assert_is_op_output
(
"pool2d"
,
"Out"
)
...
...
@@ -679,10 +1422,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->
AsIntermediate
();
auto
*
resnet_block0_2
=
OpNode
(
"resnet_block0_2"
,
"resnet_block0"
)
->
AsIntermediate
();
OpNode
(
"resnet_block0_2"
,
"resnet_block0
_d
"
)
->
AsIntermediate
();
auto
*
resnet_block0_2_out
=
VarNode
(
"resnet_block0_2_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
assert_is_op_output
(
"resnet_block0
_d
"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_2_1
=
OpNode
(
"resnet_block1_2_1"
,
"resnet_block1"
)
->
AsIntermediate
();
...
...
@@ -704,10 +1447,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->
AsIntermediate
();
auto
*
resnet_block0_3
=
OpNode
(
"resnet_block0_3"
,
"resnet_block0"
)
->
AsIntermediate
();
OpNode
(
"resnet_block0_3"
,
"resnet_block0
_d
"
)
->
AsIntermediate
();
auto
*
resnet_block0_3_out
=
VarNode
(
"resnet_block0_3_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
assert_is_op_output
(
"resnet_block0
_d
"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_3_1
=
OpNode
(
"resnet_block1_3_1"
,
"resnet_block1"
)
->
AsIntermediate
();
...
...
@@ -741,10 +1484,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->
AsIntermediate
();
auto
*
resnet_block0_4
=
OpNode
(
"resnet_block0_4"
,
"resnet_block0"
)
->
AsIntermediate
();
OpNode
(
"resnet_block0_4"
,
"resnet_block0
_d
"
)
->
AsIntermediate
();
auto
*
resnet_block0_4_out
=
VarNode
(
"resnet_block0_4_out"
)
->
assert_is_op_output
(
"resnet_block0"
,
"Outputs"
)
->
assert_is_op_output
(
"resnet_block0
_d
"
,
"Outputs"
)
->
AsIntermediate
();
auto
*
resnet_block1_4_1
=
OpNode
(
"resnet_block1_4_1"
,
"resnet_block1"
)
->
AsIntermediate
();
...
...
@@ -765,7 +1508,10 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
->
AsOutput
();
*
input
>>
*
top_conv
>>
*
top_conv_out
>>
*
top_bn
>>
*
top_bn_out
>>
*
top_relu
>>
*
top_relu_out
>>
*
top_pool
>>
*
top_pool_out
>>
*
top_relu
>>
*
top_relu_out
>>
*
second_conv
>>
*
second_conv_out
>>
*
second_bn
>>
*
second_bn_out
>>
*
second_relu
>>
*
second_relu_out
>>
*
third_conv
>>
*
third_conv_out
>>
*
third_bn
>>
*
third_bn_out
>>
*
third_relu
>>
*
third_relu_out
>>
*
top_pool
>>
*
top_pool_out
>>
*
resnet_block0_1
>>
*
resnet_block0_1_out
>>
*
resnet_block1_1_1
>>
*
resnet_block1_1_1_out
>>
*
resnet_block1_1_2
>>
*
resnet_block1_1_2_out
>>
*
resnet_block0_2
>>
*
resnet_block0_2_out
>>
...
...
@@ -789,24 +1535,59 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
*
top_bn
>>
*
top_bn_var_out
;
*
top_bn
>>
*
top_bn_saved_mean
;
*
top_bn
>>
*
top_bn_saved_var
;
*
second_conv_weight
>>
*
second_conv
;
*
second_bn_scale
>>
*
second_bn
;
*
second_bn_bias
>>
*
second_bn
;
*
second_bn_mean
>>
*
second_bn
;
*
second_bn_var
>>
*
second_bn
;
*
second_bn
>>
*
second_bn_mean_out
;
*
second_bn
>>
*
second_bn_var_out
;
*
second_bn
>>
*
second_bn_saved_mean
;
*
second_bn
>>
*
second_bn_saved_var
;
*
third_conv_weight
>>
*
third_conv
;
*
third_bn_scale
>>
*
third_bn
;
*
third_bn_bias
>>
*
third_bn
;
*
third_bn_mean
>>
*
third_bn
;
*
third_bn_var
>>
*
third_bn
;
*
third_bn
>>
*
third_bn_mean_out
;
*
third_bn
>>
*
third_bn_var_out
;
*
third_bn
>>
*
third_bn_saved_mean
;
*
third_bn
>>
*
third_bn_saved_var
;
}
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
,
const
std
::
vector
<
Node
*>&
extra_input_vars
)
override
{
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"__xpu__resnet50"
);
op_desc
.
SetType
(
"__xpu__resnet50
_d
"
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
std
::
vector
<
std
::
string
>
filter_name
=
{
matched
.
at
(
"top_conv_weight"
)
->
arg
()
->
name
};
matched
.
at
(
"top_conv_weight"
)
->
arg
()
->
name
,
matched
.
at
(
"second_conv_weight"
)
->
arg
()
->
name
,
matched
.
at
(
"third_conv_weight"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
scale_name
=
{
matched
.
at
(
"top_bn_scale"
)
->
arg
()
->
name
};
matched
.
at
(
"top_bn_scale"
)
->
arg
()
->
name
,
matched
.
at
(
"second_bn_scale"
)
->
arg
()
->
name
,
matched
.
at
(
"third_bn_scale"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
bias_name
=
{
matched
.
at
(
"top_bn_bias"
)
->
arg
()
->
name
};
matched
.
at
(
"top_bn_bias"
)
->
arg
()
->
name
,
matched
.
at
(
"second_bn_bias"
)
->
arg
()
->
name
,
matched
.
at
(
"third_bn_bias"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
mean_name
=
{
matched
.
at
(
"top_bn_mean"
)
->
arg
()
->
name
};
matched
.
at
(
"top_bn_mean"
)
->
arg
()
->
name
,
matched
.
at
(
"second_bn_mean"
)
->
arg
()
->
name
,
matched
.
at
(
"third_bn_mean"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
var_name
=
{
matched
.
at
(
"top_bn_variance"
)
->
arg
()
->
name
};
matched
.
at
(
"top_bn_variance"
)
->
arg
()
->
name
,
matched
.
at
(
"second_bn_variance"
)
->
arg
()
->
name
,
matched
.
at
(
"third_bn_variance"
)
->
arg
()
->
name
};
std
::
vector
<
std
::
string
>
max_filter_name
;
std
::
vector
<
std
::
string
>
resnet_block_vec
=
{
"resnet_block0_1"
,
...
...
@@ -900,7 +1681,9 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
max_filter_node
->
arg
()
->
is_weight
=
true
;
max_filter_node
->
arg
()
->
type
=
LiteType
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
DirectedLink
(
max_filter_node
,
matched
.
at
(
"top_conv"
));
auto
*
max_filter_t
=
scope
->
NewTensor
(
max_name
);
max_filter_t
->
Resize
({
4
});
float
*
max_ptr
=
max_filter_t
->
mutable_data
<
float
>
();
...
...
@@ -919,6 +1702,11 @@ class XPUResNet50Fuser : public xpu::XPUFuseBase {
resnet50_stmt
->
SetKernels
(
std
::
move
(
kernels
));
IR_NODE_LINK_TO
(
matched
.
at
(
"top_bn_bias"
),
matched
.
at
(
"top_conv"
));
IR_NODE_LINK_TO
(
matched
.
at
(
"second_conv_weight"
),
matched
.
at
(
"top_conv"
));
IR_NODE_LINK_TO
(
matched
.
at
(
"second_bn_bias"
),
matched
.
at
(
"top_conv"
));
IR_NODE_LINK_TO
(
matched
.
at
(
"third_conv_weight"
),
matched
.
at
(
"top_conv"
));
IR_NODE_LINK_TO
(
matched
.
at
(
"third_bn_bias"
),
matched
.
at
(
"top_conv"
));
for
(
auto
*
node
:
extra_input_vars
)
{
IR_NODE_LINK_TO
(
node
,
matched
.
at
(
"top_conv"
));
}
...
...
@@ -951,6 +1739,31 @@ class XPUResNet50FusePass : public ProgramPass {
}
};
class
XPUResNet50DtypeFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
{
if
(
GetBoolFromEnv
(
"XPU_ENABLE_XTCL"
))
return
;
bool
changed
=
false
;
SSAGraph
backup
;
backup
.
CloneFrom
(
*
graph
);
fusion
::
XPUResNetBlock0Fuser
block0_fuser
;
changed
|=
block0_fuser
(
graph
.
get
());
fusion
::
XPUResNetDtypeBlock0Fuser
d_type_block0_fuser
;
changed
|=
d_type_block0_fuser
(
graph
.
get
());
fusion
::
XPUResNetBlock1Fuser
block1_fuser
;
changed
|=
block1_fuser
(
graph
.
get
());
fusion
::
XPUResNet50DtypeFuser
resnet50_d_fuser
;
size_t
n_matches
=
resnet50_d_fuser
(
graph
.
get
());
if
(
changed
&&
!
n_matches
)
{
// Restore graph from backuped one if no whole ResNet50 graph was found
graph
->
CloneFrom
(
backup
);
}
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
...
...
@@ -959,3 +1772,8 @@ REGISTER_MIR_PASS(__xpu__resnet_fuse_pass,
paddle
::
lite
::
mir
::
XPUResNet50FusePass
)
.
BindTargets
({
TARGET
(
kXPU
)})
.
BindKernel
(
"__xpu__resnet50"
);
REGISTER_MIR_PASS
(
__xpu__resnet_d_fuse_pass
,
paddle
::
lite
::
mir
::
XPUResNet50DtypeFusePass
)
.
BindTargets
({
TARGET
(
kXPU
)})
.
BindKernel
(
"__xpu__resnet50_d"
);
lite/core/optimizer.h
浏览文件 @
8d2351c3
...
...
@@ -108,6 +108,7 @@ class Optimizer {
#endif
"identity_dropout_eliminate_pass"
,
"__xpu__resnet_fuse_pass"
,
"__xpu__resnet_d_fuse_pass"
,
"__xpu__resnet_cbam_fuse_pass"
,
"__xpu__conv2d_fuse_pass"
,
"__xpu__conv2d_link_previous_out_max_pass"
,
...
...
lite/kernels/xpu/__xpu__resnet50_compute.cc
浏览文件 @
8d2351c3
...
...
@@ -34,6 +34,21 @@ void XPUResNet50Compute::PrepareForRun() {
}
}
void
XPUResNet50DtypeCompute
::
PrepareForRun
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
for
(
auto
*
filter
:
param
.
filter
)
{
arg_filter_
.
push_back
(
reinterpret_cast
<
const
int16_t
*>
(
filter
->
data
<
float
>
()));
}
for
(
auto
*
bias
:
param
.
bias
)
{
arg_bias_
.
push_back
(
bias
->
data
<
float
>
());
}
for
(
auto
*
max_filter
:
param
.
max_filter
)
{
arg_max_filter_
.
push_back
(
max_filter
->
data
<
float
>
());
}
}
void
XPUResNet50Compute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
As
<
XPUContext
>
();
...
...
@@ -50,6 +65,22 @@ void XPUResNet50Compute::Run() {
CHECK_EQ
(
r
,
0
);
}
void
XPUResNet50DtypeCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
As
<
XPUContext
>
();
int
batch_size
=
param
.
input
->
dims
()[
0
];
int
r
=
xdnn
::
conv2d_int16_resnet_d
<
float
,
int16_t
>
(
ctx
.
GetRawContext
(),
/* context */
batch_size
,
/* num */
param
.
input
->
data
<
float
>
(),
/* bottom */
&
arg_filter_
[
0
],
/* weight_list */
param
.
output
->
mutable_data
<
float
>
(
TARGET
(
kXPU
)),
/* top */
&
arg_bias_
[
0
],
/* bias_list */
&
arg_max_filter_
[
0
]
/* max_filter_list */
);
CHECK_EQ
(
r
,
0
);
}
}
// namespace xpu
}
// namespace kernels
}
// namespace lite
...
...
@@ -67,3 +98,16 @@ REGISTER_LITE_KERNEL(__xpu__resnet50,
.
BindInput
(
"MaxFilter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
__xpu__resnet50_d
,
kXPU
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
xpu
::
XPUResNet50DtypeCompute
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
BindInput
(
"MaxFilter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kXPU
))})
.
Finalize
();
lite/kernels/xpu/__xpu__resnet50_compute.h
浏览文件 @
8d2351c3
...
...
@@ -38,6 +38,21 @@ class XPUResNet50Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
std
::
vector
<
const
float
*>
arg_bias_
;
};
class
XPUResNet50DtypeCompute
:
public
KernelLite
<
TARGET
(
kXPU
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
XPUResNet50Param
;
virtual
void
PrepareForRun
();
virtual
void
Run
();
private:
std
::
vector
<
const
int16_t
*>
arg_filter_
;
std
::
vector
<
const
float
*>
arg_max_filter_
;
std
::
vector
<
const
float
*>
arg_bias_
;
};
}
// namespace xpu
}
// namespace kernels
}
// namespace lite
...
...
lite/operators/__xpu__resnet50_op.cc
浏览文件 @
8d2351c3
...
...
@@ -62,3 +62,4 @@ bool XPUResNet50Op::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
}
// namespace paddle
REGISTER_LITE_OP
(
__xpu__resnet50
,
paddle
::
lite
::
operators
::
XPUResNet50Op
);
REGISTER_LITE_OP
(
__xpu__resnet50_d
,
paddle
::
lite
::
operators
::
XPUResNet50Op
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录