From 858261af1ff66ccab0a829eee506ceeb6035faef Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 26 Mar 2021 17:51:21 +0800 Subject: [PATCH] fix(python_module): fix conversion between numpy-ndarray and mgb tensor for qint4 and quint4 GitOrigin-RevId: 7450c4f25e52fc334344b7853c0bd897a6212847 --- dnn/include/megdnn/oprs/base.h | 1 + dnn/src/fallback/convolution/opr_impl.cpp | 2 ++ src/core/impl/dtype.cpp | 31 +++++++++++++++++------ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index 56d3297f7..f62b8927f 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -90,6 +90,7 @@ enum class AlgoDataType : uint32_t { INT8X8X16 = 1 << 4, INT16X16X32 = 1 << 5, INT4X4X16 = 1 << 6, + QINT4x4x32 = 1 << 7, }; /*! diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 803de16a4..8e1d14829 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -434,6 +434,8 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { } } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { return ConvolutionImpl::AlgoDataType::QUINT8X8X32; + } else if (src_type.enumv() == DTypeEnum::QuantizedS4) { + return ConvolutionImpl::AlgoDataType::QINT4x4x32; } else { megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n", src_type.name(), filter_type.name(), diff --git a/src/core/impl/dtype.cpp b/src/core/impl/dtype.cpp index 85b8a0e57..91c8f23b4 100644 --- a/src/core/impl/dtype.cpp +++ b/src/core/impl/dtype.cpp @@ -14,7 +14,6 @@ #include "megbrain/exception.h" #include "megbrain/utils/metahelper.h" #include "megbrain/utils/arith_helper.h" -#include "megdnn/dtype.h" #include #include @@ -383,24 +382,40 @@ struct QuantizedLowbitMemcpy { // cast with bits that 8 % bits == 0 static constexpr uint16_t bits = DTypeTrait
::low_bit; static constexpr uint8_t MASK = (1 << bits) - 1; - using Trait = QuantizedLowbitTrait
; + static constexpr bool signedness = + std::is_same::value; static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { auto dest = static_cast(dest_raw); auto src = static_cast(src_raw); memset(dest, 0, divup(n * bits, 8)); for (size_t i = 0; i < n; ++i) { - int8_t val = src[i] + Trait::SHIFT; - mgb_assert(val >= 0 && val < (1 << bits)); - dest[i * bits / 8] |= val << (i * bits % 8); + int8_t val = src[i]; + static const auto min_val = DTypeTrait
::min(); + static const auto max_val = DTypeTrait
::max(); + MGB_MARK_USED_VAR(min_val); + MGB_MARK_USED_VAR(max_val); + mgb_assert(val >= static_cast(min_val) && + val <= static_cast(max_val), + "data exceeds range(%d,%d) of data type", min_val, + max_val); + dest[i * bits / 8] |= (val & MASK) << (i * bits % 8); } } static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { - auto dest = static_cast(dest_raw); + auto dest = reinterpret_cast(dest_raw); auto src = static_cast(src_raw); for (size_t i = 0; i < n; ++i) { - int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); - dest[i] = val - Trait::SHIFT; + uint8_t intermediate = + ((src[i * bits / 8] >> (i * bits % 8)) & MASK); + if (signedness) { + int val = (intermediate & uint8_t(1 << (bits - 1))) + ? ((int)(intermediate) | ~(int)(MASK)) + : (int)(intermediate); + dest[i] = static_cast(val); + } else { + dest[i] = static_cast(intermediate); + } } } }; -- GitLab