提交 f0291883 编写于 作者: M Megvii Engine Team

fix(mgb): make error infomation of group conv input channel mismatch more readable

GitOrigin-RevId: d249408c26dfee0eaacbba62f362439cc8e0cb93
上级 31218a18
......@@ -777,8 +777,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src_or_dst_spatial_start = 1;
}
megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
errmsg().c_str());
cflt.icpg * cflt.group == src[src_or_dst_c_pos],
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
......@@ -792,8 +794,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -809,8 +813,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 8,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -826,8 +832,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 32,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -856,7 +864,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8 ||
(cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
}
} else if (
......@@ -879,15 +891,21 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4 ||
(cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
}
} else if (param().format == Param::Format::CHWN4) {
megdnn_assert(
src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[0] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[3] = src[3];
auto oc = cflt.ocpg * cflt.group;
......@@ -903,8 +921,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -918,8 +938,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4;
dst[0] = src[0];
dst[1] = infer_conv_shape(
......@@ -933,8 +955,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -950,8 +974,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 32,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -967,8 +993,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[1] * 64,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......@@ -985,8 +1013,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
src.ndim);
megdnn_assert(
cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
errmsg().c_str(), cflt.icpg, cflt.group);
cflt.icpg * cflt.group == src[2] * 4,
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
......
......@@ -148,7 +148,10 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
src_or_dst_spatial_start = 1;
}
megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", errmsg().c_str());
cflt.icpg * cflt.group == src[src_or_dst_c_pos],
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim;
dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册