Commit fd6f8e58, authored on Mar 15, 2022 by Megvii Engine Team
feat(mgb/dtype): add dtype qint1
GitOrigin-RevId: abe9fb68b15aa2a73ffa3e2c44175639d5307a57
Parent: 616352b0
Showing 21 changed files with 214 additions and 17 deletions (+214, -17)
dnn/include/megdnn/dtype.h                             +40   -3
dnn/src/common/dtype.cpp                               +13   -0
dnn/src/common/utils.cpp                                +3   -1
dnn/src/cuda/elemwise_helper.cpp                        +7   -0
dnn/src/cuda/elemwise_helper.cuh                        +4   -0
dnn/src/cuda/type_cvt/kern.cu                           +7   -2
dnn/src/cuda/type_cvt/opr_impl.cpp                      +4   -2
dnn/src/cuda/utils.cuh                                 +17   -0
dnn/src/fallback/type_cvt/opr_impl.cpp                  +3   -1
dnn/src/naive/type_cvt/opr_impl.cpp                     +4   -2
dnn/test/common/checker.cpp                             +3   -2
dnn/test/common/dtype.cpp                              +26   -0
dnn/test/common/rng.cpp                                 +1   -1
dnn/test/common/utils.h                                +13   -0
dnn/test/cuda/type_cvt.cpp                              +4   -0
imperative/python/megengine/core/tensor/dtype.py       +27   -0
imperative/python/src/helper.cpp                       +12   -2
imperative/python/test/unit/core/test_dtype_quant.py   +17   -0
src/plugin/impl/opr_io_dump.cpp                         +5   -1
src/serialization/impl/dtype.fbs                        +1   -0
src/serialization/impl/flatbuffers_helper.cpp           +3   -0
dnn/include/megdnn/dtype.h
@@ -62,7 +62,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \
     cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \
-            cb(QuantizedS16)
+            cb(QuantizedS16) cb(QuantizedS1)
 #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \
     MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \
@@ -112,7 +112,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \
     cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \
-            cb(::megdnn::dtype::QuantizedS4)
+            cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS1)
 #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \
     cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm)
@@ -292,10 +292,27 @@ public:
 };
 using dt_qint4 = dt_qlowbit<4>;
+class dt_qint1 {
+    int8_t _;
+
+public:
+    MEGDNN_DEVICE int8_t as_int8() const { return _; }
+    MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint1(int8_t val) : _(val) {}
+#ifdef MEGDNN_CC_HOST
+    explicit operator int8_t() { return _; }
+#endif
+    bool operator<(const dt_qint1& b) const { return _ < b._; }
+    bool operator>(const dt_qint1& b) const { return _ > b._; }
+    bool operator==(const dt_qint1& b) const { return _ == b._; }
+    bool operator!=(const dt_qint1& b) const { return _ != b._; }
+} MEGDNN_PACKED;
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif
 MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size");
+MEGDNN_STATIC_ASSERT(sizeof(dt_qint1) == 1, "bad dt_qint1 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size");
 MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size");
@@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT)
         return static_cast<_itype>(_maxval); \
     } \
 };
+MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0);
 MEGDNN_DEF_PARAMETERIZED_DT(Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4);
 MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4);
@@ -876,6 +893,26 @@ struct DTypeParamImpl<dt_quint4> {
     }
 };
+template <>
+struct DTypeParamImpl<dt_qint1> {
+    float scale;
+
+    DTypeParamImpl<dt_qint1>() = default;
+    MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint1>(float scale);
+#ifdef MEGDNN_CC_HOST
+    std::size_t hash() const;
+#endif
+    bool operator==(const DTypeParam<dt_qint1>& rhs) const;
+    MEGDNN_DEVICE dt_qint1 quantize(float in) const {
+        float v = in / scale;
+        v = roundf(v);
+        v = fmin(fmax(0.f, v), 1.f);
+        return static_cast<dt_qint1>(v);
+    }
+    MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; }
+    MEGDNN_DEVICE float dequantize(dt_qint1 in) const { return in.as_int8() * scale; }
+};
+
 template <>
 struct DTypeParamImpl<dt_qint4> {
     float scale;
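For intuition, here is a minimal NumPy sketch (not part of the commit) of the quantize/dequantize semantics that DTypeParamImpl<dt_qint1> defines above: divide by scale, round, then clamp to the one-bit integer range [0, 1], so a qint1 tensor stores only 0 or 1 in each int8 byte. Note that np.round uses banker's rounding while C's roundf rounds halves away from zero, so the sketch avoids exact ties.

    import numpy as np

    # Mirrors DTypeParamImpl<dt_qint1>::quantize from the hunk above.
    def qint1_quantize(x: np.ndarray, scale: float) -> np.ndarray:
        v = np.round(x / scale)              # v = roundf(in / scale)
        v = np.clip(v, 0.0, 1.0)             # v = fmin(fmax(0.f, v), 1.f)
        return v.astype(np.int8)             # stored in one int8 byte

    # Mirrors DTypeParamImpl<dt_qint1>::dequantize.
    def qint1_dequantize(q: np.ndarray, scale: float) -> np.ndarray:
        return q.astype(np.float32) * scale  # in.as_int8() * scale

    x = np.array([-0.3, 0.02, 0.07, 0.2], dtype=np.float32)
    q = qint1_quantize(x, 0.1)               # -> [0, 0, 1, 1]
    print(qint1_dequantize(q, 0.1))          # -> [0.  0.  0.1 0.1]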
dnn/src/common/dtype.cpp
@@ -142,6 +142,19 @@ inline bool DTypeParam<dt_qint32>::operator==(const DTypeParam<dt_qint32>& rhs)
     return scale == rhs.scale;
 }
+DTypeParam<dt_qint1>::DTypeParamImpl(float scale) : scale{scale} {
+    //! As the nan is not equal to any value
+    megdnn_assert(!std::isnan(scale), "nan number compare is not support");
+}
+
+inline std::size_t DTypeParam<dt_qint1>::hash() const {
+    return std::hash<float>()(scale);
+}
+
+inline bool DTypeParam<dt_qint1>::operator==(const DTypeParam<dt_qint1>& rhs) const {
+    return scale == rhs.scale;
+}
+
 DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point)
         : scale{scale}, zero_point{zero_point} {
     //! As the nan is not equal to any value
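A note on the NaN guard in the constructor above: DTypeParam equality compares scales with ==, and under IEEE-754 a NaN compares unequal to every value, itself included, so a NaN scale would make a dtype compare unequal even to its own copy. A one-line illustration:

    nan = float("nan")
    assert nan != nan  # NaN is not equal to any value, including itself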
dnn/src/common/utils.cpp
@@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
         return lhs.param<dt>().scale * rhs.param<dt>().scale;
     MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
     MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+    cb(::megdnn::dtype::QuantizedS1)
 #undef cb
     megdnn_assert_internal(0);
 }
@@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) {
         return dt.param<_dt>().scale;
     MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
     MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+    cb(::megdnn::dtype::QuantizedS1)
 #undef cb
     megdnn_assert_internal(0);
 }

 bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
dnn/src/cuda/elemwise_helper.cpp
@@ -160,6 +160,9 @@ INST_FOR_CTYPE
 #define ct dt_bool
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_qint1
+INST_FOR_CTYPE
+#undef ct
 #undef INST_FOR_CTYPE
 #undef INST
@@ -210,6 +213,9 @@ INST_FOR_CTYPE
 #define ct dt_bool
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_qint1
+INST_FOR_CTYPE
+#undef ct
 #undef ndim_cb
@@ -221,6 +227,7 @@ INST(dt_int8);
 INST(dt_uint8);
 INST(dt_bool);
 INST(dt_qint8);
+INST(dt_qint1);
 INST(dt_quint8);
 #undef dt_ibyte
dnn/src/cuda/elemwise_helper.cuh
@@ -96,6 +96,7 @@ INST(dt_bool, uchar4);
 #undef as_raw
 #define as_raw(x) x.as_int8()
 INST(dt_qint8, char4);
+INST(dt_qint1, char4);
 #undef as_raw
 #define as_raw(x) x.as_uint8()
 INST(dt_quint8, uchar4);
@@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR;
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
@@ -1299,6 +1301,7 @@ private:
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
@@ -1649,6 +1652,7 @@ public:
 INST_DT_IBYTE(dt_int8);
 INST_DT_IBYTE(dt_uint8);
 INST_DT_IBYTE(dt_qint8);
+INST_DT_IBYTE(dt_qint1);
 INST_DT_IBYTE(dt_quint8);
 INST_DT_IBYTE(dt_bool);
 #undef INST_DT_IBYTE
dnn/src/cuda/type_cvt/kern.cu
@@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized<
         typename std::enable_if<
                 std::is_same<ctype_src, dt_int8>::value ||
                 std::is_same<ctype_src, dt_uint8>::value ||
+                std::is_same<ctype_src, dt_qint1>::value ||
                 std::is_same<ctype_src, dt_bool>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_dest> param;
@@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized<
         ctype_dest, ctype_src,
         typename std::enable_if<
                 std::is_same<ctype_src, dt_qint8>::value ||
+                std::is_same<ctype_src, dt_qint1>::value ||
                 std::is_same<ctype_src, dt_quint8>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_src> param;
@@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized<
         ctype_dest, ctype_src,
         typename std::enable_if<
                 (std::is_same<ctype_src, dt_qint8>::value ||
-                 std::is_same<ctype_src, dt_quint8>::value) &&
+                 std::is_same<ctype_src, dt_quint8>::value ||
+                 std::is_same<ctype_src, dt_qint1>::value) &&
                 IsNotTypeQ4<ctype_dest>::value>::type> {
     ctype_dest* dest;
     CudaDTypeParam<ctype_src> src_param;
@@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
     cb(dtype_src, dt_quint8) \
     cb(dtype_src, dt_qint32) \
     cb(dtype_src, dt_qint8) \
+    cb(dtype_src, dt_qint1) \
 #define INST_SRC_QUANTIZED(dtype_src) \
     MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
@@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
     cb(dt_qint32) \
     cb(dt_qint8) \
     cb(dt_qint4) \
-    cb(dt_quint4)
+    cb(dt_quint4) \
+    cb(dt_qint1)
 MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
 MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
dnn/src/cuda/type_cvt/opr_impl.cpp
@@ -50,6 +50,7 @@ void exec_src_quantized(
             return; \
     }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
+        cb(::megdnn::dtype::QuantizedS1);
         default:
             megdnn_assert_internal(0);
 #undef cb
@@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
             return; \
     }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
+        cb(::megdnn::dtype::QuantizedS1);
 #undef cb
         default:
             megdnn_assert_internal(0);
@@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     }
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default:
             megdnn_assert_internal(0);
     }
 }
dnn/src/cuda/utils.cuh
@@ -241,6 +241,23 @@ struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
     }
 };
+template <>
+struct CudaDTypeParamImpl<dt_qint1> : DTypeParamImpl<dt_qint1> {
+    float inv_scale;
+    CudaDTypeParamImpl() = default;
+    CudaDTypeParamImpl(float scale)
+            : DTypeParamImpl<dt_qint1>(scale), inv_scale(1.0f / scale) {}
+    CudaDTypeParamImpl(const DTypeParamImpl<dt_qint1>& param)
+            : CudaDTypeParamImpl(param.scale) {}
+
+    __device__ dt_qint1 quantize(float in) const {
+        float v = in * inv_scale;
+        v = roundf(v);
+        v = fmin(fmax(0.f, v), 1.f);
+        return static_cast<dt_qint1>(v);
+    }
+};
+
 #if MEGDNN_CC_CUDA
 static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) {
 #if __CUDA_ARCH__ >= 610
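The CUDA specialization above computes inv_scale = 1 / scale once on the host so that the device-side quantize does a multiply per element instead of a divide, which is cheaper on GPU. A sketch of the pattern (the reciprocal can differ from true division by up to one ulp, which the subsequent rounding normally absorbs):

    import numpy as np

    # Host side: compute the reciprocal once, as CudaDTypeParamImpl(float scale) does.
    scale = np.float32(0.1)
    inv_scale = np.float32(1.0) / scale

    # Device side, per element: multiply, round, clamp to [0, 1].
    x = np.float32(0.07)
    q = np.int8(np.clip(np.round(x * inv_scale), 0.0, 1.0))
    print(q)  # 1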
dnn/src/fallback/type_cvt/opr_impl.cpp
@@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     };
     if (src.layout.is_contiguous() && dst.layout.is_contiguous() &&
         !is_quantize_lowbit(src.layout.dtype) &&
-        !is_quantize_lowbit(dst.layout.dtype)) {
+        !is_quantize_lowbit(dst.layout.dtype) &&
+        dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 &&
+        src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) {
         MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
     } else {
         naive::TypeCvtImpl::exec(src, dst);
dnn/src/naive/type_cvt/opr_impl.cpp
@@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
         cb(::megdnn::dtype::Bool)
         cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default:
             megdnn_throw("bad dtype");
     }
 }
@@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
         cb(::megdnn::dtype::Bool)
         cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default:
             megdnn_throw("bad dtype");
     }
 }
dnn/test/common/checker.cpp
@@ -79,7 +79,8 @@ template <typename ctype>
         const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1,
         float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
     if (!std::is_same<ctype, dt_qint4>::value &&
-        !std::is_same<ctype, dt_quint4>::value) {
+        !std::is_same<ctype, dt_quint4>::value &&
+        !std::is_same<ctype, dt_qint1>::value) {
         if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) {
             return assert_tensor_eq_with_iter<ctype>(
                     expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr,
@@ -158,7 +159,7 @@ void copy_tensors(
         //! In order to avoid an unnecessary increase in binary size, we just
         //! use QuantizedS16 dtype in winograd_filter_preprocess now.
         cb(::megdnn::dtype::QuantizedS16)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(::megdnn::dtype::Uint16)
+        cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default:
             megdnn_trap();
     }
dnn/test/common/dtype.cpp
@@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) {
     EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm));
 }
+TEST(TestDType, QuantizedS1) {
+    using namespace megdnn;
+    dtype::QuantizedS1 qint1(0.1f);
+    EXPECT_EQ(qint1.size(1), 1u);
+    EXPECT_FLOAT_EQ(qint1.param().scale, 0.1f);
+
+    dtype::QuantizedS1 qint1_copy = qint1;
+    EXPECT_NO_THROW(qint1_copy.assert_is(qint1));
+    EXPECT_FLOAT_EQ(qint1_copy.param().scale, 0.1f);
+
+    dtype::QuantizedS1 qint1_reconstruct_with_same_param(0.1f);
+    EXPECT_NO_THROW(qint1_reconstruct_with_same_param.assert_is(qint1));
+
+    dtype::QuantizedS1 qint1_diff(0.2f);
+    EXPECT_ANY_THROW(qint1_diff.assert_is(qint1));
+
+    DType parent = qint1;
+    ASSERT_NO_THROW(dtype::QuantizedS1::downcast_from(parent));
+    auto param = dtype::QuantizedS1::downcast_from(parent).param();
+    EXPECT_FLOAT_EQ(param.scale, 0.1f);
+
+    EXPECT_ANY_THROW(dtype::QuantizedS1::downcast_from(dtype::IntB1()));
+    EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS1));
+}
+
 TEST(TestDType, TestQuantizedS4) {
     using namespace megdnn;
dnn/test/common/rng.cpp
@@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) {
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         //! In order to avoid an unnecessary increase in binary size, we just
         //! use QuantizedS16 dtype in winograd_filter_preprocess now.
-        cb(::megdnn::dtype::QuantizedS16)
+        cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
     if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         auto ptr = static_cast<uint8_t*>(tensor.raw_ptr());
dnn/test/common/utils.h
@@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) {
     return x.as_int8() - y.as_int8();
 }
+static inline int diff(dt_qint1 x, dt_qint1 y) {
+    return x.as_int8() - y.as_int8();
+}
+
 static inline int diff(dt_quint4 x, dt_quint4 y) {
     return x.as_uint8() - y.as_uint8();
 }
@@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) {
     return true;
 }
+static inline bool good_float(dt_qint1) {
+    return true;
+}
+
 static inline bool good_float(dt_quint4) {
     return true;
 }
@@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) {
     megdnn_assert(rhs == 0, "unexpected rhs");
     return lhs.as_int8();
 }
+
+static inline int operator+(dt_qint1 lhs, int rhs) {
+    megdnn_assert(rhs == 0, "unexpected rhs");
+    return lhs.as_int8();
+}
 }  // namespace test
 static inline bool operator==(const TensorLayout& a, const TensorLayout& b) {
dnn/test/cuda/type_cvt.cpp
@@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
     };
     run(dtype::Float32(), dtype::QuantizedS8(3.0f));
+    run(dtype::Float32(), dtype::QuantizedS1(3.0f));
     run(dtype::Float16(), dtype::QuantizedS8(3.0f));
     run(dtype::Int32(), dtype::QuantizedS32(5.0f));
     run(dtype::Int8(), dtype::QuantizedS32(10.0f));
     run(dtype::Float32(), dtype::QuantizedS8(2e-3f));
+    run(dtype::Float32(), dtype::QuantizedS1(2e-3f));
     run(dtype::Float16(), dtype::QuantizedS8(1e-3f));
     run(dtype::Int32(), dtype::QuantizedS32(1e-3f));
     run(dtype::Int8(), dtype::QuantizedS32(7e-4f));
     run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f));
+    run(dtype::QuantizedS1(3.0f), dtype::QuantizedS1(10.0f));
     run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f));
     run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f));
     run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f));
@@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) {
     run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f));
     run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f));
     run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f));
+    run(dtype::QuantizedS1(1e-3f), dtype::Float32());
     run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32());
     run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16());
imperative/python/megengine/core/tensor/dtype.py
@@ -94,6 +94,7 @@ _builtin_quant_dtypes = {
     "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
     "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
     "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
+    "qint1": QuantDtypeMeta("qint1", "QuantizedS1", "int8", 0, 1),
     "qint32": QuantDtypeMeta(
         "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
     ),
@@ -192,6 +193,13 @@ def qint4(scale):
     return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)
+def qint1(scale):
+    r"""Construct a quantized int1 data type with ``scale`` (float). The real value
+    represented by a qint1 data type is float_val = scale * int1_val
+    """
+    return create_quantized_dtype(_builtin_quant_dtypes["qint1"], scale, None)
+
+
 def _convert_to_quantized_dtype(
     arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta
 ):
@@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray):
         arr: Input ndarray.
     """
     return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])
+
+
+def convert_to_qint1(arr: np.ndarray, q: np.dtype):
+    r"""Quantize a float NumPy ndarray into a qint1 one with specified params.
+
+    Args:
+        arr: Input ndarray.
+        q: Target data type, should be a qint1.
+    """
+    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint1"])
+
+
+def convert_from_qint1(arr: np.ndarray):
+    r"""Dequantize a qint1 NumPy ndarray into a float one.
+
+    Args:
+        arr: Input ndarray.
+    """
+    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint1"])
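Putting the new Python helpers together, a round trip might look like the following usage sketch (it assumes a MegEngine build that includes this commit; the printed values are illustrative and depend on the helper's rounding):

    import numpy as np
    from megengine.core.tensor.dtype import qint1, convert_to_qint1, convert_from_qint1

    dt = qint1(0.1)               # np.dtype whose metadata carries {"scale": 0.1}
    x = np.array([-0.3, 0.02, 0.07, 0.2], dtype=np.float32)

    q = convert_to_qint1(x, dt)   # int8 storage holding only 0s and 1s
    r = convert_from_qint1(q)     # back to float: int1_val * scale
    print(q, r)                   # roughly [0 0 1 1] and [0. 0. 0.1 0.1]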
imperative/python/src/helper.cpp
@@ -214,6 +214,14 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty
     if (dtype.has_param()) {
         PyArray_Descr* type_descr;
         switch (dtype.enumv()) {
+            case DTypeEnum::QuantizedS1: {
+                auto& param = dtype.param<dtype::QuantizedS1>();
+                type_descr = PyArray_DescrNewFromType(NPY_INT8);
+                type_descr->metadata = build_mgb_dtype_dict(
+                        DTypeTrait<dtype::QuantizedS1>::name,
+                        {{"scale", PyFloat_FromDouble(param.scale)}});
+                break;
+            }
             case DTypeEnum::Quantized4Asymm: {
                 auto& param = dtype.param<dtype::Quantized4Asymm>();
                 type_descr = PyArray_DescrNewFromType(NPY_UINT8);
@@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
                     static_cast<uint8_t>(zero_point));
         }
         if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
-            dtype_name == "QuantizedS4") {
+            dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") {
             PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
             mgb_assert(scale_py, "Invalid metadata: missing scale");
             mgb_assert(
@@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
                 return dtype::QuantizedS32(scale);
             } else if (dtype_name == "QuantizedS8") {
                 return dtype::QuantizedS8(scale);
-            } else {
+            } else if (dtype_name == "QuantizedS4") {
                 return dtype::QuantizedS4(scale);
+            } else if (dtype_name == "QuantizedS1") {
+                return dtype::QuantizedS1(scale);
             }
         }
         throw ConversionError(
imperative/python/test/unit/core/test_dtype_quant.py
@@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G
 from megengine.core.ops import builtin as ops
 from megengine.core.tensor.dtype import (
     _builtin_quant_dtypes,
+    convert_from_qint1,
     convert_from_qint4,
     convert_from_qint8,
     convert_from_quint4,
     convert_from_quint8,
+    convert_to_qint1,
     convert_to_qint4,
     convert_to_qint8,
     convert_to_quint4,
@@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import (
     get_scale,
     get_zero_point,
     is_quantize,
+    qint1,
     qint4,
     qint8,
     quint4,
@@ -113,9 +116,20 @@ def test_dtype_qint4():
     np.testing.assert_allclose(get_scale(dt), 0.01)
+def test_dtype_qint1():
+    dt = qint1(0.01)
+    assert isinstance(dt, np.dtype)
+    assert "mgb_dtype" in dt.metadata
+    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
+
+    assert is_quantize(dt)
+    np.testing.assert_allclose(get_scale(dt), 0.01)
+
+
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
@@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name):
 @pytest.mark.parametrize(
     "dtype, dtype_name",
     [
+        (qint1(0.01), "qint1"),
         (quint4(0.01, 5), "quint4"),
         (qint4(0.01), "qint4"),
         (quint8(0.01, 135), "quint8"),
src/plugin/impl/opr_io_dump.cpp
@@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) {
     return static_cast<double>(a.as_int8());
 }
 template <>
+double as_double(megdnn::dt_qint1& a) {
+    return static_cast<double>(a.as_int8());
+}
+template <>
 double as_double(megdnn::dt_qint32& a) {
     return static_cast<double>(a.as_int32());
 }
@@ -111,7 +115,7 @@ void print_host_val(
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(dtype::Bool)
+        cb(dtype::Bool) cb(::megdnn::dtype::QuantizedS1)
 #undef cb
         default:
             mgb_throw(MegBrainError,
src/serialization/impl/dtype.fbs
@@ -23,6 +23,7 @@ enum DTypeEnum : byte {
     BFloat16,
     Bool,
     Uint16,
+    QuantizedS1,
 }

 table LinearQuantizationParam {
src/serialization/impl/flatbuffers_helper.cpp
@@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) {
         return dtype::_dt{};
         MEGDNN_FOREACH_DTYPE_NAME(cb)
 #undef cb
+        case DTypeEnum_QuantizedS1:
+            return dtype::QuantizedS1{param->scale()};
         case DTypeEnum_QuantizedS4:
             return dtype::QuantizedS4{param->scale()};
         case DTypeEnum_QuantizedS8:
@@ -113,6 +115,7 @@ flatbuffers::Offset<fbs::DType> build_dtype(
             break;
         CASE_ASYMMETRIC(Quantized4Asymm)
         CASE_ASYMMETRIC(Quantized8Asymm)
+        CASE_SYMMETRIC(QuantizedS1)
         CASE_SYMMETRIC(QuantizedS4)
        CASE_SYMMETRIC(QuantizedS8)
         CASE_SYMMETRIC(QuantizedS16)