From 8ac7333378171dea62a5a9f793294a6f6f0efbc5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 6 Mar 2020 18:43:19 +0800 Subject: [PATCH] fix(dnn): fix data type check for quantized convolution GitOrigin-RevId: e0f97052ffedeb4ba5607dc32ac929ed19184970 --- dnn/include/megdnn/dtype.h | 9 +++++++++ dnn/src/common/convolution.cpp | 13 ++++++++++--- dnn/src/common/utils.cpp | 19 +++++++++++++++++++ dnn/src/common/utils.h | 14 ++++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index aae14fec1..df3319f0b 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -132,6 +132,15 @@ namespace megdnn { cb(::megdnn::dtype::Quantized4Asymm) \ cb(::megdnn::dtype::QuantizedS4) +#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ + cb(::megdnn::dtype::QuantizedS32) \ + cb(::megdnn::dtype::QuantizedS8) \ + cb(::megdnn::dtype::QuantizedS4) + +#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ + cb(::megdnn::dtype::Quantized8Asymm) \ + cb(::megdnn::dtype::Quantized4Asymm) + /*! 
* \brief a POD representation of a single byte * diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index 41b1249ab..0a89010fb 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -604,9 +604,16 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, if (!dst.valid()) { dst = supported_dst_dtype.at(0); } else { - megdnn_assert(vec_contains(supported_dst_dtype, dst), - "unsupported Conv(%s, %s) -> %s", src.name(), - filter.name(), dst.name()); + bool dst_supported = false; + for (auto&& dt : supported_dst_dtype) { + if (dtype_almost_equal(dt, dst)) { + dst_supported = true; + break; + } + } + MEGDNN_MARK_USED_VAR(dst_supported); + megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", + src.name(), filter.name(), dst.name()); } megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 #if !MEGDNN_DISABLE_FLOAT16 diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index 371afbb55..2dd2ca703 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) { } // clang-format on +bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { + if (lhs.enumv() != rhs.enumv()) + return false; + if (lhs.category() != DTypeCategory::QUANTIZED) + return true; +#define cb(dt) \ + if (lhs.enumv() == DTypeTrait<dt>::enumv) \ + return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale); + MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) +#undef cb +#define cb(dt) \ + if (lhs.enumv() == DTypeTrait<dt>::enumv) \ + return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale) && \ + lhs.param<dt>().zero_point == rhs.param<dt>().zero_point; + MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) +#undef cb + megdnn_assert_internal(false); +} + template <> uint8_t megdnn::convert(dt_quint4 src, uint8_t dst, size_t offset) { diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index a12f2a537..1061ab479 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -434,6 +434,20 @@ int8_t convert(dt_qint4 src, int8_t dst, size_t offset); template <> dt_qint4 convert(int8_t src, dt_qint4 dst, size_t offset); +/*! + * \brief check float equal within given ULP(unit in the last place) + */ +template <typename T> +static inline + typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type + almost_equal(T x, T y, int unit_last_place = 1) { + return std::abs(x - y) < (std::numeric_limits<T>::epsilon() * + std::abs(x + y) * unit_last_place) || + std::abs(x - y) < std::numeric_limits<T>::min(); +} + +bool dtype_almost_equal(DType lhs, DType rhs); + /** * \brief N-dimensional index space */ -- GitLab