Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
078738ad
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
078738ad
编写于
6月 23, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tensor mod & floordiv operation
上级
c22c865c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
30 addition
and
21 deletion
+30
-21
mindspore/common/api.py
mindspore/common/api.py
+3
-1
mindspore/common/tensor.py
mindspore/common/tensor.py
+15
-6
mindspore/ops/functional.py
mindspore/ops/functional.py
+2
-0
tests/ut/python/ir/test_tensor.py
tests/ut/python/ir/test_tensor.py
+9
-1
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
+1
-13
未找到文件。
mindspore/common/api.py
浏览文件 @
078738ad
...
...
@@ -158,7 +158,9 @@ class _MindSporeFunction:
# replace key with obj info and object ext info when fn is a method
if
self
.
obj
is
not
None
:
self
.
obj
.
__parse_method__
=
method_name
generate_name
=
self
.
obj
.
__module__
+
"."
+
str
(
self
.
obj
.
create_time
)
generate_name
=
self
.
obj
.
__module__
+
"."
if
self
.
obj
.
__class__
.
__name__
!=
"ClipByNorm"
:
generate_name
=
generate_name
+
str
(
self
.
obj
.
create_time
)
if
self
.
identify_obj
is
not
None
:
generate_name
=
generate_name
+
str
(
id
(
self
.
identify_obj
))
...
...
mindspore/common/tensor.py
浏览文件 @
078738ad
...
...
@@ -102,16 +102,14 @@ class Tensor(Tensor_):
return
out
def
__iadd__
(
self
,
other
):
out
=
self
.
__add__
(
other
)
return
out
return
self
.
__add__
(
other
)
def
__radd__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
return
out
def
__imul__
(
self
,
other
):
out
=
self
.
__mul__
(
other
)
return
out
return
self
.
__mul__
(
other
)
def
__rmul__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
self
,
other
)
...
...
@@ -130,8 +128,7 @@ class Tensor(Tensor_):
return
out
def
__isub__
(
self
,
other
):
out
=
self
.
__sub__
(
other
)
return
out
return
self
.
__sub__
(
other
)
def
__rsub__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__sub__'
)(
other
,
self
)
...
...
@@ -168,6 +165,18 @@ class Tensor(Tensor_):
return
1
return
out
[
0
]
def
__mod__
(
self
,
other
):
return
tensor_operator_registry
.
get
(
'__mod__'
)(
self
,
other
)
def
__imod__
(
self
,
other
):
return
self
.
__mod__
(
other
)
def
__floordiv__
(
self
,
other
):
return
tensor_operator_registry
.
get
(
'__floordiv__'
)(
self
,
other
)
def
__ifloordiv__
(
self
,
other
):
return
self
.
__floordiv__
(
other
)
def
__str__
(
self
):
if
self
.
dtype
==
mstype
.
type_none
:
return
"Unknown Tensor type!"
...
...
mindspore/ops/functional.py
浏览文件 @
078738ad
...
...
@@ -157,6 +157,8 @@ tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry
.
register
(
'__sub__'
,
tensor_sub
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__truediv__'
,
tensor_div
)
tensor_operator_registry
.
register
(
'__mod__'
,
tensor_mod
)
tensor_operator_registry
.
register
(
'__floordiv__'
,
tensor_floordiv
)
#ms cannot support Tensor(True) compare
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tensor_operator_registry
.
register
(
'__ne__'
,
not_equal
)
...
...
tests/ut/python/ir/test_tensor.py
浏览文件 @
078738ad
...
...
@@ -24,13 +24,15 @@ import pytest
import
mindspore
as
ms
import
mindspore.common.api
as
me
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
Tensor
,
context
from
mindspore.common.initializer
import
initializer
from
mindspore.common.parameter
import
Parameter
from
..ut_filter
import
non_graph_engine
ndarr
=
np
.
ones
((
2
,
3
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_tensor_flatten
():
with
pytest
.
raises
(
AttributeError
):
...
...
@@ -452,5 +454,11 @@ def test_tensor_operation():
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
8
/
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
x
%
3
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
)))
res
=
x
//
3
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
)))
x
%=
3
assert
np
.
all
(
x
.
asnumpy
()
==
np
.
ones
((
3
,
3
)))
with
pytest
.
raises
(
ValueError
):
res
=
x
*
(
2
,
3
)
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py
浏览文件 @
078738ad
...
...
@@ -18,8 +18,7 @@ import pytest
import
mindspore.nn
as
nn
from
mindspore.common.tensor
import
Tensor
from
mindspore.nn
import
WithGradCell
,
WithLossCell
from
mindspore.nn.optim
import
Momentum
from
mindspore.nn
import
WithGradCell
from
mindspore.ops
import
operations
as
P
...
...
@@ -63,17 +62,6 @@ def test_lenet_pynative_train_net():
loss_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
)
grad_fn
=
nn
.
SoftmaxCrossEntropyWithLogits
()
grad_net
=
WithGradCell
(
net
,
grad_fn
,
sens
=
dout
)
gradients
=
grad_net
(
data
,
label
)
# update parameters
opt
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
opt
(
gradients
)
# verification
if
i
==
verification_step
:
loss_net
=
WithLossCell
(
net
,
loss_fn
)
loss_output
=
loss_net
(
data
,
label
)
print
(
"The loss of %s-th iteration is %s"
%
(
i
,
loss_output
.
asnumpy
()))
def
test_lenet_pynative_train_model
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录