Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e64c755a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e64c755a
编写于
4月 24, 2020
作者:
K
kpy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change tensor equal bug
上级
00859ae1
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
19 addition
and
18 deletion
+19
-18
mindspore/ccsrc/ir/meta_tensor.cc
mindspore/ccsrc/ir/meta_tensor.cc
+0
-9
mindspore/ccsrc/ir/meta_tensor.h
mindspore/ccsrc/ir/meta_tensor.h
+0
-3
mindspore/common/tensor.py
mindspore/common/tensor.py
+11
-0
mindspore/ops/functional.py
mindspore/ops/functional.py
+2
-0
tests/vm_impl/math_ops_vm_impl.py
tests/vm_impl/math_ops_vm_impl.py
+6
-6
未找到文件。
mindspore/ccsrc/ir/meta_tensor.cc
浏览文件 @
e64c755a
...
...
@@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
return
(
MetaTensor
::
operator
==
(
tensor
)
&&
data_
==
tensor
.
data_
);
}
bool
Tensor
::
ValueEqualPy
(
const
py
::
object
&
other
)
const
{
if
(
!
py
::
isinstance
<
Tensor
>
(
other
))
{
MS_LOG
(
WARNING
)
<<
"compare other not a tensor"
;
return
false
;
}
return
ValueEqual
(
py
::
cast
<
Tensor
>
(
other
));
}
bool
Tensor
::
ValueEqual
(
const
Tensor
&
other
)
const
{
auto
equal
=
[
&
other
,
this
]()
->
bool
{
auto
np
=
py
::
module
::
import
(
"numpy"
);
...
...
@@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
)mydelimiter"
)
.
def
(
"__str__"
,
&
Tensor
::
ToString
)
.
def
(
"__repr__"
,
&
Tensor
::
ToStringRepr
)
.
def
(
"__eq__"
,
&
Tensor
::
ValueEqualPy
)
.
def
(
py
::
pickle
(
[](
const
Tensor
&
t
)
{
// __getstate__
/* Return a tuple that fully encodes the state of the object */
...
...
mindspore/ccsrc/ir/meta_tensor.h
浏览文件 @
e64c755a
...
...
@@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool
ValueEqual
(
const
Tensor
&
other
)
const
;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool
ValueEqualPy
(
const
py
::
object
&
other
)
const
;
bool
operator
==
(
const
Value
&
other
)
const
override
{
if
(
other
.
isa
<
Tensor
>
())
{
auto
other_
=
static_cast
<
const
Tensor
&>
(
other
);
...
...
mindspore/common/tensor.py
浏览文件 @
e64c755a
...
...
@@ -74,6 +74,17 @@ class Tensor(Tensor_):
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
return
out
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
return
False
x
=
self
.
asnumpy
()
y
=
other
.
asnumpy
()
out
=
np
.
equal
(
x
,
y
)
return
Tensor
(
np
.
array
(
out
))
def
__hash__
(
self
):
return
hash
(
id
(
self
))
def
__mul__
(
self
,
other
):
check_type
(
'tensor input_data'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
self
,
other
)
...
...
mindspore/ops/functional.py
浏览文件 @
e64c755a
...
...
@@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry
.
register
(
'__add__'
,
tensor_add
)
tensor_operator_registry
.
register
(
'__mul__'
,
tensor_mul
)
tensor_operator_registry
.
register
(
'__div__'
,
tensor_div
)
#ms cannot support Tensor(True) compare
tensor_operator_registry
.
register
(
'__eq__'
,
equal
)
tests/vm_impl/math_ops_vm_impl.py
浏览文件 @
e64c755a
...
...
@@ -172,7 +172,7 @@ def vm_impl_equal(self):
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
out
=
vm
.
equal
(
x
,
y
)
return
Tensor
(
out
)
return
Tensor
(
np
.
array
(
out
)
)
return
vm_impl
...
...
@@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
out
=
vm
.
not_equal
(
x
,
y
)
return
Tensor
(
out
)
return
Tensor
(
np
.
array
(
out
)
)
return
vm_impl
...
...
@@ -194,7 +194,7 @@ def vm_impl_greater(self):
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
out
=
vm
.
greater
(
x
,
y
)
return
Tensor
(
out
)
return
Tensor
(
np
.
array
(
out
)
)
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
Maximum
)
...
...
@@ -219,17 +219,17 @@ def vm_impl_minimum(self):
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
Less
)
def
vm_impl_
greater
(
self
):
def
vm_impl_
less
(
self
):
"""Generate vm_impl function for Less"""
def
vm_impl
(
x
,
y
):
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
out
=
vm
.
less
(
x
,
y
)
return
Tensor
(
out
)
return
Tensor
(
np
.
array
(
out
)
)
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
ScalarCast
)
def
vm_impl_
greater
(
self
):
def
vm_impl_
scalar_cast
(
self
):
"""Generate vm_impl function for ScalarCast"""
def
vm_impl
(
x
,
t
):
np_type
=
dtype_to_nptype
(
t
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录