Commit 32d7f25b — authored by: Megvii Engine Team

fix(dnn/naive): fix midout for relayout_format

GitOrigin-RevId: 6ff9e2280ebb5c9ba388192f57cf1aa737ca27a1
Parent commit: f856b170
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include "megdnn/tensor_iter.h" #include "megdnn/tensor_iter.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_relayout_format)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -222,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -222,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
//! ic % 4 != 0 //! ic % 4 != 0
if ((IC & 0x3)) { if ((IC & 0x3)) {
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \ #define cb(name, ctype) \
case (DTypeEnum::name): { \ case (DTypeEnum::name): { \
ctype* sptr = src.compatible_ptr<ctype>(); \ MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
ctype* dptr = workspace.ptr<ctype>(); \ midout_iv(Param::Mode::NCHW_NHWCD4I)) { \
MEGDNN_DISPATCH_CPU_KERN( \ ctype* sptr = src.compatible_ptr<ctype>(); \
m_handle, \ ctype* dptr = workspace.ptr<ctype>(); \
padding_src_to_workspace<ctype>(dptr, sptr, N, IC, IH, IW);); \ MEGDNN_DISPATCH_CPU_KERN( \
break; \ m_handle, padding_src_to_workspace<ctype>(dptr, sptr, N, \
IC, IH, IW);); \
} \
MIDOUT_END(); \
break; \
} }
cb(Float32, dt_float32); cb(Float32, dt_float32);
MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); MEGDNN_INC_FLOAT16(cb(Float16, dt_float16));
...@@ -248,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -248,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
size_t FW = src.layout[3]; size_t FW = src.layout[3];
if ((IC & 0x3)) { if ((IC & 0x3)) {
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \ #define cb(name, ctype) \
case (DTypeEnum::name): { \ case (DTypeEnum::name): { \
ctype* sptr = src.compatible_ptr<ctype>(); \ MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \
ctype* dptr = workspace.ptr<ctype>(); \ midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \
MEGDNN_DISPATCH_CPU_KERN( \ ctype* sptr = src.compatible_ptr<ctype>(); \
m_handle, padding_filter_to_workspace<ctype>(dptr, sptr, OC, \ ctype* dptr = workspace.ptr<ctype>(); \
IC, FH, FW);); \ MEGDNN_DISPATCH_CPU_KERN(m_handle, \
break; \ padding_filter_to_workspace<ctype>( \
dptr, sptr, OC, IC, FH, FW);); \
} \
MIDOUT_END(); \
break; \
} }
cb(Quantized8Asymm, dt_uint8); cb(Quantized8Asymm, dt_uint8);
cb(QuantizedS8, dt_int8); cb(QuantizedS8, dt_int8);
...@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd.raw_ptr = workspace.raw_ptr; exec_src_nd.raw_ptr = workspace.raw_ptr;
} }
} else if (param().mode == Param::Mode::NCHW_NCHW88) { } else if (param().mode == Param::Mode::NCHW_NCHW88) {
#define cb(_idx, _pack_size) \ #define cb(_idx, _pack_size, _mode) \
size_t val = src.layout[_idx]; \ MIDOUT_BEGIN(megdnn_naive_relayout_format, \
if (val % _pack_size != 0) { \ midout_iv(Param::Mode::_mode)) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ size_t val = src.layout[_idx]; \
_pack_size); \ if (val % _pack_size != 0) { \
exec_src_nd.raw_ptr = workspace.raw_ptr; \ padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
} _pack_size); \
cb(1, 8); exec_src_nd.raw_ptr = workspace.raw_ptr; \
} \
} \
MIDOUT_END();
cb(1, 8, NCHW_NCHW88);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
megdnn_assert(src.layout[0] % 8 == 0); megdnn_assert(src.layout[0] % 8 == 0);
cb(1, 8); cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
cb(0, 8); cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
megdnn_assert(src.layout[1] % 8 == 0); megdnn_assert(src.layout[1] % 8 == 0);
cb(2, 8); cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4); cb(1, 4, NCHW_NCHW4_IC_SMALL);
} else if (param().mode == } else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4); cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
} }
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
#undef cb
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册