Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8dec7490
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8dec7490
编写于
8月 05, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 05, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4000 improve interface '__bool__' for tensor
Merge pull request !4000 from zhangbuxue/improve_bool_for_tensor
上级
a722a0da
ace34525
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
20 addition
and
26 deletion
+20
-26
mindspore/_extends/parse/standard_method.py
mindspore/_extends/parse/standard_method.py
+4
-10
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+10
-10
mindspore/common/tensor.py
mindspore/common/tensor.py
+6
-2
tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
.../ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
+0
-4
未找到文件。
mindspore/_extends/parse/standard_method.py
浏览文件 @
8dec7490
...
...
@@ -17,7 +17,6 @@
"""standard_method"""
from
dataclasses
import
dataclass
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common._register_for_tensor
import
tensor_operator_registry
from
...ops
import
functional
as
F
from
...ops
import
operations
as
P
from
...ops.primitive
import
constexpr
...
...
@@ -206,13 +205,11 @@ def const_tensor_to_bool(x):
if
x
is
None
:
raise
ValueError
(
"Only constant tensor bool can be converted to bool"
)
x
=
x
.
asnumpy
()
if
x
.
shape
not
in
((),
(
1
,)):
raise
ValueError
(
"The truth value of an array with several elements is ambiguous."
)
if
x
.
shape
==
():
value
=
bool
(
x
)
else
:
value
=
bool
(
x
[
0
])
r
eturn
value
return
bool
(
x
)
if
x
.
shape
==
(
1
,)
:
return
bool
(
x
[
0
])
r
aise
ValueError
(
"The truth value of an array with several elements is ambiguous."
)
def
tensor_bool
(
x
):
...
...
@@ -349,6 +346,3 @@ def list_append(self_, item):
def
to_array
(
x
):
"""Implementation of `to_array`."""
return
x
.
__ms_to_array__
()
tensor_operator_registry
.
register
(
'__bool__'
,
tensor_bool
)
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
浏览文件 @
8dec7490
...
...
@@ -73,7 +73,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
namespace
{
bool
ConvertTuple
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
)
{
MS_LOG
(
DEBUG
)
<<
"Converting python tuple"
;
py
::
tuple
tuple
=
obj
.
cast
<
py
::
tuple
>
();
auto
tuple
=
obj
.
cast
<
py
::
tuple
>
();
std
::
vector
<
ValuePtr
>
value_list
;
for
(
size_t
it
=
0
;
it
<
tuple
.
size
();
++
it
)
{
ValuePtr
out
=
nullptr
;
...
...
@@ -91,7 +91,7 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur
bool
ConvertList
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
)
{
MS_LOG
(
DEBUG
)
<<
"Converting python list"
;
py
::
list
list
=
obj
.
cast
<
py
::
list
>
();
auto
list
=
obj
.
cast
<
py
::
list
>
();
std
::
vector
<
ValuePtr
>
value_list
;
for
(
size_t
it
=
0
;
it
<
list
.
size
();
++
it
)
{
ValuePtr
out
=
nullptr
;
...
...
@@ -124,7 +124,7 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa
bool
ConvertDict
(
const
py
::
object
&
obj
,
ValuePtr
*
data
,
bool
use_signature
)
{
MS_LOG
(
DEBUG
)
<<
"Converting python dict"
;
py
::
dict
dict_values
=
obj
.
cast
<
py
::
dict
>
();
auto
dict_values
=
obj
.
cast
<
py
::
dict
>
();
std
::
vector
<
std
::
pair
<
std
::
string
,
ValuePtr
>>
key_values
;
for
(
auto
item
:
dict_values
)
{
if
(
!
py
::
isinstance
<
py
::
str
>
(
item
.
first
))
{
...
...
@@ -208,7 +208,7 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
bool
ConvertSlice
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
)
{
MS_LOG
(
DEBUG
)
<<
"Converting slice object"
;
py
::
slice
slice_obj
=
obj
.
cast
<
py
::
slice
>
();
auto
slice_obj
=
obj
.
cast
<
py
::
slice
>
();
auto
convert_func
=
[
obj
](
std
::
string
attr
)
->
ValuePtr
{
auto
py_attr
=
py
::
getattr
(
obj
,
attr
.
c_str
());
if
(
py
::
isinstance
<
py
::
none
>
(
py_attr
))
{
...
...
@@ -335,7 +335,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
}
else
if
(
py
::
isinstance
<
MetaTensor
>
(
obj
))
{
converted
=
obj
.
cast
<
MetaTensorPtr
>
();
}
else
if
(
py
::
isinstance
<
EnvInstance
>
(
obj
))
{
std
::
shared_ptr
<
EnvInstance
>
env
=
obj
.
cast
<
std
::
shared_ptr
<
EnvInstance
>>
();
auto
env
=
obj
.
cast
<
std
::
shared_ptr
<
EnvInstance
>>
();
converted
=
env
;
}
else
if
(
py
::
hasattr
(
obj
,
PYTHON_CLASS_MEMBER_NAMESPACE
))
{
converted
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
obj
);
...
...
@@ -374,7 +374,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
data_converter
::
MakeProperNameToFuncGraph
(
func_graph
,
obj_id
);
data_converter
::
CacheObjectValue
(
obj_id
,
func_graph
);
if
(
obj_key
!=
""
)
{
if
(
!
obj_key
.
empty
()
)
{
MS_LOG
(
DEBUG
)
<<
"Add graph:"
<<
obj_key
<<
", func_graph:"
<<
func_graph
->
ToString
();
data_converter
::
SetObjGraphValue
(
obj_key
,
func_graph
);
}
...
...
@@ -440,7 +440,7 @@ bool IsCellInstance(const py::object &obj) {
py
::
object
CreatePythonObject
(
const
py
::
object
&
type
,
const
py
::
tuple
&
params
)
{
py
::
module
mod
=
python_adapter
::
GetPyModule
(
PYTHON_MOD_PARSE_MODULE
);
py
::
object
obj
;
if
(
params
.
size
()
==
0
)
{
if
(
params
.
empty
()
)
{
obj
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_CREATE_OBJ_INSTANCE
,
type
);
}
else
{
obj
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_CREATE_OBJ_INSTANCE
,
type
,
params
);
...
...
@@ -499,7 +499,7 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
ClassAttrVector
attributes
;
py
::
dict
names
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_DATACLASS_ATTRS
,
cls_obj
);
for
(
auto
&
item
:
names
)
{
TypePtr
type_value
=
item
.
second
.
cast
<
TypePtr
>
();
auto
type_value
=
item
.
second
.
cast
<
TypePtr
>
();
MS_EXCEPTION_IF_NULL
(
type_value
);
MS_LOG
(
DEBUG
)
<<
"(Name: "
<<
py
::
cast
<
std
::
string
>
(
item
.
first
)
<<
", type: "
<<
type_value
->
ToString
()
<<
")"
;
attributes
.
push_back
(
std
::
make_pair
(
py
::
cast
<
std
::
string
>
(
item
.
first
),
type_value
));
...
...
@@ -508,8 +508,8 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
methods_map
;
py
::
dict
methods
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_DATACLASS_METHODS
,
cls_obj
);
for
(
auto
&
item
:
methods
)
{
std
::
string
fun_name
=
item
.
first
.
cast
<
std
::
string
>
();
py
::
object
obj
=
py
::
cast
<
py
::
object
>
(
item
.
second
);
auto
fun_name
=
item
.
first
.
cast
<
std
::
string
>
();
auto
obj
=
py
::
cast
<
py
::
object
>
(
item
.
second
);
std
::
shared_ptr
<
PyObjectWrapper
>
method_obj
=
std
::
make_shared
<
PyObjectWrapper
>
(
obj
,
fun_name
);
methods_map
[
fun_name
]
=
method_obj
;
}
...
...
mindspore/common/tensor.py
浏览文件 @
8dec7490
...
...
@@ -108,8 +108,12 @@ class Tensor(Tensor_):
return
out
def
__bool__
(
self
):
out
=
tensor_operator_registry
.
get
(
'__bool__'
)(
self
)
return
out
data
=
self
.
asnumpy
()
if
data
.
shape
==
():
return
bool
(
data
)
if
data
.
shape
==
(
1
,):
return
bool
(
data
[
0
])
raise
ValueError
(
"The truth value of an array with several elements is ambiguous."
)
def
__pos__
(
self
):
return
self
...
...
tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
浏览文件 @
8dec7490
...
...
@@ -35,7 +35,6 @@ def test_dtype_and_shape_as_attr():
dtype
=
x
.
dtype
return
shape
,
dtype
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
ret
=
net
(
x
)
...
...
@@ -55,7 +54,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
y
=
self
.
fill
(
dtype
,
shape
,
self
.
value
)
return
y
net
=
Net
(
2.2
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
float32
))
ret
=
net
(
x
)
...
...
@@ -71,7 +69,6 @@ def test_type_not_have_the_attr():
shape
=
x
.
shapes
return
shape
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
ex
:
...
...
@@ -88,7 +85,6 @@ def test_type_not_have_the_method():
shape
=
x
.
dtypes
()
return
shape
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
ex
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录