提交 5a442553 编写于 作者: M Megvii Engine Team

fix(mgb/plugin): fix opr footprint for conv with NCHW32_NCHW4 format

GitOrigin-RevId: 9881b7971c8de57f11f7a4c61b58f74be292f97a
上级 adc49de8
......@@ -127,7 +127,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 4;
}
if (param.format == Param::Format::NCHW32) {
if (param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
group * 2;
}
......@@ -157,11 +158,12 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW32) {
param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return eval_conv_computation_nchwx();
}
if (param.format == Param::Format::CHWN4) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册