Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
af576e9a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
af576e9a
编写于
8月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/gopt): fix auto padding for nhwc layout
GitOrigin-RevId: 038e372cbecb6f14a408c87cf2d30eba020b4605
上级
af828ca9
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
203 addition
and
72 deletion
+203
-72
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
+32
-4
src/gopt/impl/layout_transform_pass.cpp
src/gopt/impl/layout_transform_pass.cpp
+18
-13
src/gopt/impl/profiler_impl.cpp
src/gopt/impl/profiler_impl.cpp
+50
-33
src/gopt/impl/reformat_manager.cpp
src/gopt/impl/reformat_manager.cpp
+74
-12
src/gopt/include/megbrain/gopt/global_layout_transform.h
src/gopt/include/megbrain/gopt/global_layout_transform.h
+1
-0
src/gopt/include/megbrain/gopt/reformat_manager.h
src/gopt/include/megbrain/gopt/reformat_manager.h
+28
-10
未找到文件。
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
浏览文件 @
af576e9a
...
...
@@ -200,6 +200,15 @@ static inline bool is_nchw_nchw4_shuffle_vec(
param
.
pattern
[
4
]
==
2
;
}
static
inline
bool
is_shape_before_nhwc
(
const
TensorShape
&
shape
)
{
return
shape
.
ndim
==
4
&&
shape
[
1
]
==
4
;
}
static
inline
bool
is_nchw_nhwc_shuffle
(
const
opr
::
Dimshuffle
::
Param
param
)
{
return
param
.
ndim
==
4
&&
param
.
pattern
[
0
]
==
0
&&
param
.
pattern
[
1
]
==
2
&&
param
.
pattern
[
2
]
==
3
&&
param
.
pattern
[
3
]
==
1
;
}
template
<
typename
T
>
static
inline
bool
is_immutable_equal
(
OperatorNodeBase
*
opr
,
T
val
,
DTypeEnum
dtype_enum
)
{
...
...
@@ -276,14 +285,20 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
auto
inp0
=
opr
->
input
()[
0
];
return
is_shape_nchw
(
inp0
->
shape
());
}};
SGM
::
Node
shuffle_root
{
opr
::
Dimshuffle
::
typeinfo
(),
{{
nchwx_reshape
}},
{{
nchwx_reshape
}
,
{
broadcast_concat
}
},
[](
OperatorNodeBase
*
opr
)
{
auto
&
shuffle_opr
=
opr
->
cast_final
<
opr
::
Dimshuffle
>
();
auto
&
input_vec
=
shuffle_opr
.
input
();
return
is_shape_before_nchw4
(
input_vec
[
0
]
->
shape
())
&&
is_nchw_nchw4_shuffle_vec
(
shuffle_opr
.
param
());
bool
nchw_nchw4_ok
=
is_shape_before_nchw4
(
input_vec
[
0
]
->
shape
())
&&
is_nchw_nchw4_shuffle_vec
(
shuffle_opr
.
param
());
bool
nchw_nhwc_ok
=
is_shape_before_nhwc
(
input_vec
[
0
]
->
shape
())
&&
is_nchw_nhwc_shuffle
(
shuffle_opr
.
param
());
return
nchw_nchw4_ok
||
nchw_nhwc_ok
;
}};
return
shuffle_root
;
};
...
...
@@ -382,6 +397,19 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
auto
out_node
=
opr
::
RelayoutFormat
::
make
(
rewriter
.
get_var
(
src_node
->
output
()[
0
]),
param
.
mode
,
config
);
const
auto
&
outshp
=
opr
->
output
(
0
)
->
shape
();
if
(
outshp
.
ndim
==
4
)
{
auto
shpvar
=
opr
::
GetVarShape
::
make
(
out_node
);
auto
cv
=
[
&
out_node
](
int
v
)
{
return
out_node
.
make_scalar
(
v
);
};
auto
sub
=
[
&
shpvar
,
&
cv
](
int
idx
)
{
return
opr
::
IndexAt
::
make
(
shpvar
,
{{
0
,
cv
(
idx
)}});
};
auto
nhwc_shp
=
opr
::
Concat
::
make
({
sub
(
0
),
sub
(
2
),
sub
(
3
),
sub
(
4
)},
0
);
out_node
=
opr
::
Reshape
::
make
(
out_node
,
nhwc_shp
);
}
return
out_node
.
node
()
->
owner_opr
();
}
else
{
return
serialization
::
copy_opr_shallow
(
*
opr
,
new_inp
,
...
...
@@ -740,4 +768,4 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
};
opt
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
}
\ No newline at end of file
}
src/gopt/impl/layout_transform_pass.cpp
浏览文件 @
af576e9a
...
...
@@ -92,19 +92,24 @@ void LayoutTransformPass::apply(OptState& opt) const {
bool
is_parameter
=
fmtcfg
.
valid
()
&&
fmtcfg
.
val
().
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
;
ReformatManager
::
ReformatImpl
reformat
;
ReformatManager
::
ReformatKey
key
{
from
,
to
,
reformat_attribute
,
var
->
dtype
().
enumv
(),
var
->
dtype
().
enumv
()};
if
(
is_parameter
)
{
auto
aligned_desc
=
make_aligned_desc
(
base_fmt
,
out_fmt
);
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_weight
(
var
,
key
,
aligned_desc
);
}
else
{
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_featrue
(
var
,
base_fmt
,
key
);
// need relayout
if
(
from
!=
to
&&
!
new_var
->
shape
().
is_scalar
())
{
ReformatManager
::
ReformatImpl
reformat
;
ReformatManager
::
ReformatKey
key
{
from
,
to
,
reformat_attribute
,
var
->
dtype
().
enumv
(),
var
->
dtype
().
enumv
()};
if
(
is_parameter
)
{
auto
aligned_desc
=
ReformatManager
::
make_aligned_desc
(
base_fmt
,
out_fmt
);
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_weight
(
var
,
key
,
aligned_desc
);
}
else
{
reformat
=
ReformatManager
::
instance
()
.
auto_aligned_reformat_featrue
(
var
,
base_fmt
,
key
);
}
new_var
=
reformat
({
new_var
});
}
if
(
from
!=
to
&&
!
new_var
->
shape
().
is_scalar
())
new_var
=
reformat
({
new_var
});
...
...
src/gopt/impl/profiler_impl.cpp
浏览文件 @
af576e9a
...
...
@@ -165,6 +165,7 @@ public:
private:
static
constexpr
float
PROFILE_TIME_OUT
=
1e7
;
using
ReformatAttribute
=
ReformatKey
::
Attribute
;
/*!
* \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.)
*
...
...
@@ -175,40 +176,48 @@ private:
*/
OperatorNodeRecord
profile_operator
(
const
OperatorNodeBase
*
opr
,
TensorFormats
base_format
,
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
)
const
;
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
,
ReformatAttribute
extra_attribute
=
ReformatAttribute
::
DEFAULT
)
const
;
float
profile_operator
(
const
OperatorNodeBase
*
opr
,
TensorFormats
base_format
,
TensorFormats
tensor_format
)
const
;
TensorFormats
tensor_format
,
ReformatAttribute
extra_attribute
=
ReformatAttribute
::
DEFAULT
)
const
;
/*!
* \brief profile opr format aware operators (like conv, deconv, conv_bias, etc.)
* \brief profile opr format aware operators (like conv, deconv, conv_bias,
* etc.)
*
* \param opr pointer to the operator node to be profiled
* \param base_config the tensor formats configuration of base opr format
* \param config all the available configuration
* \param config all the available configuration
* \return the operator node record
*/
OperatorNodeRecord
profile_operator
(
const
OperatorNodeBase
*
opr
,
const
OprTensorFormatsConfiguration
&
base_config
,
const
SmallVector
<
OprTensorFormatsConfiguration
>&
available_configs
)
const
;
const
SmallVector
<
OprTensorFormatsConfiguration
>&
available_configs
,
ReformatAttribute
extra_attribute
=
ReformatAttribute
::
DEFAULT
)
const
;
float
profile_operator
(
const
OperatorNodeBase
*
opr
,
const
OprTensorFormatsConfiguration
&
base_config
,
const
OprTensorFormatsConfiguration
&
config
)
const
;
const
OprTensorFormatsConfiguration
&
config
,
ReformatAttribute
extra_attribute
=
ReformatAttribute
::
DEFAULT
)
const
;
/*!
* \brief profile layout transform of the var node
*
* \param var pointer to the var node to be profiled
* \param base_format the original tensor formats in which the var node is
stored
* \param available_tensor_formats the available tensor formats
* \param base_format the original tensor formats in which the var node is
*
stored
\param available_tensor_formats the available tensor formats
* \param extra_attribute the extra attributes (options) of the problem
* \return the var node record
*/
VarNodeRecord
profile_var_node
(
const
VarNode
*
var
,
TensorFormats
base_format
,
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
,
Reformat
Key
::
Attribute
extra_attribute
=
Reformat
Key
::
Attribute
::
DEFAULT
)
const
;
ReformatAttribute
extra_attribute
=
ReformatAttribute
::
DEFAULT
)
const
;
float
profile_var_node
(
const
VarNode
*
var
,
TensorFormats
base_format
,
const
ReformatKey
&
key
)
const
;
int
m_runs
;
/// sample times of the profiler
...
...
@@ -216,20 +225,23 @@ private:
ProfilerImpl
::
OperatorNodeRecord
ProfilerImpl
::
profile_operator
(
const
OperatorNodeBase
*
opr
,
TensorFormats
base_format
,
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
)
const
{
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
,
ReformatAttribute
extra_attribute
)
const
{
OperatorNodeRecord
record
;
record
.
opr
=
opr
;
auto
&
costs
=
record
.
costs
;
for
(
auto
&&
f
:
available_tensor_formats
)
{
auto
opr_format
=
tensor_formats_to_opr_format
(
f
);
costs
[
opr_format
]
=
profile_operator
(
opr
,
base_format
,
f
);
costs
[
opr_format
]
=
profile_operator
(
opr
,
base_format
,
f
,
extra_attribute
);
}
return
record
;
}
float
ProfilerImpl
::
profile_operator
(
const
OperatorNodeBase
*
opr
,
TensorFormats
base_format
,
TensorFormats
tensor_format
)
const
{
TensorFormats
tensor_format
,
ReformatAttribute
extra_attribute
)
const
{
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
graph
->
options
().
var_sanity_check_first_run
=
false
;
...
...
@@ -239,8 +251,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
auto
&&
cn
=
var
->
comp_node
();
auto
&&
dtype
=
var
->
dtype
();
auto
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
dtype
);
auto
aligned_tensor_shape
=
make_aligned_tensor_shape
(
var
,
base_format
,
tensor_format
);
auto
aligned_tensor_shape
=
ReformatManager
::
make_aligned_tensor_shape
(
var
,
base_format
,
tensor_format
,
extra_attribute
);
dval
->
resize
(
aligned_tensor_shape
);
auto
aligned_var
=
opr
::
VolatileSharedDeviceTensor
::
make
(
*
graph
,
dval
);
new_inps
[
i
]
=
aligned_var
.
node
();
...
...
@@ -263,8 +275,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
ProfilerImpl
::
OperatorNodeRecord
ProfilerImpl
::
profile_operator
(
const
OperatorNodeBase
*
opr
,
const
OprTensorFormatsConfiguration
&
base_config
,
const
SmallVector
<
OprTensorFormatsConfiguration
>&
available_configs
)
const
{
const
SmallVector
<
OprTensorFormatsConfiguration
>&
available_configs
,
ReformatAttribute
extra_attribute
)
const
{
OperatorNodeRecord
record
;
record
.
opr
=
opr
;
auto
&
costs
=
record
.
costs
;
...
...
@@ -273,7 +285,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
if
(
i
.
opr_format
==
OprFormat
::
NCHW
&&
opr
->
input
(
0
)
->
dtype
().
enumv
()
!=
DTypeEnum
::
Float32
)
continue
;
costs
[
i
.
opr_format
]
=
profile_operator
(
opr
,
base_config
,
i
);
costs
[
i
.
opr_format
]
=
profile_operator
(
opr
,
base_config
,
i
,
extra_attribute
);
}
return
record
;
}
...
...
@@ -281,7 +294,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
float
ProfilerImpl
::
profile_operator
(
const
OperatorNodeBase
*
opr
,
const
OprTensorFormatsConfiguration
&
base_config
,
const
OprTensorFormatsConfiguration
&
config
)
const
{
const
OprTensorFormatsConfiguration
&
config
,
ReformatAttribute
extra_attribute
)
const
{
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
graph
->
options
().
var_sanity_check_first_run
=
false
;
...
...
@@ -297,18 +311,18 @@ float ProfilerImpl::profile_operator(
TensorShape
aligned_shape
;
if
(
config
.
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
)
{
mgb_assert
(
base_config
.
input_tensor_types
[
i
]
==
TensorType
::
WEIGHT
);
aligned_shape
=
make_aligned_weight_shape
(
aligned_shape
=
ReformatManager
::
make_aligned_weight_shape
(
var
,
base_config
.
input_tensor_formats
[
i
],
config
.
input_tensor_formats
[
i
],
config
.
output_tensor_formats
[
0
]);
config
.
output_tensor_formats
[
0
]
,
extra_attribute
);
}
else
{
mgb_assert
(
base_config
.
input_tensor_types
[
i
]
==
config
.
input_tensor_types
[
i
]);
mgb_assert
(
base_config
.
input_tensor_types
[
i
]
==
TensorType
::
FEATURE
);
aligned_shape
=
make_aligned_tensor_shape
(
aligned_shape
=
ReformatManager
::
make_aligned_tensor_shape
(
var
,
base_config
.
input_tensor_formats
[
i
],
config
.
input_tensor_formats
[
i
]);
config
.
input_tensor_formats
[
i
]
,
extra_attribute
);
}
dval
->
resize
(
aligned_shape
);
auto
aligned_var
=
opr
::
VolatileSharedDeviceTensor
::
make
(
*
graph
,
dval
);
...
...
@@ -357,7 +371,7 @@ float ProfilerImpl::profile_operator(
ProfilerImpl
::
VarNodeRecord
ProfilerImpl
::
profile_var_node
(
const
VarNode
*
var
,
TensorFormats
base_format
,
const
SmallVector
<
TensorFormats
>&
available_tensor_formats
,
Reformat
Key
::
Attribute
attribute
)
const
{
ReformatAttribute
attribute
)
const
{
VarNodeRecord
record
;
record
.
var
=
var
;
auto
&
costs
=
record
.
costs
;
...
...
@@ -379,8 +393,8 @@ float ProfilerImpl::profile_var_node(const VarNode* var,
auto
&&
cn
=
var
->
comp_node
();
auto
&&
dtype
=
var
->
dtype
();
auto
dval
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
dtype
);
auto
aligned_tensor_shape
=
make_aligned_tensor_shape
(
var
,
base_format
,
key
.
input_format
);
auto
aligned_tensor_shape
=
ReformatManager
::
make_aligned_tensor_shape
(
var
,
base_format
,
key
.
input_format
,
key
.
attribute
);
dval
->
resize
(
aligned_tensor_shape
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
...
...
@@ -468,13 +482,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
auto
base_format
=
problem
.
base_format
();
auto
&&
available_tensor_formats
=
problem
.
available_tensor_formats
();
auto
&&
reformat_attribute
=
problem
.
attribute
().
reformat_attribute
;
ProfilingResult
profiling_result
;
auto
&
opr_record
=
profiling_result
.
opr_record
;
auto
&
var_record
=
profiling_result
.
var_record
;
for
(
auto
&&
var
:
vars
)
{
var_record
[
var
]
=
profile_var_node
(
var
,
base_format
,
available_tensor_formats
);
var_record
[
var
]
=
profile_var_node
(
var
,
base_format
,
available_tensor_formats
,
reformat_attribute
);
}
for
(
auto
&&
opr
:
oprs
)
{
auto
&&
opr_configs
=
problem
.
opr_configs
();
...
...
@@ -482,11 +497,12 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
if
(
find
==
opr_configs
.
end
())
{
if
(
skip_oprs
.
count
(
opr
)
>
0
)
{
SmallVector
<
TensorFormats
>
tensor_formats
=
{
base_format
};
opr_record
[
opr
]
=
profile_operator
(
opr
,
base_format
,
tensor_formats
);
opr_record
[
opr
]
=
profile_operator
(
opr
,
base_format
,
tensor_formats
,
reformat_attribute
);
}
else
{
opr_record
[
opr
]
=
profile_operator
(
opr
,
base_format
,
available_tensor_formats
);
available_tensor_formats
,
reformat_attribute
);
}
}
else
{
auto
&&
dispatchers
=
find
->
second
;
...
...
@@ -498,7 +514,8 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
}
}
auto
base_config
=
problem
.
base_config
(
opr
);
opr_record
[
opr
]
=
profile_operator
(
opr
,
base_config
,
configs
);
opr_record
[
opr
]
=
profile_operator
(
opr
,
base_config
,
configs
,
reformat_attribute
);
}
}
for
(
auto
&&
rpair
:
opr_record
)
{
...
...
src/gopt/impl/reformat_manager.cpp
浏览文件 @
af576e9a
...
...
@@ -21,7 +21,7 @@ using NamedTensorShape = megdnn::NamedTensorShape;
using
Dimension
=
megdnn
::
Dimension
;
namespace
{
int
gcd
(
const
int
&
p
,
const
int
&
q
)
{
static
inline
int
gcd
(
const
int
&
p
,
const
int
&
q
)
{
int
x
=
p
,
y
=
q
;
while
(
y
!=
0
)
{
if
(
x
<
y
)
{
...
...
@@ -33,6 +33,47 @@ int gcd(const int& p, const int& q) {
}
return
x
;
}
static
inline
size_t
extra_alignment
(
ReformatManager
::
ReformatKey
::
Attribute
attr
,
TensorFormats
target_formats
,
DType
dt
,
size_t
channel_alignment
)
{
using
Attribute
=
ReformatManager
::
ReformatKey
::
Attribute
;
if
(
attr
&
Attribute
::
AUTO_PADDING_NHWC
)
{
constexpr
size_t
alignment_in_bits
=
32
;
size_t
dtype_bits
=
dt
.
is_low_bit
()
?
dt
.
low_bit
()
:
dt
.
size
(
1
)
*
8
;
size_t
extra_alignment
=
alignment_in_bits
>=
dtype_bits
?
alignment_in_bits
/
dtype_bits
:
1
;
if
(
target_formats
==
TensorFormats
::
NHWC
)
channel_alignment
=
extra_alignment
*
channel_alignment
/
gcd
(
channel_alignment
,
extra_alignment
);
return
channel_alignment
;
}
return
channel_alignment
;
}
static
inline
std
::
tuple
<
size_t
,
size_t
>
extra_alignment
(
const
ReformatManager
::
ReformatKey
&
key
,
DType
dt
,
size_t
input_channel_alignment
,
size_t
output_channel_alignment
)
{
using
Attribute
=
ReformatManager
::
ReformatKey
::
Attribute
;
if
(
key
.
attribute
&
Attribute
::
AUTO_PADDING_NHWC
)
{
constexpr
size_t
alignment_in_bits
=
32
;
size_t
dtype_bits
=
dt
.
is_low_bit
()
?
dt
.
low_bit
()
:
dt
.
size
(
1
)
*
8
;
size_t
extra_alignment
=
alignment_in_bits
>=
dtype_bits
?
alignment_in_bits
/
dtype_bits
:
1
;
if
(
key
.
input_format
==
TensorFormats
::
NHWC
)
input_channel_alignment
=
input_channel_alignment
*
extra_alignment
/
gcd
(
input_channel_alignment
,
extra_alignment
);
if
(
key
.
output_format
==
TensorFormats
::
NHWC
)
output_channel_alignment
=
output_channel_alignment
*
extra_alignment
/
gcd
(
output_channel_alignment
,
extra_alignment
);
return
{
input_channel_alignment
,
output_channel_alignment
};
}
return
{
input_channel_alignment
,
output_channel_alignment
};
}
};
// namespace
// =================== ReformatManager::ReformatKey ====================*/
...
...
@@ -293,7 +334,8 @@ ReformatManager::ReformatImpl ReformatManager::get(
auto
rst
=
find
->
second
;
return
rst
;
}
mgb_assert
(
key
.
attribute
==
Attribute
::
DEFAULT
);
mgb_assert
(
!
(
key
.
attribute
&
Attribute
::
IMAGE2D
)
&&
!
(
key
.
attribute
&
Attribute
::
IC_SMALL
));
auto
&&
i
=
key
.
input_format
;
auto
&&
o
=
key
.
output_format
;
auto
ishp
=
tensor_formats_to_named_tensor_shape
(
i
);
...
...
@@ -346,6 +388,8 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)"
,
input_alignment
,
output_alignment
,
input_shape
.
to_string
().
c_str
());
std
::
tie
(
input_alignment
,
output_alignment
)
=
extra_alignment
(
key
,
orig_var
->
dtype
(),
input_alignment
,
output_alignment
);
NamedTensorShape
orig_shape
=
tensor_formats_to_named_tensor_shape
(
orig_format
);
size_t
orig_channel
=
0
;
...
...
@@ -451,6 +495,12 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)"
,
in_channel_alignment
,
out_channel_alignment
,
output_shape
.
to_string
().
c_str
());
in_channel_alignment
=
::
extra_alignment
(
key
.
attribute
,
key
.
output_format
,
orig_var
->
dtype
(),
in_channel_alignment
);
out_channel_alignment
=
::
extra_alignment
(
key
.
attribute
,
key
.
output_format
,
orig_var
->
dtype
(),
out_channel_alignment
);
size_t
aligned_in_channel
=
divup
(
in_channels
,
in_channel_alignment
)
*
in_channel_alignment
;
if
(
extra_alignment
.
name
==
out_channel_name
)
{
...
...
@@ -506,9 +556,9 @@ const ReformatManager& ReformatManager::instance() {
return
inst
;
}
TensorShape
mgb
::
gopt
::
make_aligned_tensor_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
)
{
TensorShape
ReformatManager
::
make_aligned_tensor_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
ReformatKey
::
Attribute
extra_attribute
)
{
using
Dimension
=
megdnn
::
Dimension
;
static
constexpr
uint32_t
UNDETERMINED_EXTENT
=
Dimension
::
UNDETERMINED_EXTENT
;
...
...
@@ -545,6 +595,15 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
tshp
[
i
]
=
oshp
[
idx
]
*
factor
;
else
tshp
[
i
]
=
divup
(
oshp
[
idx
],
factor
);
if
(
name
==
Dimension
::
Name
::
C
)
{
size_t
channel_alignment
=
target_shape
[
i
].
stride
();
size_t
channels
=
tshp
[
i
]
*
channel_alignment
;
size_t
new_channel_alignment
=
extra_alignment
(
extra_attribute
,
target_formats
,
var
->
dtype
(),
channel_alignment
);
tshp
[
i
]
=
divup
(
channels
,
new_channel_alignment
)
*
new_channel_alignment
/
channel_alignment
;
}
}
else
{
tshp
[
i
]
=
target_shape
[
i
].
extent
();
}
...
...
@@ -552,11 +611,12 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
return
tshp
;
}
TensorShape
mgb
::
gopt
::
make_aligned_weight_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
TensorFormats
extra_formats
)
{
auto
tshp
=
make_aligned_tensor_shape
(
var
,
orig_formats
,
target_formats
);
TensorShape
ReformatManager
::
make_aligned_weight_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
TensorFormats
extra_formats
,
ReformatKey
::
Attribute
extra_attribute
)
{
auto
tshp
=
make_aligned_tensor_shape
(
var
,
orig_formats
,
target_formats
,
extra_attribute
);
auto
extra_shape
=
tensor_formats_to_named_tensor_shape
(
extra_formats
);
using
Dimension
=
megdnn
::
Dimension
;
static
constexpr
uint32_t
UNDETERMINED_EXTENT
=
...
...
@@ -567,6 +627,9 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
if
(
name
==
Dimension
::
Name
::
C
&&
extra_shape
[
i
].
extent
()
==
UNDETERMINED_EXTENT
)
{
out_channel_alignment
=
extra_shape
[
i
].
stride
();
out_channel_alignment
=
extra_alignment
(
extra_attribute
,
target_formats
,
var
->
dtype
(),
out_channel_alignment
);
}
}
...
...
@@ -583,9 +646,8 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
return
tshp
;
}
ReformatManager
::
AlignmentDesc
mgb
::
gopt
::
make_aligned_desc
(
ReformatManager
::
AlignmentDesc
ReformatManager
::
make_aligned_desc
(
TensorFormats
weight_format
,
TensorFormats
out_feature_format
)
{
using
AlignmentDesc
=
ReformatManager
::
AlignmentDesc
;
using
Name
=
Dimension
::
Name
;
auto
weight_shape
=
tensor_formats_to_named_tensor_shape
(
weight_format
);
auto
out_shape
=
tensor_formats_to_named_tensor_shape
(
out_feature_format
);
...
...
src/gopt/include/megbrain/gopt/global_layout_transform.h
浏览文件 @
af576e9a
...
...
@@ -143,6 +143,7 @@ public:
TensorFormats
base_format
()
const
{
return
m_ctx
.
attribute
().
base_tensor_formats
;
}
Attribute
attribute
()
const
{
return
m_ctx
.
attribute
();
}
/*!
* \brief return the tensor formats configuration of an operator in the
* default op format
...
...
src/gopt/include/megbrain/gopt/reformat_manager.h
浏览文件 @
af576e9a
...
...
@@ -74,6 +74,7 @@ public:
DEFAULT
=
0
,
IMAGE2D
=
1
<<
0
,
IC_SMALL
=
1
<<
1
,
AUTO_PADDING_NHWC
=
1
<<
2
,
};
TensorFormats
input_format
,
output_format
;
DTypeEnum
input_dtype
,
output_dtype
;
...
...
@@ -124,23 +125,40 @@ public:
ReformatImpl
auto_aligned_reformat_weight
(
const
VarNode
*
orig_var
,
const
ReformatKey
&
key
,
const
AlignmentDesc
&
extra_alignment
=
{})
const
;
static
TensorShape
make_aligned_tensor_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
ReformatKey
::
Attribute
extra_attribute
=
ReformatKey
::
Attribute
::
DEFAULT
);
static
TensorShape
make_aligned_weight_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
TensorFormats
extra_formats
,
ReformatKey
::
Attribute
extra_attribute
=
ReformatKey
::
Attribute
::
DEFAULT
);
static
AlignmentDesc
make_aligned_desc
(
TensorFormats
weight_format
,
TensorFormats
out_feature_format
);
static
const
ReformatManager
&
instance
();
private:
ReformatCache
m_cache
;
};
TensorShape
make_aligned_tensor_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
);
TensorShape
make_aligned_weight_shape
(
const
VarNode
*
var
,
TensorFormats
orig_formats
,
TensorFormats
target_formats
,
TensorFormats
extra_formats
);
MGB_DEF_ENUM_CLASS_BIT_OPR
(
ReformatManager
::
ReformatKey
::
Attribute
);
//
//TensorShape make_aligned_tensor_shape(
// const VarNode* var, TensorFormats orig_formats,
// TensorFormats target_formats,
// ReformatManager::ReformatKey::Attribute extra_attribute =
// ReformatManager::ReformatKey::Attribute::DEFAULT);
//
//TensorShape make_aligned_weight_shape(
// const VarNode* var, TensorFormats orig_formats,
// TensorFormats target_formats, TensorFormats extra_formats,
// ReformatManager::ReformatKey::Attribute extra_attribute =
// ReformatManager::ReformatKey::Attribute::DEFAULT);
ReformatManager
::
AlignmentDesc
make_aligned_desc
(
TensorFormats
weight_format
,
TensorFormats
out_feature_format
);
}
// namespace gopt
}
// namespace mgb
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录