Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c33126ab
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c33126ab
编写于
7月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): add reformat manager
GitOrigin-RevId: b9791b131a996f4f7c55173294abf0715fac97fc
上级
65d554ed
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
515 addition
and
13 deletion
+515
-13
dnn/include/megdnn/named_tensor.h
dnn/include/megdnn/named_tensor.h
+1
-0
dnn/src/common/named_tensor.cpp
dnn/src/common/named_tensor.cpp
+6
-2
dnn/test/common/named_tensor.cpp
dnn/test/common/named_tensor.cpp
+1
-1
src/gopt/impl/reformat_emitter.cpp
src/gopt/impl/reformat_emitter.cpp
+5
-10
src/gopt/impl/reformat_manager.cpp
src/gopt/impl/reformat_manager.cpp
+399
-0
src/gopt/include/megbrain/gopt/reformat_manager.h
src/gopt/include/megbrain/gopt/reformat_manager.h
+103
-0
未找到文件。
dnn/include/megdnn/named_tensor.h
浏览文件 @
c33126ab
...
...
@@ -30,6 +30,7 @@ public:
C
=
'C'
,
// input channel
H
=
'H'
,
// input height
W
=
'W'
,
// input width
G
=
'G'
,
// group
K
=
'K'
,
// output channel
R
=
'R'
,
// filter height
S
=
'S'
,
// filter width
...
...
dnn/src/common/named_tensor.cpp
浏览文件 @
c33126ab
...
...
@@ -18,8 +18,9 @@ using namespace megdnn;
/* ===================== Dimension ============================ */
const
Dimension
::
Name
Dimension
::
NAME_ALL
[]
=
{
Dimension
::
Name
::
N
,
Dimension
::
Name
::
C
,
Dimension
::
Name
::
H
,
Dimension
::
Name
::
W
,
Dimension
::
Name
::
K
,
Dimension
::
Name
::
R
,
Dimension
::
Name
::
S
,
Dimension
::
Name
::
P
,
Dimension
::
Name
::
Q
,
Dimension
::
Name
::
W
,
Dimension
::
Name
::
G
,
Dimension
::
Name
::
K
,
Dimension
::
Name
::
R
,
Dimension
::
Name
::
S
,
Dimension
::
Name
::
P
,
Dimension
::
Name
::
Q
,
};
const
int
Dimension
::
NR_NAMES
=
sizeof
(
Dimension
::
NAME_ALL
);
Dimension
::
Dimension
(
const
std
::
string
&
expr
)
{
...
...
@@ -92,6 +93,9 @@ bool Dimension::operator<(const Dimension& rhs) const {
if
(
m_name
!=
rhs
.
m_name
)
{
return
static_cast
<
char
>
(
m_name
)
<
static_cast
<
char
>
(
rhs
.
m_name
);
}
if
(
m_stride
==
rhs
.
m_stride
)
{
return
m_extent
>
rhs
.
m_extent
;
}
return
m_stride
>
rhs
.
m_stride
;
}
...
...
dnn/test/common/named_tensor.cpp
浏览文件 @
c33126ab
...
...
@@ -21,7 +21,7 @@ using namespace megdnn;
using
megdnn
::
test
::
MegDNNError
;
TEST
(
NAMED_TENSOR
,
NAMED_TENSOR_SHAPE_BASIC
)
{
ASSERT_EQ
(
Dimension
::
NR_NAMES
,
9
);
ASSERT_EQ
(
Dimension
::
NR_NAMES
,
10
);
Dimension
dim0
=
{
"C"
},
dim1
=
{
"C//32"
},
dim2
=
{
"C//4"
},
dim3
=
{
"C%32"
},
dim4
=
{
"C%4"
},
dim5
=
{
"C//4%8"
};
ASSERT_TRUE
(
dim0
==
dim1
*
dim3
);
...
...
src/gopt/impl/reformat_emitter.cpp
浏览文件 @
c33126ab
...
...
@@ -182,7 +182,6 @@ ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
};
std
::
sort
(
src_dims
.
begin
(),
src_dims
.
end
(),
compare
);
std
::
sort
(
dest_dims
.
begin
(),
dest_dims
.
end
(),
compare
);
auto
src_iter
=
src_dims
.
begin
();
auto
dest_iter
=
dest_dims
.
begin
();
for
(;
src_iter
!=
src_dims
.
end
()
&&
dest_iter
!=
dest_dims
.
end
();)
{
...
...
@@ -233,18 +232,14 @@ ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
}
UnderlyingBuilders
builders
;
if
(
!
m_src
.
eq_shape
(
i1
))
{
builders
.
make_shape1
=
std
::
move
(
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
i1
).
emit
()));
builders
.
reshape1
=
std
::
move
(
std
::
get
<
0
>
(
ReshapeEmitter
(
m_src
,
i1
).
emit
()));
builders
.
make_shape1
=
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
i1
).
emit
());
builders
.
reshape1
=
std
::
get
<
0
>
(
ReshapeEmitter
(
m_src
,
i1
).
emit
());
}
builders
.
dimshuffle
=
std
::
move
(
std
::
get
<
0
>
(
DimshuffleEmitter
(
permute
).
emit
()));
builders
.
dimshuffle
=
std
::
get
<
0
>
(
DimshuffleEmitter
(
permute
).
emit
());
if
(
!
m_dest
.
eq_shape
(
i2
))
{
builders
.
make_shape2
=
std
::
move
(
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
m_dest
).
emit
()));
builders
.
reshape2
=
std
::
move
(
std
::
get
<
0
>
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
()));
std
::
get
<
0
>
(
MakeShapeEmitter
(
m_src
,
m_dest
).
emit
());
builders
.
reshape2
=
std
::
get
<
0
>
(
ReshapeEmitter
(
i2
,
m_dest
).
emit
());
}
return
builders
;
}
...
...
src/gopt/impl/reformat_manager.cpp
0 → 100644
浏览文件 @
c33126ab
/**
* \file src/gopt/impl/reformat_manager.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/reformat_manager.h"
#include <numeric>
#include "megbrain/opr/tensor_manip.h"
using
namespace
mgb
;
using
namespace
gopt
;
using
NamedTensorShape
=
megdnn
::
NamedTensorShape
;
namespace
{
NamedTensorShape
tensor_formats_to_named_tensor_shape
(
TensorFormats
format
)
{
switch
(
format
)
{
case
TensorFormats
::
NCHW
:
return
{{
"N"
},
{
"C"
},
{
"H"
},
{
"W"
}};
case
TensorFormats
::
NHWC
:
return
{{
"N"
},
{
"H"
},
{
"W"
},
{
"C"
}};
case
TensorFormats
::
NCHWc4
:
return
{{
"N"
},
{
"C//4"
},
{
"H"
},
{
"W"
},
{
"C%4"
}};
case
TensorFormats
::
NCHWc8
:
return
{{
"N"
},
{
"C//8"
},
{
"H"
},
{
"W"
},
{
"C%8"
}};
case
TensorFormats
::
NCHWc32
:
return
{{
"N"
},
{
"C//32"
},
{
"H"
},
{
"W"
},
{
"C%32"
}};
case
TensorFormats
::
NCHWc64
:
return
{{
"N"
},
{
"C//64"
},
{
"H"
},
{
"W"
},
{
"C%64"
}};
case
TensorFormats
::
CHWNc4
:
return
{{
"C//4"
},
{
"H"
},
{
"W"
},
{
"N"
},
{
"C%4"
}};
case
TensorFormats
::
NHCWc4
:
return
{{
"N"
},
{
"H"
},
{
"C//4"
},
{
"W"
},
{
"C%4"
}};
case
TensorFormats
::
KRSCk4
:
return
{{
"K//4"
},
{
"R"
},
{
"S"
},
{
"C"
},
{
"K%4"
}};
case
TensorFormats
::
GKRSCk4
:
return
{{
"G"
},
{
"K//4"
},
{
"R"
},
{
"S"
},
{
"C"
},
{
"K%4"
}};
case
TensorFormats
::
C1RSc4
:
return
{{
"C//4"
},
{
"C%1"
},
{
"R"
},
{
"S"
},
{
"C%4"
}};
case
TensorFormats
::
KRSCk4c4
:
return
{{
"K//4"
},
{
"R"
},
{
"S"
},
{
"C//4"
},
{
"K%4"
},
{
"C%4"
}};
case
TensorFormats
::
GKRSCk4c4
:
return
{{
"G"
},
{
"K//4"
},
{
"R"
},
{
"S"
},
{
"C//4"
},
{
"K%4"
},
{
"C%4"
}};
case
TensorFormats
::
KCRSk4c4
:
return
{{
"K//4"
},
{
"C//4"
},
{
"R"
},
{
"S"
},
{
"K%4"
},
{
"C%4"
}};
case
TensorFormats
::
GKCRSk4c4
:
return
{{
"G"
},
{
"K//4"
},
{
"C//4"
},
{
"R"
},
{
"S"
},
{
"K%4"
},
{
"C%4"
}};
case
TensorFormats
::
KCRSc4k4
:
return
{{
"K//4"
},
{
"C//4"
},
{
"R"
},
{
"S"
},
{
"C%4"
},
{
"K%4"
}};
case
TensorFormats
::
GKCRSc4k4
:
return
{{
"G"
},
{
"K//4"
},
{
"C//4"
},
{
"R"
},
{
"S"
},
{
"C%4"
},
{
"K%4"
}};
case
TensorFormats
::
C11RSc4
:
return
{{
"C//4"
},
{
"C%1"
},
{
"C%1"
},
{
"R"
},
{
"S"
},
{
"C%4"
}};
case
TensorFormats
::
KCRSc8k8
:
return
{{
"K//8"
},
{
"C//8"
},
{
"R"
},
{
"S"
},
{
"C%8"
},
{
"K%8"
}};
case
TensorFormats
::
GKCRSc8k8
:
return
{{
"G"
},
{
"K//8"
},
{
"C//8"
},
{
"R"
},
{
"S"
},
{
"C%8"
},
{
"K%8"
}};
case
TensorFormats
::
C11RSc8
:
return
{{
"C//8"
},
{
"C%1"
},
{
"C%1"
},
{
"R"
},
{
"S"
},
{
"C%8"
}};
case
TensorFormats
::
KRSCk8
:
return
{{
"K//8"
},
{
"R"
},
{
"S"
},
{
"C"
},
{
"K%8"
}};
case
TensorFormats
::
KCRS
:
return
{{
"K"
},
{
"C"
},
{
"R"
},
{
"S"
}};
case
TensorFormats
::
GKCRS
:
return
{{
"G"
},
{
"K"
},
{
"C"
},
{
"R"
},
{
"S"
}};
case
TensorFormats
::
C11RS
:
return
{{
"C"
},
{
"C%1"
},
{
"C%1"
},
{
"R"
},
{
"S"
}};
default:
mgb_throw
(
AssertionError
,
"invalid tensor formats(%u)"
,
static_cast
<
uint32_t
>
(
format
));
}
}
};
// namespace
// =================== ReformatManager::ReformatKey ====================*/
std
::
string
ReformatManager
::
ReformatKey
::
to_string
()
const
{
auto
&&
i
=
tensor_formats_to_named_tensor_shape
(
input_format
);
auto
&&
o
=
tensor_formats_to_named_tensor_shape
(
output_format
);
std
::
string
input_name
,
output_name
;
#define cb(_name) \
if (input_dtype == DTypeEnum::_name) { \
input_name = #_name; \
} else
MEGDNN_FOREACH_DTYPE_NAME
(
cb
)
MEGDNN_FOREACH_PARAMETERIZED_DTYPE
(
cb
)
{
mgb_throw
(
MegBrainError
,
"invalid input dtype enum(%u)"
,
static_cast
<
uint32_t
>
(
input_dtype
));
}
#undef cb
#define cb(_name) \
if (output_dtype == DTypeEnum::_name) { \
output_name = #_name; \
} else
MEGDNN_FOREACH_DTYPE_NAME
(
cb
)
MEGDNN_FOREACH_PARAMETERIZED_DTYPE
(
cb
)
{
mgb_throw
(
MegBrainError
,
"invalid output dtype enum(%u)"
,
static_cast
<
uint32_t
>
(
output_dtype
));
}
#undef cb
return
ssprintf
(
"%s;%s;%s;%s;%s"
,
i
.
to_string
().
c_str
(),
o
.
to_string
().
c_str
(),
std
::
to_string
(
static_cast
<
uint32_t
>
(
attribute
)).
c_str
(),
input_name
.
c_str
(),
output_name
.
c_str
());
}
size_t
ReformatManager
::
ReformatKey
::
Hash
::
operator
()(
const
ReformatKey
&
key
)
const
{
auto
enumhash
=
mgb
::
enumhash
();
size_t
h
=
enumhash
(
key
.
input_format
);
h
=
mgb
::
hash_pair_combine
(
h
,
enumhash
(
key
.
output_format
));
h
=
mgb
::
hash_pair_combine
(
h
,
enumhash
(
key
.
attribute
));
h
=
mgb
::
hash_pair_combine
(
h
,
enumhash
(
key
.
input_dtype
));
h
=
mgb
::
hash_pair_combine
(
h
,
enumhash
(
key
.
output_dtype
));
return
h
;
}
bool
ReformatManager
::
ReformatKey
::
Equal
::
operator
()(
const
ReformatKey
&
lhs
,
const
ReformatKey
&
rhs
)
const
{
return
lhs
.
input_format
==
rhs
.
input_format
&&
lhs
.
output_format
==
rhs
.
output_format
&&
lhs
.
input_dtype
==
rhs
.
input_dtype
&&
lhs
.
output_dtype
==
rhs
.
output_dtype
&&
lhs
.
attribute
==
rhs
.
attribute
;
}
// =================== ReformatManager ====================*/
#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \
cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
cb(NHCWc4)
#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \
cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \
cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \
cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
ReformatManager
::
ReformatManager
()
{
static
constexpr
TensorFormats
feature_tensor_formats
[]
=
{
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_FEATURE_TENSOR_FORMATS
(
cb
)
#undef cb
};
static
constexpr
int
nr_feature_tensor_formats
=
sizeof
(
feature_tensor_formats
)
/
sizeof
(
TensorFormats
);
for
(
int
i
=
0
;
i
<
nr_feature_tensor_formats
;
++
i
)
{
for
(
int
o
=
0
;
o
<
nr_feature_tensor_formats
;
++
o
)
{
if
(
i
==
o
)
continue
;
NamedTensorShape
input_shape
=
tensor_formats_to_named_tensor_shape
(
feature_tensor_formats
[
i
]);
NamedTensorShape
output_shape
=
tensor_formats_to_named_tensor_shape
(
feature_tensor_formats
[
o
]);
auto
impl
=
std
::
get
<
0
>
(
ReformatEmitter
{
input_shape
,
output_shape
}.
emit
());
m_cache
.
emplace
(
ReformatKey
{
feature_tensor_formats
[
i
],
feature_tensor_formats
[
o
]},
impl
);
}
}
static
constexpr
TensorFormats
default_weight_tensor_formats
=
TensorFormats
::
KCRS
;
static
constexpr
TensorFormats
default_group_conv_weight_tensor_formats
=
TensorFormats
::
GKCRS
;
static
constexpr
TensorFormats
default_chan_conv_weight_tensor_formats
=
TensorFormats
::
C11RS
;
static
constexpr
TensorFormats
weight_tensor_formats
[]
=
{
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_WEIGHT_TENSOR_FORMATS
(
cb
)
#undef cb
};
static
constexpr
int
nr_weight_tensor_formats
=
sizeof
(
weight_tensor_formats
)
/
sizeof
(
TensorFormats
);
using
Name
=
megdnn
::
Dimension
::
Name
;
for
(
int
o
=
0
;
o
<
nr_weight_tensor_formats
;
++
o
)
{
NamedTensorShape
output_shape
=
tensor_formats_to_named_tensor_shape
(
weight_tensor_formats
[
o
]);
TensorFormats
input_format
;
if
(
output_shape
[
0
].
name
()
==
Name
::
G
)
{
input_format
=
default_group_conv_weight_tensor_formats
;
}
else
if
(
output_shape
[
0
].
name
()
==
Name
::
C
)
{
input_format
=
default_chan_conv_weight_tensor_formats
;
}
else
{
mgb_assert
(
output_shape
[
0
].
name
()
==
Name
::
K
);
input_format
=
default_weight_tensor_formats
;
}
NamedTensorShape
input_shape
=
tensor_formats_to_named_tensor_shape
(
input_format
);
auto
impl
=
std
::
get
<
0
>
(
ReformatEmitter
{
input_shape
,
output_shape
}.
emit
());
m_cache
.
emplace
(
ReformatKey
{
input_format
,
weight_tensor_formats
[
o
]},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NCHWc4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NCHW4_IC_SMALL
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IC_SMALL
},
impl
);
}
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
KCRSc4k4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IC_SMALL
},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NCHWc64
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NCHW64
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
QuantizedS4
,
DTypeEnum
::
QuantizedS4
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
Quantized4Asymm
,
DTypeEnum
::
Quantized4Asymm
},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHWc64
,
o
=
TensorFormats
::
NCHW
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NCHW64
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
QuantizedS4
,
DTypeEnum
::
QuantizedS4
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
Quantized4Asymm
,
DTypeEnum
::
Quantized4Asymm
},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NHWC
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NHWC
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
QuantizedS4
,
DTypeEnum
::
QuantizedS4
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
Quantized4Asymm
,
DTypeEnum
::
Quantized4Asymm
},
impl
);
}
{
auto
i
=
TensorFormats
::
NHWC
,
o
=
TensorFormats
::
NCHW
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NHWC
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
QuantizedS4
,
DTypeEnum
::
QuantizedS4
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
DEFAULT
,
DTypeEnum
::
Quantized4Asymm
,
DTypeEnum
::
Quantized4Asymm
},
impl
);
}
// nhcw4
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
KRSCk4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
INTER_WEIGHT_DENSEI
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
GKRSCk4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
INTER_WEIGHT_GROUPI
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
C1RSc4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
INTER_WEIGHT_CHANI
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
NCHW
,
o
=
TensorFormats
::
NHCWc4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NHWCD4I
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
{
auto
i
=
TensorFormats
::
NHCWc4
,
o
=
TensorFormats
::
NCHW
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
NCHW_NHWCD4I
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
},
impl
);
}
// nhcw4-dot
{
auto
i
=
TensorFormats
::
KCRS
,
o
=
TensorFormats
::
KRSCk4c4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
INTER_WEIGHT_DENSEI_DOT
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
,
DTypeEnum
::
QuantizedS8
,
DTypeEnum
::
QuantizedS8
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
,
DTypeEnum
::
Quantized8Asymm
,
DTypeEnum
::
Quantized8Asymm
},
impl
);
}
{
auto
i
=
TensorFormats
::
GKCRS
,
o
=
TensorFormats
::
GKRSCk4c4
;
auto
&&
impl
=
[](
const
VarNodeArray
&
vars
)
{
return
opr
::
RelayoutFormat
::
make
(
vars
[
0
],
megdnn
::
param
::
RelayoutFormat
::
Mode
::
INTER_WEIGHT_GROUPI_DOT
)
.
node
();
};
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
,
DTypeEnum
::
QuantizedS8
,
DTypeEnum
::
QuantizedS8
},
impl
);
m_cache
.
emplace
(
ReformatKey
{
i
,
o
,
Attribute
::
IMAGE2D
,
DTypeEnum
::
Quantized8Asymm
,
DTypeEnum
::
Quantized8Asymm
},
impl
);
}
}
#undef FOREACH_FEATURE_TENSOR_FORMATS
#undef FOREACH_WEIGHT_TENSOR_FORMATS
const
ReformatManager
::
ReformatImpl
&
ReformatManager
::
get
(
const
ReformatKey
&
key
)
const
{
MGB_TRY
{
auto
&&
impl
=
m_cache
.
at
(
key
);
return
impl
;
}
MGB_CATCH
(
std
::
exception
&
exc
,
{
mgb_log_error
(
"cannot find ReformatImpl for ReformatKey(%s), extra "
"message(%s)"
,
key
.
to_string
().
c_str
(),
exc
.
what
());
throw
;
})
}
const
ReformatManager
&
ReformatManager
::
instance
()
{
static
ReformatManager
*
inst
=
nullptr
;
if
(
inst
==
nullptr
)
{
inst
=
new
ReformatManager
();
}
return
*
inst
;
}
// vim: syntax=cpp.doxygen
src/gopt/include/megbrain/gopt/reformat_manager.h
0 → 100644
浏览文件 @
c33126ab
/**
* \file src/gopt/include/megbrain/gopt/reformat_manager.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/gopt/reformat_emitter.h"
#include "megbrain/graph.h"
namespace
mgb
{
namespace
gopt
{
enum
class
TensorFormats
:
uint32_t
{
// input tensor formats
NCHW
=
0
,
///< [N, C, H, W]
NHWC
=
1
,
///< [N, H, W, C]
NCHWc4
=
2
,
///< [N, C/4, H, W, C%4]
NCHWc8
=
3
,
///< [N, C/8, H, W, C%8]
NCHWc32
=
4
,
///< [N, C/32, H, W, C%32]
NCHWc64
=
5
,
///< [N, C/64, H, W, C%64]
CHWNc4
=
6
,
///< [C/4, H, W, N, C%4]
NHCWc4
=
7
,
///< [N, H, C/4, W, C%4]
// weight tensor formats
// NHWCD4
KRSCk4
=
8
,
///< [K/4, R, S, C, K%4] [dense conv]
GKRSCk4
=
9
,
///< [G, K/4, R, S, C, K%4] [group conv]
C1RSc4
=
10
,
///< [C/4, 1, R, S, C%4] [channel wise conv]
// NHWCD4-dot
KRSCk4c4
=
11
,
///< [K/4, R, S, C/4, K%4, C%4] [dense conv]
GKRSCk4c4
=
12
,
///< [G, K/4, R, S, C/4, K%4, C%4] [group conv]
// NCHW44-dot
KCRSk4c4
=
13
,
///< [K/4, C/4, R, S, K%4, C%4] [dense conv]
GKCRSk4c4
=
14
,
///< [G, K/4, C/4, R, S, K%4, C%4] [group conv]
// NCHW44
KCRSc4k4
=
15
,
///< [K/4, C/4, R, S, C%4, K%4] [dense conv]
GKCRSc4k4
=
16
,
///< [G, K/4, C/4, R, S, C%4, K%4] [group conv]
C11RSc4
=
17
,
///< [C/4, 1, 1, R, S, C%4] [channel wise conv]
// NCHW88
KCRSc8k8
=
18
,
///< [K/8, C/8, R, S, C%8, K%8] [dense conv]
GKCRSc8k8
=
19
,
///< [G, K/8, C/8, R, S, C%8, K%8] [group conv]
C11RSc8
=
20
,
///< [C/8, 1, 1, R, S, C%8] [channel wise conv]
KRSCk8
=
21
,
///< [K/8, R, S, C, K%8]
// default weight format
KCRS
=
22
,
///< [K, C, R, S]
GKCRS
=
23
,
///< [G, K, C, R, S]
C11RS
=
24
,
///< [C, 1, 1, R, S]
};
class
ReformatManager
:
public
NonCopyableObj
{
ReformatManager
();
public:
using
ReformatImpl
=
thin_function
<
VarNode
*
(
const
VarNodeArray
&
)
>
;
enum
class
Attribute
:
uint32_t
{
DEFAULT
=
0
,
IMAGE2D
=
1
<<
0
,
IC_SMALL
=
1
<<
1
,
};
struct
ReformatKey
{
TensorFormats
input_format
,
output_format
;
DTypeEnum
input_dtype
,
output_dtype
;
Attribute
attribute
;
std
::
string
to_string
()
const
;
ReformatKey
(
TensorFormats
input_format_
,
TensorFormats
output_format_
,
Attribute
attribute_
=
Attribute
::
DEFAULT
,
DTypeEnum
input_dtype_
=
DTypeEnum
::
Float32
,
DTypeEnum
output_dtype_
=
DTypeEnum
::
Float32
)
:
input_format
{
input_format_
},
output_format
{
output_format_
},
input_dtype
{
input_dtype_
},
output_dtype
{
output_dtype_
},
attribute
{
attribute_
}
{}
struct
Hash
{
size_t
operator
()(
const
ReformatKey
&
key
)
const
;
};
struct
Equal
{
bool
operator
()(
const
ReformatKey
&
lhs
,
const
ReformatKey
&
rhs
)
const
;
};
};
using
ReformatCache
=
std
::
unordered_map
<
ReformatKey
,
ReformatImpl
,
ReformatKey
::
Hash
,
ReformatKey
::
Equal
>
;
const
ReformatImpl
&
get
(
const
ReformatKey
&
key
)
const
;
static
const
ReformatManager
&
instance
();
private:
ReformatCache
m_cache
;
};
}
// namespace gopt
}
// namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录