Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
87f4b46e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
87f4b46e
编写于
12月 21, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge/imperative): move convert_inputs from python to C++
GitOrigin-RevId: baef3d348c590d477432c2c45df54835557e7c8d
上级
b310f261
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
456 addition
and
170 deletion
+456
-170
imperative/python/megengine/core/ops/special.py
imperative/python/megengine/core/ops/special.py
+8
-3
imperative/python/megengine/core/tensor/dtype.py
imperative/python/megengine/core/tensor/dtype.py
+6
-35
imperative/python/megengine/core/tensor/indexing.py
imperative/python/megengine/core/tensor/indexing.py
+5
-8
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+9
-92
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+1
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+2
-2
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+2
-4
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+4
-2
imperative/python/src/common.cpp
imperative/python/src/common.cpp
+1
-0
imperative/python/src/helper.cpp
imperative/python/src/helper.cpp
+9
-15
imperative/python/src/helper.h
imperative/python/src/helper.h
+15
-1
imperative/python/src/numpy_dtypes.cpp
imperative/python/src/numpy_dtypes.cpp
+179
-0
imperative/python/src/numpy_dtypes.h
imperative/python/src/numpy_dtypes.h
+1
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+212
-5
imperative/python/test/unit/core/test_dtype_quant.py
imperative/python/test/unit/core/test_dtype_quant.py
+2
-2
未找到文件。
imperative/python/megengine/core/ops/special.py
浏览文件 @
87f4b46e
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
numpy
as
np
from
.._imperative_rt.core2
import
Tensor
#
from .._imperative_rt.core2 import Tensor
from
..tensor.core
import
OpBase
,
TensorBase
,
apply
from
..tensor.core
import
OpBase
,
TensorBase
,
apply
...
@@ -19,5 +19,10 @@ class Const:
...
@@ -19,5 +19,10 @@ class Const:
self
.
device
=
device
self
.
device
=
device
def
__call__
(
self
,
*
reference
):
def
__call__
(
self
,
*
reference
):
Wrapper
=
type
(
reference
[
0
])
from
...tensor
import
Tensor
return
(
Wrapper
(
self
.
value
,
self
.
dtype
,
self
.
device
,
True
),)
device
=
self
.
device
if
device
is
None
:
device
=
reference
[
0
].
device
return
(
Tensor
(
self
.
value
,
self
.
dtype
,
self
.
device
,
True
),)
imperative/python/megengine/core/tensor/dtype.py
浏览文件 @
87f4b46e
...
@@ -13,6 +13,12 @@ import numpy as np
...
@@ -13,6 +13,12 @@ import numpy as np
# normal dtype related
# normal dtype related
from
.._imperative_rt
import
bfloat16
,
intb1
,
intb2
,
intb4
from
.._imperative_rt
import
bfloat16
,
intb1
,
intb2
,
intb4
from
.._imperative_rt.common
import
(
get_scale
,
get_zero_point
,
is_dtype_equal
,
is_quantize
,
)
def
is_lowbit
(
dtype
):
def
is_lowbit
(
dtype
):
...
@@ -42,41 +48,6 @@ _metadata_dict = {
...
@@ -42,41 +48,6 @@ _metadata_dict = {
}
}
def
is_quantize
(
dtype
):
return
(
hasattr
(
dtype
,
"metadata"
)
and
dtype
.
metadata
is
not
None
and
"mgb_dtype"
in
dtype
.
metadata
)
def
get_scale
(
dtype
):
assert
is_quantize
(
dtype
)
return
dtype
.
metadata
[
"mgb_dtype"
][
"scale"
]
def
get_zero_point
(
dtype
):
assert
is_quantize
(
dtype
)
metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
assert
metadata
[
"name"
]
in
(
"Quantized8Asymm"
,
"Quantized4Asymm"
)
return
metadata
[
"zero_point"
]
def
is_equal
(
dt0
,
dt1
):
def
_get_zero_point
(
dtype
):
assert
is_quantize
(
dtype
)
metadata
=
dtype
.
metadata
[
"mgb_dtype"
]
return
metadata
.
get
(
"zero_point"
)
if
is_quantize
(
dt0
)
and
is_quantize
(
dt1
):
return
get_scale
(
dt0
)
==
get_scale
(
dt1
)
and
_get_zero_point
(
dt0
)
==
_get_zero_point
(
dt1
)
if
not
(
is_quantize
(
dt0
)
or
is_quantize
(
dt1
)):
return
dt0
==
dt1
return
False
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
def
_check_zero_point
(
zp
:
int
,
dtype_str
:
str
):
qmin
=
_metadata_dict
[
dtype_str
].
qmin
qmin
=
_metadata_dict
[
dtype_str
].
qmin
qmax
=
_metadata_dict
[
dtype_str
].
qmax
qmax
=
_metadata_dict
[
dtype_str
].
qmax
...
...
imperative/python/megengine/core/tensor/indexing.py
浏览文件 @
87f4b46e
...
@@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
...
@@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def
get_index
(
i
):
def
get_index
(
i
):
if
not
isinstance
(
i
,
(
Tensor
)):
if
not
isinstance
(
i
,
(
Tensor
)):
if
is_bool_list
(
i
)
or
isinstance
(
i
,
np
.
ndarray
)
and
i
.
dtype
==
np
.
bool_
:
if
is_bool_list
(
i
)
or
isinstance
(
i
,
np
.
ndarray
)
and
i
.
dtype
==
np
.
bool_
:
(
i
,)
=
Const
(
i
,
dtype
=
np
.
bool_
,
device
=
inp
.
device
)(
inp
)
(
i
,)
=
Const
(
i
,
dtype
=
np
.
bool_
,
device
=
inp
.
device
)()
else
:
else
:
(
i
,)
=
Const
(
i
,
dtype
=
np
.
int32
,
device
=
inp
.
device
)(
inp
)
(
i
,)
=
Const
(
i
,
dtype
=
np
.
int32
,
device
=
inp
.
device
)()
return
i
return
i
assert
isinstance
(
i
,
Tensor
)
assert
isinstance
(
i
,
Tensor
)
if
i
.
dtype
!=
np
.
bool_
:
if
i
.
dtype
!=
np
.
bool_
:
...
@@ -197,7 +197,7 @@ def try_condtake(tensor, index):
...
@@ -197,7 +197,7 @@ def try_condtake(tensor, index):
):
):
return
[]
return
[]
if
isinstance
(
index
,
np
.
ndarray
):
if
isinstance
(
index
,
np
.
ndarray
):
(
index
,)
=
Const
(
index
,
dtype
=
np
.
bool_
,
device
=
tensor
.
device
)(
tensor
)
(
index
,)
=
Const
(
index
,
dtype
=
np
.
bool_
,
device
=
tensor
.
device
)()
assert
isinstance
(
index
,
Tensor
)
assert
isinstance
(
index
,
Tensor
)
if
not
isinstance
(
tensor
,
Tensor
):
if
not
isinstance
(
tensor
,
Tensor
):
raise
TypeError
(
"input must be a tensor"
)
raise
TypeError
(
"input must be a tensor"
)
...
@@ -217,9 +217,7 @@ def getitem(tensor, index):
...
@@ -217,9 +217,7 @@ def getitem(tensor, index):
if
isinstance
(
v
.
shape
,
v
.
__class__
):
if
isinstance
(
v
.
shape
,
v
.
__class__
):
break
break
if
len
(
v
.
shape
)
>
0
and
v
.
shape
[
0
]
==
0
:
if
len
(
v
.
shape
)
>
0
and
v
.
shape
[
0
]
==
0
:
(
empty_tensor
,)
=
Const
([],
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)(
(
empty_tensor
,)
=
Const
([],
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)()
tensor
)
return
empty_tensor
return
empty_tensor
if
use_subtensor
:
if
use_subtensor
:
op
=
builtin
.
Subtensor
(
items
=
items
)
op
=
builtin
.
Subtensor
(
items
=
items
)
...
@@ -240,8 +238,7 @@ def setitem(tensor, index, value):
...
@@ -240,8 +238,7 @@ def setitem(tensor, index, value):
return
tensor
return
tensor
tensor
=
tensor
.
reshape
(
-
1
)
tensor
=
tensor
.
reshape
(
-
1
)
if
not
isinstance
(
value
,
Tensor
):
if
not
isinstance
(
value
,
Tensor
):
op
=
Const
(
value
,
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)
(
value
,)
=
Const
(
value
,
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)()
(
value
,)
=
op
(
tensor
)
tensor
,
tensors
,
items
,
use_subtensor
,
_
=
unpack_getitem
(
tensor
,
index
)
tensor
,
tensors
,
items
,
use_subtensor
,
_
=
unpack_getitem
(
tensor
,
index
)
for
v
in
tensors
:
for
v
in
tensors
:
if
len
(
v
.
shape
)
>
0
and
v
.
shape
[
0
]
==
0
:
if
len
(
v
.
shape
)
>
0
and
v
.
shape
[
0
]
==
0
:
...
...
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
87f4b46e
...
@@ -11,10 +11,10 @@ from typing import Iterable, Union
...
@@ -11,10 +11,10 @@ from typing import Iterable, Union
import
numpy
as
np
import
numpy
as
np
from
.._imperative_rt.core2
import
Tensor
,
apply
from
.._imperative_rt.core2
import
Tensor
,
apply
,
dtype_promotion
,
get_device
from
..ops
import
builtin
from
..ops
import
builtin
from
..ops.special
import
Const
from
..ops.special
import
Const
from
.dtype
import
is_equal
,
is_quantize
from
.dtype
import
is_
dtype_
equal
,
is_quantize
from
.megbrain_graph
import
VarNode
from
.megbrain_graph
import
VarNode
_enable_convert_inputs
=
True
_enable_convert_inputs
=
True
...
@@ -37,94 +37,12 @@ def set_convert_inputs(flag):
...
@@ -37,94 +37,12 @@ def set_convert_inputs(flag):
return
backup
return
backup
def
dtype_promotion
(
inputs
):
"""
Returns the dtype that would result from performing an arithmetic
operation on the provided input tensors and scalars.
"""
# map numpy.dtype.kind to priority
category_priority
=
{
"f"
:
3
,
# floating-point
"i"
:
2
,
# signed integer
"u"
:
2
,
# unsigned integer
"b"
:
1
,
# boolean
}
def
scalar2dtype
(
x
):
"""
For scalar `x`, returns its corresponding type. A floating point scalar
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'.
A boolean scalar has dtype 'bool'.
"""
if
isinstance
(
x
,
bool
):
return
np
.
bool_
if
isinstance
(
x
,
int
):
return
np
.
int32
if
isinstance
(
x
,
float
):
return
np
.
float32
def
promote_types
(
types
,
cat
):
"""
Returns the data type with sufficient size to hold all types of
category `cat` in the list `types`.
"""
used_types
=
[
i
for
i
in
types
if
category_priority
.
get
(
np
.
dtype
(
i
).
kind
,
0
)
==
cat
]
assert
len
(
used_types
)
>
0
res
=
used_types
[
0
]
for
i
in
used_types
:
res
=
np
.
promote_types
(
res
,
i
)
return
res
def
max_priority
(
types
):
"""
Returns the maximum value of the priority of each type in the list
`types`.
"""
if
not
types
:
return
0
else
:
return
max
([
category_priority
.
get
(
np
.
dtype
(
i
).
kind
,
0
)
for
i
in
types
])
scalars
=
[]
tensors
=
[]
for
data
in
inputs
:
if
hasattr
(
data
,
"dtype"
):
tensors
.
append
(
data
.
dtype
)
elif
isinstance
(
data
,
(
float
,
int
,
bool
)):
scalars
.
append
(
scalar2dtype
(
data
))
max_pri_scalars
=
max_priority
(
scalars
)
max_pri_tensors
=
max_priority
(
tensors
)
assert
max_pri_scalars
>
0
or
max_pri_tensors
>
0
if
max_pri_scalars
>
max_pri_tensors
:
return
promote_types
(
scalars
,
max_pri_scalars
)
else
:
return
promote_types
(
tensors
,
max_pri_tensors
)
def
get_device
(
inputs
):
device
=
None
for
i
in
inputs
:
if
isinstance
(
i
,
(
Tensor
,
VarNode
)):
if
device
is
None
:
device
=
i
.
device
elif
device
!=
i
.
device
:
raise
ValueError
(
"ambiguous device: {} vs {}"
.
format
(
device
,
i
.
device
))
assert
device
is
not
None
return
device
def
concatenate
(
inputs
,
axis
=
0
,
*
,
device
=
None
):
def
concatenate
(
inputs
,
axis
=
0
,
*
,
device
=
None
):
dtype
=
dtype_promotion
(
inputs
)
dtype
=
dtype_promotion
(
inputs
)
device
=
get_device
(
inputs
)
device
=
get_device
(
inputs
)
def
convert
(
x
):
def
convert
(
x
):
return
convert_single_value
(
x
,
inputs
,
dtype
=
dtyp
e
)
return
convert_single_value
(
x
,
dtype
=
dtype
,
device
=
devic
e
)
inputs
=
tuple
(
map
(
convert
,
inputs
))
inputs
=
tuple
(
map
(
convert
,
inputs
))
(
result
,)
=
apply
(
builtin
.
Concat
(
axis
=
axis
,
comp_node
=
device
),
*
inputs
)
(
result
,)
=
apply
(
builtin
.
Concat
(
axis
=
axis
,
comp_node
=
device
),
*
inputs
)
...
@@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
...
@@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
def
astype
(
x
,
dtype
):
def
astype
(
x
,
dtype
):
dtype
=
np
.
dtype
(
dtype
)
dtype
=
np
.
dtype
(
dtype
)
if
not
is_equal
(
x
.
dtype
,
dtype
):
if
not
is_
dtype_
equal
(
x
.
dtype
,
dtype
):
isscalar
=
x
.
isscalar
()
isscalar
=
x
.
isscalar
()
(
x
,)
=
apply
(
builtin
.
TypeCvt
(
dtype
=
dtype
),
x
)
(
x
,)
=
apply
(
builtin
.
TypeCvt
(
dtype
=
dtype
),
x
)
if
isscalar
:
if
isscalar
:
...
@@ -141,13 +59,12 @@ def astype(x, dtype):
...
@@ -141,13 +59,12 @@ def astype(x, dtype):
return
x
return
x
def
convert_single_value
(
v
,
inputs
,
*
,
dtype
=
None
,
device
=
None
):
def
convert_single_value
(
v
,
*
,
dtype
=
None
,
device
=
None
):
tensors
=
[
i
for
i
in
inputs
if
isinstance
(
i
,
(
Tensor
,
VarNode
))]
assert
len
(
tensors
)
>
0
if
isinstance
(
v
,
(
Tensor
,
VarNode
)):
if
isinstance
(
v
,
(
Tensor
,
VarNode
)):
v
=
astype
(
v
,
v
.
dtype
if
is_quantize
(
v
.
dtype
)
else
dtype
)
if
not
is_quantize
(
v
.
dtype
):
v
=
astype
(
v
,
dtype
)
else
:
else
:
(
v
,)
=
Const
(
v
,
dtype
=
dtype
,
device
=
device
)(
*
tensors
)
(
v
,)
=
Const
(
v
,
dtype
=
dtype
,
device
=
device
)()
return
v
return
v
...
@@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor):
...
@@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor):
def
convert
(
value
):
def
convert
(
value
):
if
value
is
None
:
if
value
is
None
:
return
value
return
value
return
convert_single_value
(
value
,
args
,
dtype
=
dtype
,
device
=
device
)
return
convert_single_value
(
value
,
dtype
=
dtype
,
device
=
device
)
return
tuple
(
map
(
convert
,
args
))
return
tuple
(
map
(
convert
,
args
))
...
...
imperative/python/megengine/functional/math.py
浏览文件 @
87f4b46e
...
@@ -703,7 +703,7 @@ def topk(
...
@@ -703,7 +703,7 @@ def topk(
op
=
builtin
.
TopK
(
mode
=
mode
)
op
=
builtin
.
TopK
(
mode
=
mode
)
if
not
isinstance
(
k
,
Tensor
):
if
not
isinstance
(
k
,
Tensor
):
(
k
,)
=
Const
(
k
,
dtype
=
"int32"
,
device
=
inp
.
device
)(
inp
)
(
k
,)
=
Const
(
k
,
dtype
=
"int32"
,
device
=
inp
.
device
)()
if
len
(
inp
.
shape
)
==
1
:
if
len
(
inp
.
shape
)
==
1
:
inp
=
inp
.
reshape
(
1
,
-
1
)
inp
=
inp
.
reshape
(
1
,
-
1
)
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
87f4b46e
...
@@ -658,7 +658,7 @@ def batch_norm(
...
@@ -658,7 +658,7 @@ def batch_norm(
def
make_full_if_none
(
x
,
value
):
def
make_full_if_none
(
x
,
value
):
if
x
is
None
:
if
x
is
None
:
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)(
inp
)
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)()
shape
=
utils
.
astensor1d
(
shape
=
utils
.
astensor1d
(
(
1
,
C
,
1
,
1
),
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
(
1
,
C
,
1
,
1
),
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
)
...
@@ -1567,7 +1567,7 @@ def indexing_one_hot(
...
@@ -1567,7 +1567,7 @@ def indexing_one_hot(
"""
"""
assert
isinstance
(
src
,
Tensor
),
"src must be of Tensor type"
assert
isinstance
(
src
,
Tensor
),
"src must be of Tensor type"
op
=
builtin
.
IndexingOneHot
(
axis
=
axis
)
op
=
builtin
.
IndexingOneHot
(
axis
=
axis
)
index
=
utils
.
convert_single_value
(
index
,
(
src
,),
dtype
=
"int32"
,
device
=
src
.
device
)
index
=
utils
.
convert_single_value
(
index
,
dtype
=
"int32"
,
device
=
src
.
device
)
(
result
,)
=
apply
(
op
,
src
,
index
)
(
result
,)
=
apply
(
op
,
src
,
index
)
if
not
keepdims
:
if
not
keepdims
:
result
=
squeeze
(
result
,
axis
)
result
=
squeeze
(
result
,
axis
)
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
87f4b46e
...
@@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None):
...
@@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None):
shape
=
(
shape
,)
shape
=
(
shape
,)
if
device
is
None
:
if
device
is
None
:
device
=
get_default_device
()
device
=
get_default_device
()
(
x
,)
=
Const
(
value
,
dtype
=
dtype
,
device
=
device
)(
(
x
,)
=
Const
(
value
,
dtype
=
dtype
,
device
=
device
)()
Tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
)
if
len
(
shape
)
==
0
:
# scalar
if
len
(
shape
)
==
0
:
# scalar
return
x
return
x
return
broadcast_to
(
x
,
shape
)
return
broadcast_to
(
x
,
shape
)
...
@@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
...
@@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
device
=
as_device
(
device
)
device
=
as_device
(
device
)
def
convert
(
x
):
def
convert
(
x
):
return
convert_single_value
(
x
,
inps
,
dtype
=
dtyp
e
)
return
convert_single_value
(
x
,
dtype
=
dtype
,
device
=
devic
e
)
inps
=
tuple
(
map
(
convert
,
inps
))
inps
=
tuple
(
map
(
convert
,
inps
))
(
result
,)
=
apply
(
builtin
.
Concat
(
axis
=
axis
,
comp_node
=
device
.
to_c
()),
*
inps
)
(
result
,)
=
apply
(
builtin
.
Concat
(
axis
=
axis
,
comp_node
=
device
.
to_c
()),
*
inps
)
...
...
imperative/python/megengine/tensor.py
浏览文件 @
87f4b46e
...
@@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
...
@@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
else
:
else
:
cn
=
CompNode
(
device
)
cn
=
CompNode
(
device
)
else
:
else
:
assert
isinstance
(
device
,
CompNode
)
if
isinstance
(
device
,
CompNode
):
cn
=
device
cn
=
device
else
:
cn
=
device
.
_cn
# import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
if
isinstance
(
data
,
_Tensor
):
if
isinstance
(
data
,
_Tensor
):
...
...
imperative/python/src/common.cpp
浏览文件 @
87f4b46e
...
@@ -179,4 +179,5 @@ void init_common(py::module m) {
...
@@ -179,4 +179,5 @@ void init_common(py::module m) {
init_npy_num_bfloat16
(
m
);
init_npy_num_bfloat16
(
m
);
init_npy_num_intbx
(
m
);
init_npy_num_intbx
(
m
);
init_dtypes
(
m
);
}
}
imperative/python/src/helper.cpp
浏览文件 @
87f4b46e
...
@@ -158,7 +158,7 @@ void PyExceptionForward::throw_() {
...
@@ -158,7 +158,7 @@ void PyExceptionForward::throw_() {
/* ============== namespace npy ============== */
/* ============== namespace npy ============== */
namespace
{
namespace
npy
{
int
to_mgb_supported_dtype_raw
(
int
dtype
)
{
int
to_mgb_supported_dtype_raw
(
int
dtype
)
{
if
(
dtype
==
NPY_INT64
)
if
(
dtype
==
NPY_INT64
)
...
@@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) {
...
@@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) {
"can not convert dtype %s to numpy dtype"
,
dtype
.
name
()));
"can not convert dtype %s to numpy dtype"
,
dtype
.
name
()));
}
}
struct
PyArrayDescrDeleter
{
void
operator
()(
PyArray_Descr
*
obj
)
{
Py_XDECREF
(
obj
);
}
};
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
//! reference to the descriptor.
//! reference to the descriptor.
std
::
unique_ptr
<
PyArray_Descr
,
PyArrayDescrDeleter
>
dtype_mgb2np_descr
(
std
::
unique_ptr
<
PyArray_Descr
,
PyArrayDescrDeleter
>
dtype_mgb2np_descr
(
...
@@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) {
...
@@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) {
HostTensorNDRefHolder
::
free
(
static_cast
<
HostTensorNDRefHolder
*>
(
ptr
));
HostTensorNDRefHolder
::
free
(
static_cast
<
HostTensorNDRefHolder
*>
(
ptr
));
}
}
}
// anonymous namespace
PyObject
*
ndarray_from_tensor
(
PyObject
*
npy
::
ndarray_from_tensor
(
const
HostTensorND
&
val
,
ShareType
share_type
)
{
const
HostTensorND
&
val
,
ShareType
share_type
)
{
if
(
!
val
.
layout
().
is_contiguous
()
&&
!
val
.
shape
().
is_empty
())
{
if
(
!
val
.
layout
().
is_contiguous
()
&&
!
val
.
shape
().
is_empty
())
{
mgb_assert
(
share_type
!=
ShareType
::
MUST_SHARE
);
mgb_assert
(
share_type
!=
ShareType
::
MUST_SHARE
);
...
@@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor(
...
@@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor(
return
ret
;
return
ret
;
}
}
HostTensorND
np
y
::
np
2tensor
(
PyObject
*
obj
,
const
Meth
&
meth
,
DType
dtype
)
{
HostTensorND
np2tensor
(
PyObject
*
obj
,
const
Meth
&
meth
,
DType
dtype
)
{
auto
ret_full
=
np2tensor_try_borrow
(
obj
,
meth
,
dtype
);
auto
ret_full
=
np2tensor_try_borrow
(
obj
,
meth
,
dtype
);
if
(
meth
.
must_borrow_
)
{
if
(
meth
.
must_borrow_
)
{
mgb_assert
(
ret_full
.
second
,
mgb_assert
(
ret_full
.
second
,
...
@@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
...
@@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
return
ret_full
.
first
;
return
ret_full
.
first
;
}
}
PyObject
*
npy
::
dtype_mgb2np
(
mgb
::
DType
dtype
)
{
PyObject
*
dtype_mgb2np
(
mgb
::
DType
dtype
)
{
PYTHON_GIL
;
PYTHON_GIL
;
// According to
// According to
// https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
// https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
...
@@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
...
@@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
return
typeobj
;
return
typeobj
;
}
}
mgb
::
DType
npy
::
dtype_np2mgb
(
PyObject
*
obj
)
{
mgb
::
DType
dtype_np2mgb
(
PyObject
*
obj
)
{
mgb_assert
(
obj
&&
obj
!=
Py_None
,
mgb_assert
(
obj
&&
obj
!=
Py_None
,
"can not convert null PyObject to numpy dtype"
);
"can not convert null PyObject to numpy dtype"
);
// see
// see
...
@@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) {
...
@@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) {
return
result
;
return
result
;
}
}
PyObject
*
npy
::
to_mgb_supported_dtype
(
PyObject
*
dtype
)
{
PyObject
*
to_mgb_supported_dtype
(
PyObject
*
dtype
)
{
PYTHON_GIL
;
PYTHON_GIL
;
PyArray_Descr
*
descr
;
PyArray_Descr
*
descr
;
...
@@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
...
@@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
return
PyArray_TypeObjectFromType
(
type_num
);
return
PyArray_TypeObjectFromType
(
type_num
);
}
}
TensorShape
npy
::
vec2shape
(
const
std
::
vector
<
size_t
>
&
vec
)
{
TensorShape
vec2shape
(
const
std
::
vector
<
size_t
>
&
vec
)
{
TensorShape
shape
;
TensorShape
shape
;
mgb_assert
(
vec
.
size
()
<=
TensorShape
::
MAX_NDIM
,
mgb_assert
(
vec
.
size
()
<=
TensorShape
::
MAX_NDIM
,
"dim too large: %zd (max %zd)"
,
"dim too large: %zd (max %zd)"
,
...
@@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
...
@@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
mgb_assert
(
shape
.
ndim
,
"shape should not be empty"
);
mgb_assert
(
shape
.
ndim
,
"shape should not be empty"
);
return
shape
;
return
shape
;
}
}
}
// namespace npy
imperative/python/src/helper.h
浏览文件 @
87f4b46e
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
#pragma once
#pragma once
#include "megbrain/
graph
.h"
#include "megbrain/
common
.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/op_def.h"
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <pybind11/numpy.h>
#include <pybind11/numpy.h>
#include <pybind11/functional.h>
#include <pybind11/functional.h>
#include "./numpy_dtypes.h"
pybind11
::
module
submodule
(
pybind11
::
module
parent
,
const
char
*
name
,
const
char
*
doc
=
nullptr
);
pybind11
::
module
submodule
(
pybind11
::
module
parent
,
const
char
*
name
,
const
char
*
doc
=
nullptr
);
pybind11
::
module
rel_import
(
pybind11
::
str
name
,
pybind11
::
module
m
,
int
level
);
pybind11
::
module
rel_import
(
pybind11
::
str
name
,
pybind11
::
module
m
,
int
level
);
...
@@ -182,6 +184,18 @@ namespace npy {
...
@@ -182,6 +184,18 @@ namespace npy {
//! convert raw vector to tensor shape
//! convert raw vector to tensor shape
mgb
::
TensorShape
vec2shape
(
const
std
::
vector
<
size_t
>
&
vec
);
mgb
::
TensorShape
vec2shape
(
const
std
::
vector
<
size_t
>
&
vec
);
struct
PyArrayDescrDeleter
{
void
operator
()(
PyArray_Descr
*
obj
)
{
Py_XDECREF
(
obj
);
}
};
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
//! reference to the descriptor.
std
::
unique_ptr
<
PyArray_Descr
,
PyArrayDescrDeleter
>
dtype_mgb2np_descr
(
mgb
::
DType
dtype
);
mgb
::
DType
dtype_np2mgb_descr
(
PyArray_Descr
*
descr
);
//! convert megbrain dtype to numpy dtype object; return new reference
//! convert megbrain dtype to numpy dtype object; return new reference
PyObject
*
dtype_mgb2np
(
mgb
::
DType
dtype
);
PyObject
*
dtype_mgb2np
(
mgb
::
DType
dtype
);
...
...
imperative/python/src/numpy_dtypes.cpp
0 → 100644
浏览文件 @
87f4b46e
/**
* \file imperative/python/src/numpy_dtypes.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./numpy_dtypes.h"
#include "./helper.h"
#include "./pyext17.h"
#include "pybind11/pybind11.h"
#include <cstring>
namespace
py
=
pybind11
;
namespace
mgb
{
namespace
{
inline
bool
_is_quantize
(
PyArray_Descr
*
dtype
)
{
static
PyObject
*
PY_MGB_DTYPE_KEY
=
PyUnicode_FromString
(
"mgb_dtype"
);
return
dtype
->
metadata
&&
PyDict_CheckExact
(
dtype
->
metadata
)
&&
PyDict_Contains
(
dtype
->
metadata
,
PY_MGB_DTYPE_KEY
)
==
1
;
}
PyObject
*
_get_mgb_dtype
(
PyArray_Descr
*
dtype
)
{
// Return value: New reference.
if
(
!
_is_quantize
(
dtype
))
{
throw
py
::
type_error
(
"expact quantize dtype"
);
}
PyObject
*
ob
=
PyDict_GetItemString
(
dtype
->
metadata
,
"mgb_dtype"
);
if
(
!
PyDict_CheckExact
(
ob
))
{
throw
py
::
type_error
(
"mgb_dtype is not dict"
);
}
Py_INCREF
(
ob
);
return
ob
;
}
double
_get_scale
(
PyArray_Descr
*
dtype
)
{
PyObject
*
ob
=
_get_mgb_dtype
(
dtype
);
PyObject
*
scale
=
PyDict_GetItemString
(
ob
,
"scale"
);
if
(
!
scale
)
{
Py_DECREF
(
ob
);
throw
py
::
key_error
(
"scale"
);
}
if
(
!
PyFloat_Check
(
scale
))
{
Py_DECREF
(
ob
);
throw
py
::
type_error
(
"scale is not float"
);
}
double
ret
=
PyFloat_AsDouble
(
scale
);
Py_DECREF
(
ob
);
return
ret
;
}
long
_get_zero_point
(
PyArray_Descr
*
dtype
)
{
PyObject
*
ob
=
_get_mgb_dtype
(
dtype
);
PyObject
*
name
=
PyDict_GetItemString
(
ob
,
"name"
);
if
(
!
name
)
{
Py_DECREF
(
ob
);
throw
py
::
key_error
(
"name"
);
}
const
char
*
s
=
PyUnicode_AsUTF8
(
name
);
if
(
strcmp
(
s
,
"Quantized8Asymm"
)
!=
0
&&
strcmp
(
s
,
"Quantized4Asymm"
)
!=
0
)
{
Py_DECREF
(
ob
);
throw
py
::
value_error
(
ssprintf
(
"expect name to be
\"
Quantized8Asymm
\"
or
\"
Quantized4Asymm
\"
, got %s"
,
s
));
}
PyObject
*
zp
=
PyDict_GetItemString
(
ob
,
"zero_point"
);
if
(
!
zp
)
{
Py_DECREF
(
ob
);
throw
py
::
key_error
(
"zero_point"
);
}
long
ret
=
PyLong_AsLong
(
zp
);
Py_DECREF
(
ob
);
return
ret
;
}
bool
_is_dtype_equal
(
PyArray_Descr
*
dt1
,
PyArray_Descr
*
dt2
)
{
bool
q1
=
_is_quantize
(
dt1
),
q2
=
_is_quantize
(
dt2
);
if
(
q1
&&
q2
)
{
if
(
_get_scale
(
dt1
)
!=
_get_scale
(
dt2
))
{
return
false
;
}
PyObject
*
zp1
=
PyDict_GetItemString
(
PyDict_GetItemString
(
dt1
->
metadata
,
"mgb_dtype"
),
"zero_point"
);
PyObject
*
zp2
=
PyDict_GetItemString
(
PyDict_GetItemString
(
dt2
->
metadata
,
"mgb_dtype"
),
"zero_point"
);
if
(
!
zp1
||
!
zp2
)
{
throw
py
::
key_error
(
"zero_point"
);
}
return
PyLong_AsLong
(
zp1
)
==
PyLong_AsLong
(
zp2
);
}
if
(
!
q1
&&
!
q2
)
{
return
dt1
->
type_num
==
dt2
->
type_num
;
}
return
false
;
}
template
<
auto
f
>
struct
_wrap
{
static
constexpr
size_t
n_args
=
[]()
{
using
F
=
decltype
(
f
);
using
T
=
PyArray_Descr
*
;
static_assert
(
std
::
is_pointer
<
F
>::
value
);
if
constexpr
(
std
::
is_invocable
<
F
,
T
>::
value
)
{
return
1
;
}
else
if
constexpr
(
std
::
is_invocable
<
F
,
T
,
T
>::
value
)
{
return
2
;
}
else
{
static_assert
(
!
std
::
is_same_v
<
F
,
F
>
,
"unreachable"
);
}
}();
static
PyObject
*
impl
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
nargs
!=
n_args
)
{
PyErr_Format
(
PyExc_ValueError
,
"expected %lu arguments"
,
n_args
);
return
nullptr
;
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
(
args
[
i
]
==
Py_None
)
{
PyErr_SetString
(
PyExc_ValueError
,
"can not convert null PyObject to numpy dtype"
);
return
nullptr
;
}
}
try
{
PyArray_Descr
*
dt1
;
if
(
!
PyArray_DescrConverter
(
args
[
0
],
&
dt1
))
{
throw
ConversionError
(
ssprintf
(
"can not convert to numpy.dtype from %s"
,
args
[
0
]
->
ob_type
->
tp_name
));
}
if
constexpr
(
n_args
==
1
)
{
auto
res
=
(
*
f
)(
dt1
);
Py_DECREF
(
dt1
);
return
py
::
cast
(
res
).
release
().
ptr
();
}
else
{
PyArray_Descr
*
dt2
;
if
(
!
PyArray_DescrConverter
(
args
[
1
],
&
dt2
))
{
Py_DECREF
(
dt1
);
throw
ConversionError
(
ssprintf
(
"can not convert to numpy.dtype from %s"
,
args
[
1
]
->
ob_type
->
tp_name
));
}
auto
&&
res
=
(
*
f
)(
dt1
,
dt2
);
Py_DECREF
(
dt1
);
Py_DECREF
(
dt2
);
return
py
::
cast
(
res
).
release
().
ptr
();
}
}
catch
(
std
::
exception
&
e
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
e
.
what
());
return
nullptr
;
}
}
};
}
// anonymous namespace
void
init_dtypes
(
py
::
module
m
)
{
static
PyMethodDef
method_defs
[]
=
{
{
"is_quantize"
,
(
PyCFunction
)
_wrap
<&
_is_quantize
>::
impl
,
METH_FASTCALL
,
nullptr
},
{
"get_scale"
,
(
PyCFunction
)
_wrap
<&
_get_scale
>::
impl
,
METH_FASTCALL
,
nullptr
},
{
"get_zero_point"
,
(
PyCFunction
)
_wrap
<&
_get_zero_point
>::
impl
,
METH_FASTCALL
,
nullptr
},
{
"is_dtype_equal"
,
(
PyCFunction
)
_wrap
<&
_is_dtype_equal
>::
impl
,
METH_FASTCALL
,
nullptr
},
{
nullptr
,
nullptr
,
0
,
nullptr
}
};
for
(
auto
&&
def
:
method_defs
)
{
if
(
def
.
ml_meth
!=
nullptr
)
{
auto
*
func
=
PyCFunction_NewEx
(
&
def
,
nullptr
,
nullptr
);
if
(
!
func
)
throw
py
::
error_already_set
();
py
::
setattr
(
m
,
def
.
ml_name
,
func
);
}
}
}
}
// namespace mgb
imperative/python/src/numpy_dtypes.h
浏览文件 @
87f4b46e
...
@@ -36,6 +36,7 @@ namespace mgb {
...
@@ -36,6 +36,7 @@ namespace mgb {
int npy_num_intb##n();
int npy_num_intb##n();
FOREACH_MGB_LOW_BIT
(
DEFINE_NPY_INTBX
)
FOREACH_MGB_LOW_BIT
(
DEFINE_NPY_INTBX
)
#undef DEFINE_NPY_INTBX
#undef DEFINE_NPY_INTBX
void
init_dtypes
(
pybind11
::
module
m
);
void
init_npy_num_intbx
(
pybind11
::
module
m
);
void
init_npy_num_intbx
(
pybind11
::
module
m
);
//! numpy type num for bfloat16 type
//! numpy type num for bfloat16 type
...
...
imperative/python/src/tensor.cpp
浏览文件 @
87f4b46e
...
@@ -9,16 +9,22 @@
...
@@ -9,16 +9,22 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "./tensor.h"
#include "./tensor.h"
#include "./grad.h"
#include "./grad.h"
#include "./trace.h"
#include "./trace.h"
#include "./common.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include "./graph_rt.h"
#include "./helper.h"
#include <pybind11/numpy.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/operators.h>
#include "./helper.h"
#include <unordered_map>
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
mgb
::
imperative
::
python
{
namespace
mgb
::
imperative
::
python
{
...
@@ -413,6 +419,198 @@ struct TensorWeakRef {
...
@@ -413,6 +419,198 @@ struct TensorWeakRef {
}
}
};
};
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline
uint8_t
category_priority
(
char
c
)
{
switch
(
c
)
{
case
'f'
:
return
3
;
// floating-point
case
'i'
:
return
2
;
// signed integer
case
'u'
:
return
2
;
// unsigned integer
case
'b'
:
return
1
;
// boolean
default:
return
0
;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t
max_priority
(
SmallVector
<
PyArray_Descr
*>
types
)
{
if
(
types
.
size
()
==
0
)
{
return
0
;
}
else
{
uint8_t
max_p
=
0
;
for
(
auto
&&
desc
:
types
)
{
max_p
=
std
::
max
(
max_p
,
category_priority
(
desc
->
kind
));
}
return
max_p
;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr
*
promote_types
(
SmallVector
<
PyArray_Descr
*>
types
,
uint8_t
cat
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
used_types
;
for
(
auto
&&
desc
:
types
)
{
auto
&&
v
=
category_priority
(
desc
->
kind
);
if
(
v
==
cat
)
{
used_types
.
emplace_back
(
desc
);
}
}
mgb_assert
(
used_types
.
size
()
>
0
,
"size of used_types is 0"
);
PyArray_Descr
*
res
=
used_types
[
0
];
Py_INCREF
(
res
);
for
(
size_t
i
=
1
;
i
<
used_types
.
size
();
++
i
)
{
PyArray_Descr
*
tmp
=
PyArray_PromoteTypes
(
used_types
[
i
],
res
);
Py_DECREF
(
res
);
res
=
tmp
;
}
return
res
;
}
PyArray_Descr
*
scalar2dtype
(
PyObject
*
arg
)
{
// Return value: New reference
if
(
PyBool_Check
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_BOOL
);
return
descr
;
}
if
(
PyLong_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_INT32
);
return
descr
;
}
if
(
PyFloat_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_FLOAT32
);
return
descr
;
}
return
nullptr
;
}
PyArray_Descr
*
_dtype_promotion
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
tensors
;
SmallVector
<
PyArray_Descr
*>
scalars
;
bool
is_tuple
=
false
;
PyObject
*
tuple
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
if
(
handle
==
Py_None
)
continue
;
TensorWrapper
*
tw
=
TensorWrapper
::
cast_safe
(
handle
);
if
(
tw
)
{
mgb
::
DType
type
=
tw
->
m_tensor
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
}
else
{
if
(
PyArray_Check
(
handle
)
||
PyArray_CheckScalar
(
handle
))
{
auto
&&
descr
=
PyArray_DescrFromObject
(
handle
,
nullptr
);
tensors
.
emplace_back
(
descr
);
continue
;
}
PyArray_Descr
*
descr
=
scalar2dtype
(
handle
);
if
(
descr
)
{
scalars
.
emplace_back
(
descr
);
continue
;
}
}
}
auto
max_pri_scalars
=
max_priority
(
scalars
);
auto
max_pri_tensors
=
max_priority
(
tensors
);
if
(
max_pri_scalars
<=
0
&&
max_pri_tensors
<=
0
)
{
throw
py
::
value_error
(
"invalid input, no dtype avaliable"
);
}
PyArray_Descr
*
res
;
if
(
max_pri_scalars
>
max_pri_tensors
)
{
res
=
promote_types
(
scalars
,
max_pri_scalars
);
}
else
{
res
=
promote_types
(
tensors
,
max_pri_tensors
);
}
for
(
auto
*
p
:
tensors
)
{
Py_DECREF
(
p
);
}
for
(
auto
*
p
:
scalars
)
{
Py_DECREF
(
p
);
}
Py_DECREF
(
tuple
);
return
res
;
}
CompNode
_get_device
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
bool
is_tuple
=
false
;
PyObject
*
tuple
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
bool
valid
=
false
;
CompNode
cn
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
TensorWrapper
*
tw
=
TensorWrapper
::
cast_safe
(
handle
);
if
(
tw
)
{
if
(
!
valid
)
{
cn
=
tw
->
m_tensor
->
comp_node
();
valid
=
true
;
}
else
{
CompNode
cn1
=
tw
->
m_tensor
->
comp_node
();
if
(
cn1
!=
cn
)
{
throw
py
::
value_error
(
ssprintf
(
"ambiguous device: %s vs %s"
,
cn
.
to_string
().
c_str
(),
cn1
.
to_string
().
c_str
()));
}
}
}
}
if
(
!
valid
)
{
mgb_assert
(
0
,
"expact at least 1 device"
);
}
Py_DECREF
(
tuple
);
return
cn
;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject
*
dtype_promotion
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
PyArray_Descr
*
res
=
_dtype_promotion
(
args
,
nargs
);
return
py
::
cast
(
npy
::
dtype_np2mgb_descr
(
res
)).
release
().
ptr
();
}
catch
(
std
::
exception
&
e
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
e
.
what
());
return
nullptr
;
}
}
PyObject
*
get_device
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
CompNode
cn
=
_get_device
(
args
,
nargs
);
return
py
::
cast
(
cn
).
release
().
ptr
();
}
catch
(
std
::
exception
&
e
)
{
PyErr_SetString
(
PyExc_RuntimeError
,
e
.
what
());
return
nullptr
;
}
}
void
init_tensor
(
py
::
module
m
)
{
void
init_tensor
(
py
::
module
m
)
{
interpreter_for_py
=
interpreter
::
Interpreter
::
inst
().
create_channel
();
interpreter_for_py
=
interpreter
::
Interpreter
::
inst
().
create_channel
();
...
@@ -444,10 +642,19 @@ void init_tensor(py::module m) {
...
@@ -444,10 +642,19 @@ void init_tensor(py::module m) {
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
.
def
(
"__call__"
,
&
TensorWeakRef
::
operator
());
.
def
(
"__call__"
,
&
TensorWeakRef
::
operator
());
static
PyMethodDef
apply_def
{
"apply"
,
(
PyCFunction
)
py_apply
,
METH_FASTCALL
,
nullptr
};
static
PyMethodDef
method_defs
[]
=
{
auto
*
apply_func
=
PyCFunction_NewEx
(
&
apply_def
,
nullptr
,
nullptr
);
{
"apply"
,
(
PyCFunction
)
py_apply
,
METH_FASTCALL
,
nullptr
},
if
(
!
apply_func
)
throw
py
::
error_already_set
();
{
"dtype_promotion"
,
(
PyCFunction
)
dtype_promotion
,
METH_FASTCALL
,
nullptr
},
py
::
setattr
(
m
,
"apply"
,
apply_func
);
{
"get_device"
,
(
PyCFunction
)
get_device
,
METH_FASTCALL
,
nullptr
},
{
nullptr
,
nullptr
,
0
,
nullptr
}
};
for
(
auto
&&
def
:
method_defs
)
{
if
(
def
.
ml_meth
!=
nullptr
)
{
auto
*
func
=
PyCFunction_NewEx
(
&
def
,
nullptr
,
nullptr
);
if
(
!
func
)
throw
py
::
error_already_set
();
py
::
setattr
(
m
,
def
.
ml_name
,
func
);
}
}
m
.
def
(
"_set_swap_flag"
,
m
.
def
(
"_set_swap_flag"
,
[](
bool
flag
)
{
interpreter_for_py
->
set_swap_flag
(
flag
);
});
[](
bool
flag
)
{
interpreter_for_py
->
set_swap_flag
(
flag
);
});
...
...
imperative/python/test/unit/core/test_dtype_quant.py
浏览文件 @
87f4b46e
...
@@ -113,7 +113,7 @@ def test_quint8_typecvt():
...
@@ -113,7 +113,7 @@ def test_quint8_typecvt():
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
*
5
-
1
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
*
5
-
1
def
typecvt
(
x
,
dt
=
None
):
def
typecvt
(
x
,
dt
=
None
):
(
y
,)
=
apply
(
ops
.
TypeCvt
(
dtype
=
dt
),
x
)
(
y
,)
=
G
.
apply_normal_op
(
ops
.
TypeCvt
(
dtype
=
dt
),
x
)
return
y
return
y
# convert to quint8
# convert to quint8
...
@@ -194,7 +194,7 @@ def test_quint4_typecvt():
...
@@ -194,7 +194,7 @@ def test_quint4_typecvt():
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
*
5
-
1
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
*
5
-
1
def
typecvt
(
x
,
dt
=
None
):
def
typecvt
(
x
,
dt
=
None
):
(
y
,)
=
apply
(
ops
.
TypeCvt
(
dtype
=
dt
),
x
)
(
y
,)
=
G
.
apply_normal_op
(
ops
.
TypeCvt
(
dtype
=
dt
),
x
)
return
y
return
y
# convert to quint4
# convert to quint4
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录