From 1adb262ad4d7771e4bae1ca30b91f84068fbc633 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 4 Jul 2020 13:44:35 +0800 Subject: [PATCH] fix(dnn/naive): fix midout for pooling GitOrigin-RevId: 4edd99f3ecb3f96275fcd139076d63181487e8ed --- dnn/src/naive/pooling/opr_impl.cpp | 116 ++++++++++++++--------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/dnn/src/naive/pooling/opr_impl.cpp b/dnn/src/naive/pooling/opr_impl.cpp index b0a5222d3..c2d7d8e75 100644 --- a/dnn/src/naive/pooling/opr_impl.cpp +++ b/dnn/src/naive/pooling/opr_impl.cpp @@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, } } -} // anonymous namespace +} // namespace namespace megdnn { namespace naive { void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { - MIDOUT_BEGIN(megdnn_naive_pooling) { - check_exec(src.layout, dst.layout, workspace.size); - size_t c_pos, spatial_pos, batch_pos = 0; - if (param().format == Param::Format::NCHW || - param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW88 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW32) { - c_pos = 1; - spatial_pos = 2; - } else if (param().format == Param::Format::NHWC) { - c_pos = 3; - spatial_pos = 1; - } else if (param().format == Param::Format::CHWN4) { - c_pos = 0; - spatial_pos = 1; - batch_pos = 3; - } else { - megdnn_assert(param().format == Param::Format::NHWCD4); - c_pos = 2; - spatial_pos = 1; - } - size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], - IH = src.layout.shape[spatial_pos + 0], - IW = src.layout.shape[spatial_pos + 1]; - size_t OH = dst.layout.shape[spatial_pos + 0], - OW = dst.layout.shape[spatial_pos + 1]; - if (param().format == Param::Format::NHWCD4) { - C *= 4; - IW = src.layout.shape[spatial_pos + 2]; - OW = dst.layout.shape[spatial_pos + 2]; - } - if (param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::CHWN4) { - C *= 4; - } - if (param().format == Param::Format::NCHW88) { - C *= 8; - } - if (param().format == Param::Format::NCHW32) { - C *= 32; - } - size_t PH = param().pad_h, PW = param().pad_w; - size_t FH = param().window_h, FW = param().window_w; - size_t SH = param().stride_h, SW = param().stride_w; -#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(handle()), \ - pooling_forward_impl( \ - sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ - PW, SH, SW, FH, FW)); + check_exec(src.layout, dst.layout, workspace.size); + size_t c_pos, spatial_pos, batch_pos = 0; + if (param().format == Param::Format::NCHW || + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW88 || + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW32) { + c_pos = 1; + spatial_pos = 2; + } else if (param().format == Param::Format::NHWC) { + c_pos = 3; + spatial_pos = 1; + } else if (param().format == Param::Format::CHWN4) { + c_pos = 0; + spatial_pos = 1; + batch_pos = 3; + } else { + megdnn_assert(param().format == Param::Format::NHWCD4); + c_pos = 2; + spatial_pos = 1; + } + size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], + IH = src.layout.shape[spatial_pos + 0], + IW = src.layout.shape[spatial_pos + 1]; + size_t OH = dst.layout.shape[spatial_pos + 0], + OW = dst.layout.shape[spatial_pos + 1]; + if (param().format == Param::Format::NHWCD4) { + C *= 4; + IW = src.layout.shape[spatial_pos + 2]; + OW = dst.layout.shape[spatial_pos + 2]; + } + if (param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW44 || + param().format == Param::Format::CHWN4) { + C *= 4; + } + if (param().format == Param::Format::NCHW88) { + C *= 8; + } + if (param().format == Param::Format::NCHW32) { + C *= 32; + } + size_t PH = param().pad_h, PW = param().pad_w; + size_t FH = param().window_h, FW = param().window_w; + size_t SH = param().stride_h, SW = param().stride_w; +#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ + MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(handle()), \ + pooling_forward_impl( \ + sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ + PH, PW, SH, SW, FH, FW)); \ + } \ + MIDOUT_END(); #define DISPATCH_WITH_POOLER(Pooler) \ switch (param().format) { \ @@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, } \ } \ } - MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) #undef cb #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER #undef DISPATCH_WITH_POOLER - megdnn_assert_internal(0); - } - MIDOUT_END(); + megdnn_assert_internal(0); } WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( -- GitLab