提交 b90c1540 作者: Megvii Engine Team 提交者: Xu Xinran

fix(dnn/naive): fix midout for pooling

GitOrigin-RevId: 4edd99f3ecb3f96275fcd139076d63181487e8ed
上级 df47637d
......@@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src,
}
}
} // anonymous namespace
} // namespace
namespace megdnn {
namespace naive {
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_pooling) {
check_exec(src.layout, dst.layout, workspace.size);
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW32) {
c_pos = 1;
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
c_pos = 3;
spatial_pos = 1;
} else if (param().format == Param::Format::CHWN4) {
c_pos = 0;
spatial_pos = 1;
batch_pos = 3;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
c_pos = 2;
spatial_pos = 1;
}
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
IH = src.layout.shape[spatial_pos + 0],
IW = src.layout.shape[spatial_pos + 1];
size_t OH = dst.layout.shape[spatial_pos + 0],
OW = dst.layout.shape[spatial_pos + 1];
if (param().format == Param::Format::NHWCD4) {
C *= 4;
IW = src.layout.shape[spatial_pos + 2];
OW = dst.layout.shape[spatial_pos + 2];
}
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) {
C *= 4;
}
if (param().format == Param::Format::NCHW88) {
C *= 8;
}
if (param().format == Param::Format::NCHW32) {
C *= 32;
}
size_t PH = param().pad_h, PW = param().pad_w;
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \
PW, SH, SW, FH, FW));
check_exec(src.layout, dst.layout, workspace.size);
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW32) {
c_pos = 1;
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
c_pos = 3;
spatial_pos = 1;
} else if (param().format == Param::Format::CHWN4) {
c_pos = 0;
spatial_pos = 1;
batch_pos = 3;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
c_pos = 2;
spatial_pos = 1;
}
size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
IH = src.layout.shape[spatial_pos + 0],
IW = src.layout.shape[spatial_pos + 1];
size_t OH = dst.layout.shape[spatial_pos + 0],
OW = dst.layout.shape[spatial_pos + 1];
if (param().format == Param::Format::NHWCD4) {
C *= 4;
IW = src.layout.shape[spatial_pos + 2];
OW = dst.layout.shape[spatial_pos + 2];
}
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) {
C *= 4;
}
if (param().format == Param::Format::NCHW88) {
C *= 8;
}
if (param().format == Param::Format::NCHW32) {
C *= 32;
}
size_t PH = param().pad_h, PW = param().pad_w;
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
#define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \
MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \
PH, PW, SH, SW, FH, FW)); \
} \
MIDOUT_END();
#define DISPATCH_WITH_POOLER(Pooler) \
switch (param().format) { \
......@@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
} \
} \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef DISPATCH_WITH_POOLER_AND_IDX_GETTER
#undef DISPATCH_WITH_POOLER
megdnn_assert_internal(0);
}
MIDOUT_END();
megdnn_assert_internal(0);
}
WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册