提交 858261af 编写于 作者: M Megvii Engine Team

fix(python_module): fix conversion between numpy-ndarray and mgb tensor for qint4 and quint4

GitOrigin-RevId: 7450c4f25e52fc334344b7853c0bd897a6212847
上级 e250afb0
......@@ -90,6 +90,7 @@ enum class AlgoDataType : uint32_t {
INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5,
INT4X4X16 = 1 << 6,
QINT4x4x32 = 1 << 7,
};
/*!
......
......@@ -434,6 +434,8 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
}
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
} else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
return ConvolutionImpl::AlgoDataType::QINT4x4x32;
} else {
megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n",
src_type.name(), filter_type.name(),
......
......@@ -14,7 +14,6 @@
#include "megbrain/exception.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/arith_helper.h"
#include "megdnn/dtype.h"
#include <cmath>
#include <cstring>
......@@ -383,24 +382,40 @@ struct QuantizedLowbitMemcpy<DT, true> {
// cast with bits that 8 % bits == 0
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
static constexpr uint8_t MASK = (1 << bits) - 1;
using Trait = QuantizedLowbitTrait<DT>;
static constexpr bool signedness =
std::is_same<DT, dtype::QuantizedS4>::value;
static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<uint8_t*>(dest_raw);
auto src = static_cast<const int8_t*>(src_raw);
memset(dest, 0, divup<size_t>(n * bits, 8));
for (size_t i = 0; i < n; ++i) {
int8_t val = src[i] + Trait::SHIFT;
mgb_assert(val >= 0 && val < (1 << bits));
dest[i * bits / 8] |= val << (i * bits % 8);
int8_t val = src[i];
static const auto min_val = DTypeTrait<DT>::min();
static const auto max_val = DTypeTrait<DT>::max();
MGB_MARK_USED_VAR(min_val);
MGB_MARK_USED_VAR(max_val);
mgb_assert(val >= static_cast<int8_t>(min_val) &&
val <= static_cast<int8_t>(max_val),
"data exceeds range(%d,%d) of data type", min_val,
max_val);
dest[i * bits / 8] |= (val & MASK) << (i * bits % 8);
}
}
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<int8_t*>(dest_raw);
auto dest = reinterpret_cast<int8_t*>(dest_raw);
auto src = static_cast<const uint8_t*>(src_raw);
for (size_t i = 0; i < n; ++i) {
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK);
dest[i] = val - Trait::SHIFT;
uint8_t intermediate =
((src[i * bits / 8] >> (i * bits % 8)) & MASK);
if (signedness) {
int val = (intermediate & uint8_t(1 << (bits - 1)))
? ((int)(intermediate) | ~(int)(MASK))
: (int)(intermediate);
dest[i] = static_cast<int8_t>(val);
} else {
dest[i] = static_cast<int8_t>(intermediate);
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册