Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9785178b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9785178b
编写于
6月 02, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tensor compare & len & constexpr operation
上级
5c4731b7
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
51 addition
and
6 deletion
+51
-6
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+5
-0
mindspore/common/tensor.py
mindspore/common/tensor.py
+28
-6
mindspore/ops/functional.py
mindspore/ops/functional.py
+8
-0
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+1
-0
tests/ut/python/pynative_mode/test_parse_method.py
tests/ut/python/pynative_mode/test_parse_method.py
+9
-0
未找到文件。
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
9785178b
...
@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
...
@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
value_ret
[
0
]
=
output
[
"value"
];
value_ret
[
0
]
=
output
[
"value"
];
return
value_ret
;
return
value_ret
;
}
}
if
(
py
::
hasattr
(
op_exec_info
->
py_primitive
->
GetPyObj
(),
"const_value"
))
{
py
::
tuple
value_ret
(
1
);
value_ret
[
0
]
=
""
;
return
value_ret
;
}
}
}
MS_LOG
(
INFO
)
<<
"RunOp start, op name is: "
<<
op_exec_info
->
op_name
;
MS_LOG
(
INFO
)
<<
"RunOp start, op name is: "
<<
op_exec_info
->
op_name
;
mindspore
::
parse
::
python_adapter
::
set_python_env_flag
(
true
);
mindspore
::
parse
::
python_adapter
::
set_python_env_flag
(
true
);
...
...
mindspore/common/tensor.py
浏览文件 @
9785178b
...
@@ -71,19 +71,18 @@ class Tensor(Tensor_):
...
@@ -71,19 +71,18 @@ class Tensor(Tensor_):
return
str
(
self
.
__str__
())
return
str
(
self
.
__str__
())
def
__add__
(
self
,
other
):
def
__add__
(
self
,
other
):
check_type
(
'tensor input_data'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
return
out
return
out
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
if
not
isinstance
(
other
,
Tensor
):
return
False
return
False
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
==
other
.
asnumpy
())
)
return
tensor_operator_registry
.
get
(
'__eq__'
)(
self
,
other
)
def
__ne__
(
self
,
other
):
def
__ne__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
if
not
isinstance
(
other
,
Tensor
):
return
True
return
True
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
!=
other
.
asnumpy
())
)
return
tensor_operator_registry
.
get
(
'__ne__'
)(
self
,
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
id
(
self
))
return
hash
(
id
(
self
))
...
@@ -93,7 +92,8 @@ class Tensor(Tensor_):
...
@@ -93,7 +92,8 @@ class Tensor(Tensor_):
return
out
return
out
def
__neg__
(
self
):
def
__neg__
(
self
):
return
Tensor
(
-
self
.
asnumpy
())
out
=
tensor_operator_registry
.
get
(
'__neg__'
)(
self
)
return
out
def
__iadd__
(
self
,
other
):
def
__iadd__
(
self
,
other
):
out
=
self
.
__add__
(
other
)
out
=
self
.
__add__
(
other
)
...
@@ -120,7 +120,7 @@ class Tensor(Tensor_):
...
@@ -120,7 +120,7 @@ class Tensor(Tensor_):
return
out
return
out
def
__sub__
(
self
,
other
):
def
__sub__
(
self
,
other
):
out
=
self
.
__add__
(
-
other
)
out
=
tensor_operator_registry
.
get
(
'__sub__'
)(
self
,
other
)
return
out
return
out
def
__isub__
(
self
,
other
):
def
__isub__
(
self
,
other
):
...
@@ -128,9 +128,31 @@ class Tensor(Tensor_):
...
@@ -128,9 +128,31 @@ class Tensor(Tensor_):
return
out
return
out
def
__rsub__
(
self
,
other
):
def
__rsub__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
other
,
Tensor
(
-
self
.
asnumpy
()))
out
=
tensor_operator_registry
.
get
(
'__sub__'
)(
other
,
self
)
return
out
def
__lt__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__lt__'
)(
self
,
other
)
return
out
def
__le__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__le__'
)(
self
,
other
)
return
out
return
out
def
__gt__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__gt__'
)(
self
,
other
)
return
out
def
__ge__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__ge__'
)(
self
,
other
)
return
out
def
__len__
(
self
):
out
=
tensor_operator_registry
.
get
(
'__shape__'
)(
self
)
if
not
out
:
return
1
return
out
[
0
]
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
dtype
()
==
mstype
.
type_none
:
if
self
.
dtype
()
==
mstype
.
type_none
:
return
"Unknown Tensor type!"
return
"Unknown Tensor type!"
...
...
mindspore/ops/functional.py
浏览文件 @
9785178b
...
@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
...
@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
stop_gradient
=
Primitive
(
"stop_gradient"
)
stop_gradient
=
Primitive
(
"stop_gradient"
)
tensor_operator_registry
.
register
(
'__add__'
,
tensor_add
)
tensor_operator_registry
.
register
(
'__add__'
,
tensor_add
)
tensor_operator_registry
.
register
(
'__sub__'
,
tensor_sub
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__div__'
,
tensor_div
)
tensor_operator_registry
.
register
(
'__div__'
,
tensor_div
)
#ms cannot support Tensor(True) compare
#ms cannot support Tensor(True) compare
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tensor_operator_registry
.
register
(
'__ne__'
,
not_equal
)
tensor_operator_registry
.
register
(
'__neg__'
,
neg_tensor
)
tensor_operator_registry
.
register
(
'__lt__'
,
tensor_lt
)
tensor_operator_registry
.
register
(
'__le__'
,
tensor_le
)
tensor_operator_registry
.
register
(
'__gt__'
,
tensor_gt
)
tensor_operator_registry
.
register
(
'__ge__'
,
tensor_ge
)
tensor_operator_registry
.
register
(
'__shape__'
,
shape
)
mindspore/ops/primitive.py
浏览文件 @
9785178b
...
@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
...
@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def
__init__
(
self
):
def
__init__
(
self
):
op_name
=
name
if
name
else
fn
.
__name__
op_name
=
name
if
name
else
fn
.
__name__
PrimitiveWithInfer
.
__init__
(
self
,
op_name
)
PrimitiveWithInfer
.
__init__
(
self
,
op_name
)
self
.
const_value
=
True
def
infer_value
(
self
,
*
args
):
def
infer_value
(
self
,
*
args
):
return
fn
(
*
args
)
return
fn
(
*
args
)
...
...
tests/ut/python/pynative_mode/test_parse_method.py
浏览文件 @
9785178b
...
@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
...
@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
from
mindspore.common.api
import
ms_function
from
mindspore.common.api
import
ms_function
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops.composite
import
core
from
mindspore.ops.composite
import
core
from
mindspore.ops.primitive
import
constexpr
from
..ut_filter
import
non_graph_engine
from
..ut_filter
import
non_graph_engine
...
@@ -417,3 +418,11 @@ def test_range():
...
@@ -417,3 +418,11 @@ def test_range():
""" test_range """
""" test_range """
res
=
range_spec
(
10
,
10
)
res
=
range_spec
(
10
,
10
)
return
res
return
res
def
test_expr
():
""" test const expr """
a
=
(
1
,
2
)
@
constexpr
def
tuple_len
(
x
):
assert
len
(
x
)
==
2
tuple_len
(
a
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录