提交 3bd40887 编写于 作者: M Megvii Engine Team

feat(mgb/opr): add NHWC support for AdaptivePooling

GitOrigin-RevId: b23e37ac23764085568d4c303619eec10a0b4867
上级 e393d1cf
...@@ -6,8 +6,21 @@ namespace megdnn { ...@@ -6,8 +6,21 @@ namespace megdnn {
param::Pooling AdaptivePoolingBase::deduce_pooling_param( param::Pooling AdaptivePoolingBase::deduce_pooling_param(
const TensorLayout& src, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& dst) {
megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW); auto param_format = param().format;
size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], OW = dst.shape[3]; size_t IH, IW, OH, OW;
if (param_format == param::AdaptivePooling::Format::NCHW) {
IH = src.shape[2];
IW = src.shape[3];
OH = dst.shape[2];
OW = dst.shape[3];
} else if (param_format == param::AdaptivePooling::Format::NHWC) {
IH = src.shape[1];
IW = src.shape[2];
OH = dst.shape[1];
OW = dst.shape[2];
} else {
megdnn_throw("AdaptivePooling only support NCHW or NHWC format");
}
param::Pooling ret; param::Pooling ret;
ret.mode = param().mode; ret.mode = param().mode;
......
...@@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( ...@@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
"shape mismatch for AdaptivePooling: src=%s, out2d=%s", "shape mismatch for AdaptivePooling: src=%s, out2d=%s",
src.to_string().c_str(), oshp2d.to_string().c_str()); src.to_string().c_str(), oshp2d.to_string().c_str());
mgb_assert( auto param_format = param().format;
param().format == Param::Format::NCHW, "AdaptivePooling only support NCHW"); if (param_format == Param::Format::NCHW) {
dest.ndim = 4; dest.ndim = 4;
dest.shape[0] = src.shape[0]; dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1]; dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0]; dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1]; dest.shape[3] = oshp2d.shape[1];
} else if (param_format == Param::Format::NHWC) {
dest.ndim = 4;
dest.shape[0] = src.shape[0];
dest.shape[1] = oshp2d.shape[0];
dest.shape[2] = oshp2d.shape[1];
dest.shape[3] = src.shape[3];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
} }
size_t AdaptivePoolingForward::get_workspace_size_bytes( size_t AdaptivePoolingForward::get_workspace_size_bytes(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册