Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
938152af
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
938152af
编写于
1月 13, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/functional): convert input type to float32 for more elemwise op
GitOrigin-RevId: cf3bf8cb805a3229700dd2939393a3994bc59f35
上级
19466046
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
72 addition
and
23 deletion
+72
-23
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+25
-12
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+29
-11
imperative/python/test/unit/functional/test_elemwise.py
imperative/python/test/unit/functional/test_elemwise.py
+18
-0
未找到文件。
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
938152af
...
...
@@ -27,9 +27,31 @@ from .utils import setscalar
_ElwMod
=
Elemwise
.
Mode
def
_elwise
(
*
args
,
mode
):
def
_elwise
_apply
(
args
,
mode
):
op
=
builtin
.
Elemwise
(
mode
)
if
mode
in
(
_ElwMod
.
TRUE_DIV
,
_ElwMod
.
POW
):
_isscalar
=
True
for
i
in
args
:
if
isscalar
(
i
)
==
False
:
_isscalar
=
False
break
(
result
,)
=
apply
(
op
,
*
args
)
if
_isscalar
:
setscalar
(
result
)
return
result
def
_elwise
(
*
args
,
mode
):
if
mode
in
(
_ElwMod
.
TRUE_DIV
,
_ElwMod
.
POW
,
_ElwMod
.
CEIL
,
_ElwMod
.
FLOOR
,
_ElwMod
.
ROUND
,
):
if
mode
in
(
_ElwMod
.
CEIL
,
_ElwMod
.
FLOOR
,
_ElwMod
.
ROUND
)
and
np
.
issubdtype
(
args
[
0
].
dtype
,
np
.
integer
):
return
args
[
0
]
args
=
tuple
(
map
(
lambda
x
:
x
.
astype
(
"float32"
)
...
...
@@ -39,16 +61,7 @@ def _elwise(*args, mode):
)
)
args
=
utils
.
convert_inputs
(
*
args
)
(
result
,)
=
apply
(
op
,
*
args
)
_isscalar
=
True
for
i
in
args
:
if
isscalar
(
i
)
==
False
:
_isscalar
=
False
break
if
_isscalar
:
setscalar
(
result
)
return
result
return
_elwise_apply
(
args
,
mode
)
def
_matmul
(
inp1
,
inp2
):
...
...
imperative/python/megengine/functional/elemwise.py
浏览文件 @
938152af
...
...
@@ -9,10 +9,13 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import
functools
import
numpy
as
np
from
..core._imperative_rt.core2
import
apply
from
..core.ops
import
builtin
from
..core.ops.builtin
import
Elemwise
from
..core.tensor
import
megbrain_graph
,
utils
from
..core.tensor.array_method
import
_elwise_apply
from
..core.tensor.utils
import
isscalar
,
setscalar
from
..device
import
get_default_device
from
..jit.tracing
import
is_tracing
...
...
@@ -74,7 +77,6 @@ __all__ = [
def
_elwise
(
*
args
,
mode
):
op
=
builtin
.
Elemwise
(
mode
)
tensor_args
=
list
(
filter
(
lambda
x
:
isinstance
(
x
,
(
Tensor
,
megbrain_graph
.
VarNode
)),
args
)
)
...
...
@@ -84,17 +86,33 @@ def _elwise(*args, mode):
args
=
utils
.
convert_inputs
(
first_arg
,
*
args
[
1
:])
else
:
args
=
utils
.
convert_inputs
(
*
args
)
if
mode
in
(
"true_div"
,
"exp"
,
"pow"
,
"log"
,
"expm1"
,
"log1p"
):
if
mode
in
(
Elemwise
.
Mode
.
TRUE_DIV
,
Elemwise
.
Mode
.
EXP
,
Elemwise
.
Mode
.
POW
,
Elemwise
.
Mode
.
LOG
,
Elemwise
.
Mode
.
EXPM1
,
Elemwise
.
Mode
.
LOG1P
,
Elemwise
.
Mode
.
TANH
,
Elemwise
.
Mode
.
ACOS
,
Elemwise
.
Mode
.
ASIN
,
Elemwise
.
Mode
.
ATAN2
,
Elemwise
.
Mode
.
CEIL
,
Elemwise
.
Mode
.
COS
,
Elemwise
.
Mode
.
FLOOR
,
Elemwise
.
Mode
.
H_SWISH
,
Elemwise
.
Mode
.
ROUND
,
Elemwise
.
Mode
.
SIGMOID
,
Elemwise
.
Mode
.
SIN
,
):
if
mode
in
(
Elemwise
.
Mode
.
CEIL
,
Elemwise
.
Mode
.
FLOOR
,
Elemwise
.
Mode
.
ROUND
,
)
and
np
.
issubdtype
(
args
[
0
].
dtype
,
np
.
integer
):
return
args
[
0
]
args
=
tuple
(
map
(
lambda
x
:
x
.
astype
(
"float32"
),
args
))
_isscalar
=
True
for
i
in
args
:
if
isscalar
(
i
)
==
False
:
_isscalar
=
False
break
(
result
,)
=
apply
(
op
,
*
args
)
if
_isscalar
:
setscalar
(
result
)
return
result
return
_elwise_apply
(
args
,
mode
)
def
_elemwise_multi_type
(
*
args
,
mode
,
**
kwargs
):
...
...
imperative/python/test/unit/functional/test_elemwise.py
浏览文件 @
938152af
...
...
@@ -9,6 +9,7 @@
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.functional.elemwise
as
elemwise
from
megengine
import
tensor
from
megengine.core.tensor
import
dtype
from
megengine.functional.elemwise
import
_elwise
...
...
@@ -166,3 +167,20 @@ def test_qadd():
result_mge
=
result_mge
.
astype
(
"float32"
).
numpy
()
result_expect
=
x
.
astype
(
"float32"
).
numpy
()
+
y
.
astype
(
"float32"
).
numpy
()
np
.
testing
.
assert_almost_equal
(
result_mge
,
result_expect
,
decimal
=
6
)
def
test_int32_input
():
x
=
tensor
(
np
.
array
([
1
,
2
,
3
,
4
,
5
]),
dtype
=
"int32"
)
for
op_name
in
elemwise
.
__all__
:
op
=
getattr
(
elemwise
,
op_name
)
nargs
=
op
.
__code__
.
co_argcount
if
op_name
==
"clip"
:
inp
=
(
x
,
0
,
1
)
elif
op_name
.
endswith
(
"_shift"
):
inp
=
(
x
,
1
)
elif
op_name
.
startswith
(
"logical_"
):
continue
else
:
inp
=
(
x
,)
*
nargs
y
=
op
(
*
inp
)
y
.
numpy
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录