提交 8ac73333 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn): fix data type check for quantized convolution

GitOrigin-RevId: e0f97052ffedeb4ba5607dc32ac929ed19184970
上级 c59be192
......@@ -132,6 +132,15 @@ namespace megdnn {
cb(::megdnn::dtype::Quantized4Asymm) \
cb(::megdnn::dtype::QuantizedS4)
//! X-macro: expands \p cb once for each dtype in this list, in order:
//! QuantizedS32, QuantizedS8, QuantizedS4.
//! \p cb receives the fully qualified dtype class name as its only argument.
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \
cb(::megdnn::dtype::QuantizedS32) \
cb(::megdnn::dtype::QuantizedS8) \
cb(::megdnn::dtype::QuantizedS4)
//! X-macro: expands \p cb once for each dtype in this list, in order:
//! Quantized8Asymm, Quantized4Asymm.
//! \p cb receives the fully qualified dtype class name as its only argument.
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \
cb(::megdnn::dtype::Quantized8Asymm) \
cb(::megdnn::dtype::Quantized4Asymm)
/*!
* \brief a POD representation of a single byte
*
......
......@@ -604,9 +604,16 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
if (!dst.valid()) {
dst = supported_dst_dtype.at(0);
} else {
megdnn_assert(vec_contains(supported_dst_dtype, dst),
"unsupported Conv(%s, %s) -> %s", src.name(),
filter.name(), dst.name());
bool dst_supported = false;
for (auto&& dt : supported_dst_dtype) {
if (dtype_almost_equal(dt, dst)) {
dst_supported = true;
break;
}
}
MEGDNN_MARK_USED_VAR(dst_supported);
megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s",
src.name(), filter.name(), dst.name());
}
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
......
......@@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
}
// clang-format on
/*!
 * \brief compare two dtypes, using a tolerant comparison for quantization
 *      scales
 *
 * Dtypes with different enum values never match. Non-quantized dtypes match
 * as soon as their enum values agree. For dtypes listed in
 * MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM only the scales are compared (via
 * almost_equal); for dtypes listed in MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM the
 * zero points must additionally be exactly equal.
 *
 * \param lhs first dtype
 * \param rhs second dtype
 * \return whether the two dtypes are considered equal
 */
bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
    const bool same_enum = (lhs.enumv() == rhs.enumv());
    if (!same_enum) {
        return false;
    }
    const bool is_quantized = (lhs.category() == DTypeCategory::QUANTIZED);
    if (!is_quantized) {
        // nothing but the enum value to compare
        return true;
    }

    // scale-only comparison for every dtype in the SYMM list
#define cb(dt)                                                   \
    if (lhs.enumv() == DTypeTrait<dt>::enumv) {                  \
        const auto& lhs_param = lhs.param<dt>();                 \
        const auto& rhs_param = rhs.param<dt>();                 \
        return almost_equal(lhs_param.scale, rhs_param.scale);   \
    }
    MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb)
#undef cb

    // scale and zero point comparison for every dtype in the ASYMM list
#define cb(dt)                                                   \
    if (lhs.enumv() == DTypeTrait<dt>::enumv) {                  \
        const auto& lhs_param = lhs.param<dt>();                 \
        const auto& rhs_param = rhs.param<dt>();                 \
        return almost_equal(lhs_param.scale, rhs_param.scale) && \
               lhs_param.zero_point == rhs_param.zero_point;     \
    }
    MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb)
#undef cb

    // reached only when a quantized dtype is in neither list above
    megdnn_assert_internal(false);
}
template <>
uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst,
size_t offset) {
......
......@@ -434,6 +434,20 @@ int8_t convert<dt_qint4, int8_t>(dt_qint4 src, int8_t dst, size_t offset);
template <>
dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset);
/*!
 * \brief check whether two floating point values are equal within a given
 *      number of ULPs (units in the last place)
 *
 * Only participates in overload resolution when \p T is not an integer type.
 *
 * \param x first value
 * \param y second value
 * \param unit_last_place tolerance multiplier; the allowed difference is
 *      epsilon * |x + y| * unit_last_place
 * \return true if x == y, if |x - y| is below the scaled tolerance, or if
 *      |x - y| is smaller than std::numeric_limits<T>::min(); false otherwise
 *      (in particular whenever either operand is NaN)
 */
template <class T>
static inline
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type
almost_equal(T x, T y, int unit_last_place = 1) {
    // Exact equality has to be tested first: for two equal infinities x - y is
    // NaN, so both checks below would wrongly report a mismatch.
    if (x == y) {
        return true;
    }
    const auto diff = std::abs(x - y);
    // relative check: machine epsilon scaled to the magnitude of the operands
    // and to the requested number of ULPs
    if (diff < (std::numeric_limits<T>::epsilon() * std::abs(x + y) *
                unit_last_place)) {
        return true;
    }
    // absolute check, so that differences below the smallest normal value
    // still count as equal
    return diff < std::numeric_limits<T>::min();
}
/*!
 * \brief check whether two dtypes are equal, comparing the scales of
 *      quantized dtypes with almost_equal() rather than exact equality
 */
bool dtype_almost_equal(DType lhs, DType rhs);
/**
* \brief N-dimensional index space
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册