提交 f0291883 编写于 作者: M Megvii Engine Team

fix(mgb): make error infomation of group conv input channel mismatch more readable

GitOrigin-RevId: d249408c26dfee0eaacbba62f362439cc8e0cb93
上级 31218a18
...@@ -777,8 +777,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -777,8 +777,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src_or_dst_spatial_start = 1; src_or_dst_spatial_start = 1;
} }
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", cflt.icpg * cflt.group == src[src_or_dst_c_pos],
errmsg().c_str()); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group; dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
...@@ -792,8 +794,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -792,8 +794,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -809,8 +813,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -809,8 +813,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 8,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -826,8 +832,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -826,8 +832,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 32,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -856,7 +864,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -856,7 +864,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 8 || cflt.icpg * cflt.group == src[1] * 8 ||
(cflt.icpg * cflt.group == src[1]), (cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
} }
} else if ( } else if (
...@@ -879,15 +891,21 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -879,15 +891,21 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4 || cflt.icpg * cflt.group == src[1] * 4 ||
(cflt.icpg * cflt.group == src[1]), (cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : "
"\n%s",
src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group,
errmsg().c_str());
} }
} else if (param().format == Param::Format::CHWN4) { } else if (param().format == Param::Format::CHWN4) {
megdnn_assert( megdnn_assert(
src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[0] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[3] = src[3]; dst[3] = src[3];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -903,8 +921,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -903,8 +921,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4; dst.ndim = 4;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -918,8 +938,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -918,8 +938,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = 4; dst.ndim = 4;
dst[0] = src[0]; dst[0] = src[0];
dst[1] = infer_conv_shape( dst[1] = infer_conv_shape(
...@@ -933,8 +955,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -933,8 +955,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -950,8 +974,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -950,8 +974,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 32,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -967,8 +993,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -967,8 +993,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[1] * 64,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -985,8 +1013,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -985,8 +1013,10 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu", src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
src.ndim); src.ndim);
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u", cflt.icpg * cflt.group == src[2] * 4,
errmsg().c_str(), cflt.icpg, cflt.group); "group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details for src, filter and dst : \n%s",
src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
......
...@@ -148,7 +148,10 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd( ...@@ -148,7 +148,10 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
src_or_dst_spatial_start = 1; src_or_dst_spatial_start = 1;
} }
megdnn_assert( megdnn_assert(
cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", errmsg().c_str()); cflt.icpg * cflt.group == src[src_or_dst_c_pos],
"group conv channel mismatch : input channel got %zu, and "
"filter channel got %u. More details about src, filter and dst : \n%s",
src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
dst.ndim = src.ndim; dst.ndim = src.ndim;
dst[0] = src[0]; dst[0] = src[0];
dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group; dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册