提交 b90c1540 作者: Megvii Engine Team 提交者: Xu Xinran

fix(dnn/naive): fix midout for pooling

GitOrigin-RevId: 4edd99f3ecb3f96275fcd139076d63181487e8ed
上级 df47637d
...@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, ...@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src,
} }
} }
} // anonymous namespace } // namespace
namespace megdnn { namespace megdnn {
namespace naive { namespace naive {
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_pooling) { check_exec(src.layout, dst.layout, workspace.size);
check_exec(src.layout, dst.layout, workspace.size); size_t c_pos, spatial_pos, batch_pos = 0;
size_t c_pos, spatial_pos, batch_pos = 0; if (param().format == Param::Format::NCHW ||
if (param().format == Param::Format::NCHW || param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44 || param().format == Param::Format::NCHW32) {
param().format == Param::Format::NCHW32) { c_pos = 1;
c_pos = 1; spatial_pos = 2;
spatial_pos = 2; } else if (param().format == Param::Format::NHWC) {
} else if (param().format == Param::Format::NHWC) { c_pos = 3;
c_pos = 3; spatial_pos = 1;
spatial_pos = 1; } else if (param().format == Param::Format::CHWN4) {
} else if (param().format == Param::Format::CHWN4) { c_pos = 0;
c_pos = 0; spatial_pos = 1;
spatial_pos = 1; batch_pos = 3;
batch_pos = 3; } else {
} else { megdnn_assert(param().format == Param::Format::NHWCD4);
megdnn_assert(param().format == Param::Format::NHWCD4); c_pos = 2;
c_pos = 2; spatial_pos = 1;
spatial_pos = 1; }
} size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], IH = src.layout.shape[spatial_pos + 0],
IH = src.layout.shape[spatial_pos + 0], IW = src.layout.shape[spatial_pos + 1];
IW = src.layout.shape[spatial_pos + 1]; size_t OH = dst.layout.shape[spatial_pos + 0],
size_t OH = dst.layout.shape[spatial_pos + 0], OW = dst.layout.shape[spatial_pos + 1];
OW = dst.layout.shape[spatial_pos + 1]; if (param().format == Param::Format::NHWCD4) {
if (param().format == Param::Format::NHWCD4) { C *= 4;
C *= 4; IW = src.layout.shape[spatial_pos + 2];
IW = src.layout.shape[spatial_pos + 2]; OW = dst.layout.shape[spatial_pos + 2];
OW = dst.layout.shape[spatial_pos + 2]; }
} if (param().format == Param::Format::NCHW4 ||
if (param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44 || param().format == Param::Format::CHWN4) {
param().format == Param::Format::CHWN4) { C *= 4;
C *= 4; }
} if (param().format == Param::Format::NCHW88) {
if (param().format == Param::Format::NCHW88) { C *= 8;
C *= 8; }
} if (param().format == Param::Format::NCHW32) {
if (param().format == Param::Format::NCHW32) { C *= 32;
C *= 32; }
} size_t PH = param().pad_h, PW = param().pad_w;
size_t PH = param().pad_h, PW = param().pad_w; size_t FH = param().window_h, FW = param().window_w;
size_t FH = param().window_h, FW = param().window_w; size_t SH = param().stride_h, SW = param().stride_w;
size_t SH = param().stride_h, SW = param().stride_w; #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \ static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \
PW, SH, SW, FH, FW)); PH, PW, SH, SW, FH, FW)); \
} \
MIDOUT_END();
#define DISPATCH_WITH_POOLER(Pooler) \ #define DISPATCH_WITH_POOLER(Pooler) \
switch (param().format) { \ switch (param().format) { \
...@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
} \ } \
} \ } \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb #undef cb
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER
#undef DISPATCH_WITH_POOLER #undef DISPATCH_WITH_POOLER
megdnn_assert_internal(0); megdnn_assert_internal(0);
}
MIDOUT_END();
} }
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册