Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fa62f6c0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fa62f6c0
编写于
3月 10, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge/utils): move convert_input into C++
GitOrigin-RevId: 0d1cd362511d2d423faaeffd9d80710747cf05f2
上级
d98be080
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
374 addition
and
271 deletion
+374
-271
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+3
-9
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+7
-43
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+0
-1
imperative/python/src/numpy_dtypes.cpp
imperative/python/src/numpy_dtypes.cpp
+6
-0
imperative/python/src/numpy_dtypes.h
imperative/python/src/numpy_dtypes.h
+5
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+6
-217
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+337
-1
imperative/python/src/tensor_utils.h
imperative/python/src/tensor_utils.h
+10
-0
未找到文件。
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
fa62f6c0
...
@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
...
@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
SymbolVar
,
SymbolVar
,
Tensor
,
Tensor
,
apply
,
apply
,
astype_cpp
,
broadcast_cpp
,
broadcast_cpp
,
dtype_promotion
,
dtype_promotion
,
)
)
...
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
...
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from
..ops
import
builtin
from
..ops
import
builtin
from
.
import
amp
from
.
import
amp
from
.indexing
import
getitem
,
setitem
from
.indexing
import
getitem
,
setitem
from
.utils
import
(
from
.utils
import
_normalize_axis
,
astensor1d
,
cast_tensors
,
make_shape_tuple
,
subgraph
_normalize_axis
,
astensor1d
,
astype
,
cast_tensors
,
make_shape_tuple
,
subgraph
,
)
_ElwMod
=
builtin
.
Elemwise
.
Mode
_ElwMod
=
builtin
.
Elemwise
.
Mode
...
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
...
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
r
"""Returns a :class:`Tensor` with the same data and number of elements
r
"""Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`.
with the specified :attr:`~.Tensor.dtype`.
"""
"""
return
astype
(
self
,
dtype
)
return
astype
_cpp
(
self
,
dtype
)
def
reshape
(
self
,
*
args
):
def
reshape
(
self
,
*
args
):
r
"""See :func:`~.reshape`."""
r
"""See :func:`~.reshape`."""
...
...
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
fa62f6c0
...
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
...
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
_get_convert_inputs
,
_get_convert_inputs
,
_set_convert_inputs
,
_set_convert_inputs
,
apply
,
apply
,
astype_cpp
,
convert_inputs_cpp
,
convert_single_value_cpp
,
dtype_promotion
,
dtype_promotion
,
get_device
,
get_device
,
make_shape_tuple
,
make_shape_tuple
,
...
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
...
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
return
result
return
result
def
astype
(
x
,
dtype
):
dtype
=
np
.
dtype
(
dtype
)
if
not
is_dtype_equal
(
x
.
dtype
,
dtype
):
(
x
,)
=
apply
(
builtin
.
TypeCvt
(
dtype
=
dtype
),
x
)
return
x
def
convert_single_value
(
v
,
*
,
dtype
=
None
,
device
=
None
):
def
convert_single_value
(
v
,
*
,
dtype
=
None
,
device
=
None
):
if
isinstance
(
v
,
(
Tensor
,
SymbolVar
)):
return
convert_single_value_cpp
(
v
,
dtype
,
device
)
if
not
is_quantize
(
v
.
dtype
):
v
=
astype
(
v
,
dtype
)
else
:
v
=
Const
(
v
,
dtype
,
device
,
None
)
return
v
def
convert_inputs
(
*
args
,
device
=
None
):
def
convert_inputs
(
*
args
,
device
=
None
):
if
not
_get_convert_inputs
():
if
not
_get_convert_inputs
():
return
args
return
args
return
convert_inputs_cpp
(
*
args
,
device
)
dtype
=
dtype_promotion
(
args
)
if
device
is
None
:
device
=
get_device
(
args
)
device
=
as_device
(
device
)
graph
=
None
sym_type
=
None
for
a
in
args
:
if
isinstance
(
a
,
SymbolVar
):
if
graph
is
None
:
graph
=
a
.
var
.
graph
sym_type
=
type
(
a
)
else
:
assert
graph
==
a
.
var
.
graph
args
=
list
(
args
)
if
graph
is
not
None
:
for
i
in
range
(
len
(
args
)):
if
not
isinstance
(
args
[
i
],
SymbolVar
):
rst
=
make_const
(
graph
,
np
.
array
(
args
[
i
]),
device
.
to_c
(),
dtype
)
args
[
i
]
=
sym_type
(
rst
)
def
convert
(
value
):
if
value
is
None
:
return
value
return
convert_single_value
(
value
,
dtype
=
dtype
,
device
=
device
.
to_c
())
return
tuple
(
map
(
convert
,
args
))
def
cast_tensors
(
*
args
,
promote
=
False
):
def
cast_tensors
(
*
args
,
promote
=
False
):
...
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
...
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
pass
pass
except
ValueError
:
except
ValueError
:
if
dtype
is
not
None
and
dtype
!=
x
.
dtype
:
if
dtype
is
not
None
and
dtype
!=
x
.
dtype
:
x
=
astype
(
x
,
dtype
)
x
=
astype
_cpp
(
x
,
dtype
)
if
device
is
not
None
:
if
device
is
not
None
:
cn
=
as_device
(
device
).
to_c
()
cn
=
as_device
(
device
).
to_c
()
(
x
,)
=
apply
(
builtin
.
Copy
(
comp_node
=
cn
),
x
)
(
x
,)
=
apply
(
builtin
.
Copy
(
comp_node
=
cn
),
x
)
...
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
...
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if
any
(
isinstance
(
i
,
(
Tensor
,
SymbolVar
))
for
i
in
x
):
if
any
(
isinstance
(
i
,
(
Tensor
,
SymbolVar
))
for
i
in
x
):
x
=
concatenate
(
x
,
device
=
device
)
if
len
(
x
)
>
1
else
x
[
0
]
x
=
concatenate
(
x
,
device
=
device
)
if
len
(
x
)
>
1
else
x
[
0
]
if
dtype
is
not
None
:
if
dtype
is
not
None
:
x
=
astype
(
x
,
dtype
)
x
=
astype
_cpp
(
x
,
dtype
)
return
x
return
x
x
=
Const
(
x
,
dtype
,
device
,
reference
)
x
=
Const
(
x
,
dtype
,
device
,
reference
)
return
x
return
x
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
fa62f6c0
...
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
...
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
from
..core.tensor.array_method
import
_elwise_apply
from
..core.tensor.array_method
import
_elwise_apply
from
..core.tensor.utils
import
(
from
..core.tensor.utils
import
(
astensor1d
,
astensor1d
,
astype
,
cast_tensors
,
cast_tensors
,
convert_single_value
,
convert_single_value
,
make_shape_tuple
,
make_shape_tuple
,
...
...
imperative/python/src/numpy_dtypes.cpp
浏览文件 @
fa62f6c0
...
@@ -170,6 +170,12 @@ struct _wrap {
...
@@ -170,6 +170,12 @@ struct _wrap {
}
// anonymous namespace
}
// anonymous namespace
namespace
imperative
::
python
{
bool
dtype_equal
(
PyArray_Descr
*
dt1
,
PyArray_Descr
*
dt2
)
{
return
_is_dtype_equal
(
dt1
,
dt2
);
}
}
// namespace imperative::python
#ifdef METH_FASTCALL
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUN) \
#define MGE_PY_INTERFACE(NAME, FUN) \
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }
...
...
imperative/python/src/numpy_dtypes.h
浏览文件 @
fa62f6c0
...
@@ -26,6 +26,11 @@
...
@@ -26,6 +26,11 @@
cb(BFloat16, npy_num_bfloat16())
cb(BFloat16, npy_num_bfloat16())
namespace
mgb
{
namespace
mgb
{
namespace
imperative
::
python
{
bool
dtype_equal
(
PyArray_Descr
*
dt1
,
PyArray_Descr
*
dt2
);
}
// namespace imperative::python
//! numpy type num for intb1/2/4 type
//! numpy type num for intb1/2/4 type
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
FOREACH_MGB_LOW_BIT
(
DEFINE_NPY_INTBX
)
FOREACH_MGB_LOW_BIT
(
DEFINE_NPY_INTBX
)
...
...
imperative/python/src/tensor.cpp
浏览文件 @
fa62f6c0
...
@@ -400,223 +400,6 @@ struct TensorWeakRef {
...
@@ -400,223 +400,6 @@ struct TensorWeakRef {
int
_use_cnt
()
{
return
wptr
.
use_count
();
}
int
_use_cnt
()
{
return
wptr
.
use_count
();
}
};
};
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline
uint8_t
category_priority
(
char
c
)
{
switch
(
c
)
{
case
'f'
:
return
3
;
// floating-point
case
'i'
:
return
2
;
// signed integer
case
'u'
:
return
2
;
// unsigned integer
case
'b'
:
return
1
;
// boolean
default:
return
0
;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t
max_priority
(
SmallVector
<
PyArray_Descr
*>
types
)
{
if
(
types
.
size
()
==
0
)
{
return
0
;
}
else
{
uint8_t
max_p
=
0
;
for
(
auto
&&
desc
:
types
)
{
max_p
=
std
::
max
(
max_p
,
category_priority
(
desc
->
kind
));
}
return
max_p
;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr
*
promote_types
(
SmallVector
<
PyArray_Descr
*>
types
,
uint8_t
cat
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
used_types
;
for
(
auto
&&
desc
:
types
)
{
auto
&&
v
=
category_priority
(
desc
->
kind
);
if
(
v
==
cat
)
{
used_types
.
emplace_back
(
desc
);
}
}
mgb_assert
(
used_types
.
size
()
>
0
,
"size of used_types is 0"
);
PyArray_Descr
*
res
=
used_types
[
0
];
Py_INCREF
(
res
);
for
(
size_t
i
=
1
;
i
<
used_types
.
size
();
++
i
)
{
PyArray_Descr
*
tmp
=
PyArray_PromoteTypes
(
used_types
[
i
],
res
);
Py_DECREF
(
res
);
res
=
tmp
;
}
return
res
;
}
PyArray_Descr
*
scalar2dtype
(
PyObject
*
arg
)
{
// Return value: New reference
if
(
PyBool_Check
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_BOOL
);
return
descr
;
}
if
(
PyLong_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_INT32
);
return
descr
;
}
if
(
PyFloat_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_FLOAT32
);
return
descr
;
}
return
nullptr
;
}
PyArray_Descr
*
_dtype_promotion
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
tensors
;
SmallVector
<
PyArray_Descr
*>
scalars
;
bool
is_tuple
=
false
;
PyObject
*
tuple
=
nullptr
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
if
(
handle
==
Py_None
)
continue
;
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
if
(
tw
)
{
mgb
::
DType
type
=
tw
->
m_tensor
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
}
else
{
if
(
PyArray_Check
(
handle
)
||
PyArray_CheckScalar
(
handle
))
{
auto
&&
descr
=
PyArray_DescrFromObject
(
handle
,
nullptr
);
tensors
.
emplace_back
(
descr
);
continue
;
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
)))
{
auto
var
=
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
();
mgb
::
DType
type
=
var
->
m_node
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
continue
;
}
PyArray_Descr
*
descr
=
scalar2dtype
(
handle
);
if
(
descr
)
{
scalars
.
emplace_back
(
descr
);
continue
;
}
}
}
auto
max_pri_scalars
=
max_priority
(
scalars
);
auto
max_pri_tensors
=
max_priority
(
tensors
);
if
(
max_pri_scalars
<=
0
&&
max_pri_tensors
<=
0
)
{
throw
py
::
value_error
(
"invalid input, no dtype avaliable"
);
}
PyArray_Descr
*
res
;
if
(
max_pri_scalars
>
max_pri_tensors
)
{
res
=
promote_types
(
scalars
,
max_pri_scalars
);
}
else
{
res
=
promote_types
(
tensors
,
max_pri_tensors
);
}
for
(
auto
*
p
:
tensors
)
{
Py_DECREF
(
p
);
}
for
(
auto
*
p
:
scalars
)
{
Py_DECREF
(
p
);
}
Py_XDECREF
(
tuple
);
return
res
;
}
CompNode
_get_device
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
bool
is_tuple
=
false
;
PyObject
*
tuple
=
nullptr
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
bool
valid
=
false
;
CompNode
cn
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
bool
is_symvar
=
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
));
if
(
tw
||
is_symvar
)
{
if
(
!
valid
)
{
cn
=
tw
?
tw
->
m_tensor
->
comp_node
()
:
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
valid
=
true
;
}
else
{
CompNode
cn1
=
tw
?
tw
->
m_tensor
->
comp_node
()
:
py
::
handle
(
handle
)
.
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
if
(
cn1
!=
cn
)
{
throw
py
::
value_error
(
ssprintf
(
"ambiguous device: %s (from %s) vs %s (from %s)"
,
cn
.
to_string
().
c_str
(),
cn
.
to_string_logical
().
c_str
(),
cn1
.
to_string
().
c_str
(),
cn1
.
to_string_logical
().
c_str
()));
}
}
}
}
if
(
!
valid
)
{
return
CompNode
::
load
(
get_default_device
());
}
Py_XDECREF
(
tuple
);
return
cn
;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject
*
dtype_promotion
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
PyArray_Descr
*
res
=
_dtype_promotion
(
args
,
nargs
);
return
py
::
cast
(
npy
::
dtype_np2mgb_descr
(
res
)).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
PyObject
*
get_device
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
CompNode
cn
=
_get_device
(
args
,
nargs
);
return
py
::
cast
(
cn
).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
#ifdef METH_FASTCALL
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
...
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
...
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35
(
broadcast_cpp
);
WRAP_FUNC_PY35
(
broadcast_cpp
);
WRAP_FUNC_PY35
(
reshape_cpp
);
WRAP_FUNC_PY35
(
reshape_cpp
);
WRAP_FUNC_PY35
(
Const
);
WRAP_FUNC_PY35
(
Const
);
WRAP_FUNC_PY35
(
astype_cpp
);
WRAP_FUNC_PY35
(
convert_single_value_cpp
);
WRAP_FUNC_PY35
(
convert_inputs_cpp
);
#undef WRAP_FUNC_PY35
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
...
@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
...
@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE
(
broadcast_cpp
,
broadcast_cpp
),
MGE_PY_INTERFACE
(
broadcast_cpp
,
broadcast_cpp
),
MGE_PY_INTERFACE
(
reshape_cpp
,
reshape_cpp
),
MGE_PY_INTERFACE
(
reshape_cpp
,
reshape_cpp
),
MGE_PY_INTERFACE
(
Const
,
Const
),
MGE_PY_INTERFACE
(
Const
,
Const
),
MGE_PY_INTERFACE
(
astype_cpp
,
astype_cpp
),
MGE_PY_INTERFACE
(
convert_single_value_cpp
,
convert_single_value_cpp
),
MGE_PY_INTERFACE
(
convert_inputs_cpp
,
convert_inputs_cpp
),
{
nullptr
,
nullptr
,
0
,
nullptr
}};
{
nullptr
,
nullptr
,
0
,
nullptr
}};
for
(
auto
&&
def
:
method_defs
)
{
for
(
auto
&&
def
:
method_defs
)
{
if
(
def
.
ml_meth
!=
nullptr
)
{
if
(
def
.
ml_meth
!=
nullptr
)
{
...
...
imperative/python/src/tensor_utils.cpp
浏览文件 @
fa62f6c0
...
@@ -52,6 +52,223 @@ namespace views = ranges::views;
...
@@ -52,6 +52,223 @@ namespace views = ranges::views;
namespace
mgb
::
imperative
::
python
{
namespace
mgb
::
imperative
::
python
{
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline
uint8_t
category_priority
(
char
c
)
{
switch
(
c
)
{
case
'f'
:
return
3
;
// floating-point
case
'i'
:
return
2
;
// signed integer
case
'u'
:
return
2
;
// unsigned integer
case
'b'
:
return
1
;
// boolean
default:
return
0
;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t
max_priority
(
SmallVector
<
PyArray_Descr
*>
types
)
{
if
(
types
.
size
()
==
0
)
{
return
0
;
}
else
{
uint8_t
max_p
=
0
;
for
(
auto
&&
desc
:
types
)
{
max_p
=
std
::
max
(
max_p
,
category_priority
(
desc
->
kind
));
}
return
max_p
;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr
*
promote_types
(
SmallVector
<
PyArray_Descr
*>
types
,
uint8_t
cat
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
used_types
;
for
(
auto
&&
desc
:
types
)
{
auto
&&
v
=
category_priority
(
desc
->
kind
);
if
(
v
==
cat
)
{
used_types
.
emplace_back
(
desc
);
}
}
mgb_assert
(
used_types
.
size
()
>
0
,
"size of used_types is 0"
);
PyArray_Descr
*
res
=
used_types
[
0
];
Py_INCREF
(
res
);
for
(
size_t
i
=
1
;
i
<
used_types
.
size
();
++
i
)
{
PyArray_Descr
*
tmp
=
PyArray_PromoteTypes
(
used_types
[
i
],
res
);
Py_DECREF
(
res
);
res
=
tmp
;
}
return
res
;
}
PyArray_Descr
*
scalar2dtype
(
PyObject
*
arg
)
{
// Return value: New reference
if
(
PyBool_Check
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_BOOL
);
return
descr
;
}
if
(
PyLong_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_INT32
);
return
descr
;
}
if
(
PyFloat_CheckExact
(
arg
))
{
auto
&&
descr
=
PyArray_DescrFromType
(
NPY_FLOAT32
);
return
descr
;
}
return
nullptr
;
}
PyArray_Descr
*
_dtype_promotion
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
// Return value: New reference
SmallVector
<
PyArray_Descr
*>
tensors
;
SmallVector
<
PyArray_Descr
*>
scalars
;
bool
is_tuple
=
false
;
PyObject
*
tuple
=
nullptr
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
if
(
handle
==
Py_None
)
continue
;
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
if
(
tw
)
{
mgb
::
DType
type
=
tw
->
m_tensor
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
}
else
{
if
(
PyArray_Check
(
handle
)
||
PyArray_CheckScalar
(
handle
))
{
auto
&&
descr
=
PyArray_DescrFromObject
(
handle
,
nullptr
);
tensors
.
emplace_back
(
descr
);
continue
;
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
)))
{
auto
var
=
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
();
mgb
::
DType
type
=
var
->
m_node
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
continue
;
}
PyArray_Descr
*
descr
=
scalar2dtype
(
handle
);
if
(
descr
)
{
scalars
.
emplace_back
(
descr
);
continue
;
}
}
}
auto
max_pri_scalars
=
max_priority
(
scalars
);
auto
max_pri_tensors
=
max_priority
(
tensors
);
if
(
max_pri_scalars
<=
0
&&
max_pri_tensors
<=
0
)
{
throw
py
::
value_error
(
"invalid input, no dtype avaliable"
);
}
PyArray_Descr
*
res
;
if
(
max_pri_scalars
>
max_pri_tensors
)
{
res
=
promote_types
(
scalars
,
max_pri_scalars
);
}
else
{
res
=
promote_types
(
tensors
,
max_pri_tensors
);
}
for
(
auto
*
p
:
tensors
)
{
Py_DECREF
(
p
);
}
for
(
auto
*
p
:
scalars
)
{
Py_DECREF
(
p
);
}
Py_XDECREF
(
tuple
);
return
res
;
}
CompNode
_get_device
(
PyObject
*
const
*
args
,
size_t
nargs
)
{
bool
is_tuple
=
false
;
PyObject
*
tuple
=
nullptr
;
if
(
nargs
==
1
&&
(
PyTuple_Check
(
args
[
0
])
||
PyList_Check
(
args
[
0
])))
{
if
(
PyList_Check
(
args
[
0
]))
{
tuple
=
PyList_AsTuple
(
args
[
0
]);
}
else
{
tuple
=
args
[
0
];
Py_INCREF
(
tuple
);
}
nargs
=
PyTuple_Size
(
tuple
);
is_tuple
=
true
;
}
bool
valid
=
false
;
CompNode
cn
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
bool
is_symvar
=
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
));
if
(
tw
||
is_symvar
)
{
if
(
!
valid
)
{
cn
=
tw
?
tw
->
m_tensor
->
comp_node
()
:
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
valid
=
true
;
}
else
{
CompNode
cn1
=
tw
?
tw
->
m_tensor
->
comp_node
()
:
py
::
handle
(
handle
)
.
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
if
(
cn1
!=
cn
)
{
throw
py
::
value_error
(
ssprintf
(
"ambiguous device: %s (from %s) vs %s (from %s)"
,
cn
.
to_string
().
c_str
(),
cn
.
to_string_logical
().
c_str
(),
cn1
.
to_string
().
c_str
(),
cn1
.
to_string_logical
().
c_str
()));
}
}
}
}
if
(
!
valid
)
{
return
CompNode
::
load
(
get_default_device
());
}
Py_XDECREF
(
tuple
);
return
cn
;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject
*
dtype_promotion
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
PyArray_Descr
*
res
=
_dtype_promotion
(
args
,
nargs
);
return
py
::
cast
(
npy
::
dtype_np2mgb_descr
(
res
)).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
PyObject
*
get_device
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
if
(
!
nargs
)
{
PyErr_SetString
(
PyExc_TypeError
,
"empty input is not allowed"
);
return
nullptr
;
}
try
{
CompNode
cn
=
_get_device
(
args
,
nargs
);
return
py
::
cast
(
cn
).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
bool
is_scalar
(
PyObject
*
tensor
)
{
bool
is_scalar
(
PyObject
*
tensor
)
{
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
tensor
)))
{
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
tensor
)))
{
auto
var
=
py
::
handle
(
tensor
).
cast
<
PySymbolVar
*>
();
auto
var
=
py
::
handle
(
tensor
).
cast
<
PySymbolVar
*>
();
...
@@ -147,7 +364,6 @@ py::object _Const(
...
@@ -147,7 +364,6 @@ py::object _Const(
"dmap_callback"
);
"dmap_callback"
);
if
(
dmap
.
ptr
()
!=
Py_None
)
{
if
(
dmap
.
ptr
()
!=
Py_None
)
{
device_obj
=
dmap
(
device
);
device_obj
=
dmap
(
device
);
py
::
print
(
device_obj
);
}
else
{
}
else
{
device_obj
=
py
::
cast
(
CompNode
::
load
(
device
.
cast
<
std
::
string
>
()));
device_obj
=
py
::
cast
(
CompNode
::
load
(
device
.
cast
<
std
::
string
>
()));
}
}
...
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
...
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
return
ret
[
0
];
return
ret
[
0
];
}
}
mgb
::
DType
_get_dtype
(
py
::
handle
tensor
)
{
if
(
auto
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
()))
{
return
tw
->
m_tensor
->
dtype
();
}
else
{
auto
var
=
tensor
.
cast
<
PySymbolVar
*>
();
return
var
->
m_node
->
dtype
();
}
}
py
::
object
_astype_cpp
(
py
::
handle
tensor
,
py
::
handle
dtype_hdl
)
{
PyArray_Descr
*
descr
;
if
(
!
PyArray_DescrConverter
(
dtype_hdl
.
ptr
(),
&
descr
))
{
throw
py
::
value_error
(
ssprintf
(
"can not convert to numpy.dtype from %s"
,
dtype_hdl
.
ptr
()
->
ob_type
->
tp_name
));
}
PyArray_Descr
*
cur
=
npy
::
dtype_mgb2np_descr
(
_get_dtype
(
tensor
)).
get
();
if
(
!
dtype_equal
(
cur
,
descr
))
{
std
::
shared_ptr
<
OpDef
>
op
=
TypeCvt
::
make
(
npy
::
dtype_np2mgb_descr
(
descr
));
py
::
object
Op
=
py
::
cast
(
op
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
tensor
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
return
ret
[
0
];
}
else
{
return
py
::
reinterpret_borrow
<
py
::
object
>
(
tensor
);
}
}
py
::
object
_convert_single_value_cpp
(
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
)
{
if
(
is_tensor_or_symbolvar
(
value
))
{
if
(
_get_dtype
(
value
).
category
()
!=
DTypeCategory
::
QUANTIZED
)
{
return
_astype_cpp
(
value
,
dtype
);
}
}
else
{
return
_Const
(
value
,
dtype
,
device
,
py
::
none
());
}
return
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
}
py
::
object
_convert_inputs_cpp
(
PyObject
*
const
*
args
,
size_t
nargs
,
py
::
object
dtype
,
py
::
object
device
)
{
ComputingGraph
*
graph
=
nullptr
;
py
::
handle
typeobj
;
py
::
list
lis
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
py
::
handle
h
=
py
::
handle
(
args
[
i
]);
lis
.
append
(
h
);
if
(
py
::
isinstance
<
PySymbolVar
>
(
h
))
{
auto
var
=
h
.
cast
<
PySymbolVar
*>
();
auto
g
=
var
->
m_node
->
owner_graph
();
if
(
!
graph
)
{
graph
=
g
;
typeobj
=
h
.
get_type
();
}
else
{
mgb_assert
(
graph
==
g
);
}
}
}
if
(
graph
)
{
CompNode
cn
=
device
.
cast
<
CompNode
>
();
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
OperatorNodeConfig
config
(
cn
);
auto
hv
=
npy
::
np2tensor
(
lis
[
i
].
ptr
(),
npy
::
Meth
::
borrow
(
cn
),
dtype
.
cast
<
mgb
::
DType
>
());
if
(
py
::
isinstance
<
PySymbolVar
>
(
lis
[
i
]))
{
lis
[
i
]
=
typeobj
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
).
node
());
}
}
}
auto
convert
=
[
&
](
py
::
object
value
)
{
if
(
value
.
ptr
()
==
Py_None
)
{
return
value
;
}
return
_convert_single_value_cpp
(
value
,
dtype
,
device
);
};
for
(
size_t
i
=
0
;
i
<
lis
.
size
();
++
i
)
{
lis
[
i
]
=
convert
(
lis
[
i
]);
}
return
py
::
reinterpret_steal
<
py
::
tuple
>
(
PyList_AsTuple
(
lis
.
ptr
()));
}
PyObject
*
make_shape_tuple
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
PyObject
*
make_shape_tuple
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
try
{
return
_make_shape_tuple
(
py
::
handle
(
args
[
0
])).
release
().
ptr
();
return
_make_shape_tuple
(
py
::
handle
(
args
[
0
])).
release
().
ptr
();
...
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
...
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
}
PyObject
*
astype_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
return
_astype_cpp
(
py
::
handle
(
args
[
0
]),
py
::
handle
(
args
[
1
])).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
PyObject
*
convert_single_value_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
return
_convert_single_value_cpp
(
py
::
handle
(
args
[
0
]),
py
::
handle
(
args
[
1
]),
py
::
handle
(
args
[
2
]))
.
release
()
.
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
PyObject
*
convert_inputs_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
py
::
object
dtype
=
py
::
reinterpret_steal
<
py
::
object
>
(
dtype_promotion
(
self
,
args
,
nargs
-
1
));
py
::
object
device
;
if
(
args
[
nargs
-
1
]
==
Py_None
)
{
device
=
py
::
reinterpret_steal
<
py
::
object
>
(
get_device
(
self
,
args
,
nargs
-
1
));
}
else
{
device
=
py
::
reinterpret_borrow
<
py
::
object
>
(
args
[
nargs
-
1
]);
}
return
_convert_inputs_cpp
(
args
,
nargs
-
1
,
dtype
,
device
).
release
().
ptr
();
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
}
// namespace mgb::imperative::python
}
// namespace mgb::imperative::python
imperative/python/src/tensor_utils.h
浏览文件 @
fa62f6c0
...
@@ -2,6 +2,10 @@
...
@@ -2,6 +2,10 @@
namespace
mgb
::
imperative
::
python
{
namespace
mgb
::
imperative
::
python
{
PyObject
*
dtype_promotion
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
get_device
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
make_shape_tuple
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
make_shape_tuple
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
getitem_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
getitem_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
...
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
...
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject
*
Const
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
Const
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
astype_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
convert_single_value_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
convert_inputs_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
}
// namespace mgb::imperative::python
}
// namespace mgb::imperative::python
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录