Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fc6aa12e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fc6aa12e
编写于
3月 25, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge):add qint4/quint4 to python
GitOrigin-RevId: f94609db00fcaaa9ca249eb61639eb0482705f79
上级
a94fb7b1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
207 addition
and
4 deletion
+207
-4
dnn/src/fallback/type_cvt/opr_impl.cpp
dnn/src/fallback/type_cvt/opr_impl.cpp
+6
-1
python_module/megengine/_internal/dtype.py
python_module/megengine/_internal/dtype.py
+101
-1
python_module/src/cpp/python_helper.cpp
python_module/src/cpp/python_helper.cpp
+43
-2
src/core/impl/dtype.cpp
src/core/impl/dtype.cpp
+57
-0
未找到文件。
dnn/src/fallback/type_cvt/opr_impl.cpp
浏览文件 @
fc6aa12e
...
...
@@ -451,7 +451,12 @@ namespace fallback {
void
TypeCvtImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
)
{
check_exec
(
src
.
layout
,
dst
.
layout
);
if
(
src
.
layout
.
is_contiguous
()
&&
dst
.
layout
.
is_contiguous
())
{
auto
is_quantize_lowbit
=
[](
const
DType
&
dt
)
{
return
dt
.
category
()
==
DTypeCategory
::
QUANTIZED
&&
dt
.
is_low_bit
();
};
if
(
src
.
layout
.
is_contiguous
()
&&
dst
.
layout
.
is_contiguous
()
&&
!
is_quantize_lowbit
(
src
.
layout
.
dtype
)
&&
!
is_quantize_lowbit
(
dst
.
layout
.
dtype
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
run_contiguous
(
src
,
dst
));
}
else
{
naive
::
TypeCvtImpl
::
exec
(
src
,
dst
);
...
...
python_module/megengine/_internal/dtype.py
浏览文件 @
fc6aa12e
...
...
@@ -32,7 +32,7 @@ def get_scale(dtype):
def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in the ``mgb_dtype`` metadata of an
    asymmetric quantized data type.

    :param dtype: a quantized dtype whose ``mgb_dtype`` metadata is named
        ``Quantized8Asymm`` or ``Quantized4Asymm``.
    :return: the value stored under the ``zero_point`` metadata key.
    """
    assert is_quantize(dtype)
    metadata = dtype.metadata["mgb_dtype"]
    # Only the asymmetric quantized types carry a zero point. A stricter
    # 8-bit-only check must not precede this one, or quint4 is rejected.
    assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
    return metadata["zero_point"]
...
...
@@ -79,6 +79,38 @@ def qint32(scale):
)
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)

    :param scale: quantization scale, stored as ``float(scale)``.
    :param zero_point: integral zero point within [0, 15].
    :return: a ``np.uint8`` dtype carrying ``Quantized4Asymm`` metadata.
    :raises ValueError: if ``zero_point`` is outside [0, 15].
    """
    int_zp = int(zero_point)
    assert int_zp == zero_point, "zero_point should be an integer"
    if int_zp < 0 or int_zp > 15:
        raise ValueError("zero_point should be within [0, 15] for quint4")
    return np.dtype(
        np.uint8,
        metadata={
            "mgb_dtype": {
                "name": "Quantized4Asymm",
                "scale": float(scale),
                # Reuse the already validated integer instead of converting again.
                "zero_point": int_zp,
            }
        },
    )
def qint4(scale):
    """
    Build the quantized signed int4 data type for a given ``scale`` (float).
    A stored int4 value maps to the real value float_val = scale * int4_val.

    :param scale: quantization scale, stored as ``float(scale)``.
    :return: a ``np.int8`` dtype carrying ``QuantizedS4`` metadata.
    """
    mgb_info = {"name": "QuantizedS4", "scale": float(scale)}
    return np.dtype(np.int8, metadata={"mgb_dtype": mgb_info})
def
convert_to_quint8
(
arr
,
q
):
"""
Quantize a float NumPy ndarray into a quint8 one with specified params.
...
...
@@ -177,3 +209,71 @@ def convert_from_qint32(arr):
),
"arr should be a ndarray with qint8 dtype"
scale
=
arr
.
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
]
return
arr
.
astype
(
np
.
float32
)
*
scale
def convert_to_quint4(arr, q):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    Each element is computed as ``round(arr / scale) + zero_point``, clipped
    to [0, 15] and cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a quint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; fall back to
    # an empty mapping so such a dtype fails the assertion below with its
    # message instead of raising TypeError on the ``in`` test.
    metadata = q.metadata or {}
    assert (
        "mgb_dtype" in metadata
        and metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
    ), "q should be a quint4 dtype"
    mgb_dtype = metadata["mgb_dtype"]
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (np.round(arr / scale) + zp).clip(0, 15).astype(q)
def convert_from_quint4(arr):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    Each element is computed as ``(float32(value) - zero_point) * scale``
    using the parameters stored in the array dtype's metadata.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; fall back to
    # an empty mapping so a plain array fails the assertion below with its
    # message instead of raising TypeError on the ``in`` test.
    metadata = arr.dtype.metadata or {}
    assert (
        "mgb_dtype" in metadata
        and metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
    ), "arr should be a ndarray with quint4 dtype"
    mgb_dtype = metadata["mgb_dtype"]
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (arr.astype(np.float32) - zp) * scale
def convert_to_qint4(arr, q):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    Each element is computed as ``round(arr / scale)``, clipped to [-8, 7]
    and cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a qint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; fall back to
    # an empty mapping so such a dtype fails the assertion below with its
    # message instead of raising TypeError on the ``in`` test.
    metadata = q.metadata or {}
    assert (
        "mgb_dtype" in metadata
        and metadata["mgb_dtype"]["name"] == "QuantizedS4"
    ), "q should be a qint4 dtype"
    scale = metadata["mgb_dtype"]["scale"]
    return (np.round(arr / scale)).clip(-8, 7).astype(q)
def convert_from_qint4(arr):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    Each element is computed as ``float32(value) * scale`` using the scale
    stored in the array dtype's metadata.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None for a dtype created without metadata; fall back to
    # an empty mapping so a plain array fails the assertion below with its
    # message instead of raising TypeError on the ``in`` test.
    metadata = arr.dtype.metadata or {}
    assert (
        "mgb_dtype" in metadata
        and metadata["mgb_dtype"]["name"] == "QuantizedS4"
    ), "arr should be a ndarray with qint4 dtype"
    scale = metadata["mgb_dtype"]["scale"]
    return arr.astype(np.float32) * scale
python_module/src/cpp/python_helper.cpp
浏览文件 @
fc6aa12e
...
...
@@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
{{
"scale"
,
PyFloat_FromDouble
(
param
.
scale
)}});
break
;
}
case
DTypeEnum
::
Quantized4Asymm
:
{
auto
&
param
=
dtype
.
param
<
dtype
::
Quantized4Asymm
>
();
type_descr
=
PyArray_DescrNewFromType
(
NPY_UINT8
);
type_descr
->
metadata
=
build_mgb_dtype_dict
(
DTypeTrait
<
dtype
::
Quantized4Asymm
>::
name
,
{{
"scale"
,
PyFloat_FromDouble
(
param
.
scale
)},
{
"zero_point"
,
PyLong_FromLong
(
param
.
zero_point
)}});
break
;
}
case
DTypeEnum
::
QuantizedS4
:
{
auto
&
param
=
dtype
.
param
<
dtype
::
QuantizedS4
>
();
type_descr
=
PyArray_DescrNewFromType
(
NPY_INT8
);
type_descr
->
metadata
=
build_mgb_dtype_dict
(
DTypeTrait
<
dtype
::
QuantizedS4
>::
name
,
{{
"scale"
,
PyFloat_FromDouble
(
param
.
scale
)}});
break
;
}
case
DTypeEnum
::
QuantizedS32
:
{
auto
&
param
=
dtype
.
param
<
dtype
::
QuantizedS32
>
();
type_descr
=
PyArray_DescrNewFromType
(
NPY_INT32
);
...
...
@@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
static_cast
<
float
>
(
PyFloat_AS_DOUBLE
(
scale_py
)),
static_cast
<
uint8_t
>
(
zero_point
));
}
if
(
dtype_name
==
"QuantizedS32"
||
dtype_name
==
"QuantizedS8"
)
{
if (dtype_name == "Quantized4Asymm") {
    // Rebuild a Quantized4Asymm dtype from the "scale" and "zero_point"
    // entries of the numpy metadata dict, validating both first.
    PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
    PyObject* zero_point_py = PyDict_GetItemString(metadata, "zero_point");
    mgb_assert(scale_py && zero_point_py,
               "Invalid Quantized4Asymm metadata: missing scale or "
               "zero_point.");
    mgb_assert(PyFloat_Check(scale_py),
               "Invalid Quantized4Asymm metadata: scale should be float");
    mgb_assert(PyLong_Check(zero_point_py),
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "integer");
    auto zero_point = PyLong_AS_LONG(zero_point_py);
    // The upper bound is inclusive: the Python-side quint4() constructor
    // accepts zero_point within [0, 15], so 15 must be accepted here too.
    mgb_assert(zero_point >= 0 && zero_point <= 15,
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "in [0, 15]");
    return dtype::Quantized4Asymm(
            static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
            static_cast<uint8_t>(zero_point));
}
if
(
dtype_name
==
"QuantizedS32"
||
dtype_name
==
"QuantizedS8"
||
dtype_name
==
"QuantizedS4"
)
{
PyObject
*
scale_py
=
PyDict_GetItemString
(
metadata
,
"scale"
);
mgb_assert
(
scale_py
,
"Invalid metadata: missing scale"
);
mgb_assert
(
PyFloat_Check
(
scale_py
),
...
...
@@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
float
scale
=
static_cast
<
float
>
(
PyFloat_AS_DOUBLE
(
scale_py
));
if
(
dtype_name
==
"QuantizedS32"
)
{
return
dtype
::
QuantizedS32
(
scale
);
}
else
{
}
else
if
(
dtype_name
==
"QuantizedS8"
)
{
return
dtype
::
QuantizedS8
(
scale
);
}
else
{
return
dtype
::
QuantizedS4
(
scale
);
}
}
throw
ConversionError
(
...
...
src/core/impl/dtype.cpp
浏览文件 @
fc6aa12e
...
...
@@ -14,6 +14,7 @@
#include "megbrain/exception.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/arith_helper.h"
#include "megdnn/dtype.h"
#include <cmath>
#include <cstring>
...
...
@@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> {
}
}
};
/*!
 * \brief per-dtype constant consumed by QuantizedLowbitMemcpy
 *
 * SHIFT is added to every byte value before it is packed and subtracted
 * again after it is unpacked.
 */
template <typename DT>
struct QuantizedLowbitTrait;

//! Quantized4Asymm values are packed without any offset
template <>
struct QuantizedLowbitTrait<dtype::Quantized4Asymm> {
    static constexpr int8_t SHIFT = 0;
};

//! QuantizedS4 values are offset by 8 before packing
template <>
struct QuantizedLowbitTrait<dtype::QuantizedS4> {
    static constexpr int8_t SHIFT = 8;
};
/*!
 * \brief conversion between one-value-per-byte storage and compact low-bit
 *      storage for quantized dtypes
 *
 * Only the specialization with div_byte == true is defined, i.e. for
 * quantized dtypes whose bit width divides 8.
 */
template <typename DT,
          bool div_byte =
                  (DTypeTrait<DT>::category == DTypeCategory::QUANTIZED) &&
                  (8 % DTypeTrait<DT>::low_bit == 0)>
struct QuantizedLowbitMemcpy;

template <typename DT>
struct QuantizedLowbitMemcpy<DT, true> {
    // cast with bits that 8 % bits == 0
    static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
    static constexpr uint8_t MASK = (1 << bits) - 1;
    using Trait = QuantizedLowbitTrait<DT>;

    /*!
     * \brief pack \p n int8 values from \p src_raw into \p dest_raw
     *
     * The destination is zero-filled first; each value is shifted by
     * Trait::SHIFT, asserted to fit in `bits` bits and OR-ed into place.
     */
    static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
        auto packed = static_cast<uint8_t*>(dest_raw);
        auto unpacked = static_cast<const int8_t*>(src_raw);
        memset(packed, 0, divup<size_t>(n * bits, 8));
        for (size_t i = 0; i < n; ++i) {
            int8_t val = unpacked[i] + Trait::SHIFT;
            mgb_assert(val >= 0 && val < (1 << bits));
            // absolute bit offset of element i inside the packed buffer
            const size_t bit_pos = i * bits;
            packed[bit_pos / 8] |= val << (bit_pos % 8);
        }
    }

    /*!
     * \brief unpack \p n low-bit values from \p src_raw into int8 values in
     *      \p dest_raw, undoing the Trait::SHIFT offset
     */
    static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
        auto unpacked = static_cast<int8_t*>(dest_raw);
        auto packed = static_cast<const uint8_t*>(src_raw);
        for (size_t i = 0; i < n; ++i) {
            // absolute bit offset of element i inside the packed buffer
            const size_t bit_pos = i * bits;
            int8_t val = ((packed[bit_pos / 8] >> (bit_pos % 8)) & MASK);
            unpacked[i] = val - Trait::SHIFT;
        }
    }
};
}
// anonymous namespace
void
mgb
::
lowbit_memcpy_byte2compact
(
...
...
@@ -365,6 +412,11 @@ void mgb::lowbit_memcpy_byte2compact(
if (dtype == mgb::dtype::name##bits()) \
return LowbitMemcpy<bits>::byte2compact(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE
(
cb
)
#undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
mgb_throw
(
MegBrainError
,
"bad dtype for lowbit: %s"
,
dtype
.
name
());
}
...
...
@@ -375,6 +427,11 @@ void mgb::lowbit_memcpy_compact2byte(
if (dtype == mgb::dtype::name##bits()) \
return LowbitMemcpy<bits>::compact2byte(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE
(
cb
)
#undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
mgb_throw
(
MegBrainError
,
"bad dtype for lowbit: %s"
,
dtype
.
name
());
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录