提交 df47637d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/naive): fix midout for relayout_format

GitOrigin-RevId: 6ff9e2280ebb5c9ba388192f57cf1aa737ca27a1
上级 0d8b9136
......@@ -14,6 +14,10 @@
#include "megdnn/tensor_iter.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_relayout_format)
using namespace megdnn;
using namespace naive;
......@@ -224,11 +228,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \
case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::NCHW_NHWCD4I)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, \
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \
m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \
IC, IH, IW);); \
} \
MIDOUT_END(); \
break; \
}
cb(Float32, dt_float32);
......@@ -250,11 +258,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \
case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \
IC, FH, FW);); \
MEGDNN_DISPATCH_CPU_KERN(m_handle, \
padding_filter_to_workspace<ctype>( \
dptr, sptr, OC, IC, FH, FW);); \
} \
MIDOUT_END(); \
break; \
}
cb(Quantized8Asymm, dt_uint8);
......@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
} else if (param().mode == Param::Mode::NCHW_NCHW88) {
#define cb(_idx, _pack_size) \
#define cb(_idx, _pack_size, _mode) \
MIDOUT_BEGIN(megdnn_naive_relayout_format, \
midout_iv(Param::Mode::_mode)) { \
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
}
cb(1, 8);
} \
} \
MIDOUT_END();
cb(1, 8, NCHW_NCHW88);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
megdnn_assert(src.layout[0] % 8 == 0);
cb(1, 8);
cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
cb(0, 8);
cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
megdnn_assert(src.layout[1] % 8 == 0);
cb(2, 8);
cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4);
cb(1, 4, NCHW_NCHW4_IC_SMALL);
} else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4);
cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
}
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
#undef cb
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册