MegEngine 天元 / MegEngine
Commit a6230ba9
Authored on Sep 24, 2021 by Megvii Engine Team

feat(mgb/gopt): global layout transform support arm

GitOrigin-RevId: db50b33c112b99ab6f34cd81d9cf62790fc87c6e
Parent: 0be6ca88
Showing 8 changed files with 314 additions and 31 deletions (+314 −31)
src/gopt/impl/framework.cpp                                            +9    −6
src/gopt/impl/global_layout_transform/layout_transform_context.cpp    +41   −0
src/gopt/impl/global_layout_transform/layout_transform_pass.cpp       +10   −2
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp   +202  −1
src/gopt/impl/global_layout_transform/profiler_impl.cpp               +4    −4
src/gopt/impl/global_layout_transform/reformat_manager.cpp            +22   −5
src/gopt/impl/global_layout_transform/subgraph_extractor.cpp          +19   −13
src/gopt/impl/global_layout_transform/utils.h                          +7    −0
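For orientation, a minimal sketch of how this path is reached from user code. It only uses names visible in the framework.cpp hunk below (GraphTuningOptions, GraphOptimizer::add_passes_for_graph_tuning_options); the graph construction and the option that enables layout_transform in the cb macro are omitted and assumed.

    // Sketch only: graph setup and the option that turns on layout_transform are omitted.
    using namespace mgb::gopt;

    GraphTuningOptions options;
    options.target = GraphTuningOptions::Target::ARM;  // new target supported by this commit

    GraphOptimizer optimizer;
    // For ARM this now chains FuseConvBiasNonlinPass, LayoutTransformPass::make(Target::ARM)
    // and ShuffleShuffleRemovePass (plus ParamFusePass/ParamMergePass when parameters fuse).
    optimizer.add_passes_for_graph_tuning_options(options);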
src/gopt/impl/framework.cpp
...
...
@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        _passes need_param_fuse = true; \
    }
    using Target = GraphTuningOptions::Target;
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        if (options.target == Target::CUDA)
            add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        if (options.target == Target::CUDA) {
            add_pass(FuseNCHW4Int8Preprocess::make());
            add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
            add_pass<FoldingConvBiasDimshufflePass>();
            add_pass<FoldingConvBiasTypecvtPass>();
#endif
        }
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
        add_pass<ParamMergePass>();
    }
    return *this;
}
...
...
src/gopt/impl/global_layout_transform/layout_transform_context.cpp
...
...
@@ -15,6 +15,7 @@
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
+#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace gopt;
...
...
@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
    return ctx;
}

std::unique_ptr<LayoutTransformContext> make_arm_ctx(
        OprFormat base_opr_format, TensorFormats base_tensor_format) {
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::Resize::typeinfo(),
            opr::PowC::typeinfo(),
            opr::Concat::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NCHWc4,
            DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
    Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW, OprFormat::NCHW44,
                DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)})
            .add_opr_config(
                    opr::ResizeForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)});
    return ctx;
}
}  // namespace

/* ================= LayoutTransformContext ==================*/
...
...
@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
    switch (target) {
        case Target::CUDA:
            return make_cuda_ctx(base_opr_format, base_tensor_format);
+       case Target::ARM:
+           return make_arm_ctx(base_opr_format, base_tensor_format);
        default:
            mgb_assert(false, "unsupported target %s\n", target_to_string(target));
    }
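For reference, a minimal sketch of obtaining the ARM context through the factory above. The argument order (target, base opr format, base tensor format) is inferred from the hunk header and switch body, and NCHW as the base format is an assumption, not something this diff states.

    // Sketch only: argument order and the NCHW base formats are assumptions.
    auto ctx = LayoutTransformContext::make(
            Target::ARM, OprFormat::NCHW, TensorFormats::NCHW);
    // ctx now allows NCHW / NCHWc4 (and NCHWc8 when fp16 is enabled) tensor formats and
    // registers NCHW44 / NCHW88 / NCHW44_DOT configs for conv, pooling and resize oprs.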
...
...
src/gopt/impl/global_layout_transform/layout_transform_pass.cpp
...
...
@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
    auto&& opr_configs = m_ctx->opr_configs();
    auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
+   auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
    auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
    ThinHashMap<VarNode*, TensorFormats> var2fmts;
    static ThinHashSet<Typeinfo*> format_aware_oprs = {
...
...
@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const {
#undef cb
    };
    auto rewriter = opt.graph().make_rewriter();
-   auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution,
-                  &var2fmts, &endpoint_vars](OperatorNodeBase* opr) {
+   auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
+                  &rewriter, &solution, &var2fmts,
+                  &endpoint_vars](OperatorNodeBase* opr) {
        auto it = solution.find(opr);
        if (it != solution.end()) {
            auto opr_fmt = it->second;
            auto find = opr_configs.find(opr->dyn_typeinfo());
            Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
+           Maybe<OprTensorFormatsConfiguration> basecfg = None;
            if (find != opr_configs.end()) {
                fmtcfg = (*find->second.at(opr_fmt))(opr);
+               basecfg = (*find->second.at(base_opr_fmt))(opr);
            }
            VarNodeArray new_inp;
            size_t nr_inps = opr->input().size();
...
...
@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const {
                bool is_parameter =
                        fmtcfg.valid() &&
                        fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
+               if (is_parameter) {
+                   mgb_assert(basecfg.valid());
+                   from = basecfg.val().input_tensor_formats[i];
+               }
                // need relayout
                if (from != to && !new_var->shape().is_scalar()) {
                    ReformatManager::ReformatImpl reformat;
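The practical effect of the basecfg lookup is easiest to see on a weight input. A worked illustration, assuming a dense ConvBias being rewritten from the base NCHW configuration to NCHW44 (the format tuples come from the dispatcher table added in opr_tensor_formats_config.cpp below):

    // Illustrative values only, for a dense ConvBias rewritten from base NCHW to NCHW44:
    //   fmtcfg.val().input_tensor_formats  == {NCHWc4, KCRSc4k4, NCHWc4, NCHWc4}   // target config
    //   basecfg.val().input_tensor_formats == {NCHW,   KCRS,     NCHW,   NCHW}     // base config
    // For the weight input (input_tensor_types[i] == TensorType::WEIGHT) the relayout source
    // `from` is now taken from basecfg, so the weight is reformatted KCRS -> KCRSc4k4 rather
    // than being treated as a plain base-tensor-format feature map.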
...
...
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
...
...
@@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc4};
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc8};
        config.output_tensor_formats = {TensorFormats::NCHWc8};
        if (!available)
            return None;
        return config;
    }
};
#endif

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
...
...
@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
-                   TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW,
+                   TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW,
                    TensorFormats::NCHW};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
...
...
@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSc4k4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::C11RSc4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};

#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc8, TensorFormats::KCRSc8k8,
                    TensorFormats::NCHWc8, TensorFormats::NCHWc8};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::C11RSc8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc8};
        if (!available)
            return None;
        return config;
    }
};
#endif

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44_DOT;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2) {
                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            } else {
                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                             opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
            }
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                     opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSk4c4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                available = false;
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> {
    using Opr = opr::ConvolutionBackwardData;
...
...
@@ -530,9 +711,19 @@ StaticData::StaticData() {
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
+#endif
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
+#endif
+   OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
...
...
@@ -549,6 +740,16 @@ StaticData::StaticData() {
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
+   OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+   OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
+#endif
+   OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
+   OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+   OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
+#endif

#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
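For orientation, a rough reading of the blocked layouts referenced by the new ARM configs; these shape decompositions follow the usual interpretation of the format names and are not spelled out anywhere in this diff.

    // Assumed shape correspondence (from the format names, not from the diff):
    //   TensorFormats::NCHWc4   : {N, C/4, H, W, 4}   -- NCHW44 / NCHW44_DOT activations
    //   TensorFormats::NCHWc8   : {N, C/8, H, W, 8}   -- NCHW88 (fp16) activations
    //   TensorFormats::KCRSc4k4 : dense conv weight with input and output channels blocked by 4
    //   TensorFormats::KCRSk4c4 : same blocking with the inner 4x4 order swapped (dot kernels)
    //   TensorFormats::C11RSc4  : channel-wise conv weight with channels blocked by 4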
...
...
src/gopt/impl/global_layout_transform/profiler_impl.cpp
...
...
@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
        case TensorFormats::NCHW:
            return OprFormat::NCHW;
        case TensorFormats::NCHWc4:
-           return OprFormat::NCHW4;
+           return OprFormat::NCHW44;
        case TensorFormats::NCHWc8:
-           return OprFormat::NCHW8;
+           return OprFormat::NCHW88;
        case TensorFormats::NCHWc32:
            return OprFormat::NCHW32;
        case TensorFormats::NCHWc64:
...
...
@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) const {
            skip &= problem.graph_partition().input().count(i) > 0 ||
                    skip_oprs.count(i->owner_opr()) > 0;
        }
        skip &= skip_opr_types.count(opr->dyn_typeinfo());
+       auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
+       skip &= find == format_aware_input_tensors.end();
        if (skip)
            skip_oprs.insert(opr);
        oprs.insert(opr);
-       auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        if (find == format_aware_input_tensors.end()) {
            for (auto&& i : opr->input()) {
                if (!cvprop.is_const(i)) {
...
...
src/gopt/impl/global_layout_transform/reformat_manager.cpp
...
...
@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
            in_channels = orig_var->shape()[i] * input_shape[i].stride();
            input_channel_idx = i;
-           // mgb_assert(input_shape[i].stride() == 1,
-           //            "unsupport weight format(got:%s)",
-           //            input_shape.to_string().c_str());
+           mgb_assert(
+                   input_shape[i].stride() == 1, "unsupport weight format(got:%s)",
+                   input_shape.to_string().c_str());
        } else if (
                (input_shape[i].name() == Dimension::Name::K ||
                 input_shape[i].name() == Dimension::Name::N) &&
...
...
@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                    input_shape.to_string().c_str());
        }
    }
+   /* \notes: FIXME this is a hack. Since the layout of weight in channelwise
+    * convolution does not have output channel dimension, so we mannually modify the
+    * out_channel_name, out_channel_idx to bypass the following assertion statements. */
+   bool is_channelwise = key.input_format == TensorFormats::C11RS;
+   if (is_channelwise) {
+       out_channel_name = Dimension::Name::K;
+       out_channels = in_channels;
+       output_channel_idx = input_channel_idx;
+   }
    mgb_assert(
            out_channel_name == Dimension::Name::K ||
                    out_channel_name == Dimension::Name::N,
            "invalid out channel(shp:%s)", input_shape.to_string().c_str());
    mgb_assert(
-           input_channel_idx < input_shape.ndim && output_channel_idx < input_shape.ndim,
+           (input_channel_idx < input_shape.ndim &&
+            output_channel_idx < input_shape.ndim) ||
+                   (is_channelwise && output_channel_idx == input_channel_idx),
            "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)",
            input_channel_idx, output_channel_idx, input_shape.to_string().c_str());
    size_t in_channel_alignment = 0, out_channel_alignment = 0;
...
...
@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
            out_channel_alignment = output_shape[i].stride();
        }
    }
+   /* \notes: FIXME this is a hack. Since the layout of weight in channelwise
+    * convolution does not have output channel dimension, so we mannually modify the
+    * out_channel_alignment to bypass the following assertion statements. */
+   if (is_channelwise) {
+       mgb_assert(out_channel_alignment == 0);
+       out_channel_alignment = 1;
+   }
    mgb_assert(
            in_channel_alignment > 0 && out_channel_alignment > 0,
            "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
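The reason both FIXME hacks are needed can be stated in one line about the channel-wise weight layout; this note is explanatory and not part of the diff.

    // A channel-wise convolution weight in C11RS layout is roughly {C, 1, 1, R, S}: it has no
    // separate output-channel (K) dimension, so the generic K/N search above finds nothing and
    // out_channel_alignment stays 0. The overrides reuse the input channel as the output channel
    // and force the alignment to 1 so the following assertions still hold.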
...
...
src/gopt/impl/global_layout_transform/subgraph_extractor.cpp
...
...
@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
    std::vector<GraphPartition> partitions;
    partitions.reserve(topo.size());
    ThinHashMap<OperatorNodeBase*, GraphPartition*> roots;
+   /// backward pass
    for (const auto& opr : reverse_adaptor(topo)) {
-       if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
-           for (const auto& i : opr->input()) {
-               if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
-                   auto root = union_find(i->owner_opr());
-                   GraphPartition* partition;
-                   auto find = roots.find(root);
-                   if (find != roots.end()) {
-                       partition = find->second;
-                       partition->output().insert(i);
-                   }
-               }
-           }
-       } else {
+       if (m_opr_list.count(opr->dyn_typeinfo()) > 0) {
            auto root = union_find(opr);
            auto find = roots.find(root);
            GraphPartition* partition = nullptr;
...
...
@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
                partition->input().insert(i);
            }
        }
+   }
+   /// forward pass
+   for (auto&& opr : topo) {
+       if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
+           for (const auto& i : opr->input()) {
+               if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
+                   auto root = union_find(i->owner_opr());
+                   GraphPartition* partition;
+                   auto find = roots.find(root);
+                   if (find != roots.end()) {
+                       partition = find->second;
+                       partition->output().insert(i);
+                   }
+               }
+           }
+       }
    }
    for (auto&& partition : partitions) {
        auto& all_oprs = partition.all_oprs();
        std::reverse(all_oprs.begin(), all_oprs.end());
...
...
src/gopt/impl/global_layout_transform/utils.h
...
...
@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string(
        cb(NCHW32);
        cb(NCHW64);
        cb(CHWN4);
+       cb(NCHW44);
+       cb(NCHW88);
+       cb(NCHW44_DOT);
        default:
            mgb_assert(
                    false, "Invalid opr format(got:%u)",
...
...
@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats(
            return TensorFormats::NCHWc64;
        case OprFormat::CHWN4:
            return TensorFormats::CHWNc4;
+       case OprFormat::NCHW88:
+           return TensorFormats::NCHWc8;
+       case OprFormat::NCHW44:
+           return TensorFormats::NCHWc4;
        default:
            mgb_throw(
                    AssertionError, "format(%s) is not supported",
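With the new cases in place, the two helpers also cover the ARM formats; a small usage sketch (assuming, as the cb macro suggests, that opr_format_to_string returns the stringified enum name):

    // Sketch only: illustrates the mappings added above.
    TensorFormats f44 = opr_format_to_tensor_formats(OprFormat::NCHW44);   // TensorFormats::NCHWc4
    TensorFormats f88 = opr_format_to_tensor_formats(OprFormat::NCHW88);   // TensorFormats::NCHWc8
    const char* name = opr_format_to_string(OprFormat::NCHW44_DOT);        // presumably "NCHW44_DOT"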
...
...