Commit fe93013a
Authored on Sep 22, 2021 by Megvii Engine Team
feat(mgb/gopt): global layout transform support nchw_nchwxx hybrid mode
GitOrigin-RevId: 6d5b55d7fc67b536b25c2fe49457f6a74f9c62b5
Parent: 3d45d352
Showing 18 changed files with 866 additions and 272 deletions.
dnn/src/common/convolution.cpp                                          +8   −9
src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp   +33  −25
src/gopt/impl/global_layout_transform/layout_transform_context.cpp     +43  −31
src/gopt/impl/global_layout_transform/layout_transform_pass.cpp        +19  −12
src/gopt/impl/global_layout_transform/opr_format_modifier.h            +1   −1
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp    +281 −62
src/gopt/impl/global_layout_transform/profiler_cache.cpp               +1   −1
src/gopt/impl/global_layout_transform/profiler_impl.cpp                +30  −32
src/gopt/impl/global_layout_transform/profiling_based_solver.cpp       +2   −1
src/gopt/impl/global_layout_transform/utils.h                          +42  −1
src/gopt/include/megbrain/gopt/layout_transform_context.h              +48  −14
src/gopt/include/megbrain/gopt/profiler.h                              +8   −5
src/gopt/include/megbrain/gopt/solver.h                                +2   −1
src/gopt/test/embed_cache.py                                           +2   −1
src/gopt/test/layout_transform_pass.cpp                                +229 −66
src/gopt/test/network.cpp                                              +90  −0
src/gopt/test/network.h                                                +15  −1
src/gopt/test/profiler.cpp                                             +12  −9
dnn/src/common/convolution.cpp

@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                 src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
         dst[4] = 32;
     } else if (param().format == Param::Format::NCHW88) {
-        megdnn_assert(
-                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
-                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
+        megdnn_assert(
+                src.ndim == 5 || src.ndim == 4,
+                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
@@ -850,12 +850,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                     "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
         }
     } else if (
             param().format == Param::Format::NCHW44 ||
             param().format == Param::Format::NCHW44_DOT) {
-        megdnn_assert(
-                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
-                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
+        megdnn_assert(
+                src.ndim == 5 || src.ndim == 4,
+                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
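Editorial note, not part of the commit: the relaxed asserts above are what let a convolution accept a plain 4-D NCHW input while its output is already in a blocked NCHWxx layout ("hybrid mode", e.g. NCHW_NCHW44). A minimal standalone sketch of the resulting shapes follows; the helper name and signature are made up for illustration, and only the shape rule (dst.ndim = 5, batch kept, channels split into blocks of 4 or 8) mirrors the deduction code in this file.

// Hypothetical sketch of hybrid-mode shape deduction; not MegEngine code.
#include <array>
#include <cassert>
#include <cstddef>

// src = {N, C, H, W} in plain NCHW; result = {N, OC/pack, OH, OW, pack} in NCHWc<pack>.
std::array<size_t, 5> hybrid_conv_output_shape(
        const std::array<size_t, 4>& src, size_t out_channels, size_t kernel,
        size_t stride, size_t pad, size_t pack /* 4 for NCHW44, 8 for NCHW88 */) {
    assert(out_channels % pack == 0);
    size_t oh = (src[2] + 2 * pad - kernel) / stride + 1;
    size_t ow = (src[3] + 2 * pad - kernel) / stride + 1;
    return {src[0], out_channels / pack, oh, ow, pack};
}

int main() {
    // First layer of a network: 3-channel NCHW image in, NCHW44 feature map out.
    auto dst = hybrid_conv_output_shape({1, 3, 224, 224}, 32, 3, 2, 1, 4);
    assert((dst == std::array<size_t, 5>{1, 8, 112, 112, 4}));
    return 0;
}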
src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp

@@ -47,7 +47,7 @@ private:
     struct Value {
         OperatorNodeBase* opr;
         const State* prev;
-        OprFormat opr_fmt;
+        OprFormatConfigID cfg_id;
         float time;
         ///! index in the topo order of the correspoding operator
         size_t opr_idx;
@@ -87,14 +87,15 @@ private:
     };
     /*!
      * \brief get the tensor formats configuration for the operator with
-     * particular op format \param[out] var2fmts hashmap that maps varnode to
-     * actual tensor formats of the op format configuration \param[in] opr given
-     * operator \param[in] opr_fmt given op format, an enum type argument which
-     * indicates the op format configuration. \param[in] ctx context
+     * particular op format
+     * \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op
+     * format configuration \param[in] opr given operator \param[in] opr_fmt given op
+     * format, an enum type argument which indicates the op format configuration.
+     * \param[in] ctx context
      */
     TensorFormats get_io_formats(
             ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
-            OprFormat opr_fmt, const Context& ctx);
+            OprFormatConfigID config_id, const Context& ctx);
     /*!
      * \brief compute the distace of two states of the given varnode
      * \param[in] from the source state
@@ -140,28 +141,35 @@ private:
 TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
         ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
-        OprFormat opr_fmt, const Context& ctx) {
+        OprFormatConfigID config_id, const Context& ctx) {
     auto&& rst = ctx.rst;
     auto&& opr_configs = ctx.opr_configs;
     auto iter = opr_configs.find(opr->dyn_typeinfo());
     Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
+    Maybe<OprFormat> opr_fmt = None;
     if (iter != opr_configs.end()) {
-        fmtcfg = (*iter->second.at(opr_fmt))(opr);
+        fmtcfg = (*iter->second.at(config_id))(opr);
+    } else {
+        opr_fmt = OprTensorFormatsConfiguration::safe_cast_to_opr_format(config_id);
     }
     TensorFormats out_fmt;
     if (fmtcfg.valid())
         out_fmt = fmtcfg.val().output_tensor_formats[0];
-    else
-        out_fmt = opr_format_to_tensor_formats(opr_fmt);
+    else {
+        mgb_assert(opr_fmt.valid());
+        out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
+    }
     for (size_t i = 0; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
         auto iter = rst.var_record.find(var);
         if (iter != rst.var_record.end()) {
             if (fmtcfg.valid())
                 var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
-            else
-                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
+            else {
+                mgb_assert(opr_fmt.valid());
+                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt.val());
+            }
         }
     }
     return out_fmt;
@@ -342,13 +350,13 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     cuts.emplace_back(Cut{});
     auto& states = cuts.back().states;
     for (const auto& record : records) {
-        auto opr_fmt = record.first;
+        auto cfg_id = record.first;
         float opr_time = record.second;
         ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
-        auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
+        auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
         const auto& edge = edges[cur];
         State state(edge.size(), 0);
-        Value value{opr, nullptr, opr_fmt, 0.f, cur};
+        Value value{opr, nullptr, cfg_id, 0.f, cur};
         float ovar_time = 0.f;
         for (size_t i = 0; i < edge.size(); ++i) {
             auto&& var = edge[i];
@@ -396,16 +404,16 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     const auto& records = it->second.costs;
     StateTable states;
     for (const auto& record : records) {
-        auto opr_fmt = record.first;
+        auto cfg_id = record.first;
         float opr_time = record.second;
         ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
-        auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
+        auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
         for (const auto& kv : cuts.back().states) {
             auto&& prev_state = kv.first;
             float prev_time = kv.second.time;
             const auto& edge = edges[cur];
             State state(edge.size(), 0);
-            Value value{opr, &prev_state, opr_fmt, 0.f, cur};
+            Value value{opr, &prev_state, cfg_id, 0.f, cur};
             float ovar_time = 0.f;
             for (size_t i = 0; i < edge.size(); ++i) {
                 auto&& var = edge[i];
@@ -482,7 +490,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     /// backward pass to generate the solution
     float min_time = std::numeric_limits<float>::max();
     OperatorNodeBase* cur_opr = nullptr;
-    OprFormat min_fmt = OprFormat::NCHW;
+    OprFormatConfigID min_cfg = OprFormatConfigID::NCHW;
     const State* pstate = nullptr;
     for (auto&& kv : cuts.back().states) {
         auto&& v = kv.second;
@@ -490,7 +498,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
         cur_opr = v.opr;
         pstate = v.prev;
         min_time = v.time;
-        min_fmt = v.opr_fmt;
+        min_cfg = v.cfg_id;
         ///! just to check the tensor formats of the output varnode
         auto&& k = kv.first;
         size_t opr_idx = v.opr_idx;
@@ -505,10 +513,10 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     }
     mgb_assert(cur_opr != nullptr);
     mgb_log_debug(
-            "opr:%s;format:%s;time:%f", cur_opr->cname(), opr_format_to_string(min_fmt),
+            "opr:%s;config:%s;time:%f", cur_opr->cname(), config_id_to_string(min_cfg),
             min_time);
-    solution.insert({cur_opr, min_fmt});
+    solution.insert({cur_opr, min_cfg});
     cur = cuts.size() - 2;
     while (pstate) {
         auto val = cuts[cur].states[*pstate];
@@ -522,9 +530,9 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
             }
         }
         mgb_log_debug(
-                "opr:%s;format:%s;time:%f", val.opr->cname(),
-                opr_format_to_string(val.opr_fmt), val.time);
-        solution.insert({val.opr, val.opr_fmt});
+                "opr:%s;cofig:%s;time:%f", val.opr->cname(),
+                config_id_to_string(val.cfg_id), val.time);
+        solution.insert({val.opr, val.cfg_id});
         pstate = val.prev;
         cur--;
     }
src/gopt/impl/global_layout_transform/layout_transform_context.cpp

@@ -22,6 +22,7 @@ using namespace gopt;
 namespace {
 using OprFormat = LayoutTransformContext::OprFormat;
+using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
 using OprList = LayoutTransformContext::OprList;
 using Attribute = LayoutTransformContext::Attribute;
 using Target = LayoutTransformContext::Target;
@@ -43,7 +44,7 @@ const char* target_to_string(Target target) {
 }
 std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
-        OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
             opr::ConvolutionForward::typeinfo(),
@@ -58,34 +59,38 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
     SmallVector<TensorFormats> available_tensor_formats = {
             TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
             TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
     Attribute attribute = {
-            base_opr_format, base_tensor_format, Target::CUDA,
+            base_config_id, base_tensor_format, Target::CUDA,
             LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
-               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
-                OprFormat::NCHW64, OprFormat::CHWN4})
+               {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
+                OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
+                OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4,
+                    OprFormatConfigID::NHWC})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
-                   {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
-                    OprFormat::NCHW64, OprFormat::CHWN4})
+                   {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                    OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
+                    OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::WarpPerspectiveForward::typeinfo(),
-                   {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+                   {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
+                    OprFormatConfigID::NCHW64});
     return ctx;
 }
 std::unique_ptr<LayoutTransformContext> make_arm_ctx(
-        OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
             opr::ConvolutionForward::typeinfo(),
@@ -101,57 +106,64 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
     SmallVector<TensorFormats> available_tensor_formats = {
             TensorFormats::NCHW, TensorFormats::NCHWc4,
             DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
+    Attribute attribute = {base_config_id, base_tensor_format, Target::ARM};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
-               {OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88),
-                OprFormat::NCHW44_DOT})
+               {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                OprFormatConfigID::NCHW44_HYBRID,
+                DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
+                DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
+                OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW44,
-                    DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                    OprFormatConfigID::NCHW44_HYBRID,
+                    DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
+                    DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
+                    OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW44,
-                    DNN_INC_FLOAT16(OprFormat::NCHW88)})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                    DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)})
            .add_opr_config(
                    opr::ResizeForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW44,
-                    DNN_INC_FLOAT16(OprFormat::NCHW88)})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                    DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)});
     return ctx;
 }
 }  // namespace

 /* ================= LayoutTransformContext ==================*/
 LayoutTransformContext& LayoutTransformContext::add_opr_config(
-        Typeinfo* opr, OprFormat opr_format) {
+        Typeinfo* opr, OprFormatConfigID config_id) {
     auto& dispatchers = m_opr_configs[opr];
-    dispatchers[opr_format] =
+    dispatchers[config_id] =
             OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                    opr, opr_format);
+                    opr, config_id);
     return *this;
 }
 LayoutTransformContext& LayoutTransformContext::add_opr_config(
-        Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
+        Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids) {
     auto& dispatchers = m_opr_configs[opr];
-    for (auto opr_fmt : opr_formats) {
-        dispatchers[opr_fmt] =
-                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                        opr, opr_fmt);
+    for (auto cfg : config_ids) {
+        dispatchers[cfg] =
+                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(opr, cfg);
     }
     return *this;
 }
 std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
-        Target target, OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        Target target, OprFormatConfigID base_config_id,
+        TensorFormats base_tensor_format) {
     switch (target) {
         case Target::CUDA:
-            return make_cuda_ctx(base_opr_format, base_tensor_format);
+            return make_cuda_ctx(base_config_id, base_tensor_format);
         case Target::ARM:
-            return make_arm_ctx(base_opr_format, base_tensor_format);
+            return make_arm_ctx(base_config_id, base_tensor_format);
         default:
             mgb_assert(false, "unsupported target %s\n", target_to_string(target));
     }
src/gopt/impl/global_layout_transform/layout_transform_pass.cpp

@@ -43,6 +43,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto partitions = extractor.extract(opt.graph().endpoint_vars());
     using Solution = SolverBase::Solution;
+    using OprFormat = SolverBase::OprFormat;
     Solution solution;
     ThinHashSet<VarNode*> endpoint_vars;
     for (auto&& partition : partitions) {
@@ -60,7 +61,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto&& opr_configs = m_ctx->opr_configs();
     auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
-    auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
+    auto&& base_cfg_id = m_ctx->attribute().base_config_id;
     auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
     ThinHashMap<VarNode*, TensorFormats> var2fmts;
     static ThinHashSet<Typeinfo*> format_aware_oprs = {
@@ -69,18 +70,25 @@ void LayoutTransformPass::apply(OptState& opt) const {
 #undef cb
     };
     auto rewriter = opt.graph().make_rewriter();
-    auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
+    auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
                    &rewriter, &solution, &var2fmts,
                    &endpoint_vars](OperatorNodeBase* opr) {
         auto it = solution.find(opr);
         if (it != solution.end()) {
-            auto opr_fmt = it->second;
+            auto cfg_id = it->second;
             auto find = opr_configs.find(opr->dyn_typeinfo());
             Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
             Maybe<OprTensorFormatsConfiguration> basecfg = None;
+            Maybe<OprFormat> opr_fmt = None;
             if (find != opr_configs.end()) {
-                fmtcfg = (*find->second.at(opr_fmt))(opr);
-                basecfg = (*find->second.at(base_opr_fmt))(opr);
+                fmtcfg = (*find->second.at(cfg_id))(opr);
+                auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
+                        opr->dyn_typeinfo(), base_cfg_id);
+                basecfg = (*_)(opr);
+                opr_fmt = fmtcfg.val().opr_format;
+            } else {
+                opr_fmt = OprTensorFormatsConfiguration::safe_cast_to_opr_format(cfg_id);
             }
             VarNodeArray new_inp;
             size_t nr_inps = opr->input().size();
@@ -89,7 +97,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                 nr_inps = std::min(fmtcfg.val().input_tensor_formats.size(), nr_inps);
                 out_fmt = fmtcfg.val().output_tensor_formats[0];
             } else {
-                out_fmt = opr_format_to_tensor_formats(opr_fmt);
+                out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
             }
             new_inp.resize(nr_inps);
             for (size_t i = 0; i < nr_inps; ++i) {
@@ -103,7 +111,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                     from = find->second;
                 }
                 auto to = fmtcfg.valid() ? fmtcfg.val().input_tensor_formats[i]
-                                         : opr_format_to_tensor_formats(opr_fmt);
+                                         : opr_format_to_tensor_formats(opr_fmt.val());
                 bool is_parameter =
                         fmtcfg.valid() &&
                         fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
@@ -119,7 +127,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                         var->dtype().enumv()};
                 if (is_parameter) {
                     auto aligned_desc =
-                            ReformatManager::make_aligned_desc(base_fmt, out_fmt);
+                            ReformatManager::make_aligned_desc(from, out_fmt);
                     reformat = ReformatManager::instance()
                                        .auto_aligned_reformat_weight(
                                                var, key, aligned_desc);
@@ -134,7 +142,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
             }
             VarNode* new_out;
             if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
-                new_out = intl::modify_opr_format(opr_fmt, new_inp, opr);
+                new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
             } else {
                 new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
                                   ->output(0);
@@ -170,9 +178,8 @@ void LayoutTransformPass::apply(OptState& opt) const {
                         ovar, new_ovar,
                         mgb_cstr_log(ssprintf(
                                              "replace opr(%s) to new opr "
-                                             "format(%s)",
-                                             opr->cname(),
-                                             opr_format_to_string(opr_fmt))
+                                             "format config(%s)",
+                                             opr->cname(), config_id_to_string(cfg_id))
                                              .c_str()));
             }
         } else {
src/gopt/impl/global_layout_transform/opr_format_modifier.h

@@ -24,7 +24,7 @@ namespace intl {
 bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);

 VarNode* modify_opr_format(
-        opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
+        opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr);
 }  // namespace intl
src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp

This diff is collapsed (+281 −62).
src/gopt/impl/global_layout_transform/profiler_cache.cpp

@@ -64,7 +64,7 @@ void ProfilerCache::Key::build_blob_from_opr() {
     // serialize opr_format
     m_blob_storage.append(
-            std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.opr_format)));
+            std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.config_id)));

     // serialize extra_attribute
     m_blob_storage.append(
src/gopt/impl/global_layout_transform/profiler_impl.cpp

@@ -29,30 +29,6 @@ using namespace gopt;
 using ReformatKey = ReformatManager::ReformatKey;

 namespace {
 using OprFormat = Problem::OprFormat;
-OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
-    switch (tensor_format) {
-        case TensorFormats::NCHW:
-            return OprFormat::NCHW;
-        case TensorFormats::NCHWc4:
-            return OprFormat::NCHW44;
-        case TensorFormats::NCHWc8:
-            return OprFormat::NCHW88;
-        case TensorFormats::NCHWc32:
-            return OprFormat::NCHW32;
-        case TensorFormats::NCHWc64:
-            return OprFormat::NCHW64;
-        case TensorFormats::NHWC:
-            return OprFormat::NHWC;
-        case TensorFormats::CHWNc4:
-            return OprFormat::CHWN4;
-        default:
-            mgb_throw(
-                    MegBrainError, "tensor format(%u) is not supported",
-                    static_cast<uint32_t>(tensor_format));
-    }
-}
 class GraphPartitionProfiler final : public PluginBase {
     using CompNodeEventPtr = std::unique_ptr<CompNode::Event>;
@@ -214,8 +190,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
     record.opr = opr;
     auto& costs = record.costs;
     for (auto&& f : available_tensor_formats) {
-        auto opr_format = tensor_formats_to_opr_format(f);
-        costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute);
+        auto config_id = tensor_formats_to_config_id(f);
+        costs[config_id] = profile_operator(opr, base_format, f, extra_attribute);
     }
     return record;
 }
@@ -261,7 +237,7 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
     record.opr = opr;
     auto& costs = record.costs;
     for (auto&& i : available_configs) {
-        costs[i.opr_format] = profile_operator(opr, base_config, i, extra_attribute);
+        costs[i.config_id] = profile_operator(opr, base_config, i, extra_attribute);
     }
     return record;
 }
@@ -316,7 +292,6 @@ float ProfilerImpl::profile_operator(
         new_inps[i] = imm.node();
     }
     VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
-#if 0
     static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
             opr::Convolution::typeinfo(),
             opr::ConvBiasForward::typeinfo(),
@@ -326,7 +301,6 @@ float ProfilerImpl::profile_operator(
     if (multi_algo_oprs.count(opr->dyn_typeinfo()) &&
         !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
         return PROFILE_TIME_OUT;
-#endif
     if (!m_opr_filter(opr, y->owner_opr()))
         return PROFILE_TIME_OUT;
     auto mark = MarkInputContiguous::make(SymbolVar(y));
@@ -494,6 +468,30 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons
     return profiling_result;
 }

+ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
+        TensorFormats tensor_format) const {
+    switch (tensor_format) {
+        case TensorFormats::NCHW:
+            return OprFormatConfigID::NCHW;
+        case TensorFormats::NCHWc4:
+            return OprFormatConfigID::NCHW4;
+        case TensorFormats::NCHWc8:
+            return OprFormatConfigID::NCHW8;
+        case TensorFormats::NCHWc32:
+            return OprFormatConfigID::NCHW32;
+        case TensorFormats::NCHWc64:
+            return OprFormatConfigID::NCHW64;
+        case TensorFormats::NHWC:
+            return OprFormatConfigID::NHWC;
+        case TensorFormats::CHWNc4:
+            return OprFormatConfigID::CHWN4;
+        default:
+            mgb_throw(
+                    MegBrainError, "tensor format(%u) is not supported",
+                    static_cast<uint32_t>(tensor_format));
+    }
+}
+
 /* ================== ProfilerBase =================*/
 std::string ProfilerBase::OperatorNodeRecord::to_string() const {
     auto str = ssprintf(
@@ -508,7 +506,7 @@ std::string ProfilerBase::OperatorNodeRecord::to_string() const {
             opr->output(0)->shape().to_string().c_str());
     for (auto&& cpair : costs) {
         str += ssprintf(
-                "\tformat: %s; cost:%f", opr_format_to_string(cpair.first),
+                "\tconfig: %s; cost:%f", config_id_to_string(cpair.first),
                 cpair.second);
     }
     return str;
@@ -557,7 +555,7 @@ float CachedProfiler::profile_operator(
         const OperatorNodeBase* opr, TensorFormats base_format,
         TensorFormats tensor_format, ReformatAttribute extra_attribute) const {
     ProfilerCache::Key key{
-            opr, tensor_formats_to_opr_format(tensor_format), extra_attribute};
+            opr, tensor_formats_to_config_id(tensor_format), extra_attribute};
     auto ret = ProfilerCache::inst().get(key);
     if (ret.valid())
         return ret.val();
@@ -571,7 +569,7 @@ float CachedProfiler::profile_operator(
         const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
         const OprTensorFormatsConfiguration& config,
         ReformatAttribute extra_attribute) const {
-    ProfilerCache::Key key{opr, config.opr_format, extra_attribute};
+    ProfilerCache::Key key{opr, config.config_id, extra_attribute};
     auto ret = ProfilerCache::inst().get(key);
     if (ret.valid())
         return ret.val();
src/gopt/impl/global_layout_transform/profiling_based_solver.cpp

@@ -48,7 +48,8 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile
     };

     m_problem_filter = [](const Problem& problem) {
-        auto&& base_opr_format = problem.attribute().base_opr_format;
+        auto&& base_opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format(
+                problem.attribute().base_config_id);
         bool has_format_aware_opr = false;
         for (auto&& opr : problem.graph_partition().all_oprs()) {
             auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo());
src/gopt/impl/global_layout_transform/utils.h

@@ -40,6 +40,37 @@ static inline const char* opr_format_to_string(
 #undef cb
 }

+static inline const char* config_id_to_string(
+        OprTensorFormatsConfiguration::OprFormatConfigID config_id) {
+    using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
+#define cb(_fmt)                  \
+    case OprFormatConfigID::_fmt: \
+        return #_fmt
+    switch (config_id) {
+        cb(NCHW);
+        cb(NHWC);
+        cb(NCHW4);
+        cb(NCHW8);
+        cb(NCHW4_NCHW32);
+        cb(NCHW4_NCHW);
+        cb(NCHW32);
+        cb(NCHW32_NCHW4);
+        cb(NCHW64);
+        cb(CHWN4);
+        cb(NCHW44);
+        cb(NCHW44_HYBRID);
+        cb(NCHW88);
+        cb(NCHW88_HYBRID);
+        cb(NCHW44_DOT);
+        cb(NCHW44_DOT_HYBRID);
+        default:
+            mgb_assert(
+                    false, "Invalid config id(got:%u)",
+                    static_cast<uint32_t>(config_id));
+    }
+#undef cb
+}
+
 static inline TensorFormats opr_format_to_tensor_formats(
         OprTensorFormatsConfiguration::OprFormat opr_format) {
     using OprFormat = OprTensorFormatsConfiguration::OprFormat;
@@ -60,6 +91,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
             return TensorFormats::NCHWc8;
         case OprFormat::NCHW44:
             return TensorFormats::NCHWc4;
+        case OprFormat::NCHW8:
+            return TensorFormats::NCHWc8;
         default:
             mgb_throw(
                     AssertionError, "format(%s) is not supported",
@@ -124,9 +157,17 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
             return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}};
         case TensorFormats::C11RS:
             return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}};
+        case TensorFormats::KRSC:
+            return {{"K"}, {"R"}, {"S"}, {"C"}};
+        case TensorFormats::KCRSc32:
+            return {{"K"}, {"C//32"}, {"R"}, {"S"}, {"C%32"}};
+        case TensorFormats::KCRSc64:
+            return {{"K"}, {"C//64"}, {"R"}, {"S"}, {"C%64"}};
+        case TensorFormats::CRSKc4:
+            return {{"C//4"}, {"R"}, {"S"}, {"K"}, {"C%4"}};
         default:
             mgb_throw(
-                    AssertionError, "invalid tensor formats(%u)",
+                    MegBrainError, "invalid tensor formats(%u)",
                     static_cast<uint32_t>(format));
     }
 }
src/gopt/include/megbrain/gopt/layout_transform_context.h

@@ -26,19 +26,48 @@ namespace gopt {
  * configuration of the opr format
  */
 struct OprTensorFormatsConfiguration {
-    using OprFormat = opr::ConvBias::Param::Format;
+    using OprFormat = opr::Convolution::Param::Format;
+    static constexpr uint32_t FORMAT_NR_MEMBER =
+            opr::Convolution::Param::FORMAT_NR_MEMBER;
+    enum class OprFormatConfigID : uint32_t {
+#define cb(fmt_) fmt_ = static_cast<uint32_t>(OprFormat::fmt_)
+        cb(NCHW),
+        cb(NHWC),
+        cb(NHWCD4),
+        cb(NCHW4),
+        cb(NCHW8),
+        cb(NCHW32),
+        cb(NCHW88),
+        cb(NCHW44),
+        cb(NCHW44_DOT),
+        cb(NCHW4_NCHW32),
+        cb(NCHW32_NCHW4),
+        cb(NCHW4_NCHW),
+        cb(NCHW4_NHWC),
+        cb(CHWN4),
+        cb(NCHW64),
+        NCHW44_HYBRID = FORMAT_NR_MEMBER,
+        NCHW88_HYBRID = FORMAT_NR_MEMBER + 1,
+        NCHW44_DOT_HYBRID = FORMAT_NR_MEMBER + 2,
+    };
+#undef cb
     using OprTensorFormatsDispatcher =
             thin_function<Maybe<OprTensorFormatsConfiguration>(
                     const cg::OperatorNodeBase*)>;
     Typeinfo* typeinfo;
     OprFormat opr_format;
+    OprFormatConfigID config_id;
     SmallVector<DTypeEnum> input_dtypes;
     SmallVector<DTypeEnum> output_dtypes;
     SmallVector<TensorFormats> input_tensor_formats;
     SmallVector<TensorType> input_tensor_types;
     SmallVector<TensorFormats> output_tensor_formats;
     static OprTensorFormatsDispatcher* find_dispatcher_by_type_format(
-            Typeinfo* type, OprFormat opr_format);
+            Typeinfo* type, OprFormatConfigID config_id);
+    static OprFormat safe_cast_to_opr_format(OprFormatConfigID config_id) {
+        mgb_assert(static_cast<uint32_t>(config_id) < FORMAT_NR_MEMBER);
+        return static_cast<OprFormat>(static_cast<uint32_t>(config_id));
+    }
 };

 /*!
@@ -48,14 +77,15 @@ class LayoutTransformContext {
 public:
     using OprList = SubGraphExtractor::OprList;
     using OprFormat = OprTensorFormatsConfiguration::OprFormat;
+    using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
     using OprTensorFormatsDispatcher =
             OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
-    using OprConfigTrait = ThinHashMap<
-            Typeinfo*, ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>;
+    using OprConfigTrait = ThinHashMap<
+            Typeinfo*, ThinHashMap<OprFormatConfigID, OprTensorFormatsDispatcher*>>;
     using Target = GraphTuningOptions::Target;
     using ReformatAttribute = ReformatManager::ReformatKey::Attribute;
     struct Attribute {
-        OprFormat base_opr_format;         /// the base opr format indicates that the
+        OprFormatConfigID base_config_id;  /// the base opr format indicates that the
                                            /// network to be optimized is constructed
                                            /// in the base opr format, i.e. all the
                                            /// format aware operators (conv, conv_bias,
@@ -97,21 +127,22 @@ public:
     /*!
      * \brief add an op format configuration for a particular operator type
      * \param opr runtime typeinfo of operator
-     * \param opr_format op format configuration which to be enabled in the
-     * layout transform problem
+     * \param config_id op format configuration id which is going to be enabled
+     * in the layout transform problem
      */
-    LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format);
+    LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormatConfigID config_id);
     /*!
      * \brief add a vector of op format configurations for a particular operator
      * type
      * \param opr runtime typeinfo of operator
-     * \param opr_format op format configuration which to be enabled in the
-     * layout transform problem
+     * \param config_ids ids of op format configurations which are enabled in
+     * the layout transform problem
      */
     LayoutTransformContext& add_opr_config(
-            Typeinfo* opr, SmallVector<OprFormat> opr_formats);
+            Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids);
     static std::unique_ptr<LayoutTransformContext> make(
-            Target target = Target::UNSPEC, OprFormat base_opr_format = OprFormat::NCHW,
+            Target target = Target::UNSPEC,
+            OprFormatConfigID base_config_id = OprFormatConfigID::NCHW,
             TensorFormats base_tensor_format = TensorFormats::NCHW);

 private:
@@ -130,6 +161,7 @@ private:
 class Problem {
 public:
     using OprFormat = OprTensorFormatsConfiguration::OprFormat;
+    using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
     using OprTensorFormatsDispatcher =
             OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
     using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
@@ -152,13 +184,15 @@ public:
      */
     OprTensorFormatsConfiguration base_config(const cg::OperatorNodeBase* opr) const {
         auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format);
+                opr->dyn_typeinfo(), m_ctx.attribute().base_config_id);
         auto rst = (*_)(opr);
         if (rst.valid())
             return rst.val();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = m_ctx.attribute().base_opr_format;
+        config.config_id = m_ctx.attribute().base_config_id;
+        config.opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format(
+                config.config_id);
         for (const auto& i : opr->input()) {
             config.input_dtypes.emplace_back(i->dtype().enumv());
             config.input_tensor_formats.emplace_back(base_format());
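Editorial note, not part of the commit: the new OprFormatConfigID enum above reuses the numeric values of OprFormat for every non-hybrid entry (via the cb macro) and appends the *_HYBRID configurations at and after FORMAT_NR_MEMBER, which is why safe_cast_to_opr_format() only accepts IDs below that bound. Below is a self-contained toy version of the same scheme, with made-up names, for illustration only.

// Toy illustration of the config-ID numbering scheme; not MegEngine code.
#include <cassert>
#include <cstdint>

enum class ToyOprFormat : uint32_t { NCHW = 0, NCHW44 = 1, NCHW88 = 2 };
constexpr uint32_t TOY_FORMAT_NR_MEMBER = 3;

enum class ToyConfigID : uint32_t {
    NCHW = static_cast<uint32_t>(ToyOprFormat::NCHW),
    NCHW44 = static_cast<uint32_t>(ToyOprFormat::NCHW44),
    NCHW88 = static_cast<uint32_t>(ToyOprFormat::NCHW88),
    NCHW44_HYBRID = TOY_FORMAT_NR_MEMBER,      // no OprFormat counterpart
    NCHW88_HYBRID = TOY_FORMAT_NR_MEMBER + 1,  // no OprFormat counterpart
};

ToyOprFormat safe_cast_to_opr_format(ToyConfigID id) {
    // Only non-hybrid config IDs map back onto an OprFormat value.
    assert(static_cast<uint32_t>(id) < TOY_FORMAT_NR_MEMBER);
    return static_cast<ToyOprFormat>(static_cast<uint32_t>(id));
}

int main() {
    assert(safe_cast_to_opr_format(ToyConfigID::NCHW44) == ToyOprFormat::NCHW44);
    // Hybrid configs such as ToyConfigID::NCHW44_HYBRID would trip the assert; in the
    // real pass they are resolved through the per-operator dispatcher table instead.
    return 0;
}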
src/gopt/include/megbrain/gopt/profiler.h

@@ -33,9 +33,10 @@ class CachedProfiler;
 class ProfilerBase {
 public:
     using OprFormat = Problem::OprFormat;
+    using OprFormatConfigID = Problem::OprFormatConfigID;
     struct OperatorNodeRecord {
         const cg::OperatorNodeBase* opr;  ///< pointer to operator node
-        ThinHashMap<OprFormat, float>
+        ThinHashMap<OprFormatConfigID, float>
                 costs;  ///< costs of operator node, i.e. the elapsed device
                         ///< time of the operator node on different opr format
                         ///< (layout configuration).
@@ -199,6 +200,8 @@ protected:
     virtual float profile_var_node(
             const VarNode* var, TensorFormats base_format,
             const ReformatKey& key) const;
+    OprFormatConfigID tensor_formats_to_config_id(
+            TensorFormats tensor_format) const;
     OprFootprint m_opr_footprint;
     float m_opr_threshold;  /// a threshold, when the computation of the newly
                             /// created operator that is built in some opr
@@ -224,14 +227,14 @@ class ProfilerCache : public NonCopyableObj {
 public:
     using ReformatKey = ReformatManager::ReformatKey;
     using ReformatAttribute = ReformatKey::Attribute;
-    using OprFormat = ProfilerBase::OprFormat;
+    using OprFormatConfigID = ProfilerBase::OprFormatConfigID;
     class Key final : public NonCopyableObj {
         std::string m_blob_storage;
         std::string m_category;

         struct OprKey {
             const OperatorNodeBase* opr;
-            OprFormat opr_format;
+            OprFormatConfigID config_id;
             ReformatAttribute extra_attribute;
         };
@@ -254,9 +257,9 @@ public:
         void build_category(CompNode cn);

     public:
-        Key(const OperatorNodeBase* opr, OprFormat opr_format,
+        Key(const OperatorNodeBase* opr, OprFormatConfigID config_id,
             ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) {
-            m_key_impl.opr_key = {opr, opr_format, extra_attribute};
+            m_key_impl.opr_key = {opr, config_id, extra_attribute};
             build_blob_from_opr();
             mgb_assert(
                     opr->node_prop().contain(
src/gopt/include/megbrain/gopt/solver.h

@@ -28,7 +28,8 @@ class ProfilerBase;
 class SolverBase {
 public:
     using OprFormat = Problem::OprFormat;
-    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
+    using OprFormatConfigID = Problem::OprFormatConfigID;
+    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormatConfigID>;
     SolverBase() = default;
     virtual ~SolverBase() = default;
     /*!
src/gopt/test/embed_cache.py

 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -95,7 +96,7 @@ static const std::vector<uint8_t> {} = {{
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='embed cache into cache header file',
+        description='embed cubin into cpp source file',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('-o', '--output', help='output source file', required=True)
src/gopt/test/layout_transform_pass.cpp

This diff is collapsed (+229 −66).
src/gopt/test/network.cpp

@@ -45,6 +45,36 @@ SymbolVar Network::add_conv(
     return conv;
 }

+SymbolVar Network::add_group_conv(
+        SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
+        DType out_dtype, bool has_relu, Stride stride, Padding padding) {
+    static int weight_idx = 0;
+    static int bias_idx = 0;
+    size_t input_channels = f.node()->shape()[1];
+    auto weight = add_cvar(
+            ssprintf("w%d", weight_idx).c_str(),
+            {groups, output_channels / groups, input_channels / groups, kern_size[0],
+             kern_size[1]});
+    auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
+    mgb_assert(out_dtype.category() == DTypeCategory::FLOAT);
+    opr::ConvBias::Param param;
+    param.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    param.stride_h = stride[0], param.stride_w = stride[1];
+    param.pad_h = padding[0], param.pad_w = padding[1];
+    if (has_relu) {
+        param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    } else {
+        param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
+    }
+    auto conv = opr::ConvBias::make(
+            f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
+    weight_idx++;
+    bias_idx++;
+    return conv;
+}
+
 SymbolVar Network::add_deconv(
         SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) {
     static int weight_idx = 0;
@@ -208,6 +238,7 @@ SymbolVarArray fusion_pyramids_feature(
                 false, {1, 1}, {0, 0});
         if (!touch) {
             x = f;
+            touch = true;
         } else {
             x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f});
             x = network.add_elemwise(
@@ -236,4 +267,63 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
     return outputs;
 }

+SymbolVar mgb::bottleneck(
+        Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
+        size_t stride) {
+    size_t in_channels = f.node()->shape()[1];
+    SymbolVar x = f;
+    if (t != 1) {
+        x = network.add_conv(
+                f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0});
+    }
+    x = network.add_group_conv(
+            x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true,
+            {stride, stride}, {1, 1});
+    x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0});
+    if (stride == 1 && in_channels == channels)
+        x = f + x;
+    return x;
+}
+
+SymbolVar mgb::bottleneck_group(
+        Network& network, SymbolVar f, size_t input_channels, size_t channels,
+        size_t stages, size_t s, size_t t) {
+    SymbolVar x = f;
+    for (size_t i = 0; i < stages; ++i) {
+        size_t stride = i == 0 ? s : 1;
+        x = bottleneck(network, x, input_channels, channels, t, stride);
+        input_channels = channels;
+    }
+    return x;
+}
+
+namespace {
+size_t make_divisible(size_t v, size_t divisor) {
+    size_t min_value = divisor;
+    size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor);
+    if (new_v < 0.9 * v)
+        new_v += divisor;
+    return new_v;
+}
+}  // namespace
+
+SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) {
+    auto data = network.add_var("data", {batch, 3, 224, 224});
+    constexpr size_t round_nearest = 8;
+    auto x = network.add_conv(
+            data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true,
+            {2, 2}, {1, 1});
+    x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1);
+    x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6);
+    x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6);
+    x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6);
+    x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6);
+    x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6);
+    x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6);
+    x = network.add_conv(
+            x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true,
+            {1, 1}, {0, 0});
+    return x;
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
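Editorial note, not part of the commit: make_divisible() in the new MobileNetV2 test network above follows the usual channel-rounding rule (round v to the nearest multiple of divisor, then bump up by one divisor if more than 10% of v would be lost). A standalone check of that behaviour, with the function body copied verbatim from the diff above:

// Standalone check of the rounding helper; values computed from the code above.
#include <algorithm>
#include <cassert>
#include <cstddef>

static size_t make_divisible(size_t v, size_t divisor) {
    size_t min_value = divisor;
    size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor);
    if (new_v < 0.9 * v)
        new_v += divisor;
    return new_v;
}

int main() {
    assert(make_divisible(32, 8) == 32);   // already a multiple: unchanged
    assert(make_divisible(30, 8) == 32);   // rounds up to the nearest multiple of 8
    assert(make_divisible(20, 8) == 24);   // (20 + 4) / 8 * 8 = 24
    assert(make_divisible(23, 16) == 32);  // 16 < 0.9 * 23, so bump by one divisor
    return 0;
}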
src/gopt/test/network.h

@@ -28,7 +28,7 @@
 namespace mgb {
 class Network {
 private:
-    HostTensorGenerator<> gen;
+    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{-0.01, 0.01};
     CompNode cn;

 public:
@@ -49,6 +49,10 @@ public:
             SymbolVar f, size_t output_channels, KernSize kern_size,
             DType out_dtype = dtype::Float32(), bool has_relu = true,
             Stride stride = {1, 1}, Padding padding = {0, 0});
+    SymbolVar add_group_conv(
+            SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
+            DType out_dtype = dtype::Float32(), bool has_relu = true,
+            Stride stride = {1, 1}, Padding padding = {0, 0});
     SymbolVar add_deconv(
             SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype);
     SymbolVar add_elemwise(
@@ -73,6 +77,16 @@ SymbolVar make_resnet18(
 SymbolVarArray make_det(
         Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());

+SymbolVar bottleneck(
+        Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
+        size_t stride);
+
+SymbolVar bottleneck_group(
+        Network& network, SymbolVar f, size_t input_channels, size_t channels,
+        size_t stages, size_t s, size_t t);
+
+SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1);
+
 }  // namespace mgb

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/test/profiler.cpp

@@ -26,7 +26,7 @@ using namespace serialization;
 #if MGB_CUDA
 namespace {
 std::unique_ptr<LayoutTransformContext> make_ctx() {
-    using OprFormat = LayoutTransformContext::OprFormat;
+    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
     using OprList = LayoutTransformContext::OprList;
     using Attribute = LayoutTransformContext::Attribute;
     using Target = LayoutTransformContext::Target;
@@ -44,26 +44,29 @@ std::unique_ptr<LayoutTransformContext> make_ctx() {
     SmallVector<TensorFormats> available_tensor_formats = {
             TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
             TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA};
+    Attribute attribute = {OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::CUDA};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
-               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
-                OprFormat::NCHW64, OprFormat::CHWN4})
+               {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
+                OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
-                   {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
-                    OprFormat::NCHW64, OprFormat::CHWN4})
+                   {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                    OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
+                    OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::WarpPerspectiveForward::typeinfo(),
-                   {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+                   {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
+                    OprFormatConfigID::NCHW64});
     return ctx;
 }
 }  // namespace