From 32d7f25b194ef94357e1f7a9d0e2a529e99754f8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 4 Jul 2020 09:44:24 +0800 Subject: [PATCH] fix(dnn/naive): fix midout for relayout_format GitOrigin-RevId: 6ff9e2280ebb5c9ba388192f57cf1aa737ca27a1 --- dnn/src/naive/relayout_format/opr_impl.cpp | 75 +++++++++++++--------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp index 21d4e8112..cc4e71ace 100644 --- a/dnn/src/naive/relayout_format/opr_impl.cpp +++ b/dnn/src/naive/relayout_format/opr_impl.cpp @@ -14,6 +14,10 @@ #include "megdnn/tensor_iter.h" +#include "midout.h" + +MIDOUT_DECL(megdnn_naive_relayout_format) + using namespace megdnn; using namespace naive; @@ -222,14 +226,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, //! ic % 4 != 0 if ((IC & 0x3)) { switch (src.layout.dtype.enumv()) { -#define cb(name, ctype) \ - case (DTypeEnum::name): { \ - ctype* sptr = src.compatible_ptr(); \ - ctype* dptr = workspace.ptr(); \ - MEGDNN_DISPATCH_CPU_KERN( \ - m_handle, \ - padding_src_to_workspace(dptr, sptr, N, IC, IH, IW);); \ - break; \ +#define cb(name, ctype) \ + case (DTypeEnum::name): { \ + MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ + midout_iv(Param::Mode::NCHW_NHWCD4I)) { \ + ctype* sptr = src.compatible_ptr(); \ + ctype* dptr = workspace.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN( \ + m_handle, padding_src_to_workspace(dptr, sptr, N, \ + IC, IH, IW);); \ + } \ + MIDOUT_END(); \ + break; \ } cb(Float32, dt_float32); MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); @@ -248,14 +256,18 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t FW = src.layout[3]; if ((IC & 0x3)) { switch (src.layout.dtype.enumv()) { -#define cb(name, ctype) \ - case (DTypeEnum::name): { \ - ctype* sptr = src.compatible_ptr(); \ - ctype* dptr = workspace.ptr(); \ - MEGDNN_DISPATCH_CPU_KERN( \ - m_handle, padding_filter_to_workspace(dptr, sptr, OC, \ - IC, FH, FW);); \ - break; \ +#define cb(name, ctype) \ + case (DTypeEnum::name): { \ + MIDOUT_BEGIN(megdnn_naive_relayout_format, ctype, \ + midout_iv(Param::Mode::INTER_WEIGHT_DENSEI_DOT)) { \ + ctype* sptr = src.compatible_ptr(); \ + ctype* dptr = workspace.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN(m_handle, \ + padding_filter_to_workspace( \ + dptr, sptr, OC, IC, FH, FW);); \ + } \ + MIDOUT_END(); \ + break; \ } cb(Quantized8Asymm, dt_uint8); cb(QuantizedS8, dt_int8); @@ -266,30 +278,35 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, exec_src_nd.raw_ptr = workspace.raw_ptr; } } else if (param().mode == Param::Mode::NCHW_NCHW88) { -#define cb(_idx, _pack_size) \ - size_t val = src.layout[_idx]; \ - if (val % _pack_size != 0) { \ - padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ - _pack_size); \ - exec_src_nd.raw_ptr = workspace.raw_ptr; \ - } - cb(1, 8); +#define cb(_idx, _pack_size, _mode) \ + MIDOUT_BEGIN(megdnn_naive_relayout_format, \ + midout_iv(Param::Mode::_mode)) { \ + size_t val = src.layout[_idx]; \ + if (val % _pack_size != 0) { \ + padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ + _pack_size); \ + exec_src_nd.raw_ptr = workspace.raw_ptr; \ + } \ + } \ + MIDOUT_END(); + cb(1, 8, NCHW_NCHW88); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { megdnn_assert(src.layout[0] % 8 == 0); - cb(1, 8); + cb(1, 8, NCHW_NCHW88_CONV_DENSE_WEIGHT); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { - cb(0, 8); + cb(0, 8, NCHW_NCHW88_CONV_CHAN_WEIGHT); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { megdnn_assert(src.layout[1] % 8 == 0); - cb(2, 8); + cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT); } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { - cb(1, 4); + cb(1, 4, NCHW_NCHW4_IC_SMALL); } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { - cb(1, 4); + cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); } m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); +#undef cb } // vim: syntax=cpp.doxygen -- GitLab