提交 5b1383e0 编写于 作者: M Megvii Engine Team

fix(mgb/plugin): fix io dump for qint4, quint4 and bool type tensors

GitOrigin-RevId: bad1e8869069ae6cceb9d0ef2963e4815ec09018
上级 5e07e1e0
......@@ -34,6 +34,14 @@ double as_double(megdnn::dt_qint8& a) {
return static_cast<double>(a.as_int8());
}
template <>
double as_double(megdnn::dt_quint4& a) {
return static_cast<double>(a.as_uint8());
}
template <>
double as_double(megdnn::dt_qint4& a) {
return static_cast<double>(a.as_int8());
}
template <>
double as_double(megdnn::dt_qint32& a) {
return static_cast<double>(a.as_int32());
}
......@@ -69,7 +77,10 @@ void do_print_host_val(
sum2 += as_double(i) * as_double(i);
};
size_t nr = val.layout.total_nr_elems();
if (val.layout.is_contiguous()) {
bool normal_contig = !val.layout.dtype.is_low_bit() && val.layout.is_contiguous();
bool lowbit_contig =
val.layout.dtype.is_low_bit() && val.layout.is_physical_contiguous();
if (normal_contig || lowbit_contig) {
ctype* ptr = val.ptr<ctype>();
for (size_t i = 0; i < nr; ++i) {
update(ptr[i]);
......@@ -99,13 +110,14 @@ void print_host_val(
fout, max_nr_print, val, print_stat);
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(dtype::Bool)
#undef cb
default:
mgb_throw(
MegBrainError,
"can not handle dtype %s in "
"print_host_val",
val.layout.dtype.name());
default : mgb_throw(
MegBrainError,
"can not handle dtype %s in "
"print_host_val",
val.layout.dtype.name());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册