Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2f4a75e7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2f4a75e7
编写于
9月 28, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/utils): redesign dtype promotion
GitOrigin-RevId: 4f2fe1b6ce1430e96cb8bac34ffbbc46548007f5
上级
2e9ba679
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
68 addition
and
33 deletion
+68
-33
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+68
-33
未找到文件。
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
2f4a75e7
...
...
@@ -16,39 +16,74 @@ from ..ops.special import Const
from
..tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
def
dtype_promotion
(
raw_inputs
):
def
add_dtype
(
i
):
if
type
(
i
)
==
int
:
return
np
.
array
(
i
,
dtype
=
np
.
int32
)
if
type
(
i
)
==
float
:
return
np
.
array
(
i
,
dtype
=
np
.
float32
)
if
type
(
i
)
==
bool
:
return
np
.
array
(
i
,
dtype
=
np
.
bool_
)
return
None
scalar_inputs
=
[
add_dtype
(
i
)
for
i
in
raw_inputs
if
not
hasattr
(
i
,
"dtype"
)
and
add_dtype
(
i
)
]
inputs
=
[
i
for
i
in
raw_inputs
if
hasattr
(
i
,
"dtype"
)]
assert
len
(
scalar_inputs
+
inputs
)
>
0
dtype
=
None
if
len
(
inputs
)
>
0
:
dtype
=
np
.
result_type
(
*
inputs
)
dtype_all
=
np
.
result_type
(
*
(
inputs
+
scalar_inputs
))
assert
(
dtype
!=
np
.
float64
and
dtype
!=
np
.
int64
),
"unsupport dtype {} by dtype_promotion, please use explict type convert"
.
format
(
dtype
)
if
dtype_all
==
np
.
bool_
:
for
i
in
raw_inputs
:
if
not
hasattr
(
i
,
"dtype"
)
or
i
.
dtype
!=
np
.
bool_
:
raise
TypeError
(
"bool dtype can not be operated with an element without bool dtype"
)
if
dtype_all
==
np
.
float64
:
dtype_all
=
np
.
float32
return
dtype_all
def
dtype_promotion
(
inputs
):
"""
Returns the dtype that would result from performing an arithmetic
operation on the provided input tensors and scalars.
"""
# map numpy.dtype.kind to priority
category_priority
=
{
"f"
:
3
,
# floating-point
"i"
:
2
,
# signed integer
"u"
:
2
,
# unsigned integer
"b"
:
1
,
# boolean
}
def
scalar2dtype
(
x
):
"""
For scalar `x`, returns its corresponding type. A floating point scalar
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'.
A boolean scalar has dtype 'bool'.
"""
if
isinstance
(
x
,
bool
):
return
np
.
bool_
if
isinstance
(
x
,
int
):
return
np
.
int32
if
isinstance
(
x
,
float
):
return
np
.
float32
def
promote_types
(
types
,
cat
):
"""
Returns the data type with sufficient size to hold all types of
category `cat` in the list `types`.
"""
used_types
=
[
i
for
i
in
types
if
category_priority
.
get
(
np
.
dtype
(
i
).
kind
,
0
)
==
cat
]
assert
len
(
used_types
)
>
0
res
=
used_types
[
0
]
for
i
in
used_types
:
res
=
np
.
promote_types
(
res
,
i
)
return
res
def
max_priority
(
types
):
"""
Returns the maximum value of the priority of each type in the list
`types`.
"""
if
not
types
:
return
0
else
:
return
max
([
category_priority
.
get
(
np
.
dtype
(
i
).
kind
,
0
)
for
i
in
types
])
scalars
=
[]
tensors
=
[]
for
data
in
inputs
:
if
hasattr
(
data
,
"dtype"
):
tensors
.
append
(
data
.
dtype
)
elif
isinstance
(
data
,
(
float
,
int
,
bool
)):
scalars
.
append
(
scalar2dtype
(
data
))
max_pri_scalars
=
max_priority
(
scalars
)
max_pri_tensors
=
max_priority
(
tensors
)
assert
max_pri_scalars
>
0
or
max_pri_tensors
>
0
if
max_pri_scalars
>
max_pri_tensors
:
return
promote_types
(
scalars
,
max_pri_scalars
)
else
:
return
promote_types
(
tensors
,
max_pri_tensors
)
def
get_device
(
inputs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录