MegEngine 天元 / MegEngine
Commit 777f3ea9
Authored Aug 20, 2020 by Megvii Engine Team
Committed by Xinran Xu, Aug 25, 2020
refactor(gopt): format code
GitOrigin-RevId: 9d5c87000fdfa291d91306365f7401c8af443dc1
Parent: b44e0549
Showing 2 changed files with 353 additions and 346 deletions (+353 −346)
src/gopt/impl/tensor_reformat.cpp  (+151 −144)
src/gopt/test/inference.cpp        (+202 −202)
src/gopt/impl/tensor_reformat.cpp
@@ -10,23 +10,23 @@
  * implied.
  */
 
-#include "megbrain/gopt/inference.h"
-#include "megbrain/gopt/gtrans.h"
 #include "megbrain/gopt/basic_arith.h"
+#include "megbrain/gopt/gtrans.h"
+#include "megbrain/gopt/inference.h"
 #include "megbrain/graph/event.h"
-#include "megbrain/opr/dnn/batch_norm.h"
-#include "megbrain/opr/dnn/local.h"
-#include "megbrain/utils/shared_set.h"
-#include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/opr/basic_arith.h"
-#include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/blas.h"
-#include "megbrain/opr/misc.h"
-#include "megbrain/opr/utility.h"
+#include "megbrain/opr/dnn/batch_norm.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/pooling.h"
-#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
+#include "megbrain/utils/shared_set.h"
 #include "megdnn/opr_param_defs.h"
 #include "megdnn/tensor_format.h"
@@ -68,8 +68,8 @@ using namespace gopt;
 MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
                      cg::SingleCNOperatorNodeBase) // {
 public:
-    //! relayout type of this opr
-    enum class LayoutType {
+    //! relayout type of this opr
+    enum class LayoutType {
         NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
         NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
         NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
@@ -112,25 +112,28 @@ public:
                                           //!< NCHW44_DOT layout dense
         WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< weight from NCHW44 layout to
                                           //!< NCHW44_DOT layout group
-    };
+    };
-    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
+    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
-    /*!
+    /*!
      * \param src_var the input var
      * \param layout_type tensor layout transform type of this relayout
      * placeholder as described in LayoutType
      */
-    static SymbolVar make(VarNode* src_var, LayoutType layout_type);
+    static SymbolVar make(VarNode* src_var, LayoutType layout_type);
-    LayoutType layout_type() const { return m_layout_type; }
+    LayoutType layout_type() const { return m_layout_type; }
 
 private:
-    void init_output_static_infer_desc() override;
-    void scn_do_execute() override;
-    void init_output_comp_node() override;
-    const LayoutType m_layout_type;
-};
+    void init_output_static_infer_desc() override;
+    void scn_do_execute() override;
+    void init_output_comp_node() override;
+    const LayoutType m_layout_type;
+};
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);
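For orientation: the LayoutType values above name channel-blocking transforms. NCHW4 packs channels in groups of four into an innermost dimension of size 4, so an (N, C, H, W) tensor is stored as (N, C/4, H, W, 4); NCHW32 and NCHW88 follow the same scheme with block sizes 32 and 8. A minimal standalone sketch of the packing on a raw buffer (a hypothetical helper for illustration, not MegEngine's implementation):

#include <cstddef>
#include <vector>

// Hypothetical illustration of NCHW -> NCHW4:
// (N, C, H, W) -> (N, C/4, H, W, 4), assuming C % 4 == 0.
std::vector<float> nchw_to_nchw4(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W) {
    std::vector<float> dst(src.size());
    const size_t HW = H * W;
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t hw = 0; hw < HW; ++hw) {
                // source index in NCHW
                size_t si = (n * C + c) * HW + hw;
                // destination index in NCHW4: (n, c/4, h, w, c%4)
                size_t di = ((n * (C / 4) + c / 4) * HW + hw) * 4 + c % 4;
                dst[di] = src[si];
            }
    return dst;
}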
@@ -211,7 +214,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = 4;
         } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW){
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
             mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
             dst.ndim = 4;
             dst[0] = inp_shape[0];
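The rules in this hunk are pure shape arithmetic: blocking adds a trailing block dimension, unblocking folds it back into the channel axis. A hedged standalone sketch of the two directions (plain C++ arrays, not MegEngine's TensorShape API; the asserts mirror the checks above):

#include <array>
#include <cassert>
#include <cstddef>

// NCHW -> NCHW4: (N, C, H, W) -> (N, C/4, H, W, 4); requires C % 4 == 0.
std::array<size_t, 5> infer_nchw_to_nchw4(const std::array<size_t, 4>& s) {
    assert(s[1] % 4 == 0);
    return {s[0], s[1] / 4, s[2], s[3], 4};
}

// NCHW4 -> NCHW: (N, C/4, H, W, 4) -> (N, C, H, W); requires last dim == 4.
std::array<size_t, 4> infer_nchw4_to_nchw(const std::array<size_t, 5>& s) {
    assert(s[4] == 4);
    return {s[0], s[1] * 4, s[2], s[3]};
}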
@@ -249,7 +252,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = inp_shape[4];
             dst[5] = 4;
-        } else if (layout_type() ==
+        } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
             dst.ndim = 5;
@@ -1033,7 +1036,6 @@ EnableTensorCorePass::make_tensorcore_converter() {
                        "can not be changed in this opt "
                        "pass");
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         };
-
     auto replace_warp_affine_opr =
             [replace_inps_to_nchw4, replace_non_nchw4_opr](
@@ -1247,7 +1249,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
         }
         if (nr_shape_changed) {
             auto inps = new_inp;
-            if (nr_shape_changed >= nr_inps / 2) {  // CHWN4 > NCHW4 -> use CHWN4
+            if (nr_shape_changed >=
+                nr_inps / 2) {  // CHWN4 > NCHW4 -> use CHWN4
                 for (size_t i = 0; i < nr_inps; ++i) {
                     if (varshape_changed.count(new_inp[i]) == 0) {
                         auto symvar = RelayoutPlaceholder::make(
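The heuristic in this hunk decides, per operator, whether to unify its inputs on CHWN4: when at least half of the inputs were already converted, the remaining ones get a RelayoutPlaceholder as well, so the operator runs in a single layout. A simplified sketch of that logic (hypothetical names, templated only to stay self-contained):

#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical distillation of the majority heuristic above: if most inputs
// are already CHWN4, relayout the remaining NCHW4 inputs to match.
template <typename Var, typename RelayoutFn>
void unify_inputs_on_chwn4(std::vector<Var*>& inps,
                           const std::unordered_set<Var*>& varshape_changed,
                           RelayoutFn relayout) {
    size_t nr_shape_changed = 0;
    for (auto* v : inps)
        nr_shape_changed += varshape_changed.count(v);
    if (nr_shape_changed >= inps.size() / 2) {  // CHWN4 majority -> use CHWN4
        for (auto*& v : inps)
            if (varshape_changed.count(v) == 0)
                v = relayout(v);  // insert an NCHW4 -> CHWN4 placeholder
    }
}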
@@ -1309,7 +1312,6 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
                        "can not be changed in this opt "
                        "pass");
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         };
-
     // capture by copy to avoid use after return
     auto replace_warp_affine_opr =
@@ -1410,7 +1412,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }
 
-std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
+std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
     MIDOUT_B("EnableNCHW4Pass::make")
     auto ret = std::make_unique<EnableNCHW4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
@@ -1469,14 +1471,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         }
-        auto conv_mode = trans_nchw4(conv_opr.param().sparse,
-                                     new_inp[1]);
+        auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
         VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
         // src: NCHW --> NCWH4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     conv_mode.src);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_src = new_src.node();
         }
         // weight: NCHW --> NCHW4
@@ -1488,8 +1488,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         new_param.format = conv_format;
         // dst
         auto new_conv_opr = opr::Convolution::make(
-                conv_src, conv_filter, new_param,
-                conv_opr.execution_policy(), conv_opr.config());
+                conv_src, conv_filter, new_param, conv_opr.execution_policy(),
+                conv_opr.config());
         OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
         mgb_assert(new_conv_opr.shape().ndim == 5,
                    "The conv dst dim is not trans to nchw4");
@@ -1515,10 +1515,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // what should be converted: src, weight
         VarNode *src = new_inp[0], *filter = new_inp[1];
         // src: NCHW --> NCHW4
-        if (new_inp[0]->shape().ndim != 5) {
+        if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
             src = new_src.node();
         }
         // weight: BNCHW --> BNCHW4
@@ -1531,7 +1531,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         auto new_param = batch_conv_bias_opr.param();
         new_param.format = batch_conv_bias_format;
         if (new_inp.size() == 2) {
-            auto dst = opr::BatchConvBias::make(src, filter, new_param,
-                    batch_conv_bias_opr.execution_policy(),
-                    batch_conv_bias_opr.config());
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
             OperatorNodeBase* new_opr = dst.node()->owner_opr();
@@ -1542,12 +1543,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
             // bias: NCHW --> NCHW4
             VarNode* bias = new_inp[2];
             if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchw4_mode);
+                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
+                                                          src_to_nchw4_mode);
                 bias = new_bias.node();
             }
             if (new_inp.size() == 3) {
-                auto dst = opr::BatchConvBias::make(src, filter, bias, new_param,
-                        batch_conv_bias_opr.execution_policy(),
-                        batch_conv_bias_opr.config());
+                auto dst = opr::BatchConvBias::make(
+                        src, filter, bias, new_param,
+                        batch_conv_bias_opr.execution_policy(),
+                        batch_conv_bias_opr.config());
                 OperatorNodeBase* new_opr = dst.node()->owner_opr();
@@ -1558,12 +1560,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
             // z_inp: NCHW --> NCHW4
             VarNode* z_inp = new_inp[3];
             if (new_inp[3]->shape().ndim == 4) {
-                auto new_z = RelayoutPlaceholder::make(new_inp[3],
-                                                       src_to_nchw4_mode);
+                auto new_z = RelayoutPlaceholder::make(new_inp[3],
+                                                       src_to_nchw4_mode);
                 z_inp = new_z.node();
             }
-            auto dst = opr::BatchConvBias::make(src, filter, bias, z_inp,
-                    new_param, batch_conv_bias_opr.execution_policy(),
-                    batch_conv_bias_opr.config());
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, bias, z_inp, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
             OperatorNodeBase* new_opr = dst.node()->owner_opr();
             mgb_assert(dst.shape().ndim == 5,
@@ -1584,13 +1587,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // what should be converted: src, weight
         VarNode *conv_bias_src = new_inp[0],
                 *conv_bias_filter = new_inp[1];
-        auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse,
-                                     new_inp[1]);
+        auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
         // src: NCHW --> NCHW4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     conv_mode.src);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_bias_src = new_src.node();
         }
         // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
@@ -1632,9 +1633,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
                     RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
             z_inp = new_z.node();
         }
-        auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
-                conv_bias_filter, conv_bias_bias, z_inp, new_param,
-                conv_bias_opr.execution_policy(), conv_bias_opr.config());
+        auto new_conv_bias_opr = opr::ConvBias::make(
+                conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
+                new_param, conv_bias_opr.execution_policy(),
+                conv_bias_opr.config());
         OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
         mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                    "The conv_bias dst dim is not trans to nchw4");
@@ -1654,8 +1656,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         auto temp_inp = new_inp;
         for (size_t i = 0; i < opr->input().size(); i++) {
             if (new_inp[i]->shape().ndim == 4) {
-                auto new_var = RelayoutPlaceholder::make(new_inp[i],
-                                                         src_to_nchw4_mode);
+                auto new_var = RelayoutPlaceholder::make(new_inp[i],
+                                                         src_to_nchw4_mode);
                 temp_inp[i] = new_var.node();
             } else {
                 mgb_assert((new_inp[i]->shape().ndim == 5) ||
@@ -1697,8 +1699,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
         auto new_param = pooling.param();
         new_param.format = Format::NCHW4;
-        auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
-                                                     opr->config());
+        auto new_pooling = opr::PoolingForward::make(
+                new_inp[0], new_param, opr->config());
         mgb_assert(new_pooling.shape().ndim == 5,
                    "out var of Pooling opr after transform must be 5 (got: "
                    "%zu).",
@@ -1767,8 +1769,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     //! supportted nchw4
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
-    replace_func[opr::BatchConvBias::typeinfo()] =
-            replace_batch_conv_bias_opr;
+    replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
     replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
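The registrations above show the pass's dispatch mechanism: replace_func maps an operator's typeinfo to a rewrite callback, and each visited operator is looked up in that map. A minimal sketch of the same pattern using std::type_index in place of MegEngine's Typeinfo (hypothetical, for illustration only):

#include <functional>
#include <typeindex>
#include <unordered_map>
#include <vector>

struct Opr { virtual ~Opr() = default; };  // stand-in for OperatorNodeBase
using ReplaceFunc = std::function<Opr*(Opr*, const std::vector<Opr*>&)>;

std::unordered_map<std::type_index, ReplaceFunc> replace_func;

Opr* rewrite(Opr* opr, const std::vector<Opr*>& new_inp) {
    auto it = replace_func.find(typeid(*opr));
    if (it != replace_func.end())
        return it->second(opr, new_inp);  // type-specific layout conversion
    return opr;  // no rule registered: keep the operator as-is
}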
@@ -1811,7 +1812,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }
 
-void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
+void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
     RelayoutMode weight_to_nchwxx_mode_dense =
@@ -2205,7 +2206,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         size_t OC = filter->shape()[0];
         if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
             ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-            ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+            ret.relayout_mod =
+                    RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
         } else if (IC < pack_c_size && OC % pack_c_size == 0) {
             ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
@@ -2223,7 +2225,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
         } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
             ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-            ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+            ret.relayout_mod =
+                    RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
             ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
         }
     }
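Both hunks above exercise the same filter-transform decision: compare the input and output channel counts against the pack size, then pick a transform type and weight relayout mode. A hedged distillation of the branch structure (TRANS_NONE is an assumed fallback name; the real enum and the relayout-mode assignments live in the pass):

#include <cstddef>

enum class TransType { TRANS_PURE_NCHWXX, TRANS_HYBIRD_NCHWXX, TRANS_NONE };

TransType choose_trans_type(size_t IC, size_t OC, size_t pack_c_size) {
    if (IC % pack_c_size == 0 && OC % pack_c_size == 0)
        return TransType::TRANS_PURE_NCHWXX;    // fully blocked weights
    if (IC < pack_c_size && OC % pack_c_size == 0)
        return TransType::TRANS_HYBIRD_NCHWXX;  // first layer: IC of 1 or 3
    return TransType::TRANS_NONE;               // keep plain NCHW
}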
@@ -2538,7 +2541,6 @@ public:
         param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
         auto reformat = opr::RelayoutFormat::make(inp, param);
         return reformat.node();
     };
-
     m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
@@ -2563,7 +2565,6 @@ public:
         auto y0 = opr::Reshape::make(x, tshp);
         auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
         return y1.node();
     };
-
     m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
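As a sanity check on the Dimshuffle pattern: assuming tshp reshapes the NCHW input to (N, C/4, 4, H, W), the pattern {1, 3, 4, 0, 2} permutes those axes to (C/4, H, W, N, 4), which is exactly the CHWN4 layout. A tiny snippet verifying the permutation symbolically:

#include <array>
#include <cstdio>

int main() {
    // Axis names of the reshaped tensor (N, C/4, 4, H, W).
    const char* names[5] = {"N", "C/4", "4", "H", "W"};
    const std::array<int, 5> pattern = {1, 3, 4, 0, 2};
    for (int i : pattern)
        std::printf("%s ", names[i]);  // prints: C/4 H W N 4
    std::printf("\n");
    return 0;
}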
@@ -2593,22 +2594,27 @@ public:
 MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                      cg::SingleCNOperatorNodeBase) // {
 public:
-    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
+    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                        TensorFormat out_format);
-    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
+    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                           TensorFormat out_format);
-    TensorFormat inp_format() const { return m_inp_format; }
+    TensorFormat inp_format() const { return m_inp_format; }
-    TensorFormat out_format() const { return m_out_format; }
+    TensorFormat out_format() const { return m_out_format; }
 
 private:
-    void init_output_static_infer_desc() override;
-    void scn_do_execute() override;
-    const TensorFormat m_inp_format;
-    const TensorFormat m_out_format;
-};
+    void init_output_static_infer_desc() override;
+    void scn_do_execute() override;
+    const TensorFormat m_inp_format;
+    const TensorFormat m_out_format;
+};
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);
@@ -2914,7 +2920,8 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
         bool force_folding_typecvt = false;
         bool first_shuffle = false;
         // initialize inp_format and out_format
-        TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format;
+        TensorFormat out_format = TensorFormat::NCHW,
+                     inp_format = out_format;
         megdnn::DType inp_dtype = cur->input(0)->dtype(),
                       out_dtype = cur->output(0)->dtype();
         SmallVector<megdnn::DType> out_dtype_vec;
src/gopt/test/inference.cpp
This diff is collapsed.