From 5a442553240881504c60b2e4e15afa9f72e6f36e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 25 Jan 2021 12:36:16 +0800 Subject: [PATCH] fix(mgb/plugin): fix opr footprint for conv with NCHW32_NCHW4 format GitOrigin-RevId: 9881b7971c8de57f11f7a4c61b58f74be292f97a --- src/plugin/impl/opr_footprint.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 2b691d875..cabb61c96 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -127,7 +127,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, src_shape[1] / group * 2; return hybird_nchwx ? computation : computation * 4; } - if (param.format == Param::Format::NCHW32) { + if (param.format == Param::Format::NCHW32 || + param.format == Param::Format::NCHW32_NCHW4) { return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 / group * 2; } @@ -157,11 +158,12 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, }; if (param.format == Param::Format::NCHW4 || param.format == Param::Format::NCHW4_NCHW || - param.format == Param::Format::NCHW4_NCHW32 || + param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || param.format == Param::Format::NCHW44_DOT || - param.format == Param::Format::NCHW32) { + param.format == Param::Format::NCHW32 || + param.format == Param::Format::NCHW32_NCHW4) { return eval_conv_computation_nchwx(); } if (param.format == Param::Format::CHWN4) { -- GitLab