Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
df47637d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
df47637d
编写于
7月 04, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
7月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/naive): fix midout for relayout_format
GitOrigin-RevId: 6ff9e2280ebb5c9ba388192f57cf1aa737ca27a1
上级
0d8b9136
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
46 addition
and
29 deletion
+46
-29
dnn/src/naive/relayout_format/opr_impl.cpp
dnn/src/naive/relayout_format/opr_impl.cpp
+46
-29
未找到文件。
dnn/src/naive/relayout_format/opr_impl.cpp
浏览文件 @
df47637d
...
...
@@ -14,6 +14,10 @@
#include "megdnn/tensor_iter.h"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_relayout_format
)
using
namespace
megdnn
;
using
namespace
naive
;
...
...
@@ -222,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
//! ic % 4 != 0
if
((
IC
&
0x3
))
{
switch
(
src
.
layout
.
dtype
.
enumv
())
{
#define cb(name, ctype) \
case (DTypeEnum::name): { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, \
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \
break; \
#define cb(name, ctype) \
case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::NCHW_NHWCD4I)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \
IC, IH, IW);); \
} \
MIDOUT_END(); \
break; \
}
cb
(
Float32
,
dt_float32
);
MEGDNN_INC_FLOAT16
(
cb
(
Float16
,
dt_float16
));
...
...
@@ -248,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
size_t
FW
=
src
.
layout
[
3
];
if
((
IC
&
0x3
))
{
switch
(
src
.
layout
.
dtype
.
enumv
())
{
#define cb(name, ctype) \
case (DTypeEnum::name): { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \
IC, FH, FW);); \
break; \
#define cb(name, ctype) \
case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN(m_handle, \
padding_filter_to_workspace<ctype>( \
dptr, sptr, OC, IC, FH, FW);); \
} \
MIDOUT_END(); \
break; \
}
cb
(
Quantized8Asymm
,
dt_uint8
);
cb
(
QuantizedS8
,
dt_int8
);
...
...
@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd
.
raw_ptr
=
workspace
.
raw_ptr
;
}
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW88
)
{
#define cb(_idx, _pack_size) \
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
}
cb
(
1
,
8
);
#define cb(_idx, _pack_size, _mode) \
MIDOUT_BEGIN(megdnn_naive_relayout_format, \
midout_iv(Param::Mode::_mode)) { \
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
} \
} \
MIDOUT_END();
cb
(
1
,
8
,
NCHW_NCHW88
);
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW88_CONV_DENSE_WEIGHT
)
{
megdnn_assert
(
src
.
layout
[
0
]
%
8
==
0
);
cb
(
1
,
8
);
cb
(
1
,
8
,
NCHW_NCHW88_CONV_DENSE_WEIGHT
);
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW88_CONV_CHAN_WEIGHT
)
{
cb
(
0
,
8
);
cb
(
0
,
8
,
NCHW_NCHW88_CONV_CHAN_WEIGHT
);
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW88_CONV_GROUP_WEIGHT
)
{
megdnn_assert
(
src
.
layout
[
1
]
%
8
==
0
);
cb
(
2
,
8
);
cb
(
2
,
8
,
NCHW_NCHW88_CONV_GROUP_WEIGHT
);
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW4_IC_SMALL
)
{
cb
(
1
,
4
);
cb
(
1
,
4
,
NCHW_NCHW4_IC_SMALL
);
}
else
if
(
param
().
mode
==
Param
::
Mode
::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT
)
{
cb
(
1
,
4
);
cb
(
1
,
4
,
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT
);
}
m_handle
->
relayout_opr
()
->
exec
(
exec_src_nd
,
exec_dst_nd
,
handle
());
#undef cb
}
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录