MegEngine 天元 / MegEngine
Commit 0fb9cc41
Authored May 20, 2021 by Megvii Engine Team

fix(gopt): fix nchw64 opt pass

GitOrigin-RevId: dec18d1ab1b7bd0723395e490c215356f178e44a
Parent: e661ae90
Showing 3 changed files with 310 additions and 123 deletions (+310 / -123):

    src/gopt/impl/fuse_nchw4_int8_preprocess.cpp    +109  -54
    src/gopt/impl/tensor_reformat.cpp               +138  -68
    src/gopt/test/inference.cpp                      +63   -1
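The bulk of the change in src/gopt/impl/fuse_nchw4_int8_preprocess.cpp generalizes the sub-graph matcher: a Node's predecessor list changes from a single std::vector<Node> to a std::vector<std::vector<Node>> of alternative input patterns, and match() now succeeds when the current operator's callback passes and at least one alternative matches against the operator's inputs. The following is a minimal standalone sketch of that matching rule; the Op, Pattern, and match names here are illustrative stand-ins, not MegEngine's actual SubGraphMatcher API.

#include <cstdio>
#include <string>
#include <vector>

// Toy stand-ins for illustration; MegEngine's real matcher works on
// OperatorNodeBase* and Typeinfo*, not on these types.
struct Op {
    std::string type;
    std::vector<const Op*> inputs;
};

struct Pattern {
    std::string type;  // empty string matches any op (like a nullptr Typeinfo)
    // Alternative input patterns: the op matches if ANY alternative matches.
    std::vector<std::vector<Pattern>> alternatives;
};

bool match(const Pattern& root, const Op& op) {
    if (!root.type.empty() && root.type != op.type)
        return false;
    if (root.alternatives.empty())
        return true;  // no constraint on the inputs
    for (auto& alt : root.alternatives) {
        bool ok = true;
        for (size_t i = 0; i < alt.size(); ++i) {
            if (i >= op.inputs.size() || !match(alt[i], *op.inputs[i])) {
                ok = false;
                break;
            }
        }
        if (ok)
            return true;  // one matching alternative is enough
    }
    return false;
}

int main() {
    // add(cvt(x), const) should match whether the constant is the first or the
    // second input, mirroring the {cvt_fp32, nullptr} / {nullptr, cvt_fp32}
    // alternatives introduced by the patch.
    Op x{"data", {}}, cvt{"typecvt", {&x}}, c{"const", {}};
    Op add1{"add", {&cvt, &c}}, add2{"add", {&c, &cvt}};
    Pattern cvt_p{"typecvt", {}};
    Pattern any_p{"", {}};
    Pattern add_p{"add", {{cvt_p, any_p}, {any_p, cvt_p}}};
    std::printf("%d %d\n", match(add_p, add1), match(add_p, add2));  // prints: 1 1
    return 0;
}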
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
@@ -36,15 +36,23 @@ struct SubGraphMatcher {
     Node(Typeinfo* in_op_type) : op_type(in_op_type){};
     Node(Typeinfo* in_op_type, CallBack func) : op_type(in_op_type), cbk(func){};
-    Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node)
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node)
             : op_type(in_op_type), pre_node(in_pre_node){};
-    Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node, CallBack func)
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
+         CallBack func)
             : op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
+    Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
+         CallBack func, std::string in_msg)
+            : op_type(in_op_type), pre_node(in_pre_node), cbk(func), msg(in_msg){};
     Typeinfo* op_type{nullptr};
-    std::vector<Node> pre_node;
+    std::vector<std::vector<Node>> pre_node;
     //! cbk used to check param and gather args for creating fusion op
     CallBack cbk;
+    std::string msg{""};
 };

 bool match(Node& root, OperatorNodeBase* opr) {
@@ -53,20 +61,34 @@ struct SubGraphMatcher {
         }
         //! match nullptr node always
         if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
-            bool match_ok = true;
+            bool current_match = true;
             if (root.cbk)
-                match_ok &= root.cbk(opr);
-            RETURN_IF_FALSE(match_ok);
+                current_match &= root.cbk(opr);
+            RETURN_IF_FALSE(current_match);
             auto& inp = opr->input();
-            for (size_t node_idx = 0; node_idx < root.pre_node.size(); ++node_idx) {
-                bool valid_node_idx = node_idx < inp.size();
-                RETURN_IF_FALSE(valid_node_idx);
-                match_ok &= match(root.pre_node[node_idx], inp[node_idx]->owner_opr());
-                RETURN_IF_FALSE(match_ok);
+            bool any_sub_patten_match = root.pre_node.size() == 0 ? true : false;
+            for (auto& sub_patten : root.pre_node) {
+                bool patten_ok = true;
+                for (size_t node_idx = 0; node_idx < sub_patten.size(); ++node_idx) {
+                    bool valid_node_idx = node_idx < inp.size();
+                    if (!valid_node_idx) {
+                        patten_ok = false;
+                        break;
+                    }
+                    patten_ok = patten_ok &&
+                                match(sub_patten[node_idx], inp[node_idx]->owner_opr());
+                    if (!patten_ok) {
+                        break;
+                    }
+                }
+                any_sub_patten_match = any_sub_patten_match || patten_ok;
+                if (any_sub_patten_match) {
+                    break;
+                }
             }
-            return match_ok;
+            return current_match && any_sub_patten_match;
         } else {
             return false;
         }
@@ -237,24 +259,26 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
                     return false;
                 }
             };
-    SGM::Node broadcast_or_immutable{nullptr, check_pad};
+    SGM::Node broadcast_or_immutable{nullptr, {}, check_pad,
+                                     "broadcast_or_immutable"};
     SGM::Node broadcast_concat{opr::Concat::typeinfo(),
-                               {in_node, broadcast_or_immutable},
+                               {{in_node, broadcast_or_immutable}},
                                [](OperatorNodeBase* opr) {
                                    auto concat_pad = opr->try_cast_final<opr::Concat>();
                                    return concat_pad->axis() == 1;
-                               }};
+                               },
+                               "broadcast_concat"};
     SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
-                            {broadcast_concat, SGM::Node(nullptr)},
+                            {{broadcast_concat, SGM::Node(nullptr)}},
                             [](OperatorNodeBase* opr) {
                                 auto inp0 = opr->input()[0];
                                 return is_shape_nchw(inp0->shape());
                             }};
     SGM::Node shuffle_root{opr::Dimshuffle::typeinfo(),
-                           {nchwx_reshape},
+                           {{nchwx_reshape}},
                            [](OperatorNodeBase* opr) {
                                auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
                                auto& input_vec = shuffle_opr.input();
@@ -263,13 +287,55 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
                            }};
         return shuffle_root;
     };
+    auto gen_u8_cvt2_q8 = [](OperatorNodeBase*& src_node,
+                             OperatorNodeBase*& neg_128_immu_node) {
+        SGM::Node input_data_u8{
+                nullptr, [&](OperatorNodeBase* opr) {
+                    auto src_dtype = opr->output()[0]->dtype();
+                    if (src_dtype.enumv() == DTypeEnum::Uint8) {
+                        src_node = opr;
+                        return true;
+                    } else {
+                        return false;
+                    }
+                }};
+        SGM::Node cvt_fp32{
+                opr::TypeCvt::typeinfo(), {{input_data_u8}},
+                [](OperatorNodeBase* opr) {
+                    auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
+                    bool is_fp32 = cvt_op->param().enumv() == DTypeEnum::Float32;
+                    return is_fp32;
+                }};
+        SGM::Node sub_128{
+                opr::Elemwise::typeinfo(),
+                {{cvt_fp32, nullptr}, {nullptr, cvt_fp32}},
+                [&](OperatorNodeBase* opr) {
+                    auto elem_op = opr->try_cast_final<opr::Elemwise>();
+                    bool is_add_op =
+                            elem_op->param().mode == opr::Elemwise::Param::Mode::ADD;
+                    auto neg_128_op = elem_op->input()[1]->owner_opr();
+                    bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
+                                                         DTypeEnum::Float32);
+                    neg_128_op = elem_op->input()[0]->owner_opr();
+                    is_neg_128 = is_neg_128 ||
+                                 is_immutable_equal(neg_128_op, -128.f,
+                                                    DTypeEnum::Float32);
+                    neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
+                    return is_add_op && is_neg_128;
+                },
+                "sub_128"};
+        return sub_128;
+    };
     auto replace_shuffle_opr = [&](OperatorNodeBase* opr, const VarNodeArray& new_inp,
                                    SubGraph::Rewriter& rewriter, ReaderType& reader) {
         SGM matcher;
         OperatorNodeBase* src_node = nullptr;
-        SGM::Node input_data_cp{
+        OperatorNodeBase* neg_128_immu_node = nullptr;
+        auto u8_q8_input = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
+        SGM::Node input_data_qu8{
                 nullptr, [&](OperatorNodeBase* opr) {
                     auto src_dtype = opr->output()[0]->dtype();
                     if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
@@ -279,7 +345,18 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
                         return false;
                     }
                 }};
-        SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}};
+        SGM::Node type_cvt{opr::TypeCvt::typeinfo(),
+                           {{input_data_qu8}, {u8_q8_input}},
+                           [](OperatorNodeBase* opr) {
+                               auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
+                               if (cvt_op) {
+                                   return cvt_op->param().enumv() ==
+                                          DTypeEnum::QuantizedS8;
+                               } else {
+                                   return false;
+                               }
+                           }};
         SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
             bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
             bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
@@ -321,37 +398,7 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
         OperatorNodeBase* neg_128_immu_node = nullptr;
         OperatorNodeBase* pad0_immu_node = nullptr;
         OperatorNodeBase* const_reshape_last_dim_node = nullptr;
-        SGM::Node input_data_cp{
-                nullptr, [&](OperatorNodeBase* opr) {
-                    auto src_dtype = opr->output()[0]->dtype();
-                    if (src_dtype.enumv() == DTypeEnum::Uint8) {
-                        src_node = opr;
-                        return true;
-                    } else {
-                        return false;
-                    }
-                }};
-        SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(), {input_data_cp},
-                           [](OperatorNodeBase* opr) {
-                               auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
-                               bool is_fp32 = cvt_op->param().enumv() ==
-                                              DTypeEnum::Float32;
-                               return is_fp32;
-                           }};
-        SGM::Node sub_128{opr::Elemwise::typeinfo(), {cvt_fp32},
-                          [&](OperatorNodeBase* opr) {
-                              auto elem_op = opr->try_cast_final<opr::Elemwise>();
-                              bool is_add_op = elem_op->param().mode ==
-                                               opr::Elemwise::Param::Mode::ADD;
-                              auto neg_128_op = elem_op->input()[1]->owner_opr();
-                              bool is_neg_128 = is_immutable_equal(
-                                      neg_128_op, -128.f, DTypeEnum::Float32);
-                              neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
-                              return is_add_op && is_neg_128;
-                          }};
+        auto sub_128 = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
         SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
             pad0_immu_node = opr;
             bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
@@ -364,8 +411,16 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
         };
         auto&& shuffle_root =
                 gen_pad_dimshuffle_graph(sub_128, const_pad_cbk, const_reshape_cbk);
-        SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}};
+        SGM::Node::CallBack cvt_q8_cbk = [](OperatorNodeBase* opr) {
+            auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
+            if (cvt_op) {
+                return cvt_op->param().enumv() == DTypeEnum::QuantizedS8;
+            } else {
+                return false;
+            }
+        };
+        SGM::Node astype_root{opr::TypeCvt::typeinfo(), {{shuffle_root}}, cvt_q8_cbk};
         bool match = matcher.match(astype_root, opr);
         bool check_ok = false;
         if (match) {
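The new gen_u8_cvt2_q8 helper above factors out the sub-graph both rewrite callbacks look for: a Uint8 input converted to Float32, shifted by an immutable -128 (accepted on either side of the ADD), and then cast to QuantizedS8. Numerically this is the usual uint8-to-int8 recentering. The snippet below is a hedged, standalone sketch of that per-element arithmetic only; u8_to_s8 is a hypothetical helper, and the quantization scale applied by the real TypeCvt to QuantizedS8 is ignored here.

#include <cstdint>
#include <cstdio>

// Sketch of the arithmetic the matched sub-graph performs per element:
// uint8 -> float32 -> add(-128) -> cast to int8 (quantization scale omitted).
int8_t u8_to_s8(uint8_t u) {
    float f = static_cast<float>(u) - 128.f;  // the Elemwise ADD with -128
    return static_cast<int8_t>(f);            // the TypeCvt to (scaled) int8
}

int main() {
    std::printf("%d %d %d\n", u8_to_s8(0), u8_to_s8(128), u8_to_s8(255));
    // prints: -128 0 127
    return 0;
}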
src/gopt/impl/tensor_reformat.cpp

(This diff is collapsed.)
src/gopt/test/inference.cpp
@@ -3815,7 +3815,7 @@ TEST(TestGoptInference, PreProcessCase1) {
     HostTensorND host_y_opt, host_y;
     auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_opt, host_y_opt)});
+                                make_callback_copy(y_opt, host_y_opt)});
     func->execute();
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
@@ -3882,6 +3882,68 @@ TEST(TestGoptInference, WarpAndPreProcessCase0) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
 }

+TEST(TestGoptInference, PreProcessCaseAutopadNCHW64) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
+    auto cn = CompNode::load("gpu0");
+    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+    auto sm_ver = prop.major * 10 + prop.minor;
+    if (sm_ver < 75) {
+        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
+               "expected: %d)\n",
+               sm_ver, 75);
+        return;
+    }
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
+                dtype);
+    };
+    size_t n = 2;
+    size_t c = 3;
+    size_t h = 32;
+    size_t w = 32;
+    auto host_x1 = gen({n, c, h, w}, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    auto x_u8_fp32 = opr::TypeCvt::make(x, dtype::Float32(), cn);
+    auto x_s8_fp32 = x_u8_fp32 - 128;
+    auto x_s8 = opr::TypeCvt::make(x_s8_fp32, dtype::QuantizedS8(2.5f), cn);
+
+    auto weight = mkcvar("weight", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
+         bias = mkcvar("bias", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 2;
+    param.pad_h = param.pad_w = 1;
+
+    auto result = opr::ConvBias::make(x_s8, weight, bias, param, {},
+                                      OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nchw64();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.PreProcessCaseAutopadNCHW64.json"));
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+    ASSERT_TRUE(find_opr<opr::RelayoutFormat>(y_opt).param().mode ==
+                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4);
+}
+
 TEST(TestGoptInference, WarpAndPreProcessCase1) {
     REQUIRE_GPU(1);
     HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);