Commit f5597d9a authored by Megvii Engine Team

fix(mgb): make error information of input channel mismatch more readable

GitOrigin-RevId: 6f95260070bf826dc78cd23a3e62548c0d1cb9a8
Parent 38bd5999
@@ -573,8 +573,15 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
filter.param<dtype::QuantizedS8>().scale));
} else {
megdnn_throw(ssprintf(
"unsupported input / filter DType: %s x %s", src.name(),
filter.name()));
"runtime does not support input / filter DType: %s x %s\n"
"now supported case list: FLOAT x FLOAT\n"
"                         Int8 x Int8\n"
"                         QuantizedS8 x QuantizedS8\n"
"                         Quantized8Asymm x Quantized8Asymm\n"
"                         QuantizedS4 x QuantizedS4\n"
"                         Quantized4Asymm x Quantized4Asymm\n"
"                         QuantizedS1 x QuantizedS1\n",
src.name(), filter.name()));
}
if (!dst.valid()) {
dst = supported_dst_dtype.at(0);
@@ -588,8 +595,21 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
}
MEGDNN_MARK_USED_VAR(dst_supported);
megdnn_assert(
dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
filter.name(), dst.name());
dst_supported,
"runtime does not support Conv(%s, %s) -> %s\n"
"now supported case list: Conv(FLOAT x FLOAT) -> FLOAT\n"
"                         Conv(Int8 x Int8) -> Int32\n"
"                         Conv(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32\n"
"                         Conv(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm\n"
"                         Conv(QuantizedS4 x QuantizedS4) -> "
"QuantizedS32\n"
"                         Conv(Quantized4Asymm x Quantized4Asymm) -> "
"Quantized32Asymm\n"
"                         Conv(QuantizedS1 x QuantizedS1) -> "
"QuantizedS32\n",
src.name(), filter.name(), dst.name());
}
megdnn_assert(
(param().compute_mode == Param::ComputeMode::FLOAT32 ||
@@ -1098,15 +1118,26 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad
}
} else {
megdnn_throw(ssprintf(
"unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
"runtime does not support input / diff DType: %s x %s\n"
"now supported case list: FLOAT x FLOAT\n"
"                         Int8 x Int8\n"
"                         QuantizedS8 x QuantizedS8\n"
"                         Quantized8Asymm x Quantized8Asymm\n",
filter.name(), diff.name()));
}
if (!grad.valid()) {
grad = supported_dst_dtype.at(0);
} else {
megdnn_assert(
vec_contains(supported_dst_dtype, grad),
"unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
grad.name());
"runtime does not support ConvBwd(%s, %s) -> %s\n"
"now supported case list: ConvBwd(FLOAT x FLOAT) -> FLOAT\n"
"                         ConvBwd(Int8 x Int8) -> Int32\n"
"                         ConvBwd(QuantizedS8 x QuantizedS8) -> "
"QuantizedS32\n"
"                         ConvBwd(Quantized8Asymm x Quantized8Asymm) -> "
"Quantized32Asymm\n",
filter.name(), diff.name(), grad.name());
}
megdnn_assert(
param().compute_mode != Param::ComputeMode::FLOAT32
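The error messages in this file are built from adjacent string literals, which the C++ compiler concatenates into a single format string; the `\n` terminating the first fragment is what keeps the summary line and the case list on separate lines (omitting it would fuse "%s x %s" and "now supported case list:" into one run-on line). A minimal, self-contained sketch of the pattern, with plain fprintf standing in for MegEngine's megdnn_throw(ssprintf(...)):

#include <cstdio>

int main() {
    const char* src = "Float16";
    const char* filter = "QuantizedS8";
    // The adjacent literals below form one format string after
    // compile-time concatenation; each fragment's trailing '\n'
    // keeps the rendered message on separate lines.
    std::fprintf(stderr,
            "runtime does not support input / filter DType: %s x %s\n"
            "now supported case list: FLOAT x FLOAT\n"
            "                         Int8 x Int8\n",
            src, filter);
    return 0;
}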
@@ -95,8 +95,11 @@ TensorLayout do_shape_infer(
dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
}
mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");
mgb_assert(
icpg * group == src[src_or_dst_c_pos],
"group conv invalid: input channel of Conv expected %zu, but got %zu\n"
"hint: the weight may have been changed by mistake\n",
icpg * group, src[src_or_dst_c_pos]);
TensorLayout dst{src.dtype};
dst.ndim = src_ndim;
dst[0] = src[0];
@@ -310,8 +313,11 @@ TensorLayout convbwd_do_shape_infer(
dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
}
mgb_assert(ocpg * group == diff[src_or_dst_c_pos], "group conv invalid");
mgb_assert(
ocpg * group == diff[src_or_dst_c_pos],
"group conv invalid: output channel of Conv expected %zu, but got %zu\n"
"hint: the weight may have been changed by mistake\n",
ocpg * group, diff[src_or_dst_c_pos]);
auto deduce = [](size_t out, size_t filter, size_t stride, size_t pad) {
auto i = (out - 1) * stride + filter;
mgb_assert(i > pad * 2);
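The deduce lambda above (shown truncated by the diff) inverts the forward convolution's output-size formula out = (in + 2*pad - filter) / stride + 1: the backward-data pass recovers the input spatial size as (out - 1) * stride + filter - 2 * pad, and the assertion guards the subtraction against underflow. A small sketch under the assumption that the elided remainder of the lambda returns i - pad * 2:

#include <cassert>
#include <cstddef>
#include <cstdio>

// Assumed completion of the truncated `deduce` lambda: inverse of
// out = (in + 2*pad - filter) / stride + 1.
std::size_t deduce(std::size_t out, std::size_t filter, std::size_t stride,
                   std::size_t pad) {
    auto i = (out - 1) * stride + filter;
    assert(i > pad * 2);  // otherwise the deduced size would underflow
    return i - pad * 2;
}

int main() {
    // A 3x3 filter with stride 1 and pad 1 maps a 32-wide input to a
    // 32-wide output, so backward data must deduce 32 again.
    std::printf("%zu\n", deduce(32, 3, 1, 1));  // prints 32
    return 0;
}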
@@ -479,8 +485,11 @@ TensorLayout do_shape_infer(
dilated_spatial[i] =
(filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
}
mgb_assert(icpg * group == src[src_or_dst_c_pos], "group conv invalid");
mgb_assert(
icpg * group == src[src_or_dst_c_pos],
"group conv invalid: input channel of Conv expected %zu, but got %zu\n"
"hint: the weight may have been changed by mistake\n",
icpg * group, src[src_or_dst_c_pos]);
TensorLayout dst{src.dtype};
dst.ndim = src_ndim;
dst[0] = src[0];
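The assertions rewritten in this commit all encode the grouped-convolution constraint that a tensor's channel dimension must equal the per-group channel count times the group count: icpg * group for the source in the forward path, ocpg * group for diff in the backward-data path. A self-contained sketch of the check, using a hypothetical check_group_conv_channels helper and plain C++ in place of mgb_assert:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// For a grouped convolution the input channel dimension must be exactly
// icpg * group, where icpg is the per-group input channel count taken
// from the filter layout.
void check_group_conv_channels(std::size_t icpg, std::size_t group,
                               std::size_t src_channels) {
    if (icpg * group != src_channels) {
        std::fprintf(stderr,
                "group conv invalid: input channel of Conv expected %zu, "
                "but got %zu\n"
                "hint: the weight may have been changed by mistake\n",
                icpg * group, src_channels);
        std::abort();
    }
}

int main() {
    check_group_conv_channels(16, 4, 64);  // passes: 16 * 4 == 64
    check_group_conv_channels(16, 4, 48);  // aborts with the hint above
    return 0;
}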