提交 df47637d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/naive): fix midout for relayout_format

GitOrigin-RevId: 6ff9e2280ebb5c9ba388192f57cf1aa737ca27a1
上级 0d8b9136
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include "megdnn/tensor_iter.h" #include "megdnn/tensor_iter.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_relayout_format)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -224,11 +228,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -224,11 +228,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \ #define cb(name, ctype) \
case (DTypeEnum::name): { \ case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::NCHW_NHWCD4I)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \ ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \ ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
m_handle, \ m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \ IC, IH, IW);); \
} \
MIDOUT_END(); \
break; \ break; \
} }
cb(Float32, dt_float32); cb(Float32, dt_float32);
...@@ -250,11 +258,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -250,11 +258,15 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \ #define cb(name, ctype) \
case (DTypeEnum::name): { \ case (DTypeEnum::name): { \
MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \
ctype* sptr = src.compatible_ptr<ctype>(); \ ctype* sptr = src.compatible_ptr<ctype>(); \
ctype* dptr = workspace.ptr<ctype>(); \ ctype* dptr = workspace.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN(m_handle, \
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \ padding_filter_to_workspace<ctype>( \
IC, FH, FW);); \ dptr, sptr, OC, IC, FH, FW);); \
} \
MIDOUT_END(); \
break; \ break; \
} }
cb(Quantized8Asymm, dt_uint8); cb(Quantized8Asymm, dt_uint8);
...@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd.raw_ptr = workspace.raw_ptr; exec_src_nd.raw_ptr = workspace.raw_ptr;
} }
} else if (param().mode == Param::Mode::NCHW_NCHW88) { } else if (param().mode == Param::Mode::NCHW_NCHW88) {
#define cb(_idx, _pack_size) \ #define cb(_idx, _pack_size, _mode) \
MIDOUT_BEGIN(megdnn_naive_relayout_format, \
midout_iv(Param::Mode::_mode)) { \
size_t val = src.layout[_idx]; \ size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \ if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \ _pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \ exec_src_nd.raw_ptr = workspace.raw_ptr; \
} } \
cb(1, 8); } \
MIDOUT_END();
cb(1, 8, NCHW_NCHW88);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
megdnn_assert(src.layout[0] % 8 == 0); megdnn_assert(src.layout[0] % 8 == 0);
cb(1, 8); cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
cb(0, 8); cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
megdnn_assert(src.layout[1] % 8 == 0); megdnn_assert(src.layout[1] % 8 == 0);
cb(2, 8); cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4); cb(1, 4, NCHW_NCHW4_IC_SMALL);
} else if (param().mode == } else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4); cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
} }
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
#undef cb
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册