提交 b90c1540 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/naive): fix midout for pooling

GitOrigin-RevId: 4edd99f3ecb3f96275fcd139076d63181487e8ed
上级 df47637d
......@@ -370,14 +370,13 @@ void pooling_backward_max_impl(const ctype* __restrict src,
}
}
} // anonymous namespace
} // namespace
namespace megdnn {
namespace naive {
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_pooling) {
check_exec(src.layout, dst.layout, workspace.size);
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
......@@ -424,11 +423,14 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \
PW, SH, SW, FH, FW));
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \
PH, PW, SH, SW, FH, FW)); \
} \
MIDOUT_END();
#define DISPATCH_WITH_POOLER(Pooler) \
switch (param().format) { \
......@@ -490,8 +492,6 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER
#undef DISPATCH_WITH_POOLER
megdnn_assert_internal(0);
}
MIDOUT_END();
}
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册