Commit 554ce352 · MegEngine 天元 / MegEngine

feat(mgb/gopt): add nchw44 optpass

GitOrigin-RevId: dc38724558b0c6635ea9a3137e1c0d0acc665a0f

Authored on Apr 17, 2020 by Megvii Engine Team; committed by Xinran Xu on May 06, 2020.
Parent: 7d1e1f9a

Showing 8 changed files with 364 additions and 20 deletions (+364 -20):
python_module/megengine/_internal/__init__.py   +6   -2
python_module/src/swig/misc.i                   +1   -0
sdk/load-and-run/dump_with_testcase_mge.py      +7   -0
src/gopt/impl/framework.cpp                     +3   -0
src/gopt/impl/tensor_reformat.cpp               +248 -15
src/gopt/include/megbrain/gopt/inference.h      +7   -2
src/gopt/test/inference.cpp                     +81  -0
src/plugin/impl/opr_footprint.cpp               +11  -1
python_module/megengine/_internal/__init__.py

@@ -541,7 +541,8 @@ def optimize_for_inference(
     fuse_conv_bias_nonlinearity=False,
     use_tensor_core=False,
     fuse_conv_bias_with_z=False,
-    use_nchw88=False
+    use_nchw88=False,
+    use_nchw44=False
 ):
     """optimize computing graph for inference
@@ -559,7 +560,9 @@ def optimize_for_inference(
         OpenCL devices
     :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
         into one opr. This is supported only in NHWCD4 format.
-    :param use_nchw88: whether to use NCHW4 tensor format. This maybe faster some
+    :param use_nchw88: whether to use NCHW88 tensor format. This maybe faster some
         times.
+    :param use_nchw44: whether to use NCHW44 tensor format. This maybe faster some
+        times.
@@ -577,6 +580,7 @@ def optimize_for_inference(
         "use_tensor_core",
         "fuse_conv_bias_with_z",
         "use_nchw88",
+        "use_nchw44",
     ]:
         if settings[i]:
             getattr(opt, "enable_{}".format(i))()
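The settings loop above maps every keyword that is set to a truthy value onto the matching enable_* setter of the underlying options object, so the new flag rides on the existing machinery. A minimal usage sketch follows (the graph and its `outputs` variable are illustrative and not part of this diff):

    # Hypothetical sketch: request the NCHW44 layout when optimizing an
    # already-built MegEngine graph for inference. `outputs` stands for the
    # graph's output SymbolVars.
    from megengine._internal import optimize_for_inference

    optimized_outputs = optimize_for_inference(
        outputs,
        use_nchw44=True,  # keyword added by this commit
    )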
python_module/src/swig/misc.i

@@ -79,6 +79,7 @@ struct _OptimizeForInferenceOptions {
     SET(use_tensor_core);
     SET(fuse_conv_bias_with_z);
     SET(use_nchw88);
+    SET(use_nchw44);
 #undef SET
 };
sdk/load-and-run/dump_with_testcase_mge.py

@@ -253,6 +253,7 @@ def optimize_for_inference(args, outputs):
         'enable_ioc16': 'f16_io_comp',
         'enable_hwcd4': 'use_nhwcd4',
         'enable_nchw88': 'use_nchw88',
+        'enable_nchw44': 'use_nchw44',
         'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_tensorcore': 'use_tensor_core',
         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
@@ -385,6 +386,12 @@ def main():
         help='transform the model format from NCHW to NCHW88 '
         'for inference')
+    parser.add_argument(
+        '--enable-nchw44',
+        action='store_true',
+        help='transform the model format from NCHW to NCHW44 '
+        'for inference')
     parser.add_argument(
         '--enable-tensorcore',
         action='store_true',
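The script maps each --enable-* switch onto the corresponding optimize_for_inference keyword through a name dictionary, as shown in the first hunk above. A self-contained sketch of that pattern, reproducing only the two layout-related entries from the diff:

    # Illustrative sketch of the flag-to-option mapping used in
    # dump_with_testcase_mge.py; only the layout-related entries are shown.
    import argparse

    FLAG_TO_OPTION = {
        'enable_nchw88': 'use_nchw88',
        'enable_nchw44': 'use_nchw44',  # entry added by this commit
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--enable-nchw88', action='store_true',
                        help='transform the model format from NCHW to NCHW88 '
                             'for inference')
    parser.add_argument('--enable-nchw44', action='store_true',
                        help='transform the model format from NCHW to NCHW44 '
                             'for inference')

    args = parser.parse_args(['--enable-nchw44'])
    kwargs = {option: True
              for flag, option in FLAG_TO_OPTION.items() if getattr(args, flag)}
    print(kwargs)  # -> {'use_nchw44': True}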
src/gopt/impl/framework.cpp

@@ -700,6 +700,9 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     if (inference_opt->use_nchw88) {
         add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
     }
+    if (inference_opt->use_nchw44) {
+        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
+    }
     if (inference_opt->use_tensor_core) {
         mgb_assert(inference_opt->fuse_conv_bias_nonlinearity,
                    "enable tensor core should fuse conv bias activation "
src/gopt/impl/tensor_reformat.cpp

@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "megbrain/gopt/inference.h"
@@ -63,7 +64,10 @@ public:
         NCHW4_TO_CHWN4,    //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,    //!< from chwn4 layout to nchw4 layout
         NCHW_TO_NCHW88,    //!< from nchw layout to nchw88 layout
+        NCHW_TO_NCHW44,    //!< from nchw layout to nchw44 layout
         NCHW88_TO_NCHW,    //!< from nchw88 layout to nchw layout
+        NCHW44_TO_NCHW,    //!< from nchw44 layout to nchw layout
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
         WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
@@ -73,6 +77,16 @@ public:
         //!< the weight layout of input is nchw output is nchw88, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
         WEIGHT_HYBIRD_NCHW_NCHW88,
+        WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
+                                      //!< layout
+        WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
+                                      //!< nchw44 layout
+        WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
+                                      //!< to nchw44 layout
+        //!< the weight layout of input is nchw output is nchw44, special for
+        //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
+        WEIGHT_HYBIRD_NCHW_NCHW44,
     };

     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
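The new enum entries only name the source and destination layouts; the actual reshapes are added further down in translate_pass. Both the plain NCHW to NCHW44 relayout and the hybrid weight shape quoted in the comment can be reproduced with NumPy (an illustrative sketch, not MegEngine code; the array names are arbitrary):

    import numpy as np

    # NCHW -> NCHW44: split the channel dim into (C//4, 4) and move the 4-pack last.
    x = np.random.rand(2, 8, 5, 5)                   # NCHW with C divisible by 4
    n, c, h, w = x.shape
    x_nchw44 = x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)
    assert x_nchw44.shape == (2, 2, 5, 5, 4)

    # WEIGHT_HYBIRD_NCHW_NCHW44: {64, 2, 3, 3} -> {16, 3, 3, 2, 4} as in the comment.
    wgt = np.random.rand(64, 2, 3, 3)                # OIHW with O divisible by 4
    o, i, fh, fw = wgt.shape
    wgt_hybrid = wgt.reshape(o // 4, 4, i, fh, fw).transpose(0, 3, 4, 2, 1)
    assert wgt_hybrid.shape == (16, 3, 3, 2, 4)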
@@ -203,10 +217,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = inp_shape[4];
             dst[5] = 8;
-        } else {
-            mgb_assert(layout_type() ==
-                       RelayoutPlaceholder::LayoutType::
-                               WEIGHT_HYBIRD_NCHW_NCHW88);
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
             dst.ndim = 5;
             dst[0] = inp_shape[0] / 8;
@@ -214,6 +226,68 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[3];
             dst[3] = inp_shape[1];
             dst[4] = 8;
-        }
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] * 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
+                       inp_shape[1] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+            dst[5] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
+                       inp_shape[2] % 4 == 0);
+            dst.ndim = 7;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2] / 4;
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+            dst[6] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
+                       inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1];
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+        } else {
+            mgb_assert(layout_type() ==
+                       RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1];
+            dst[4] = 4;
+        }
         return true;
     };
@@ -418,6 +492,104 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2) / 4, sub(3), sub(4), cv(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };

     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
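Each lambda above is the same Reshape -> Dimshuffle -> Reshape chain built from symbolic shapes. In the dense weight case, 4 output and 4 input channels end up packed into the two trailing dimensions; a NumPy analogue of WEIGHT_NCHW_TO_NCHW44_DENSE (illustrative only, not MegEngine code) is:

    import numpy as np

    # OIHW -> {O/4, I/4, H, W, 4(i), 4(o)}, mirroring the tshp0 / Dimshuffle /
    # tshp1 chain of the WEIGHT_NCHW_TO_NCHW44_DENSE lambda above.
    wgt = np.random.rand(8, 12, 3, 3)                     # O and I divisible by 4
    o, i, fh, fw = wgt.shape
    packed = (wgt.reshape(o // 4, 4, i // 4, 4, fh, fw)   # tshp0
                 .transpose(0, 2, 4, 5, 3, 1))            # Dimshuffle {0,2,4,5,3,1}
    assert packed.shape == (o // 4, i // 4, fh, fw, 4, 4) # tshp1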
@@ -1071,16 +1243,24 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
 VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
     if (!orig_var->shape().eq_shape(new_var->shape())) {
-        return RelayoutPlaceholder::make(
-                       new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
-                .node();
+        if (m_pack_c_size == 8) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
+                    .node();
+        } else if (m_pack_c_size == 4) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                    .node();
+        }
     }
     return new_var;
 }

 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         size_t pack_c_size) {
-    auto ret = std::make_unique<EnableNchwxxPass>();
+    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter
     //! trans mode
@@ -1102,8 +1282,18 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
-    mgb_assert(pack_c_size == static_cast<size_t>(8),
-               "The ConvertFormatPass to nchwxx only support NCHW88 now !");
+    if (pack_c_size == 4) {
+        weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
+        weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
+        weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
+        hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
+        src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+        conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
+        conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
+        pooling_format = megdnn::param::Pooling::Format::NCHW44;
+        convter_pass_name = "conv_format_nchw44";
+    }
     auto test_trans_nchwxx = [pack_c_size, weight_to_nchwxx_mode_dense,
                               weight_to_nchwxx_mode_group,
                               weight_to_nchwxx_mode_chan,
@@ -1297,7 +1487,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
             auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, new_param,
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                     conv_bias_opr.execution_policy(), conv_bias_opr.config());
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
@@ -1330,6 +1520,51 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         }
     };
+    auto replace_concat_opr = [=](OperatorNodeBase* opr,
+                                  const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool has_inp_changed = false;
+        bool can_exec_ncwxx = true;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (new_inp[i]->shape().ndim == 5) {
+                has_inp_changed = true;
+                break;
+            } else if (new_inp[i]->shape().ndim == 4) {
+                if (new_inp[i]->shape()[1] % pack_c_size != 0) {
+                    can_exec_ncwxx = false;
+                }
+            }
+        }
+        if (has_inp_changed) {
+            auto temp_inp = new_inp;
+            if (can_exec_ncwxx) {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 4) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchwxx_mode);
+                        temp_inp[i] = new_var.node();
+                    } else {
+                        mgb_assert((new_inp[i]->shape().ndim == 5) ||
+                                   new_inp[i]->shape().is_scalar());
+                    }
+                }
+            } else {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 5) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchw_mode);
+                        temp_inp[i] = new_var.node();
+                    }
+                }
+            }
+            return serialization::copy_opr_shallow(*opr, temp_inp,
+                                                   opr->config());
+        } else {
+            return serialization::copy_opr_shallow(*opr, new_inp,
+                                                   opr->config());
+        }
+    };
+
     auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
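In words, replace_concat_opr keeps a Concat in the packed layout only when at least one input has already been converted (ndim == 5) and every remaining 4-dim NCHW input has a channel count divisible by the pack size; otherwise every converted input is relayouted back to NCHW. A small sketch of that decision (illustrative Python, not MegEngine code):

    # Shapes are plain tuples; 5-dim means already packed (NCHW44/NCHW88),
    # 4-dim means still NCHW with the channel count at index 1.
    def concat_stays_packed(input_shapes, pack_c_size=4):
        has_packed_input = any(len(s) == 5 for s in input_shapes)
        all_packable = all(s[1] % pack_c_size == 0
                           for s in input_shapes if len(s) == 4)
        return has_packed_input and all_packable

    assert concat_stays_packed([(2, 2, 8, 8, 4), (2, 8, 8, 8)])      # 8 % 4 == 0
    assert not concat_stays_packed([(2, 2, 8, 8, 4), (2, 6, 8, 8)])  # 6 % 4 != 0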
@@ -1382,6 +1617,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
+    replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
     replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
@@ -1390,13 +1626,10 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::ConvolutionBackwardData::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
src/gopt/include/megbrain/gopt/inference.h

@@ -234,16 +234,18 @@ namespace gopt {
      */
     class EnableNchwxxPass final : public TensorReformatPass {
         std::string m_name = "tensor_format_nchwxx";
+        size_t m_pack_c_size;
         VarNode* on_graph_endpoint_var(VarNode* new_var,
                                        VarNode* orig_var) const override;
         //! the flag for conv to transform to nchwxx
         enum class TransType {
-            TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchw88
-            TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchw88
+            TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchwxx
+            TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchwxx
             TRANS_NONE,           //!< no need trans
         };

     public:
+        EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
         const char* name() const override {
             return mgb_cstr_log(m_name.c_str());
         }
@@ -265,6 +267,8 @@ namespace gopt {
         bool use_nhwcd4 = false;
         //! whether to compute using NCHW88 tensor format
         bool use_nchw88 = false;
+        //! whether to compute using NCHW44 tensor format
+        bool use_nchw44 = false;
         //! whether to enable tensor core
         bool use_tensor_core = false;
         //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
@@ -283,6 +287,7 @@ namespace gopt {
             SET(use_tensor_core);
             SET(fuse_conv_bias_with_z);
             SET(use_nchw88);
+            SET(use_nchw44);
 #undef SET
         };
src/gopt/test/inference.cpp

@@ -2325,5 +2325,86 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }

+TEST(TestGoptInference, ConvertFormatNCHW44) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({2, 3, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    //!Hybrid nchw88 mode
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+    //!channel wise
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
+         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+    //! group
+    auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
+         conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
+
+    auto shape_of = opr::GetVarShape::make(conv3);
+    auto subtensor = opr::Subtensor::make(
+            shape_of, {opr::Subtensor::AxisIndexer::make_interval(
+                              0, x.make_scalar(2), None, x.make_scalar(1))});
+
+    opr::Resize::Param param_resize;
+    param_resize.format = opr::Resize::Param::Format::NCHW;
+    auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
+    auto mat = mkcvar("mat", {2, 3, 3}),
+         warp = opr::WarpPerspectiveForward::make(
+                 resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
+
+    auto b = mkvar("b", {1, 8, 1, 1}),
+         elem = opr::Elemwise::make({warp + b},
+                                    opr::Elemwise::Param::Mode::RELU);
+    //! Dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
+         conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
+
+    auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
+         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
+    auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
+         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
+
+    SymbolVar y_opt;
+    unpack_vector(gopt::optimize_for_inference(
+                          {y}, gopt::OptimizeForInferenceOptions{}
+                                       .enable_use_nchw44()),
+                  y_opt);
+
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    //! meybe go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+
+    *host_x = *gen({2, 3, 32, 32}, cn);
+    func->execute();
+    //! meybe go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/plugin/impl/opr_footprint.cpp

@@ -99,7 +99,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
         group = filter_shape[0];
     }
     if (param.format == Param::Format::NCHW88) {
-        //! if channel wise weight layout is {group/8, 1, 1, FH, FW, 8}
+        //! if channel wise weight layout is {group/8, FH, FW, 1, 1, 8}
         if (filter_shape[1] == 1 && filter_shape[2] == 1) {
             group *= 8;
         }
@@ -107,6 +107,15 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
                src_shape[1] / group * 2;
         return hybird_nchwx ? computation : computation * 8;
     }
+    if (param.format == Param::Format::NCHW44) {
+        //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
+        if (filter_shape[1] == 1 && filter_shape[2] == 1) {
+            group *= 4;
+        }
+        size_t computation = dst_shape.total_nr_elems() * fh * fw *
+                             src_shape[1] / group * 2;
+        return hybird_nchwx ? computation : computation * 4;
+    }
     if (param.format == Param::Format::NCHW32) {
         return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
                group * 2;
@@ -135,6 +144,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
     };
     if (param.format == Param::Format::NCHW4 ||
         param.format == Param::Format::NCHW88 ||
+        param.format == Param::Format::NCHW44 ||
         param.format == Param::Format::NCHW32) {
         return eval_conv_computation_nchwx();
     }
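In the packed layout src_shape[1] holds C/4 while dst_shape.total_nr_elems() already includes the trailing 4, so multiplying the non-hybrid result by 4 recovers the usual 2 * N * OC * OH * OW * IC * FH * FW convolution FLOP count. A quick numeric check with illustrative values:

    # Dense NCHW44 convolution: N=1, IC=16, OC=32, OH=OW=32, 3x3 filter, group=1.
    n, ic, oc, oh, ow, fh, fw, group = 1, 16, 32, 32, 32, 3, 3, 1

    dst_total_nr_elems = n * (oc // 4) * oh * ow * 4   # NCHW44 dst {1, 8, 32, 32, 4}
    src_channel_dim = ic // 4                          # src_shape[1] in NCHW44

    computation = dst_total_nr_elems * fh * fw * src_channel_dim // group * 2
    nchw44_flops = computation * 4                     # non-hybrid branch in the diff

    assert nchw44_flops == 2 * n * oc * oh * ow * ic * fh * fw // group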