Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
95212b55
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
95212b55
编写于
8月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3271 make reftype a subtype of MetaTensor and try to mark ref in node input
Merge pull request !3271 from vlne-v1/ref_demo
上级
7b10f21f
24a10225
变更
47
隐藏空白更改
内联
并排
Showing
47 changed file
with
813 addition
and
721 deletion
+813
-721
mindspore/_checkparam.py
mindspore/_checkparam.py
+15
-173
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+0
-5
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
+42
-50
mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
.../ccsrc/frontend/operator/composite/multitype_funcgraph.cc
+1
-7
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
...spore/ccsrc/frontend/operator/ops_front_infer_function.cc
+12
-0
mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
+2
-2
mindspore/ccsrc/frontend/parallel/step_parallel.cc
mindspore/ccsrc/frontend/parallel/step_parallel.cc
+3
-8
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+8
-7
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+0
-6
mindspore/ccsrc/pipeline/jit/parse/function_block.cc
mindspore/ccsrc/pipeline/jit/parse/function_block.cc
+10
-12
mindspore/ccsrc/pipeline/jit/parse/function_block.h
mindspore/ccsrc/pipeline/jit/parse/function_block.h
+0
-3
mindspore/ccsrc/pipeline/jit/parse/parse.cc
mindspore/ccsrc/pipeline/jit/parse/parse.cc
+2
-3
mindspore/ccsrc/pipeline/jit/parse/parse.h
mindspore/ccsrc/pipeline/jit/parse/parse.h
+1
-1
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+4
-16
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+14
-11
mindspore/ccsrc/pybind_api/ir/dtype_py.cc
mindspore/ccsrc/pybind_api/ir/dtype_py.cc
+1
-1
mindspore/ccsrc/pybind_api/ir/param_info_py.cc
mindspore/ccsrc/pybind_api/ir/param_info_py.cc
+2
-2
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+1
-0
mindspore/common/parameter.py
mindspore/common/parameter.py
+12
-12
mindspore/core/abstract/abstract_value.cc
mindspore/core/abstract/abstract_value.cc
+12
-36
mindspore/core/abstract/abstract_value.h
mindspore/core/abstract/abstract_value.h
+19
-18
mindspore/core/abstract/prim_others.cc
mindspore/core/abstract/prim_others.cc
+2
-11
mindspore/core/ir/anf.cc
mindspore/core/ir/anf.cc
+11
-0
mindspore/core/ir/anf.h
mindspore/core/ir/anf.h
+2
-1
mindspore/core/ir/dtype.cc
mindspore/core/ir/dtype.cc
+4
-169
mindspore/core/ir/dtype.h
mindspore/core/ir/dtype.h
+5
-93
mindspore/core/ir/dtype/number.h
mindspore/core/ir/dtype/number.h
+2
-0
mindspore/core/ir/dtype/ref.cc
mindspore/core/ir/dtype/ref.cc
+4
-4
mindspore/core/ir/dtype/ref.h
mindspore/core/ir/dtype/ref.h
+6
-21
mindspore/core/ir/dtype/tensor_type.cc
mindspore/core/ir/dtype/tensor_type.cc
+194
-0
mindspore/core/ir/dtype/tensor_type.h
mindspore/core/ir/dtype/tensor_type.h
+132
-0
mindspore/core/ir/func_graph.h
mindspore/core/ir/func_graph.h
+0
-3
mindspore/core/ir/meta_tensor.h
mindspore/core/ir/meta_tensor.h
+13
-0
mindspore/core/ir/meta_tensor_extends.cc
mindspore/core/ir/meta_tensor_extends.cc
+10
-1
mindspore/core/ir/named.h
mindspore/core/ir/named.h
+15
-0
mindspore/core/ir/param_info.h
mindspore/core/ir/param_info.h
+6
-3
mindspore/core/ir/tensor.cc
mindspore/core/ir/tensor.cc
+10
-1
mindspore/core/ir/value.cc
mindspore/core/ir/value.cc
+0
-10
mindspore/core/ir/value.h
mindspore/core/ir/value.h
+6
-13
mindspore/lite/test/CMakeLists.txt
mindspore/lite/test/CMakeLists.txt
+2
-0
mindspore/lite/tools/converter/CMakeLists.txt
mindspore/lite/tools/converter/CMakeLists.txt
+2
-0
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+1
-1
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+6
-11
tests/st/control/test_ascend_control_sink.py
tests/st/control/test_ascend_control_sink.py
+10
-5
tests/ut/python/pipeline/parse/test_parse.py
tests/ut/python/pipeline/parse/test_parse.py
+58
-0
tests/ut/python/pipeline/parse/test_while_param.py
tests/ut/python/pipeline/parse/test_while_param.py
+144
-0
tests/vm_impl/array_ops_vm_impl.py
tests/vm_impl/array_ops_vm_impl.py
+7
-1
未找到文件。
mindspore/_checkparam.py
浏览文件 @
95212b55
...
@@ -185,14 +185,23 @@ class Validator:
...
@@ -185,14 +185,23 @@ class Validator:
raise
TypeError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be float."
)
raise
TypeError
(
f
"
{
msg_prefix
}
`
{
arg_name
}
` must be float."
)
@
staticmethod
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
prim_name
):
def
check_subclass
(
arg_name
,
type_
,
template_type
s
,
prim_name
):
"""Checks whether some type is subclass of another type"""
"""Checks whether some type is subclass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
if
not
isinstance
(
template_types
,
Iterable
):
template_type
=
(
template_type
,)
template_types
=
(
template_types
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
hit
=
False
for
template_type
in
template_types
:
if
isinstance
(
template_type
,
mstype
.
Type
):
if
mstype
.
issubclass_
(
type_
,
template_type
):
hit
=
True
break
elif
type_
is
template_type
:
hit
=
True
break
if
not
hit
:
type_str
=
(
type
(
type_
).
__name__
if
isinstance
(
type_
,
(
tuple
,
list
))
else
""
)
+
str
(
type_
)
type_str
=
(
type
(
type_
).
__name__
if
isinstance
(
type_
,
(
tuple
,
list
))
else
""
)
+
str
(
type_
)
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be subclass'
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
the type of `
{
arg_name
}
` should be subclass'
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
))
}
, but got
{
type_str
}
.'
)
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
s
))
}
, but got
{
type_str
}
.'
)
@
staticmethod
@
staticmethod
def
check_const_input
(
arg_name
,
arg_value
,
prim_name
):
def
check_const_input
(
arg_name
,
arg_value
,
prim_name
):
...
@@ -206,13 +215,7 @@ class Validator:
...
@@ -206,13 +215,7 @@ class Validator:
def
_check_tensor_type
(
arg
):
def
_check_tensor_type
(
arg
):
arg_key
,
arg_val
=
arg
arg_key
,
arg_val
=
arg
elem_type
=
arg_val
elem_type
=
arg_val
if
not
elem_type
in
valid_values
:
Validator
.
check_subclass
(
arg_key
,
elem_type
,
valid_values
,
prim_name
)
type_names
=
[]
for
t
in
valid_values
:
type_names
.
append
(
str
(
t
))
types_info
=
'['
+
', '
.
join
(
type_names
)
+
']'
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg_key
}
` should be in
{
types_info
}
,'
f
' but got
{
elem_type
}
.'
)
return
(
arg_key
,
elem_type
)
return
(
arg_key
,
elem_type
)
def
_check_types_same
(
arg1
,
arg2
):
def
_check_types_same
(
arg1
,
arg2
):
...
@@ -335,12 +338,6 @@ class Validator:
...
@@ -335,12 +338,6 @@ class Validator:
class
ParamValidator
:
class
ParamValidator
:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@
staticmethod
def
equal
(
arg_name
,
arg_value
,
cond_str
,
cond
):
"""Judging valid value."""
if
not
cond
:
raise
ValueError
(
f
'The `
{
arg_name
}
` must be
{
cond_str
}
, but got
{
arg_value
}
.'
)
@
staticmethod
@
staticmethod
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
):
def
check
(
arg_name
,
arg_value
,
value_name
,
value
,
rel
=
Rel
.
EQ
):
"""This method is only used for check int values, since when compare float values,
"""This method is only used for check int values, since when compare float values,
...
@@ -360,27 +357,6 @@ class ParamValidator:
...
@@ -360,27 +357,6 @@ class ParamValidator:
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
return
arg_value
@
staticmethod
def
check_shape_length
(
arg_name
,
arg_value
,
value
,
rel
):
"""Shape length judgment."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
value
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'The length of `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
'
)
return
arg_value
@
staticmethod
def
check_int_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_mismatch
=
not
isinstance
(
arg_value
,
int
)
if
type_mismatch
or
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int in range
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
@
staticmethod
def
check_isinstance
(
arg_name
,
arg_value
,
classes
):
def
check_isinstance
(
arg_name
,
arg_value
,
classes
):
"""Check arg isinstance of classes"""
"""Check arg isinstance of classes"""
...
@@ -388,33 +364,6 @@ class ParamValidator:
...
@@ -388,33 +364,6 @@ class ParamValidator:
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isinstance of
{
classes
}
, but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isinstance of
{
classes
}
, but got
{
arg_value
}
.'
)
return
arg_value
return
arg_value
@
staticmethod
def
check_number_range
(
arg_name
,
arg_value
,
lower_limit
,
upper_limit
,
rel
):
"""Is it necessary to consider error when comparing float values."""
rel_fn
=
Rel
.
get_fns
(
rel
)
if
not
rel_fn
(
arg_value
,
lower_limit
,
upper_limit
):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
lower_limit
,
upper_limit
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be in range
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
with_type_of
=
True
):
"""Check whether some type is subclass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
type_str
=
(
type
(
type_
).
__name__
if
isinstance
(
type_
,
(
tuple
,
list
))
else
""
)
+
str
(
type_
)
raise
TypeError
(
f
'The
{
"type of"
if
with_type_of
else
""
}
`
{
arg_name
}
` should be subclass'
f
' of
{
","
.
join
((
str
(
x
)
for
x
in
template_type
))
}
, but got
{
type_str
}
.'
)
@
staticmethod
def
check_args_tensor
(
args
):
"""Check whether args are all tensor."""
if
not
isinstance
(
args
,
dict
):
raise
TypeError
(
"The args should be a dict."
)
for
arg
,
value
in
args
.
items
():
ParamValidator
.
check_subclass
(
arg
,
value
,
mstype
.
tensor
)
@
staticmethod
@
staticmethod
def
check_bool
(
arg_name
,
arg_value
):
def
check_bool
(
arg_name
,
arg_value
):
"""Check arg isinstance of bool"""
"""Check arg isinstance of bool"""
...
@@ -442,113 +391,6 @@ class ParamValidator:
...
@@ -442,113 +391,6 @@ class ParamValidator:
return
arg_value
return
arg_value
raise_error_msg
()
raise_error_msg
()
@
staticmethod
def
check_typename
(
arg_name
,
arg_type
,
valid_types
):
"""Does it contain the _name_ attribute."""
def
get_typename
(
t
):
return
t
.
__name__
if
hasattr
(
t
,
'__name__'
)
else
str
(
t
)
if
isinstance
(
arg_type
,
type
(
mstype
.
tensor
)):
arg_type
=
arg_type
.
element_type
()
if
arg_type
in
valid_types
:
return
arg_type
type_names
=
[
get_typename
(
t
)
for
t
in
valid_types
]
if
len
(
valid_types
)
==
1
:
raise
ValueError
(
f
'The type of `
{
arg_name
}
` should be
{
type_names
[
0
]
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
raise
ValueError
(
f
'The type of `
{
arg_name
}
` should be one of
{
type_names
}
,'
f
' but got
{
get_typename
(
arg_type
)
}
.'
)
@
staticmethod
def
check_string
(
arg_name
,
arg_value
,
valid_values
):
"""String type judgment."""
if
isinstance
(
arg_value
,
str
)
and
arg_value
in
valid_values
:
return
arg_value
if
len
(
valid_values
)
==
1
:
raise
ValueError
(
f
'The `
{
arg_name
}
` should be str and must be
{
valid_values
[
0
]
}
,'
f
' but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be str and must be one of
{
valid_values
}
,'
f
' but got
{
arg_value
}
.'
)
@
staticmethod
def
check_type_same
(
args
,
valid_values
):
"""Determine whether the types are the same."""
name
=
list
(
args
.
keys
())[
0
]
value
=
list
(
args
.
values
())[
0
]
if
isinstance
(
value
,
type
(
mstype
.
tensor
)):
value
=
value
.
element_type
()
for
arg_name
,
arg_value
in
args
.
items
():
if
isinstance
(
arg_value
,
type
(
mstype
.
tensor
)):
arg_value
=
arg_value
.
element_type
()
if
arg_value
not
in
valid_values
:
raise
TypeError
(
f
'The `
{
arg_name
}
` should be in
{
valid_values
}
,'
f
' but `
{
arg_name
}
` is
{
arg_value
}
.'
)
if
arg_value
!=
value
:
raise
TypeError
(
f
'`
{
arg_name
}
` should be same as `
{
name
}
`,'
f
' but `
{
arg_name
}
` is
{
arg_value
}
, `
{
name
}
` is
{
value
}
.'
)
@
staticmethod
def
check_two_types_same
(
arg1_name
,
arg1_type
,
arg2_name
,
arg2_type
):
"""Determine whether the types of two variables are the same."""
if
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'The type of `
{
arg1_name
}
` and `
{
arg2_name
}
` should be same.'
)
@
staticmethod
def
check_value_on_integer
(
arg_name
,
arg_value
,
value
,
rel
):
"""Judging integer type."""
rel_fn
=
Rel
.
get_fns
(
rel
)
type_match
=
isinstance
(
arg_value
,
int
)
if
type_match
and
(
not
rel_fn
(
arg_value
,
value
)):
rel_str
=
Rel
.
get_strs
(
rel
).
format
(
value
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be an int and must
{
rel_str
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
def
check_param_equal
(
param1_name
,
param1_value
,
param2_name
,
param2_value
):
"""Judging the equality of parameters."""
if
param1_value
!=
param2_value
:
raise
ValueError
(
f
"`
{
param1_name
}
` must equal `
{
param2_name
}
`,"
f
" but got `
{
param1_name
}
` =
{
param1_value
}
,"
f
" `
{
param2_name
}
` =
{
param2_value
}
."
)
@
staticmethod
def
check_const_input
(
arg_name
,
arg_value
):
"""Check valid value."""
if
arg_value
is
None
:
raise
ValueError
(
f
'The `
{
arg_name
}
` must be a const input, but got
{
arg_value
}
.'
)
@
staticmethod
def
check_float_positive
(
arg_name
,
arg_value
):
"""Float type judgment."""
if
isinstance
(
arg_value
,
float
):
if
arg_value
>
0
:
return
arg_value
raise
ValueError
(
f
"The `
{
arg_name
}
` must be positive, but got
{
arg_value
}
."
)
raise
TypeError
(
f
"`
{
arg_name
}
` must be float!"
)
@
staticmethod
def
check_pad_value_by_mode
(
op_name
,
pad_mode
,
padding
):
"""Validate value of padding according to pad_mode"""
if
pad_mode
!=
'pad'
and
padding
!=
0
:
raise
ValueError
(
f
"For op '
{
op_name
}
', padding must be zero when pad_mode is '
{
pad_mode
}
'."
)
return
padding
@
staticmethod
def
check_empty_shape_input
(
arg_name
,
arg_value
):
"""Check zeros value."""
if
0
in
arg_value
:
raise
ValueError
(
f
"Input `
{
arg_name
}
` cannot be empty."
)
@
staticmethod
def
check_scalar_shape_input
(
arg_name
,
arg_value
):
"""Check scalar shape input."""
if
arg_value
!=
[]:
raise
ValueError
(
f
"Input `
{
arg_name
}
` shape should be (). got
{
arg_value
}
"
)
def
check_int
(
input_param
):
def
check_int
(
input_param
):
"""Int type judgment."""
"""Int type judgment."""
...
...
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
浏览文件 @
95212b55
...
@@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
...
@@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
return
get_single_type
((
*
tuple_ptr
)[
output_idx
]);
return
get_single_type
((
*
tuple_ptr
)[
output_idx
]);
};
};
TypePtr
type_ptr
=
node
->
Type
();
TypePtr
type_ptr
=
node
->
Type
();
if
(
type_ptr
->
isa
<
RefType
>
())
{
auto
ref_type_ptr
=
type_ptr
->
cast
<
RefTypePtr
>
();
MS_EXCEPTION_IF_NULL
(
ref_type_ptr
);
return
get_tuple_type
(
ref_type_ptr
->
subtype
(),
output_idx
);
}
return
get_tuple_type
(
type_ptr
,
output_idx
);
return
get_tuple_type
(
type_ptr
,
output_idx
);
}
}
...
...
mindspore/ccsrc/frontend/operator/composite/do_signature.cc
浏览文件 @
95212b55
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "abstract/abstract_value.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "abstract/dshape.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"
#include "frontend/operator/cc_implementations.h"
...
@@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
...
@@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
return
empty
;
return
empty
;
}
}
void
ProcessDefault
(
const
std
::
string
&
func_name
,
const
AbstractBasePtrList
&
args_spec_list
,
void
ProcessDefault
(
const
std
::
string
&
func_name
,
size_t
actual_param_number
,
const
std
::
vector
<
Signature
>
&
signature
,
const
std
::
vector
<
Signature
>
&
signature
,
bool
has_var
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
)
{
bool
has_var
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
)
{
std
::
size_t
sig_size
=
signature
.
size
();
std
::
size_t
sig_size
=
signature
.
size
();
auto
positional_size
=
sig_size
;
auto
positional_size
=
sig_size
;
if
(
has_var
)
{
if
(
has_var
)
{
positional_size
=
sig_size
-
1
;
positional_size
=
sig_size
-
1
;
}
}
if
(
a
rgs_spec_list
.
size
()
<
positional_size
)
{
if
(
a
ctual_param_number
<
positional_size
)
{
for
(
size_t
i
=
a
rgs_spec_list
.
size
()
;
i
<
sig_size
;
++
i
)
{
for
(
size_t
i
=
a
ctual_param_number
;
i
<
sig_size
;
++
i
)
{
auto
default_value
=
signature
[
i
].
default_value
;
auto
default_value
=
signature
[
i
].
default_value
;
if
(
default_value
==
nullptr
)
{
if
(
default_value
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Function "
<<
func_name
<<
"'s input length is not equal to Signature length."
;
MS_LOG
(
EXCEPTION
)
<<
"Function "
<<
func_name
<<
"'s input length is not equal to Signature length."
;
...
@@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
...
@@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
*
max_type_number
=
type_number
;
*
max_type_number
=
type_number
;
}
}
bool
GetTensorOrScalarTypeInfo
(
AbstractBasePtr
arg_value
,
bool
is_write
,
TypeId
*
arg_type_id
,
bool
GetTensorOrScalarTypeInfo
(
TypePtr
arg_type_origin
,
bool
is_write
,
TypeId
*
arg_type_id
,
TypeId
*
arg_type
=
nullptr
)
{
TypeId
*
arg_type
=
nullptr
)
{
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
if
(
arg_type_origin
->
isa
<
TensorType
>
())
{
auto
ref
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
();
auto
tensor
=
arg_type_origin
->
cast
<
TensorTypePtr
>
();
arg_value
=
ref
->
ref
();
auto
tensor_type
=
tensor
->
element
();
if
(
!
is_write
&&
ref
->
need_cast
())
{
auto
tensor_type
=
ref
->
target_type
();
*
arg_type_id
=
tensor_type
->
type_id
();
if
(
arg_type
!=
nullptr
)
{
*
arg_type
=
kObjectTypeTensorType
;
}
return
true
;
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
auto
tensor
=
arg_value
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_type
=
tensor
->
element
()
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
tensor_type
);
MS_EXCEPTION_IF_NULL
(
tensor_type
);
*
arg_type_id
=
tensor_type
->
type_id
();
*
arg_type_id
=
tensor_type
->
type_id
();
if
(
arg_type
!=
nullptr
)
{
if
(
arg_type
!=
nullptr
)
{
...
@@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
...
@@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
}
}
return
true
;
return
true
;
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractScalar
>
())
{
if
(
arg_type_origin
->
isa
<
Number
>
())
{
auto
scalar
=
arg_value
->
cast
<
abstract
::
AbstractScalarPtr
>
();
auto
scalar_type
=
arg_type_origin
->
cast
<
NumberPtr
>
();
auto
scalar_type
=
scalar
->
BuildType
();
MS_EXCEPTION_IF_NULL
(
scalar_type
);
MS_EXCEPTION_IF_NULL
(
scalar_type
);
*
arg_type_id
=
scalar_type
->
type_id
();
*
arg_type_id
=
scalar_type
->
type_id
();
if
(
arg_type
!=
nullptr
)
{
if
(
arg_type
!=
nullptr
)
{
...
@@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
...
@@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
return
false
;
return
false
;
}
}
TypeId
GetMaxTypeId
(
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
std
::
vector
<
size_t
>
indices
,
TypeId
GetMaxTypeId
(
const
std
::
vector
<
TypePtr
>
&
input_types
,
std
::
vector
<
size_t
>
indices
,
const
std
::
set
<
size_t
>
&
write_indices
)
{
const
std
::
set
<
size_t
>
&
write_indices
)
{
TypeId
max_type_id
=
kTypeUnknown
;
TypeId
max_type_id
=
kTypeUnknown
;
size_t
max_type_number
=
0
;
size_t
max_type_number
=
0
;
...
@@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
...
@@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId
arg_type_id
=
kTypeUnknown
;
TypeId
arg_type_id
=
kTypeUnknown
;
TypeId
arg_type
=
kTypeUnknown
;
TypeId
arg_type
=
kTypeUnknown
;
auto
is_write
=
(
write_indices
.
find
(
index
)
!=
write_indices
.
end
());
auto
is_write
=
(
write_indices
.
find
(
index
)
!=
write_indices
.
end
());
if
(
!
GetTensorOrScalarTypeInfo
(
args_spec_list
[
index
],
is_write
,
&
arg_type_id
,
&
arg_type
))
{
if
(
!
GetTensorOrScalarTypeInfo
(
input_types
[
index
],
is_write
,
&
arg_type_id
,
&
arg_type
))
{
continue
;
continue
;
}
}
if
(
arg_type
!=
kObjectTypeTensorType
)
{
if
(
arg_type
!=
kObjectTypeTensorType
)
{
...
@@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
...
@@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
// Get the largest type of index in the same SignatureEnumDType of arguments.
using
MaxTypeMap
=
std
::
map
<
SignatureEnumDType
,
TypeId
>
;
using
MaxTypeMap
=
std
::
map
<
SignatureEnumDType
,
TypeId
>
;
MaxTypeMap
GetMaxDtype
(
const
std
::
vector
<
SignatureEnumDType
>
&
dtypes
,
MaxTypeMap
GetMaxDtype
(
const
std
::
vector
<
SignatureEnumDType
>
&
dtypes
,
const
std
::
vector
<
TypePtr
>
&
input_types
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
std
::
set
<
size_t
>
&
write_indices
)
{
const
std
::
set
<
size_t
>
&
write_indices
)
{
// record index for signature.dtypes of the same type
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indices
;
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indices
;
...
@@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
...
@@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
}
}
bool
has_tensor
=
false
;
bool
has_tensor
=
false
;
for
(
const
auto
&
index
:
indices
)
{
for
(
const
auto
&
index
:
indices
)
{
AbstractBasePtr
arg_value
=
args_spec_list
[
index
];
auto
arg_value
=
input_types
[
index
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
if
(
arg_value
->
isa
<
TensorType
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
has_tensor
=
true
;
has_tensor
=
true
;
break
;
break
;
}
}
...
@@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
...
@@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
kTypeUnknown
));
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
kTypeUnknown
));
continue
;
continue
;
}
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
args_spec_list
,
indices
,
write_indices
)));
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
input_types
,
indices
,
write_indices
)));
}
}
return
dst_type
;
return
dst_type
;
}
}
...
@@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
...
@@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
}
}
void
DoAutoCast
(
const
std
::
string
&
func_name
,
const
std
::
vector
<
Signature
>
&
signature
,
void
DoAutoCast
(
const
std
::
string
&
func_name
,
const
std
::
vector
<
Signature
>
&
signature
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
FuncGraphPtr
&
graph
,
const
std
::
vector
<
TypePtr
>
&
input_types
,
const
FuncGraphPtr
&
graph
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_indices
)
{
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_indices
)
{
std
::
vector
<
SignatureEnumDType
>
dtypes
;
std
::
vector
<
SignatureEnumDType
>
dtypes
;
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
...
@@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
...
@@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return
;
return
;
}
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
args_spec_list
,
write_indices
);
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
input_types
,
write_indices
);
// Identify which arg requires auto cast
// Identify which arg requires auto cast
for
(
size_t
i
=
0
;
i
<
args_spec_list
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_types
.
size
();
++
i
)
{
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
kTypeUnknown
)
{
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
kTypeUnknown
)
{
continue
;
continue
;
...
@@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
...
@@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
auto
is_write
=
(
rw_it
!=
write_indices
.
end
());
auto
is_write
=
(
rw_it
!=
write_indices
.
end
());
TypeId
arg_type_id
=
kTypeUnknown
;
TypeId
arg_type_id
=
kTypeUnknown
;
AbstractBasePtr
arg_value
=
args_spec_list
[
i
];
auto
arg_value
=
input_types
[
i
];
(
void
)
GetTensorOrScalarTypeInfo
(
arg_value
,
is_write
,
&
arg_type_id
);
(
void
)
GetTensorOrScalarTypeInfo
(
arg_value
,
is_write
,
&
arg_type_id
);
auto
it_map
=
type_name_map
.
find
(
arg_type_id
);
auto
it_map
=
type_name_map
.
find
(
arg_type_id
);
if
(
it_map
==
type_name_map
.
end
())
{
if
(
it_map
==
type_name_map
.
end
())
{
...
@@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
...
@@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
}
}
continue
;
continue
;
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
(
)
&&
arg_type_id
==
it
->
second
)
{
if
(
(
arg_value
->
isa
<
TensorType
>
()
)
&&
arg_type_id
==
it
->
second
)
{
continue
;
continue
;
}
}
MS_LOG
(
DEBUG
)
<<
"do cast for inputs "
<<
i
<<
" "
<<
(
*
op_inputs
)[
i
+
1
]
->
ToString
()
<<
" "
<<
arg_type_id
MS_LOG
(
DEBUG
)
<<
"do cast for inputs "
<<
i
<<
" "
<<
(
*
op_inputs
)[
i
+
1
]
->
ToString
()
<<
" "
<<
arg_type_id
...
@@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
...
@@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
}
std
::
vector
<
AnfNodePtr
>
op_inputs
;
std
::
vector
<
AnfNodePtr
>
op_inputs
;
std
::
set
<
size_t
>
write_indices
;
std
::
set
<
size_t
>
write_indices
;
std
::
vector
<
TypePtr
>
input_types
;
op_inputs
.
push_back
(
NewValueNode
(
function
));
op_inputs
.
push_back
(
NewValueNode
(
function
));
// Assume, the write input of op is always the first input. We check if any write op,
// Assume, the write input of op is always the first input. We check if any write op,
// and add cast op on other inputs to keep the same type with assigned parameter.
// and add cast op on other inputs to keep the same type with assigned parameter.
...
@@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
...
@@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
sig
=
signature
[
sig_size
-
1
].
rw
;
sig
=
signature
[
sig_size
-
1
].
rw
;
}
}
TypePtr
type
=
args_spec_list
[
i
]
->
GetTypeTrack
();
TypePtr
type
=
args_spec_list
[
i
]
->
BuildType
();
if
(
type
&&
type
->
type_id
()
==
kObjectTypeRef
)
{
if
(
type
&&
type
->
isa
<
RefType
>
()
)
{
auto
ref_abs
=
args_spec_list
[
i
]
->
cast
<
abstract
::
AbstractRefPtr
>
(
);
auto
cast_type
=
parse
::
GetMixedPrecisionTargetType
(
func_graph
);
if
(
sig
==
SignatureEnumRW
::
kRWRead
)
{
if
(
sig
==
SignatureEnumRW
::
kRWRead
)
{
param
=
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefValue
),
param
},
func_graph
);
auto
source_tensor_type
=
type
->
cast
<
TensorTypePtr
>
();
if
(
ref_abs
&&
ref_abs
->
need_cast
())
{
if
(
source_tensor_type
!=
nullptr
)
{
auto
cast
=
prim
::
GetPythonOps
(
"cast"
,
"mindspore.ops.functional"
);
auto
source_element
=
source_tensor_type
->
element
();
param
=
NewCNode
({
NewValueNode
(
cast
),
param
,
NewValueNode
(
ref_abs
->
target_type
())},
func_graph
);
if
(
cast_type
!=
nullptr
&&
IsSubType
(
source_element
,
kFloat
)
&&
*
source_element
!=
*
cast_type
)
{
auto
cast
=
prim
::
GetPythonOps
(
"cast"
,
"mindspore.ops.functional"
);
param
=
NewCNode
({
NewValueNode
(
cast
),
param
,
NewValueNode
(
cast_type
)},
func_graph
);
type
=
cast_type
->
type_id
()
==
kNumberTypeFloat16
?
kTensorTypeFP16
:
kTensorTypeFP32
;
}
}
}
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
)
{
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
)
{
param
=
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefValue
),
param
},
func_graph
);
write_indices
.
insert
(
i
);
write_indices
.
insert
(
i
);
}
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
// If sig is SignatureEnumRW::kRWRef, not do anything.
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
&&
type
->
type_id
()
!=
kObjectTypeRefKey
)
{
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
&&
MS_EXCEPTION
(
TypeError
)
<<
"Function "
<<
func_name
<<
"'s input "
<<
i
<<
" should be a Parameter."
;
!
((
type
->
type_id
()
==
kObjectTypeRef
)
||
(
type
->
type_id
()
==
kObjectTypeRefKey
)))
{
MS_EXCEPTION
(
TypeError
)
<<
"Function "
<<
func_name
<<
"'s input "
<<
i
<<
" should be a Parameter, but "
<<
type
->
ToString
();
}
}
MS_LOG
(
DEBUG
)
<<
"Function "
<<
func_name
<<
"'s input "
<<
i
<<
" "
<<
param
->
DebugString
(
2
)
<<
" type "
MS_LOG
(
DEBUG
)
<<
"Function "
<<
func_name
<<
"'s input "
<<
i
<<
" "
<<
param
->
DebugString
(
2
)
<<
" type "
<<
args_spec_list
[
i
]
->
ToString
();
<<
args_spec_list
[
i
]
->
ToString
();
input_types
.
push_back
(
type
);
op_inputs
.
push_back
(
param
);
op_inputs
.
push_back
(
param
);
}
}
// process default
// process default
ProcessDefault
(
func_name
,
args_spec_list
,
signature
,
has_var
,
&
op_inputs
);
ProcessDefault
(
func_name
,
args_spec_list
.
size
()
,
signature
,
has_var
,
&
op_inputs
);
DoAutoCast
(
func_name
,
signature
,
args_spec_list
,
func_graph
,
&
op_inputs
,
write_indices
);
DoAutoCast
(
func_name
,
signature
,
input_types
,
func_graph
,
&
op_inputs
,
write_indices
);
return
func_graph
->
NewCNode
(
op_inputs
);
return
func_graph
->
NewCNode
(
op_inputs
);
}
}
}
// namespace
}
// namespace
...
...
mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
浏览文件 @
95212b55
...
@@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
...
@@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
}
}
Register
(
types_name
,
py_fn
);
Register
(
types_name
,
py_fn
);
}
}
static
TypePtr
UnwrapRef
(
const
TypePtr
&
type
)
{
if
(
type
->
isa
<
RefType
>
())
{
return
type
->
cast
<
RefTypePtr
>
()
->
subtype
();
}
return
type
;
}
// Return Exact match if exists, else return non ambiguous sub class match
// Return Exact match if exists, else return non ambiguous sub class match
// Return py::none() if matching is ambiguous
// Return py::none() if matching is ambiguous
...
@@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
...
@@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
}
}
auto
match
=
true
;
auto
match
=
true
;
for
(
size_t
i
=
0
;
i
<
sign
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sign
.
size
();
++
i
)
{
if
(
!
IsIdentidityOrSubclass
(
UnwrapRef
(
types
[
i
])
,
sign
[
i
]))
{
if
(
!
IsIdentidityOrSubclass
(
types
[
i
]
,
sign
[
i
]))
{
match
=
false
;
match
=
false
;
break
;
break
;
}
}
...
...
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
浏览文件 @
95212b55
...
@@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
...
@@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
return
std
::
make_shared
<
AbstractClass
>
(
cls
->
tag
(),
abs_attributes
,
cls
->
methods
());
return
std
::
make_shared
<
AbstractClass
>
(
cls
->
tag
(),
abs_attributes
,
cls
->
methods
());
}
}
AbstractBasePtr
InferImplAssign
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tensor
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
2
);
MS_LOG
(
DEBUG
)
<<
"InferImplAssign "
<<
args_spec_list
[
0
];
return
args_spec_list
[
0
];
}
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
TypeOf
,
prim
::
kPrimTypeOf
,
InferImplTypeof
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
TypeOf
,
prim
::
kPrimTypeOf
,
InferImplTypeof
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
HasType
,
prim
::
kPrimHasType
,
InferImplHasType
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
HasType
,
prim
::
kPrimHasType
,
InferImplHasType
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
MakeRecord
,
prim
::
kPrimMakeRecord
,
InferImplMakeRecord
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
MakeRecord
,
prim
::
kPrimMakeRecord
,
InferImplMakeRecord
);
...
@@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
...
@@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
J
,
prim
::
kPrimJ
,
InferImplJ
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
J
,
prim
::
kPrimJ
,
InferImplJ
);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
BroadcastGradientArgs
,
prim
::
kPrimBroadcastGradientArgs
,
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL
(
BroadcastGradientArgs
,
prim
::
kPrimBroadcastGradientArgs
,
InferImplBroadcastGradientArgs
);
InferImplBroadcastGradientArgs
);
REGISTER_PRIMITIVE_EVAL_IMPL
(
Assign
,
prim
::
kPrimAssign
,
InferImplAssign
);
}
// namespace abstract
}
// namespace abstract
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
浏览文件 @
95212b55
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "ir/anf.h"
#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/param_info.h"
#include "ir/meta_tensor.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/python_adapter.h"
namespace
mindspore
{
namespace
mindspore
{
...
@@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
...
@@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if
(
!
para_ptr
->
has_default
())
{
if
(
!
para_ptr
->
has_default
())
{
return
false
;
return
false
;
}
}
auto
obj
=
py
::
cast
(
para_ptr
->
default_param
());
auto
param_value
=
para_ptr
->
param_info
();
auto
param_value
=
py
::
cast
<
ParamValuePtr
>
(
obj
.
attr
(
"_value"
));
if
(
param_value
==
nullptr
)
{
if
(
param_value
==
nullptr
)
{
return
false
;
return
false
;
}
}
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.cc
浏览文件 @
95212b55
...
@@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr ¶meter_node) {
...
@@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr ¶meter_node) {
if
(
!
cloned_parameter
->
has_default
())
{
if
(
!
cloned_parameter
->
has_default
())
{
return
false
;
return
false
;
}
}
auto
obj
=
py
::
cast
(
cloned_parameter
->
default_param
());
auto
param_value
=
cloned_parameter
->
param_info
();
auto
param_value
=
py
::
cast
<
ParamValuePtr
>
(
obj
.
attr
(
"_value"
));
if
(
param_value
==
nullptr
)
{
if
(
param_value
==
nullptr
)
{
return
false
;
return
false
;
}
}
...
@@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
...
@@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if
(
!
ParameterIsCloned
(
cloned_parameter_node
))
{
if
(
!
ParameterIsCloned
(
cloned_parameter_node
))
{
continue
;
continue
;
}
}
auto
obj
=
py
::
cast
(
cloned_parameter
->
default_param
());
auto
param_value
=
cloned_parameter
->
param_info
();
auto
param_value
=
py
::
cast
<
ParamValuePtr
>
(
obj
.
attr
(
"_value"
));
if
(
param_value
==
nullptr
)
{
if
(
param_value
==
nullptr
)
{
continue
;
continue
;
}
}
...
@@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
...
@@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
continue
;
continue
;
}
}
const
auto
&
param_value_cloned
=
be_cloned_parameter
->
default_param
();
auto
param_value_in
=
be_cloned_parameter
->
param_info
();
auto
obj_in
=
py
::
cast
(
param_value_cloned
);
auto
param_value_in
=
py
::
cast
<
ParamValuePtr
>
(
obj_in
.
attr
(
"_value"
));
if
(
param_value_in
==
nullptr
)
{
if
(
param_value_in
==
nullptr
)
{
continue
;
continue
;
}
}
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
95212b55
...
@@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
...
@@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for
(
const
auto
&
param
:
func_graph
->
parameters
())
{
for
(
const
auto
&
param
:
func_graph
->
parameters
())
{
auto
param_node
=
std
::
static_pointer_cast
<
Parameter
>
(
param
);
auto
param_node
=
std
::
static_pointer_cast
<
Parameter
>
(
param
);
if
(
param_node
->
has_default
())
{
if
(
param_node
->
has_default
())
{
ValuePtr
value
=
param_node
->
default_param
();
auto
value
=
param_node
->
default_param
();
constexpr
bool
broaden
=
true
;
auto
abs_value
=
value
->
ToAbstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
AbstractBasePtr
ptr
=
abstract
::
FromValue
(
value
,
broaden
);
auto
ref_key
=
std
::
make_shared
<
RefKey
>
(
param_node
->
name
());
auto
abs_ref_key
=
ref_key
->
ToAbstract
();
parallel
::
ParallelParameterContextRestoreInNoTraining
(
func_graph
,
param_node
,
ptr
);
auto
abs_ref
=
std
::
make_shared
<
abstract
::
AbstractRef
>
(
abs_ref_key
,
abs_value
);
args_spec
.
push_back
(
ptr
);
parallel
::
ParallelParameterContextRestoreInNoTraining
(
func_graph
,
param_node
,
abs_ref
);
parallel
::
ParallelParameterContextCkptInTraining
(
func_graph
,
param_node
,
ptr
);
args_spec
.
push_back
(
abs_ref
);
parallel
::
ParallelParameterContextCkptInTraining
(
func_graph
,
param_node
,
abs_ref
);
}
}
}
}
// Analyze
// Analyze
...
...
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
浏览文件 @
95212b55
...
@@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
...
@@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
converted
=
env
;
converted
=
env
;
}
else
if
(
py
::
hasattr
(
obj
,
PYTHON_CLASS_MEMBER_NAMESPACE
))
{
}
else
if
(
py
::
hasattr
(
obj
,
PYTHON_CLASS_MEMBER_NAMESPACE
))
{
converted
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
obj
);
converted
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
obj
);
}
else
if
(
py
::
hasattr
(
obj
,
"__parameter__"
))
{
auto
to_convert
=
py
::
cast
<
py
::
object
>
(
python_adapter
::
GetPyObjAttr
(
obj
,
"default_input"
));
ret
=
ConvertData
(
to_convert
,
&
converted
);
}
else
{
}
else
{
ret
=
ConvertOtherObj
(
obj
,
&
converted
);
ret
=
ConvertOtherObj
(
obj
,
&
converted
);
}
}
...
@@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
...
@@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
ValuePtr
PyDataToValue
(
const
py
::
object
&
obj
)
{
ValuePtr
PyDataToValue
(
const
py
::
object
&
obj
)
{
py
::
object
to_convert
=
obj
;
py
::
object
to_convert
=
obj
;
if
(
py
::
hasattr
(
obj
,
"__parameter__"
))
{
to_convert
=
py
::
cast
<
py
::
object
>
(
python_adapter
::
GetPyObjAttr
(
obj
,
"default_input"
));
}
ValuePtr
value
=
nullptr
;
ValuePtr
value
=
nullptr
;
(
void
)
ConvertData
(
to_convert
,
&
value
);
(
void
)
ConvertData
(
to_convert
,
&
value
);
return
value
;
return
value
;
...
...
mindspore/ccsrc/pipeline/jit/parse/function_block.cc
浏览文件 @
95212b55
...
@@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
...
@@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
}
}
void
FunctionBlock
::
SetStateAssgin
(
const
AnfNodePtr
&
target
,
const
std
::
string
&
readid
)
{
void
FunctionBlock
::
SetStateAssgin
(
const
AnfNodePtr
&
target
,
const
std
::
string
&
readid
)
{
state_assign_
[
target
]
=
readid
;
const
std
::
string
primitive_name
(
"assign"
);
const
std
::
string
module_name
(
"mindspore.ops.functional"
);
ValueNodePtr
assign_op
=
NewValueNode
(
prim
::
GetPythonOps
(
primitive_name
,
module_name
,
true
));
auto
source
=
ReadVariable
(
readid
);
auto
assign
=
func_graph
()
->
NewCNode
({
assign_op
,
target
,
source
});
WriteVariable
(
readid
,
assign
);
MS_LOG
(
INFO
)
<<
"SetState read "
<<
target
->
DebugString
()
<<
", "
<<
readid
;
AddAutoDepend
(
assign
);
}
}
void
FunctionBlock
::
AddAutoDepend
(
const
AnfNodePtr
&
target
)
{
auto_depends_
.
push_back
(
target
);
}
void
FunctionBlock
::
AddAutoDepend
(
const
AnfNodePtr
&
target
)
{
auto_depends_
.
push_back
(
target
);
}
...
@@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
...
@@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr
make_tuple_op
=
NewValueNode
(
prim
::
kPrimMakeTuple
);
ValueNodePtr
make_tuple_op
=
NewValueNode
(
prim
::
kPrimMakeTuple
);
ValueNodePtr
depend_op
=
NewValueNode
(
prim
::
kPrimDepend
);
ValueNodePtr
depend_op
=
NewValueNode
(
prim
::
kPrimDepend
);
ValueNodePtr
stop_gradient_op
=
NewValueNode
(
prim
::
kPrimStopGradient
);
ValueNodePtr
stop_gradient_op
=
NewValueNode
(
prim
::
kPrimStopGradient
);
const
std
::
string
primitive_name
(
"assign"
);
const
std
::
string
module_name
(
"mindspore.ops.functional"
);
if
(
auto_depends_
.
size
()
==
0
)
{
ValueNodePtr
assign_op
=
NewValueNode
(
prim
::
GetPythonOps
(
primitive_name
,
module_name
,
true
));
if
(
state_assign_
.
size
()
==
0
&&
auto_depends_
.
size
()
==
0
)
{
return
;
return
;
}
}
AnfNodePtr
state
=
nullptr
;
AnfNodePtr
state
=
nullptr
;
std
::
vector
<
AnfNodePtr
>
vec_states
;
std
::
vector
<
AnfNodePtr
>
vec_states
;
vec_states
.
emplace_back
(
make_tuple_op
);
vec_states
.
emplace_back
(
make_tuple_op
);
for
(
auto
&
item
:
state_assign_
)
{
auto
source
=
ReadVariable
(
item
.
second
);
auto
assign
=
func_graph
()
->
NewCNode
({
assign_op
,
item
.
first
,
source
});
MS_LOG
(
INFO
)
<<
"SetState read "
<<
item
.
first
->
ToString
()
<<
", "
<<
item
.
second
;
vec_states
.
emplace_back
(
assign
);
}
for
(
auto
&
item
:
auto_depends_
)
{
for
(
auto
&
item
:
auto_depends_
)
{
MS_LOG
(
DEBUG
)
<<
"auto_depends "
<<
item
->
ToString
();
MS_LOG
(
DEBUG
)
<<
"auto_depends "
<<
item
->
ToString
();
vec_states
.
emplace_back
(
item
);
vec_states
.
emplace_back
(
item
);
...
@@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
...
@@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
AnfNodePtr
stopped
=
func_graph
()
->
NewCNode
({
stop_gradient_op
,
state
});
AnfNodePtr
stopped
=
func_graph
()
->
NewCNode
({
stop_gradient_op
,
state
});
AnfNodePtr
ret
=
func_graph
()
->
NewCNode
({
depend_op
,
old_ret
,
stopped
});
AnfNodePtr
ret
=
func_graph
()
->
NewCNode
({
depend_op
,
old_ret
,
stopped
});
func_graph
()
->
set_output
(
ret
,
true
);
func_graph
()
->
set_output
(
ret
,
true
);
state_assign_
.
clear
();
}
}
}
// namespace parse
}
// namespace parse
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/parse/function_block.h
浏览文件 @
95212b55
...
@@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
...
@@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// keeps all removable phis which will be removed in one pass.
// keeps all removable phis which will be removed in one pass.
std
::
unordered_map
<
ParameterPtr
,
AnfNodePtr
>
removable_phis_
;
std
::
unordered_map
<
ParameterPtr
,
AnfNodePtr
>
removable_phis_
;
// set state nodes need to insert before function return nodes.
OrderedMap
<
AnfNodePtr
,
std
::
string
>
state_assign_
;
// hold declared global variables in function
// hold declared global variables in function
std
::
set
<
std
::
string
>
global_vars_
;
std
::
set
<
std
::
string
>
global_vars_
;
...
...
mindspore/ccsrc/pipeline/jit/parse/parse.cc
浏览文件 @
95212b55
...
@@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
...
@@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
return
func_graph
;
return
func_graph
;
}
}
ValuePtr
GetMixedPrecisionTargetType
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
)
{
TypePtr
GetMixedPrecisionTargetType
(
const
FuncGraphPtr
&
func_graph
)
{
TypePtr
dst_type
;
if
(
func_graph
->
has_flag
(
GRAPH_FLAG_MIX_PRECISION_FP32
))
{
if
(
func_graph
->
has_flag
(
GRAPH_FLAG_MIX_PRECISION_FP32
))
{
return
kFloat32
;
return
kFloat32
;
}
else
if
(
func_graph
->
has_flag
(
GRAPH_FLAG_MIX_PRECISION_FP16
))
{
}
else
if
(
func_graph
->
has_flag
(
GRAPH_FLAG_MIX_PRECISION_FP16
))
{
return
kFloat16
;
return
kFloat16
;
}
else
{
}
else
{
return
kNone
;
return
nullptr
;
}
}
}
}
...
...
mindspore/ccsrc/pipeline/jit/parse/parse.h
浏览文件 @
95212b55
...
@@ -364,7 +364,7 @@ class ParseAst {
...
@@ -364,7 +364,7 @@ class ParseAst {
bool
UpdateFuncGraphFlags
(
py
::
object
obj
,
const
FuncGraphPtr
&
func_graph
);
bool
UpdateFuncGraphFlags
(
py
::
object
obj
,
const
FuncGraphPtr
&
func_graph
);
AnfNodePtr
GetMixedPrecisionCastHelp
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
);
AnfNodePtr
GetMixedPrecisionCastHelp
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
);
ValuePtr
GetMixedPrecisionTargetType
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
param
);
TypePtr
GetMixedPrecisionTargetType
(
const
FuncGraphPtr
&
func_graph
);
}
// namespace parse
}
// namespace parse
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
浏览文件 @
95212b55
...
@@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
...
@@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
auto
value
=
py
::
cast
<
tensor
::
MetaTensorPtr
>
(
obj
);
auto
value
=
py
::
cast
<
tensor
::
MetaTensorPtr
>
(
obj
);
node
->
set_default_param
(
value
);
node
->
set_default_param
(
value
);
// set_abstract for parameter
// set_abstract for parameter
constexpr
bool
broaden
=
true
;
auto
abs
=
value
->
ToAbstract
()
;
node
->
set_abstract
(
abs
tract
::
FromValue
(
value
,
broaden
)
);
node
->
set_abstract
(
abs
);
para_node
=
node
;
para_node
=
node
;
}
}
auto
iter
=
func_graph
->
make_ref_params
().
find
(
para_node
);
if
(
iter
==
func_graph
->
make_ref_params
().
end
())
{
return
para_node
;
ValuePtr
target_type
=
GetMixedPrecisionTargetType
(
func_graph
,
para_node
);
AnfNodePtr
make_ref
=
NewValueNode
(
prim
::
kPrimMakeRef
);
AnfNodePtr
ref_key
=
NewValueNode
(
std
::
make_shared
<
RefKey
>
(
param_name
));
AnfNodePtr
target_type_node
=
NewValueNode
(
target_type
);
AnfNodePtr
ref_node
=
func_graph
->
NewCNode
({
make_ref
,
ref_key
,
para_node
,
target_type_node
});
func_graph
->
make_ref_params
()[
para_node
]
=
ref_node
;
func_graph
->
add_parameter_obj_node
(
ref_node
);
return
ref_node
;
}
else
{
return
iter
->
second
;
}
}
}
bool
ResolveObjectToNode
(
const
FuncGraphPtr
&
func_graph
,
const
py
::
object
&
obj
,
AnfNodePtr
*
const
node
)
{
bool
ResolveObjectToNode
(
const
FuncGraphPtr
&
func_graph
,
const
py
::
object
&
obj
,
AnfNodePtr
*
const
node
)
{
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
95212b55
...
@@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
...
@@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
size_t
size
=
op_exec_info
->
op_inputs
.
size
();
size_t
size
=
op_exec_info
->
op_inputs
.
size
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
auto
obj
=
op_exec_info
->
op_inputs
[
i
];
auto
obj
=
op_exec_info
->
op_inputs
[
i
];
bool
op_mask
=
py
::
hasattr
(
obj
,
"__parameter__"
);
bool
op_mask
=
false
;
if
(
py
::
isinstance
<
tensor
::
MetaTensor
>
(
obj
))
{
auto
meta_tensor
=
obj
.
cast
<
tensor
::
MetaTensorPtr
>
();
if
(
meta_tensor
)
{
op_mask
=
meta_tensor
->
is_parameter
();
}
}
(
*
op_masks
).
push_back
(
op_mask
);
(
*
op_masks
).
push_back
(
op_mask
);
MS_LOG
(
DEBUG
)
<<
"gen "
<<
op_exec_info
->
op_name
<<
" arg "
<<
i
<<
": op mask "
<<
op_mask
<<
" grad_flag_ "
MS_LOG
(
DEBUG
)
<<
"gen "
<<
op_exec_info
->
op_name
<<
" arg "
<<
i
<<
": op mask "
<<
op_mask
<<
" grad_flag_ "
<<
grad_flag_
;
<<
grad_flag_
;
...
@@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
...
@@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
if
(
graph_info_map_
[
df_builder_
].
param_map
.
count
(
obj_id
)
==
0
)
{
if
(
graph_info_map_
[
df_builder_
].
param_map
.
count
(
obj_id
)
==
0
)
{
auto
free_param
=
df_builder_
->
add_parameter
();
auto
free_param
=
df_builder_
->
add_parameter
();
free_param
->
set_name
(
param_name
);
free_param
->
set_name
(
param_name
);
free_param
->
set_default_param
(
py
::
cast
<
tensor
::
TensorPtr
>
(
obj
));
free_param
->
debug_info
()
->
set_name
(
param_name
);
free_param
->
debug_info
()
->
set_name
(
param_name
);
auto
value
=
py
::
cast
<
tensor
::
TensorPtr
>
(
obj
);
free_param
->
set_default_param
(
value
);
MS_LOG
(
DEBUG
)
<<
"Top graph set free parameter "
<<
obj_id
;
MS_LOG
(
DEBUG
)
<<
"Top graph set free parameter "
<<
obj_id
;
graph_info_map_
[
df_builder_
].
param_map
[
obj_id
]
=
free_param
;
graph_info_map_
[
df_builder_
].
param_map
[
obj_id
]
=
free_param
;
return
free_param
;
return
free_param
;
...
@@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
...
@@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
auto
param_name
=
py
::
cast
<
std
::
string
>
(
name_attr
);
auto
param_name
=
py
::
cast
<
std
::
string
>
(
name_attr
);
auto
free_param
=
df_builder_
->
add_parameter
();
auto
free_param
=
df_builder_
->
add_parameter
();
free_param
->
set_name
(
param_name
);
free_param
->
set_name
(
param_name
);
free_param
->
set_default_param
(
py
::
cast
<
tensor
::
TensorPtr
>
(
param
));
auto
value
=
py
::
cast
<
tensor
::
TensorPtr
>
(
param
);
free_param
->
set_default_param
(
value
);
free_param
->
debug_info
()
->
set_name
(
param_name
);
free_param
->
debug_info
()
->
set_name
(
param_name
);
para_node
=
free_param
;
para_node
=
free_param
;
}
}
ValuePtr
target_type
=
parse
::
GetMixedPrecisionTargetType
(
df_builder_
,
para_node
);
w_args
.
push_back
(
para_node
);
AnfNodePtr
make_ref
=
NewValueNode
(
prim
::
kPrimMakeRef
);
auto
refkey
=
std
::
make_shared
<
RefKey
>
(
para_node
->
cast
<
ParameterPtr
>
()
->
name
());
AnfNodePtr
ref_key_node
=
NewValueNode
(
refkey
);
AnfNodePtr
target_type_node
=
NewValueNode
(
target_type
);
AnfNodePtr
ref_node
=
df_builder_
->
NewCNode
({
make_ref
,
ref_key_node
,
para_node
,
target_type_node
});
w_args
.
push_back
(
ref_node
);
}
}
}
else
{
}
else
{
MS_LOG
(
DEBUG
)
<<
"training not paramter_tuple"
;
MS_LOG
(
DEBUG
)
<<
"training not paramter_tuple"
;
...
@@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
...
@@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
auto
param_node
=
std
::
static_pointer_cast
<
Parameter
>
(
param
);
auto
param_node
=
std
::
static_pointer_cast
<
Parameter
>
(
param
);
if
(
param_node
->
has_default
())
{
if
(
param_node
->
has_default
())
{
ValuePtr
value
=
param_node
->
default_param
();
ValuePtr
value
=
param_node
->
default_param
();
AbstractBasePtr
ptr
=
abstract
::
FromValue
(
value
,
true
);
auto
ptr
=
value
->
ToAbstract
(
);
if
(
ptr
==
nullptr
)
{
if
(
ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Args convert error"
;
MS_LOG
(
EXCEPTION
)
<<
"Args convert error"
;
}
}
...
...
mindspore/ccsrc/pybind_api/ir/dtype_py.cc
浏览文件 @
95212b55
...
@@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
...
@@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
(
void
)
py
::
class_
<
TypeType
,
Type
,
std
::
shared_ptr
<
TypeType
>>
(
m_sub
,
"TypeType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeType
,
Type
,
std
::
shared_ptr
<
TypeType
>>
(
m_sub
,
"TypeType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
String
,
Type
,
std
::
shared_ptr
<
String
>>
(
m_sub
,
"String"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
String
,
Type
,
std
::
shared_ptr
<
String
>>
(
m_sub
,
"String"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefKeyType
,
Type
,
std
::
shared_ptr
<
RefKeyType
>>
(
m_sub
,
"RefKeyType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefKeyType
,
Type
,
std
::
shared_ptr
<
RefKeyType
>>
(
m_sub
,
"RefKeyType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefType
,
Type
,
std
::
shared_ptr
<
RefType
>>
(
m_sub
,
"RefType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefType
,
T
ensorType
,
T
ype
,
std
::
shared_ptr
<
RefType
>>
(
m_sub
,
"RefType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeAnything
,
Type
,
std
::
shared_ptr
<
TypeAnything
>>
(
m_sub
,
"TypeAnything"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeAnything
,
Type
,
std
::
shared_ptr
<
TypeAnything
>>
(
m_sub
,
"TypeAnything"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
Slice
,
Type
,
std
::
shared_ptr
<
Slice
>>
(
m_sub
,
"Slice"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
Slice
,
Type
,
std
::
shared_ptr
<
Slice
>>
(
m_sub
,
"Slice"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeEllipsis
,
Type
,
std
::
shared_ptr
<
TypeEllipsis
>>
(
m_sub
,
"TypeEllipsis"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeEllipsis
,
Type
,
std
::
shared_ptr
<
TypeEllipsis
>>
(
m_sub
,
"TypeEllipsis"
).
def
(
py
::
init
());
...
...
mindspore/ccsrc/pybind_api/ir/param_info_py.cc
浏览文件 @
95212b55
...
@@ -21,7 +21,7 @@ namespace mindspore {
...
@@ -21,7 +21,7 @@ namespace mindspore {
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
REGISTER_PYBIND_DEFINE
(
ParamInfo
,
([](
const
py
::
module
*
m
)
{
REGISTER_PYBIND_DEFINE
(
ParamInfo
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
ParamInfo
,
Param
Value
Ptr
>
(
*
m
,
"ParamInfo"
)
(
void
)
py
::
class_
<
ParamInfo
,
Param
Info
Ptr
>
(
*
m
,
"ParamInfo"
)
.
def
(
py
::
init
())
.
def
(
py
::
init
())
.
def
(
"clone"
,
&
ParamInfo
::
Clone
)
.
def
(
"clone"
,
&
ParamInfo
::
Clone
)
.
def_property
(
"name"
,
&
ParamInfo
::
name
,
&
ParamInfo
::
set_name
)
.
def_property
(
"name"
,
&
ParamInfo
::
name
,
&
ParamInfo
::
set_name
)
...
@@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
...
@@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
if
(
t
.
size
()
!=
6
)
{
if
(
t
.
size
()
!=
6
)
{
std
::
runtime_error
(
"Invalid state for ParamInfo!"
);
std
::
runtime_error
(
"Invalid state for ParamInfo!"
);
}
}
Param
Value
Ptr
p
=
std
::
make_shared
<
ParamInfo
>
();
Param
Info
Ptr
p
=
std
::
make_shared
<
ParamInfo
>
();
p
->
set_name
(
t
[
1
].
cast
<
std
::
string
>
());
p
->
set_name
(
t
[
1
].
cast
<
std
::
string
>
());
p
->
set_requires_grad
(
t
[
2
].
cast
<
bool
>
());
p
->
set_requires_grad
(
t
[
2
].
cast
<
bool
>
());
p
->
set_layerwise_parallel
(
t
[
3
].
cast
<
bool
>
());
p
->
set_layerwise_parallel
(
t
[
3
].
cast
<
bool
>
());
...
...
mindspore/ccsrc/pybind_api/ir/tensor_py.cc
浏览文件 @
95212b55
...
@@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
...
@@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.
def
(
py
::
init
<
TypePtr
,
const
std
::
vector
<
int
>>
(),
py
::
arg
(
"dtype"
),
py
::
arg
(
"shape"
))
.
def
(
py
::
init
<
TypePtr
,
const
std
::
vector
<
int
>>
(),
py
::
arg
(
"dtype"
),
py
::
arg
(
"shape"
))
.
def_property_readonly
(
"dtype"
,
&
MetaTensor
::
Dtype
,
"Get the MetaTensor's dtype."
)
.
def_property_readonly
(
"dtype"
,
&
MetaTensor
::
Dtype
,
"Get the MetaTensor's dtype."
)
.
def_property_readonly
(
"shape"
,
&
MetaTensor
::
shape
,
"Get the MetaTensor's shape."
)
.
def_property_readonly
(
"shape"
,
&
MetaTensor
::
shape
,
"Get the MetaTensor's shape."
)
.
def_property
(
"_param_info"
,
&
MetaTensor
::
param_info
,
&
MetaTensor
::
set_param_info
)
.
def
(
py
::
pickle
(
.
def
(
py
::
pickle
(
[](
const
MetaTensor
&
t
)
{
// __getstate__
[](
const
MetaTensor
&
t
)
{
// __getstate__
/* Return a tuple that fully encodes the state of the object */
/* Return a tuple that fully encodes the state of the object */
...
...
mindspore/common/parameter.py
浏览文件 @
95212b55
...
@@ -42,7 +42,7 @@ class Parameter(MetaTensor):
...
@@ -42,7 +42,7 @@ class Parameter(MetaTensor):
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
compil
e
for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
compil
ing
for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
Note:
Note:
Each parameter of Cell is represented by Parameter class.
Each parameter of Cell is represented by Parameter class.
...
@@ -108,7 +108,7 @@ class Parameter(MetaTensor):
...
@@ -108,7 +108,7 @@ class Parameter(MetaTensor):
Parameter
,
(
data
,
self
.
name
,
self
.
requires_grad
,
self
.
layerwise_parallel
))
Parameter
,
(
data
,
self
.
name
,
self
.
requires_grad
,
self
.
layerwise_parallel
))
def
__init__
(
self
,
default_input
,
name
,
requires_grad
=
True
,
layerwise_parallel
=
False
):
def
__init__
(
self
,
default_input
,
name
,
requires_grad
=
True
,
layerwise_parallel
=
False
):
self
.
_
value
=
ParamInfo
()
self
.
_
param_info
=
ParamInfo
()
self
.
name
=
name
self
.
name
=
name
self
.
requires_grad
=
requires_grad
self
.
requires_grad
=
requires_grad
self
.
layerwise_parallel
=
layerwise_parallel
self
.
layerwise_parallel
=
layerwise_parallel
...
@@ -156,13 +156,13 @@ class Parameter(MetaTensor):
...
@@ -156,13 +156,13 @@ class Parameter(MetaTensor):
value_str
=
MetaTensor
.
__str__
(
self
)
value_str
=
MetaTensor
.
__str__
(
self
)
if
isinstance
(
self
,
Tensor
):
if
isinstance
(
self
,
Tensor
):
value_str
=
Tensor
.
__str__
(
self
)
value_str
=
Tensor
.
__str__
(
self
)
return
f
'Parameter (name=
{
self
.
_
value
.
name
}
, value=
{
value_str
}
)'
return
f
'Parameter (name=
{
self
.
_
param_info
.
name
}
, value=
{
value_str
}
)'
def
__repr__
(
self
):
def
__repr__
(
self
):
value_str
=
MetaTensor
.
__repr__
(
self
)
value_str
=
MetaTensor
.
__repr__
(
self
)
if
isinstance
(
self
,
Tensor
):
if
isinstance
(
self
,
Tensor
):
value_str
=
Tensor
.
__repr__
(
self
)
value_str
=
Tensor
.
__repr__
(
self
)
return
f
'Parameter (name=
{
self
.
_
value
.
name
}
, value=
{
value_str
}
)'
return
f
'Parameter (name=
{
self
.
_
param_info
.
name
}
, value=
{
value_str
}
)'
def
__parameter__
(
self
):
def
__parameter__
(
self
):
"""For parse check."""
"""For parse check."""
...
@@ -181,7 +181,7 @@ class Parameter(MetaTensor):
...
@@ -181,7 +181,7 @@ class Parameter(MetaTensor):
@
property
@
property
def
name
(
self
):
def
name
(
self
):
"""Get the name of the parameter."""
"""Get the name of the parameter."""
return
self
.
_
value
.
name
return
self
.
_
param_info
.
name
@
name
.
setter
@
name
.
setter
def
name
(
self
,
name_
):
def
name
(
self
,
name_
):
...
@@ -203,7 +203,7 @@ class Parameter(MetaTensor):
...
@@ -203,7 +203,7 @@ class Parameter(MetaTensor):
format
(
name_
,
PARAMETER_NAME_PREFIX_MAX_LEN
))
format
(
name_
,
PARAMETER_NAME_PREFIX_MAX_LEN
))
else
:
else
:
raise
ValueError
(
"The type of the name should be `str` or `None`."
)
raise
ValueError
(
"The type of the name should be `str` or `None`."
)
self
.
_
value
.
name
=
name_
self
.
_
param_info
.
name
=
name_
@
property
@
property
def
cast_type
(
self
):
def
cast_type
(
self
):
...
@@ -254,8 +254,8 @@ class Parameter(MetaTensor):
...
@@ -254,8 +254,8 @@ class Parameter(MetaTensor):
_check_str_by_regular
(
prefix
)
_check_str_by_regular
(
prefix
)
x
=
copy
(
self
)
x
=
copy
(
self
)
# pylint: disable=protected-access
# pylint: disable=protected-access
x
.
_
value
=
self
.
_value
.
clone
()
x
.
_
param_info
=
self
.
_param_info
.
clone
()
x
.
_
value
.
name
=
prefix
+
'.'
+
self
.
_value
.
name
x
.
_
param_info
.
name
=
prefix
+
'.'
+
self
.
_param_info
.
name
x
.
is_init
=
False
x
.
is_init
=
False
if
init
!=
'same'
:
if
init
!=
'same'
:
shape
=
self
.
shape
shape
=
self
.
shape
...
@@ -265,24 +265,24 @@ class Parameter(MetaTensor):
...
@@ -265,24 +265,24 @@ class Parameter(MetaTensor):
@
property
@
property
def
layerwise_parallel
(
self
):
def
layerwise_parallel
(
self
):
return
self
.
_
value
.
layerwise_parallel
return
self
.
_
param_info
.
layerwise_parallel
@
layerwise_parallel
.
setter
@
layerwise_parallel
.
setter
def
layerwise_parallel
(
self
,
value
=
True
):
def
layerwise_parallel
(
self
,
value
=
True
):
if
not
isinstance
(
value
,
bool
):
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"`layerwise_parallel` parameter must be bool type"
)
raise
TypeError
(
"`layerwise_parallel` parameter must be bool type"
)
self
.
_
value
.
layerwise_parallel
=
value
self
.
_
param_info
.
layerwise_parallel
=
value
@
property
@
property
def
requires_grad
(
self
):
def
requires_grad
(
self
):
"""Return whether the parameter requires gradient."""
"""Return whether the parameter requires gradient."""
return
self
.
_
value
.
requires_grad
return
self
.
_
param_info
.
requires_grad
@
requires_grad
.
setter
@
requires_grad
.
setter
def
requires_grad
(
self
,
value
=
True
):
def
requires_grad
(
self
,
value
=
True
):
if
not
isinstance
(
value
,
bool
):
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"`requires_grad` parameter must be bool type"
)
raise
TypeError
(
"`requires_grad` parameter must be bool type"
)
self
.
_
value
.
requires_grad
=
value
self
.
_
param_info
.
requires_grad
=
value
@
property
@
property
def
data
(
self
):
def
data
(
self
):
...
...
mindspore/core/abstract/abstract_value.cc
浏览文件 @
95212b55
...
@@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
...
@@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
}
auto
other_tensor
=
dyn_cast
<
AbstractTensor
>
(
other
);
auto
other_tensor
=
dyn_cast
<
AbstractTensor
>
(
other
);
if
(
other_tensor
==
nullptr
)
{
if
(
other_tensor
==
nullptr
)
{
auto
ref_tensor
=
dyn_cast
<
AbstractRef
>
(
other
);
if
(
ref_tensor
!=
nullptr
)
{
return
this
->
Join
(
ref_tensor
->
ref
());
}
MS_LOG
(
EXCEPTION
)
<<
"Join failed as type mismatch, this: "
<<
ToString
()
<<
", other: "
<<
other
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
"Join failed as type mismatch, this: "
<<
ToString
()
<<
", other: "
<<
other
->
ToString
();
}
}
if
(
*
this
==
*
other
)
{
if
(
*
this
==
*
other
)
{
...
@@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
...
@@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
return
std
::
make_shared
<
AbstractTensor
>
(
element
,
shape
);
return
std
::
make_shared
<
AbstractTensor
>
(
element
,
shape
);
}
}
bool
AbstractTensor
::
operator
==
(
const
AbstractTensor
&
other
)
const
{
bool
AbstractTensor
::
equal_to
(
const
AbstractTensor
&
other
)
const
{
if
(
&
other
==
this
)
{
if
(
&
other
==
this
)
{
return
true
;
return
true
;
}
}
...
@@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
...
@@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
return
(
*
element_
==
*
other
.
element_
)
&&
(
*
shape
()
==
*
other
.
shape
())
&&
is_value_equal
;
return
(
*
element_
==
*
other
.
element_
)
&&
(
*
shape
()
==
*
other
.
shape
())
&&
is_value_equal
;
}
}
bool
AbstractTensor
::
operator
==
(
const
AbstractTensor
&
other
)
const
{
return
equal_to
(
other
);
}
bool
AbstractTensor
::
operator
==
(
const
AbstractBase
&
other
)
const
{
bool
AbstractTensor
::
operator
==
(
const
AbstractBase
&
other
)
const
{
if
(
&
other
==
this
)
{
if
(
&
other
==
this
)
{
return
true
;
return
true
;
}
}
if
(
other
.
isa
<
AbstractTensor
>
())
{
if
(
other
.
tid
()
==
tid
())
{
auto
other_tensor
=
static_cast
<
const
AbstractTensor
*>
(
&
other
);
auto
other_tensor
=
static_cast
<
const
AbstractTensor
*>
(
&
other
);
return
*
this
==
*
other_tensor
;
return
*
this
==
*
other_tensor
;
}
else
{
}
else
{
...
@@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
...
@@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
return
buffer
.
str
();
return
buffer
.
str
();
}
}
AbstractRef
::
AbstractRef
(
const
AbstractBasePtr
&
ref_key
,
const
AbstractBasePtr
&
ref_value
,
bool
need_cast
,
AbstractRef
::
AbstractRef
(
const
AbstractBasePtr
&
ref_key
,
const
AbstractTensorPtr
&
ref_value
)
TypePtr
cast_target
)
:
AbstractTensor
(
*
ref_value
),
ref_key_
(
ref_key
),
ref_key_value_
(
nullptr
)
{
:
ref_key_
(
ref_key
),
ref_
(
ref_value
),
need_cast_
(
false
),
target_type_
(
nullptr
),
ref_key_value_
(
nullptr
)
{
set_type
(
std
::
make_shared
<
RefType
>
());
set_type
(
std
::
make_shared
<
RefType
>
());
auto
origin_type
=
ref_value
->
BuildType
();
if
(
need_cast
&&
cast_target
&&
origin_type
&&
origin_type
->
isa
<
TensorType
>
())
{
auto
tensor_dtype
=
origin_type
->
cast
<
TensorTypePtr
>
()
->
element
();
if
(
tensor_dtype
&&
IsSubType
(
tensor_dtype
,
kFloat
))
{
if
(
cast_target
!=
tensor_dtype
)
{
need_cast_
=
true
;
target_type_
=
cast_target
;
}
}
}
if
(
ref_key
&&
ref_key
->
isa
<
AbstractRefKey
>
())
{
if
(
ref_key
&&
ref_key
->
isa
<
AbstractRefKey
>
())
{
ref_key_value_
=
ref_key
->
cast
<
AbstractRefKeyPtr
>
()
->
ref_key_value
();
ref_key_value_
=
ref_key
->
cast
<
AbstractRefKeyPtr
>
()
->
ref_key_value
();
}
}
}
}
BaseShapePtr
AbstractRef
::
BuildShape
()
const
{
return
ref_
->
BuildShape
();
}
TypePtr
AbstractRef
::
BuildType
()
const
{
TypePtr
AbstractRef
::
BuildType
()
const
{
TypePtr
subtype
=
ref_
->
BuildType
();
auto
subtype
=
AbstractTensor
::
BuildType
()
->
cast
<
TensorTypePtr
>
();
TypePtr
subtype_origin
=
subtype
;
return
std
::
make_shared
<
RefType
>
(
subtype
);
if
(
need_cast_
)
{
subtype_origin
=
std
::
make_shared
<
TensorType
>
(
target_type_
);
}
return
std
::
make_shared
<
RefType
>
(
subtype
,
subtype_origin
);
}
}
bool
AbstractRef
::
operator
==
(
const
AbstractRef
&
other
)
const
{
bool
AbstractRef
::
operator
==
(
const
AbstractRef
&
other
)
const
{
return
(
*
ref_
==
*
other
.
ref_
)
&&
(
need_cast_
==
other
.
need_cast_
)
&&
(
*
ref_key_
==
*
other
.
ref_key_
)
&&
return
AbstractTensor
::
equal_to
(
other
)
&&
(
*
ref_key_
==
*
other
.
ref_key_
);
(
!
need_cast_
||
(
*
target_type_
==
*
other
.
target_type_
));
}
}
bool
AbstractRef
::
operator
==
(
const
AbstractBase
&
other
)
const
{
bool
AbstractRef
::
operator
==
(
const
AbstractBase
&
other
)
const
{
...
@@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
...
@@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
AbstractBasePtr
AbstractRef
::
Join
(
const
AbstractBasePtr
&
other
)
{
AbstractBasePtr
AbstractRef
::
Join
(
const
AbstractBasePtr
&
other
)
{
auto
other_ref
=
other
->
cast
<
AbstractRefPtr
>
();
auto
other_ref
=
other
->
cast
<
AbstractRefPtr
>
();
if
(
other_ref
==
nullptr
)
{
if
(
other_ref
==
nullptr
)
{
auto
new_ref
=
ref_
->
Join
(
other
);
return
AbstractTensor
::
Join
(
other
)
->
cast
<
AbstractTensorPtr
>
();
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
,
new_ref
);
}
}
if
((
*
this
==
*
other
)
&&
(
*
ref_key_
==
*
other_ref
->
ref_key_
))
{
if
((
*
this
==
*
other
)
&&
(
*
ref_key_
==
*
other_ref
->
ref_key_
))
{
return
shared_from_base
<
AbstractBase
>
();
return
shared_from_base
<
AbstractBase
>
();
}
}
auto
ref_key
=
ref_key_
->
Join
(
other_ref
->
ref_key_
);
auto
ref_key
=
ref_key_
->
Join
(
other_ref
->
ref_key_
);
auto
ref
=
ref_
->
Join
(
other_ref
->
ref
()
);
auto
ref
=
AbstractTensor
::
Join
(
other_ref
->
ref
())
->
cast
<
AbstractTensorPtr
>
(
);
return
std
::
make_shared
<
AbstractRef
>
(
ref_key
,
ref
);
return
std
::
make_shared
<
AbstractRef
>
(
ref_key
,
ref
);
}
}
std
::
string
AbstractRef
::
ToString
()
const
{
std
::
string
AbstractRef
::
ToString
()
const
{
std
::
ostringstream
buffer
;
std
::
ostringstream
buffer
;
buffer
<<
type_name
()
<<
"("
buffer
<<
type_name
()
<<
"("
<<
"key: "
<<
ref_key_
->
ToString
()
<<
" ref_value: "
<<
ref_
->
ToString
();
<<
"key: "
<<
ref_key_
->
ToString
()
<<
" ref_value: "
<<
AbstractTensor
::
ToString
();
if
(
need_cast_
)
{
buffer
<<
" cast to: "
<<
target_type_
->
ToString
();
}
auto
value
=
GetValueTrack
();
auto
value
=
GetValueTrack
();
if
(
value
)
{
if
(
value
)
{
buffer
<<
", value: "
<<
value
->
ToString
();
buffer
<<
", value: "
<<
value
->
ToString
();
...
...
mindspore/core/abstract/abstract_value.h
浏览文件 @
95212b55
...
@@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
...
@@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
BroadenWithShape
()
const
;
AbstractBasePtr
BroadenWithShape
()
const
;
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
final
;
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
);
bool
operator
==
(
const
AbstractTensor
&
other
)
const
;
bool
operator
==
(
const
AbstractTensor
&
other
)
const
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
size_t
hash
()
const
override
{
std
::
size_t
hash
()
const
override
{
auto
value
=
GetValueTrack
();
auto
value
=
GetValueTrack
();
...
@@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
...
@@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
}
}
return
hash_sum
;
return
hash_sum
;
}
}
protected:
bool
equal_to
(
const
AbstractTensor
&
other
)
const
;
};
};
using
AbstractTensorPtr
=
std
::
shared_ptr
<
AbstractTensor
>
;
using
AbstractTensorPtr
=
std
::
shared_ptr
<
AbstractTensor
>
;
using
AbstractTensorPtrList
=
std
::
vector
<
AbstractTensorPtr
>
;
using
AbstractTensorPtrList
=
std
::
vector
<
AbstractTensorPtr
>
;
...
@@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
...
@@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
};
};
using
AbstractRefKeyPtr
=
std
::
shared_ptr
<
AbstractRefKey
>
;
using
AbstractRefKeyPtr
=
std
::
shared_ptr
<
AbstractRefKey
>
;
class
AbstractRef
:
public
Abstract
Base
{
class
AbstractRef
:
public
Abstract
Tensor
{
public:
public:
AbstractRef
(
const
AbstractBasePtr
&
ref_key
,
const
AbstractBasePtr
&
ref_value
,
bool
need_cast
=
false
,
AbstractRef
(
const
AbstractBasePtr
&
ref_key
,
const
AbstractTensorPtr
&
ref_value
);
TypePtr
cast_target
=
nullptr
);
~
AbstractRef
()
override
=
default
;
~
AbstractRef
()
override
=
default
;
MS_DECLARE_PARENT
(
AbstractRef
,
Abstract
Base
)
MS_DECLARE_PARENT
(
AbstractRef
,
Abstract
Tensor
)
TypePtr
BuildType
()
const
override
;
TypePtr
BuildType
()
const
override
;
BaseShapePtr
BuildShape
()
const
override
;
bool
operator
==
(
const
AbstractRef
&
other
)
const
;
bool
operator
==
(
const
AbstractRef
&
other
)
const
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
AbstractBasePtr
Clone
()
const
override
{
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Clone
(),
ref_
->
Clone
(),
need_cast_
,
target_type_
);
auto
abs_tensor
=
AbstractTensor
::
Clone
()
->
cast
<
AbstractTensorPtr
>
();
if
(
abs_tensor
==
nullptr
)
{
return
nullptr
;
}
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Clone
(),
abs_tensor
);
}
}
std
::
string
ToString
()
const
override
;
std
::
string
ToString
()
const
override
;
inline
Abstract
BasePtr
ref
()
const
{
return
ref_
;
}
inline
Abstract
TensorPtr
ref
()
{
return
shared_from_base
<
AbstractTensor
>
()
;
}
inline
AbstractBasePtr
ref_key
()
const
{
return
ref_key_
;
}
inline
AbstractBasePtr
ref_key
()
const
{
return
ref_key_
;
}
inline
RefKeyPtr
ref_key_value
()
const
{
return
ref_key_value_
;
}
inline
RefKeyPtr
ref_key_value
()
const
{
return
ref_key_value_
;
}
inline
TypePtr
target_type
()
const
{
return
target_type_
;
}
inline
bool
need_cast
()
const
{
return
need_cast_
;
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
// always broaden for ref
// always broaden for ref
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Broaden
(
config
),
ref_
->
Broaden
(),
need_cast_
,
target_type_
);
auto
abs_tensor
=
AbstractTensor
::
Broaden
()
->
cast
<
AbstractTensorPtr
>
();
if
(
abs_tensor
==
nullptr
)
{
return
nullptr
;
}
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Broaden
(
config
),
abs_tensor
);
}
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
std
::
size_t
hash
()
const
override
{
std
::
size_t
hash
()
const
override
{
return
ref_
->
hash
()
^
(
std
::
hash
<
uint32_t
>
{}(
this
->
tid
())
<<
1
);
// ref_key_->hash() ^
return
AbstractTensor
::
hash
()
^
(
std
::
hash
<
uint32_t
>
{}(
this
->
tid
())
<<
1
);
// ref_key_->hash() ^
}
}
private:
private:
AbstractBasePtr
ref_key_
;
AbstractBasePtr
ref_key_
;
AbstractBasePtr
ref_
;
// For mix presicion, only float type need to cast to float16 of float32
bool
need_cast_
;
TypePtr
target_type_
;
// cache for ref_key after build value, when value is null, return nullptr.
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr
ref_key_value_
;
RefKeyPtr
ref_key_value_
;
};
};
...
...
mindspore/core/abstract/prim_others.cc
浏览文件 @
95212b55
...
@@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
...
@@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
MS_LOG
(
EXCEPTION
)
<<
"make_ref evaluator requires 3 parameters, while the input size is "
<<
args_spec_list
.
size
()
MS_LOG
(
EXCEPTION
)
<<
"make_ref evaluator requires 3 parameters, while the input size is "
<<
args_spec_list
.
size
()
<<
"."
;
<<
"."
;
}
}
TypePtr
type
=
args_spec_list
[
0
]
->
GetTypeTrack
();
auto
tensor
=
args_spec_list
[
1
]
->
cast
<
abstract
::
AbstractTensorPtr
>
();
ValuePtr
tensor_target_v
=
args_spec_list
[
2
]
->
BuildValue
();
return
std
::
make_shared
<
AbstractRef
>
(
args_spec_list
[
0
],
tensor
);
if
(
type
->
type_id
()
!=
kObjectTypeRefKey
)
{
MS_LOG
(
EXCEPTION
)
<<
"First input of make_ref should be a RefKey but a "
<<
type
->
ToString
();
}
auto
need_cast
=
!
tensor_target_v
->
isa
<
None
>
();
if
(
need_cast
&&
!
tensor_target_v
->
isa
<
Type
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Third input of make_ref should be a Type but a "
<<
tensor_target_v
->
ToString
();
}
TypePtr
cast_target
=
tensor_target_v
->
cast
<
TypePtr
>
();
return
std
::
make_shared
<
AbstractRef
>
(
args_spec_list
[
0
],
args_spec_list
[
1
],
need_cast
,
cast_target
);
}
}
AbstractBasePtr
InferImplGetRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
AbstractBasePtr
InferImplGetRefKey
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
...
...
mindspore/core/ir/anf.cc
浏览文件 @
95212b55
...
@@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
...
@@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
return
buffer
.
str
();
return
buffer
.
str
();
}
}
ParamInfoPtr
Parameter
::
param_info
()
const
{
if
(
!
has_default
())
{
return
nullptr
;
}
auto
tensor
=
default_param
()
->
cast
<
tensor
::
MetaTensorPtr
>
();
if
(
tensor
==
nullptr
||
!
tensor
->
is_parameter
())
{
return
nullptr
;
}
return
tensor
->
param_info
();
}
std
::
string
ValueNode
::
ToString
()
const
{
std
::
string
ValueNode
::
ToString
()
const
{
MS_EXCEPTION_IF_NULL
(
value_
);
MS_EXCEPTION_IF_NULL
(
value_
);
if
(
value_
->
isa
<
FuncGraph
>
())
{
if
(
value_
->
isa
<
FuncGraph
>
())
{
...
...
mindspore/core/ir/anf.h
浏览文件 @
95212b55
...
@@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
...
@@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
class
AnfIrVisitor
;
class
AnfIrVisitor
;
class
ParamInfo
;
class
ParamInfo
;
using
Param
Value
Ptr
=
std
::
shared_ptr
<
ParamInfo
>
;
using
Param
Info
Ptr
=
std
::
shared_ptr
<
ParamInfo
>
;
// AnfNode is the basic class of the IR definition derived from Base.
// AnfNode is the basic class of the IR definition derived from Base.
// Only two types of nodes are derived: CNode and ANode.
// Only two types of nodes are derived: CNode and ANode.
...
@@ -288,6 +288,7 @@ class Parameter : public ANode {
...
@@ -288,6 +288,7 @@ class Parameter : public ANode {
has_default_
=
true
;
has_default_
=
true
;
}
}
ValuePtr
default_param
()
const
{
return
default_param_
;
}
ValuePtr
default_param
()
const
{
return
default_param_
;
}
ParamInfoPtr
param_info
()
const
;
bool
operator
==
(
const
AnfNode
&
other
)
const
override
{
bool
operator
==
(
const
AnfNode
&
other
)
const
override
{
if
(
!
other
.
isa
<
Parameter
>
())
{
if
(
!
other
.
isa
<
Parameter
>
())
{
...
...
mindspore/core/ir/dtype.cc
浏览文件 @
95212b55
...
@@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
...
@@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
std
::
string
Slice
::
DumpText
()
const
{
return
ToString
();
}
std
::
string
Slice
::
DumpText
()
const
{
return
ToString
();
}
TypePtr
UndeterminedType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
UndeterminedType
>
();
}
return
std
::
make_shared
<
UndeterminedType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
UndeterminedType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
UndeterminedType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
UndeterminedType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
UndeterminedType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
UndeterminedType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
TensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
TensorType
>
();
}
return
std
::
make_shared
<
TensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
TensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"tensor"
;
}
return
"tensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
TensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Tensor"
;
}
return
"Tensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
TensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Tensor"
;
}
return
"Tensor("
+
element_type_
->
DumpText
()
+
")"
;
}
bool
TensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
TensorType
&>
(
other
).
element_type_
;
// When element_type_ = nullptr, which means any type of Array.
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
RowTensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
RowTensorType
>
();
}
return
std
::
make_shared
<
RowTensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
RowTensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
RowTensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
RowTensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
RowTensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
RowTensorType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
SparseTensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
SparseTensorType
>
();
}
return
std
::
make_shared
<
SparseTensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
SparseTensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
SparseTensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
SparseTensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
SparseTensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
SparseTensorType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
Function
::
Function
()
:
Object
(
kObjectTypeFunction
)
{
Function
::
Function
()
:
Object
(
kObjectTypeFunction
)
{
args_
=
std
::
vector
<
TypePtr
>
();
args_
=
std
::
vector
<
TypePtr
>
();
retval_
=
nullptr
;
retval_
=
nullptr
;
...
@@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
...
@@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
os
<<
problem
->
ToString
();
os
<<
problem
->
ToString
();
return
os
;
return
os
;
}
}
const
TypePtr
kTensorTypeFP16
=
std
::
make_shared
<
TensorType
>
(
std
::
make_shared
<
Float
>
(
16
));
const
TypePtr
kTensorTypeFP32
=
std
::
make_shared
<
TensorType
>
(
std
::
make_shared
<
Float
>
(
32
));
}
// namespace mindspore
}
// namespace mindspore
mindspore/core/ir/dtype.h
浏览文件 @
95212b55
...
@@ -32,10 +32,11 @@
...
@@ -32,10 +32,11 @@
#include "ir/named.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
#include "ir/dtype/type.h"
#include "ir/dtype/ref.h"
#include "ir/dtype/number.h"
#include "ir/dtype/number.h"
#include "ir/dtype/container.h"
#include "ir/dtype/container.h"
#include "ir/dtype/empty.h"
#include "ir/dtype/empty.h"
#include "ir/dtype/tensor_type.h"
#include "ir/dtype/ref.h"
/* namespace to support intermediate representation definition */
/* namespace to support intermediate representation definition */
namespace
mindspore
{
namespace
mindspore
{
...
@@ -108,98 +109,6 @@ class Slice : public Object {
...
@@ -108,98 +109,6 @@ class Slice : public Object {
};
};
using
SlicePtr
=
std
::
shared_ptr
<
Slice
>
;
using
SlicePtr
=
std
::
shared_ptr
<
Slice
>
;
class
UndeterminedType
:
public
Object
{
public:
UndeterminedType
()
:
Object
(
kObjectTypeUndeterminedType
)
{}
explicit
UndeterminedType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeUndeterminedType
,
kMetaTypeObject
,
false
),
element_type_
(
ele
)
{}
~
UndeterminedType
()
override
=
default
;
MS_DECLARE_PARENT
(
UndeterminedType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeUndeterminedType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
protected:
TypePtr
element_type_
;
};
using
MetaTensorTypePtr
=
std
::
shared_ptr
<
UndeterminedType
>
;
class
TensorType
:
public
Object
{
public:
TensorType
()
:
Object
(
kObjectTypeTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
TensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
TensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
TensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
TensorTypePtr
=
std
::
shared_ptr
<
TensorType
>
;
class
RowTensorType
:
public
Object
{
public:
RowTensorType
()
:
Object
(
kObjectTypeRowTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
RowTensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeRowTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
RowTensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
RowTensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeRowTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
RowTensorTypePtr
=
std
::
shared_ptr
<
RowTensorType
>
;
class
SparseTensorType
:
public
Object
{
public:
SparseTensorType
()
:
Object
(
kObjectTypeSparseTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
SparseTensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeSparseTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
SparseTensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
SparseTensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeSparseTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
SparseTensorTypePtr
=
std
::
shared_ptr
<
SparseTensorType
>
;
class
Function
:
public
Object
{
class
Function
:
public
Object
{
public:
public:
Function
();
Function
();
...
@@ -353,6 +262,9 @@ extern const TypePtr kDict;
...
@@ -353,6 +262,9 @@ extern const TypePtr kDict;
extern
const
TypePtr
kSlice
;
extern
const
TypePtr
kSlice
;
extern
const
TypePtr
kKeyword
;
extern
const
TypePtr
kKeyword
;
extern
const
TypePtr
kTensorType
;
extern
const
TypePtr
kTensorType
;
extern
const
TypePtr
kTensorTypeFP16
;
extern
const
TypePtr
kTensorTypeFP32
;
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_H_
#endif // MINDSPORE_CORE_IR_DTYPE_H_
mindspore/core/ir/dtype/number.h
浏览文件 @
95212b55
...
@@ -68,6 +68,8 @@ class Number : public Object {
...
@@ -68,6 +68,8 @@ class Number : public Object {
const
int
nbits_
;
const
int
nbits_
;
};
};
using
NumberPtr
=
std
::
shared_ptr
<
Number
>
;
// Bool
// Bool
class
Bool
:
public
Number
{
class
Bool
:
public
Number
{
public:
public:
...
...
mindspore/core/ir/dtype/ref.cc
浏览文件 @
95212b55
...
@@ -19,15 +19,15 @@
...
@@ -19,15 +19,15 @@
#include <cstdlib>
#include <cstdlib>
#include <algorithm>
#include <algorithm>
#include "utils/log_adapter.h"
#include "utils/log_adapter.h"
#include "ir/dtype/tensor_type.h"
namespace
mindspore
{
namespace
mindspore
{
TypePtr
RefType
::
DeepCopy
()
const
{
TypePtr
RefType
::
DeepCopy
()
const
{
if
(
IsGeneric
())
{
if
(
IsGeneric
())
{
return
std
::
make_shared
<
RefType
>
();
return
std
::
make_shared
<
RefType
>
();
}
else
{
}
else
{
auto
subtype
=
subtype_
->
DeepCopy
();
auto
subtype
=
TensorType
::
DeepCopy
()
->
cast
<
TensorTypePtr
>
();
auto
subtype_origin
=
subtype_origin_
->
DeepCopy
();
return
std
::
make_shared
<
RefType
>
(
subtype
);
return
std
::
make_shared
<
RefType
>
(
subtype
,
subtype_origin
);
}
}
}
}
...
@@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
...
@@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
buffer
<<
"Ref"
;
buffer
<<
"Ref"
;
}
else
{
}
else
{
buffer
<<
"Ref["
;
buffer
<<
"Ref["
;
buffer
<<
subtype_
->
DumpText
()
<<
"]"
;
buffer
<<
TensorType
::
DumpText
()
<<
"]"
;
}
}
return
buffer
.
str
();
return
buffer
.
str
();
}
}
...
...
mindspore/core/ir/dtype/ref.h
浏览文件 @
95212b55
...
@@ -17,21 +17,13 @@
...
@@ -17,21 +17,13 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
#define MINDSPORE_CORE_IR_DTYPE_REF_H_
#define MINDSPORE_CORE_IR_DTYPE_REF_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>
#include "base/base.h"
#include "base/base.h"
#include "ir/named.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
#include "ir/dtype/type.h"
#include "ir/dtype/tensor_type.h"
namespace
mindspore
{
namespace
mindspore
{
// TypeRefKey type
// TypeRefKey type
...
@@ -48,23 +40,16 @@ class RefKeyType : public Object {
...
@@ -48,23 +40,16 @@ class RefKeyType : public Object {
};
};
// TypeRef type
// TypeRef type
class
RefType
:
public
Object
{
class
RefType
:
public
TensorType
{
public:
public:
RefType
()
:
Object
(
kObjectTypeRef
)
{}
RefType
()
:
TensorType
()
{}
RefType
(
const
TypePtr
&
subtype
,
const
TypePtr
&
subtype_origin
)
explicit
RefType
(
const
TensorTypePtr
&
subtype
)
:
TensorType
(
subtype
->
element
())
{}
:
Object
(
kObjectTypeRef
,
false
),
subtype_
(
subtype
),
subtype_origin_
(
subtype_origin
)
{}
~
RefType
()
override
{}
~
RefType
()
override
{}
MS_DECLARE_PARENT
(
RefType
,
Object
)
MS_DECLARE_PARENT
(
RefType
,
TensorType
)
TypePtr
subtype
()
const
{
return
subtype_
;
}
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeRef
;
}
TypePtr
DeepCopy
()
const
override
;
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
DumpText
()
const
override
;
std
::
string
DumpText
()
const
override
;
private:
TypePtr
subtype_
;
TypePtr
subtype_origin_
;
};
};
using
RefTypePtr
=
std
::
shared_ptr
<
RefType
>
;
using
RefTypePtr
=
std
::
shared_ptr
<
RefType
>
;
...
...
mindspore/core/ir/dtype/tensor_type.cc
0 → 100644
浏览文件 @
95212b55
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype/tensor_type.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
namespace
mindspore
{
TypePtr
UndeterminedType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
UndeterminedType
>
();
}
return
std
::
make_shared
<
UndeterminedType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
UndeterminedType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
UndeterminedType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
UndeterminedType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Undetermined"
;
}
return
"Undetermined["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
UndeterminedType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
UndeterminedType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
TensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
TensorType
>
();
}
return
std
::
make_shared
<
TensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
TensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"tensor"
;
}
return
"tensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
TensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Tensor"
;
}
return
"Tensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
TensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"Tensor"
;
}
return
"Tensor("
+
element_type_
->
DumpText
()
+
")"
;
}
bool
TensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
TensorType
&>
(
other
).
element_type_
;
// When element_type_ = nullptr, which means any type of Array.
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
RowTensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
RowTensorType
>
();
}
return
std
::
make_shared
<
RowTensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
RowTensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
RowTensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
RowTensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"RowTensor"
;
}
return
"RowTensor["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
RowTensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
RowTensorType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
TypePtr
SparseTensorType
::
DeepCopy
()
const
{
MS_EXCEPTION_IF_NULL
(
element_type_
);
if
(
IsGeneric
())
{
return
std
::
make_shared
<
SparseTensorType
>
();
}
return
std
::
make_shared
<
SparseTensorType
>
(
element_type_
->
DeepCopy
());
}
std
::
string
SparseTensorType
::
ToReprString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
ToReprString
()
+
"]"
;
}
std
::
string
SparseTensorType
::
ToString
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
ToString
()
+
"]"
;
}
std
::
string
SparseTensorType
::
DumpText
()
const
{
if
(
element_type_
==
nullptr
)
{
return
"SparseTensor"
;
}
return
"SparseTensor["
+
element_type_
->
DumpText
()
+
"]"
;
}
bool
SparseTensorType
::
operator
==
(
const
Type
&
other
)
const
{
if
(
!
IsSameObjectType
(
*
this
,
other
))
{
return
false
;
}
auto
other_elem_type
=
static_cast
<
const
SparseTensorType
&>
(
other
).
element_type_
;
if
(
element_type_
==
nullptr
&&
other_elem_type
==
nullptr
)
{
return
true
;
}
else
if
(
element_type_
==
nullptr
||
other_elem_type
==
nullptr
)
{
return
false
;
}
return
*
element_type_
==
*
other_elem_type
;
}
}
// namespace mindspore
mindspore/core/ir/dtype/tensor_type.h
0 → 100644
浏览文件 @
95212b55
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>
#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
namespace
mindspore
{
class
UndeterminedType
:
public
Object
{
public:
UndeterminedType
()
:
Object
(
kObjectTypeUndeterminedType
)
{}
explicit
UndeterminedType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeUndeterminedType
,
kMetaTypeObject
,
false
),
element_type_
(
ele
)
{}
~
UndeterminedType
()
override
=
default
;
MS_DECLARE_PARENT
(
UndeterminedType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeUndeterminedType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
protected:
TypePtr
element_type_
;
};
using
MetaTensorTypePtr
=
std
::
shared_ptr
<
UndeterminedType
>
;
class
TensorType
:
public
Object
{
public:
TensorType
()
:
Object
(
kObjectTypeTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
TensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
TensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
TensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
TensorTypePtr
=
std
::
shared_ptr
<
TensorType
>
;
class
RowTensorType
:
public
Object
{
public:
RowTensorType
()
:
Object
(
kObjectTypeRowTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
RowTensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeRowTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
RowTensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
RowTensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeRowTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
RowTensorTypePtr
=
std
::
shared_ptr
<
RowTensorType
>
;
class
SparseTensorType
:
public
Object
{
public:
SparseTensorType
()
:
Object
(
kObjectTypeSparseTensorType
,
kObjectTypeUndeterminedType
)
{}
explicit
SparseTensorType
(
const
TypePtr
&
ele
)
:
Object
(
kObjectTypeSparseTensorType
,
kObjectTypeUndeterminedType
,
false
),
element_type_
(
ele
)
{}
~
SparseTensorType
()
override
=
default
;
MS_DECLARE_PARENT
(
SparseTensorType
,
Object
)
TypeId
generic_type_id
()
const
override
{
return
kObjectTypeSparseTensorType
;
}
const
TypePtr
element
()
const
{
return
element_type_
;
}
void
set_element
(
const
TypePtr
&
element_type
)
{
element_type_
=
element_type
;
}
TypePtr
DeepCopy
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
ToReprString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
Type
&
other
)
const
override
;
private:
TypePtr
element_type_
;
};
using
SparseTensorTypePtr
=
std
::
shared_ptr
<
SparseTensorType
>
;
}
// namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
mindspore/core/ir/func_graph.h
浏览文件 @
95212b55
...
@@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
...
@@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
const
std
::
vector
<
AnfNodePtr
>
&
paramter_obj_nodes
()
const
{
return
paramter_obj_nodes_
;
}
const
std
::
vector
<
AnfNodePtr
>
&
paramter_obj_nodes
()
const
{
return
paramter_obj_nodes_
;
}
void
add_parameter_obj_node
(
const
AnfNodePtr
&
p
);
void
add_parameter_obj_node
(
const
AnfNodePtr
&
p
);
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
&
make_ref_params
()
{
return
make_ref_params_
;
}
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
std
::
vector
<
BaseShapePtr
>
joined_shapes_
;
std
::
vector
<
BaseShapePtr
>
joined_shapes_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphTransform
>
transforms_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphTransform
>
transforms_
;
// parameter default value
// parameter default value
std
::
map
<
std
::
string
,
AnfNodePtr
>
parameter_default_value_
;
std
::
map
<
std
::
string
,
AnfNodePtr
>
parameter_default_value_
;
std
::
unordered_map
<
AnfNodePtr
,
AnfNodePtr
>
make_ref_params_
;
size_t
seen_
;
size_t
seen_
;
std
::
list
<
CNodePtr
>
GetOrderedCnodes
();
std
::
list
<
CNodePtr
>
GetOrderedCnodes
();
...
...
mindspore/core/ir/meta_tensor.h
浏览文件 @
95212b55
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include <string>
#include <string>
#include "base/base.h"
#include "base/base.h"
#include "ir/param_info.h"
#include "ir/dtype.h"
#include "ir/dtype.h"
#include "utils/convert_utils_base.h"
#include "utils/convert_utils_base.h"
#include "utils/hashing.h"
#include "utils/hashing.h"
...
@@ -163,6 +164,15 @@ class MetaTensor : public Value {
...
@@ -163,6 +164,15 @@ class MetaTensor : public Value {
return
false
;
return
false
;
}
}
}
}
// Get tensor's param_info info.
ParamInfoPtr
param_info
()
const
{
return
param_info_
;
}
bool
is_parameter
()
const
{
return
is_parameter_
;
}
// Set tensor's param_info info.
void
set_param_info
(
const
ParamInfoPtr
&
param_info
)
{
is_parameter_
=
true
;
param_info_
=
param_info
;
}
protected:
protected:
// brief Data type of the tensor.
// brief Data type of the tensor.
...
@@ -184,6 +194,9 @@ class MetaTensor : public Value {
...
@@ -184,6 +194,9 @@ class MetaTensor : public Value {
//
//
// Includes the format and data type of a tensor on device.
// Includes the format and data type of a tensor on device.
DeviceInfo
device_info_
;
DeviceInfo
device_info_
;
bool
is_parameter_
{
false
};
ParamInfoPtr
param_info_
{
nullptr
};
};
};
using
MetaTensorPtr
=
std
::
shared_ptr
<
MetaTensor
>
;
using
MetaTensorPtr
=
std
::
shared_ptr
<
MetaTensor
>
;
...
...
mindspore/core/ir/meta_tensor_extends.cc
浏览文件 @
95212b55
...
@@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
...
@@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
}
}
auto
tensor_shape
=
tens
->
shape
();
auto
tensor_shape
=
tens
->
shape
();
auto
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
dtype
,
tensor_shape
);
auto
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
dtype
,
tensor_shape
);
abs_tensor
->
set_value
(
shared_from_base
<
MetaTensor
>
());
// if is parameter always no value.
if
(
is_parameter
())
{
auto
param_name
=
param_info
()
->
name
();
auto
ref_key
=
std
::
make_shared
<
RefKey
>
(
param_name
);
auto
abs_ref_key
=
ref_key
->
ToAbstract
();
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractRef
>
(
abs_ref_key
,
abs_tensor
);
}
else
{
abs_tensor
->
set_value
(
shared_from_base
<
MetaTensor
>
());
}
return
abs_tensor
;
return
abs_tensor
;
}
}
...
...
mindspore/core/ir/named.h
浏览文件 @
95212b55
...
@@ -62,6 +62,21 @@ class Named : public Value {
...
@@ -62,6 +62,21 @@ class Named : public Value {
};
};
using
NamedPtr
=
std
::
shared_ptr
<
Named
>
;
using
NamedPtr
=
std
::
shared_ptr
<
Named
>
;
struct
NamedHasher
{
std
::
size_t
operator
()(
NamedPtr
const
&
name
)
const
{
std
::
size_t
hash
=
name
->
Hash
();
return
hash
;
}
};
struct
NamedEqual
{
bool
operator
()(
NamedPtr
const
&
t1
,
NamedPtr
const
&
t2
)
const
{
MS_EXCEPTION_IF_NULL
(
t1
);
MS_EXCEPTION_IF_NULL
(
t2
);
return
*
t1
==
*
t2
;
}
};
class
None
:
public
Named
{
class
None
:
public
Named
{
public:
public:
None
()
:
Named
(
"None"
)
{}
None
()
:
Named
(
"None"
)
{}
...
...
mindspore/core/ir/param_info.h
浏览文件 @
95212b55
...
@@ -21,10 +21,13 @@
...
@@ -21,10 +21,13 @@
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "ir/anf.h"
#include "ir/
tensor
.h"
#include "ir/
dtype
.h"
namespace
mindspore
{
namespace
mindspore
{
class
ParamInfo
;
using
ParamInfoPtr
=
std
::
shared_ptr
<
ParamInfo
>
;
class
ParamInfo
{
class
ParamInfo
{
public:
public:
ParamInfo
()
{}
ParamInfo
()
{}
...
@@ -55,7 +58,7 @@ class ParamInfo {
...
@@ -55,7 +58,7 @@ class ParamInfo {
int32_t
cloned_index
()
const
{
return
cloned_index_
;
}
int32_t
cloned_index
()
const
{
return
cloned_index_
;
}
// Make a cloned parameter and update clone info.
// Make a cloned parameter and update clone info.
Param
Value
Ptr
Clone
()
{
Param
Info
Ptr
Clone
()
{
static
std
::
atomic
<
int32_t
>
parameter_cloned_index
{
1
};
static
std
::
atomic
<
int32_t
>
parameter_cloned_index
{
1
};
int32_t
index
=
parameter_cloned_index
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
int32_t
index
=
parameter_cloned_index
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
auto
clone
=
std
::
make_shared
<
ParamInfo
>
(
*
this
);
auto
clone
=
std
::
make_shared
<
ParamInfo
>
(
*
this
);
...
...
mindspore/core/ir/tensor.cc
浏览文件 @
95212b55
...
@@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
...
@@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
}
}
return
*
this
;
return
*
this
;
}
}
abstract
::
AbstractBasePtr
Tensor
::
ToAbstract
()
{
abstract
::
AbstractBasePtr
Tensor
::
ToAbstract
()
{
auto
tens
=
shared_from_base
<
Tensor
>
();
auto
tens
=
shared_from_base
<
Tensor
>
();
auto
dtype
=
tens
->
Dtype
();
auto
dtype
=
tens
->
Dtype
();
...
@@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
...
@@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
}
}
auto
tensor_shape
=
tens
->
shape
();
auto
tensor_shape
=
tens
->
shape
();
auto
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
dtype
,
tensor_shape
);
auto
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
dtype
,
tensor_shape
);
abs_tensor
->
set_value
(
shared_from_base
<
Tensor
>
());
// if is parameter always no value.
if
(
is_parameter
())
{
auto
param_name
=
param_info
()
->
name
();
auto
ref_key
=
std
::
make_shared
<
RefKey
>
(
param_name
);
auto
abs_ref_key
=
ref_key
->
ToAbstract
();
abs_tensor
=
std
::
make_shared
<
abstract
::
AbstractRef
>
(
abs_ref_key
,
abs_tensor
);
}
else
{
abs_tensor
->
set_value
(
shared_from_base
<
Tensor
>
());
}
return
abs_tensor
;
return
abs_tensor
;
}
}
...
...
mindspore/core/ir/value.cc
浏览文件 @
95212b55
...
@@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
...
@@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
}
}
bool
StringImm
::
operator
==
(
const
StringImm
&
other
)
const
{
return
str_
==
other
.
str_
;
}
bool
StringImm
::
operator
==
(
const
StringImm
&
other
)
const
{
return
str_
==
other
.
str_
;
}
bool
RefKey
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
RefKey
>
())
{
auto
other_
=
static_cast
<
const
RefKey
&>
(
other
);
return
*
this
==
other_
;
}
else
{
return
false
;
}
}
bool
RefKey
::
operator
==
(
const
RefKey
&
other
)
const
{
return
tag_
==
other
.
tag_
;
}
bool
AnyValue
::
operator
==
(
const
Value
&
other
)
const
{
bool
AnyValue
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
AnyValue
>
())
{
if
(
other
.
isa
<
AnyValue
>
())
{
return
true
;
return
true
;
...
...
mindspore/core/ir/value.h
浏览文件 @
95212b55
...
@@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
...
@@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
IMM_TRAITS
(
StringImmPtr
,
std
::
string
)
IMM_TRAITS
(
StringImmPtr
,
std
::
string
)
IMM_TRAITS
(
StringImmPtr
,
const
char
*
)
IMM_TRAITS
(
StringImmPtr
,
const
char
*
)
class
RefKey
:
public
Value
{
class
RefKey
:
public
Named
{
public:
public:
explicit
RefKey
(
const
std
::
string
&
tag
)
:
Value
(
kRefKeyType
),
tag_
(
tag
),
hash_
(
std
::
hash
<
std
::
string
>
{}(
tag
)
)
{}
explicit
RefKey
(
const
std
::
string
&
tag
)
:
Named
(
tag
)
{}
~
RefKey
()
override
=
default
;
~
RefKey
()
override
=
default
;
MS_DECLARE_PARENT
(
RefKey
,
Value
)
MS_DECLARE_PARENT
(
RefKey
,
Named
)
std
::
size_t
hash
()
const
override
{
return
hash_
;
}
const
std
::
string
&
tag
()
const
{
return
name
();
}
const
std
::
string
&
tag
()
const
{
return
tag_
;
}
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
RefKey
&
other
)
const
;
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
std
::
string
ToString
()
const
override
{
return
"RefKey["
+
tag_
+
"]"
;
}
std
::
string
ToString
()
const
override
{
return
"RefKey["
+
name
()
+
"]"
;
}
std
::
string
DumpText
()
const
override
{
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
oss
<<
"RefKey[
\"
"
<<
tag_
<<
"
\"
]"
;
oss
<<
"RefKey[
\"
"
<<
name
()
<<
"
\"
]"
;
return
oss
.
str
();
return
oss
.
str
();
}
}
private:
std
::
string
tag_
;
std
::
size_t
hash_
=
0
;
};
};
using
RefKeyPtr
=
std
::
shared_ptr
<
RefKey
>
;
using
RefKeyPtr
=
std
::
shared_ptr
<
RefKey
>
;
...
...
mindspore/lite/test/CMakeLists.txt
浏览文件 @
95212b55
...
@@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
...
@@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/scope.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/scope.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/value.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/value.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/value_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/value_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/ref.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/tensor_type.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/container.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/container.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/empty.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/empty.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/number.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../core/ir/dtype/number.cc
...
...
mindspore/lite/tools/converter/CMakeLists.txt
浏览文件 @
95212b55
...
@@ -29,6 +29,8 @@ set(ANF_SRC
...
@@ -29,6 +29,8 @@ set(ANF_SRC
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/scope.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/scope.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/value.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/value.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/value_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/value_extends.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/ref.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/tensor_type.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/container.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/container.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/empty.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/empty.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/number.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../../core/ir/dtype/number.cc
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
95212b55
...
@@ -23,7 +23,7 @@ from ...common import dtype as mstype
...
@@ -23,7 +23,7 @@ from ...common import dtype as mstype
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
class
Assign
(
Primitive
WithInfer
):
class
Assign
(
Primitive
):
"""
"""
Assign `Parameter` with a value.
Assign `Parameter` with a value.
...
...
mindspore/ops/primitive.py
浏览文件 @
95212b55
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
import
inspect
import
inspect
import
copy
import
copy
from
mindspore.common.api
import
_wrap_func
from
mindspore.common.api
import
_wrap_func
from
mindspore.common
import
Parameter
from
mindspore.common._register_for_tensor
import
tensor_operator_registry
from
mindspore.common._register_for_tensor
import
tensor_operator_registry
from
mindspore
import
context
from
mindspore
import
context
from
.._c_expression
import
Primitive_
,
real_run_op
,
prim_type
from
.._c_expression
import
Primitive_
,
real_run_op
,
prim_type
...
@@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
...
@@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
if
op_name
==
"Cast"
or
obj
.
update_parameter
:
if
op_name
==
"Cast"
or
obj
.
update_parameter
:
cast_args
=
args
cast_args
=
args
else
:
else
:
cast_args
=
list
()
cast_args
=
args
for
arg
in
args
:
for
idx
,
arg
in
enumerate
(
args
):
if
isinstance
(
arg
,
Parameter
):
cast_type
=
getattr
(
arg
,
"cast_type"
,
None
)
if
arg
.
cast_type
:
if
cast_type
:
cast_args
.
append
(
cast
(
arg
,
arg
.
cast_type
))
cast_args
[
idx
]
=
cast
(
arg
,
cast_type
)
else
:
output
=
real_run_op
(
obj
,
op_name
,
cast_args
)
cast_args
.
append
(
arg
)
else
:
cast_args
.
append
(
arg
)
output
=
real_run_op
(
obj
,
op_name
,
tuple
(
cast_args
))
if
not
output
:
if
not
output
:
raise
RuntimeError
(
"Pynative run op %s failed!"
%
op_name
)
raise
RuntimeError
(
"Pynative run op %s failed!"
%
op_name
)
if
len
(
output
)
==
1
:
if
len
(
output
)
==
1
:
...
...
tests/st/control/test_ascend_control_sink.py
浏览文件 @
95212b55
...
@@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
...
@@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
self
.
var
=
Parameter
(
initializer
(
1
,
(
1
),
mstype
.
float32
),
name
=
"var"
)
self
.
var
=
Parameter
(
initializer
(
1
,
(
1
),
mstype
.
float32
),
name
=
"var"
)
def
construct
(
self
,
x
,
y
,
z
,
c2
,
c4
):
def
construct
(
self
,
x
,
y
,
z
,
c2
,
c4
):
out
=
self
.
assign
(
self
.
var
,
c4
)
out
=
c4
self
.
assign
(
self
.
var
,
c4
)
while
x
<
c2
:
while
x
<
c2
:
y
=
self
.
assign
(
self
.
var
,
c4
)
y
=
c4
self
.
assign
(
self
.
var
,
c4
)
while
y
<
c2
and
x
<
c2
:
while
y
<
c2
and
x
<
c2
:
if
2
*
y
<
c2
:
if
2
*
y
<
c2
:
y
=
y
+
2
y
=
y
+
2
else
:
else
:
y
=
y
+
1
y
=
y
+
1
out
=
out
+
y
out
=
out
+
y
z
=
self
.
assign
(
self
.
var
,
c4
)
z
=
c4
self
.
assign
(
self
.
var
,
c4
)
while
z
<
c2
:
while
z
<
c2
:
z
=
z
+
1
z
=
z
+
1
out
=
out
+
z
out
=
out
+
z
x
=
x
+
1
x
=
x
+
1
out
=
out
+
x
out
=
out
+
x
while
x
<
2
*
c2
:
while
x
<
2
*
c2
:
y
=
self
.
assign
(
self
.
var
,
c4
)
y
=
c4
self
.
assign
(
self
.
var
,
c4
)
x
=
x
+
1
x
=
x
+
1
while
y
<
c2
:
while
y
<
c2
:
z
=
self
.
assign
(
self
.
var
,
c4
)
z
=
c4
self
.
assign
(
self
.
var
,
c4
)
while
z
<
c2
:
while
z
<
c2
:
z
=
z
+
1
z
=
z
+
1
if
x
<
c2
:
if
x
<
c2
:
...
...
tests/ut/python/pipeline/parse/test_parse.py
浏览文件 @
95212b55
...
@@ -27,6 +27,7 @@ import mindspore.nn as nn
...
@@ -27,6 +27,7 @@ import mindspore.nn as nn
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
context
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.common.api
import
ms_function
,
_executor
from
mindspore.common.api
import
ms_function
,
_executor
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore.ops.primitive
import
prim_attr_register
,
PrimitiveWithInfer
from
mindspore.ops.primitive
import
prim_attr_register
,
PrimitiveWithInfer
...
@@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
...
@@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
net
=
BpropWithWrongOutputShapeCell
()
net
=
BpropWithWrongOutputShapeCell
()
net
.
set_grad
()
net
.
set_grad
()
grad_all
(
net
)(
Tensor
(
np
.
ones
([
64
,
10
]).
astype
(
np
.
int32
)))
grad_all
(
net
)(
Tensor
(
np
.
ones
([
64
,
10
]).
astype
(
np
.
int32
)))
class
AssignWhenInsertGrad
(
nn
.
Cell
):
""" NetWithNDarray definition """
def
__init__
(
self
):
super
(
AssignWhenInsertGrad
,
self
).
__init__
()
self
.
gather
=
P
.
GatherV2
()
self
.
damping
=
Tensor
(
np
.
array
([
0.03
,
0.03
]).
astype
(
np
.
float32
))
self
.
cov_step
=
ms
.
Parameter
(
0
,
name
=
"cov_step"
,
requires_grad
=
False
)
self
.
freq
=
Tensor
(
278
,
ms
.
int32
)
self
.
getG
=
P
.
InsertGradientOf
(
self
.
save_gradient
)
def
save_gradient
(
self
,
dout
):
self
.
cov_step
=
self
.
cov_step
+
self
.
freq
return
dout
def
construct
(
self
,
x
):
self
.
gather
(
self
.
damping
,
self
.
cov_step
,
0
)
out
=
P
.
ReLU
()(
x
)
out
=
self
.
getG
(
out
)
return
out
grad_all
=
C
.
GradOperation
(
'get_all'
,
get_all
=
True
)
class
GradNet
(
nn
.
Cell
):
def
__init__
(
self
,
net
):
super
(
GradNet
,
self
).
__init__
()
self
.
net
=
net
def
construct
(
self
,
*
inputs
):
out
=
self
.
net
(
*
inputs
)
return
out
,
grad_all
(
self
.
net
)(
*
inputs
)
def
test_assign_in_insert_grad
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
AssignWhenInsertGrad
().
to_float
(
ms
.
float16
)
input_data
=
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]]).
astype
(
'float32'
)
net_back
=
GradNet
(
net
)
net_back
(
ms
.
Tensor
(
input_data
))
class
Assign
(
nn
.
Cell
):
""" NetWithNDarray definition """
def
__init__
(
self
):
super
(
Assign
,
self
).
__init__
()
self
.
cov_step
=
ms
.
Parameter
(
0.0
,
name
=
"cov_step"
,
requires_grad
=
False
)
def
construct
(
self
,
x
):
self
.
cov_step
=
self
.
cov_step
+
x
return
self
.
cov_step
def
test_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
Assign
()
input_data
=
ms
.
Tensor
(
np
.
array
(
1
).
astype
(
np
.
int32
))
net_back
=
GradNet
(
net
)
net_back
(
input_data
)
tests/ut/python/pipeline/parse/test_while_param.py
0 → 100644
浏览文件 @
95212b55
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
Tensor
,
context
,
nn
,
ms_function
from
mindspore.nn
import
Cell
from
mindspore.ops
import
operations
as
P
class
WhileSubGraphParam
(
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
update
=
ms
.
Parameter
(
Tensor
(
1
,
ms
.
float32
),
"update"
)
def
construct
(
self
,
x
,
y
,
z
):
out1
=
z
while
x
<
y
:
self
.
update
=
self
.
update
+
1
out1
=
out1
+
1
x
=
x
+
1
return
out1
,
self
.
update
def
test_while_loop_phi
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
0
,
ms
.
float32
)
y
=
Tensor
(
10
,
ms
.
float32
)
z
=
Tensor
(
100
,
ms
.
float32
)
net
=
WhileSubGraphParam
()
net
(
x
,
y
,
z
)
class
WhileSubGraphParam2
(
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
update
=
ms
.
Parameter
(
Tensor
(
1
,
ms
.
float32
),
"update"
)
def
construct
(
self
,
x
,
y
,
z
):
out1
=
z
i
=
self
.
update
while
x
<
y
:
i
=
i
+
1
out1
=
out1
+
1
x
=
x
+
1
return
out1
,
self
.
update
def
test_while_loop_phi_2
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
0
,
ms
.
float32
)
y
=
Tensor
(
10
,
ms
.
float32
)
z
=
Tensor
(
100
,
ms
.
float32
)
net
=
WhileSubGraphParam2
()
net
(
x
,
y
,
z
)
class
WhileSubGraphParam3
(
Cell
):
def
__init__
(
self
,
initial_input_x
):
super
().
__init__
()
self
.
initial_input_x
=
initial_input_x
self
.
X
=
ms
.
Parameter
(
initial_input_x
,
name
=
"parameter_x"
)
self
.
Y
=
ms
.
Parameter
(
self
.
initial_input_x
,
name
=
"parameter_y"
)
def
construct
(
self
):
a
=
0
while
a
<
3
:
self
.
X
=
self
.
X
+
self
.
Y
a
+=
1
return
self
.
X
def
test_while_loop_phi_3
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
0
,
ms
.
float32
)
net
=
WhileSubGraphParam3
(
x
)
net
()
class
ControlMixedWhileIf
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
assign
=
P
.
Assign
()
self
.
var
=
ms
.
Parameter
(
ms
.
Tensor
([
1
],
ms
.
float32
),
name
=
"var"
)
@
ms_function
def
construct
(
self
,
x
,
y
,
z
,
c2
,
c4
):
out
=
self
.
assign
(
self
.
var
,
c4
)
while
x
<
c2
:
y
=
self
.
assign
(
self
.
var
,
c4
)
while
y
<
c2
and
x
<
c2
:
if
2
*
y
<
c2
:
y
=
y
+
2
else
:
y
=
y
+
1
out
=
out
+
y
z
=
self
.
assign
(
self
.
var
,
c4
)
while
z
<
c2
:
z
=
z
+
1
out
=
out
+
z
x
=
x
+
1
out
=
out
+
x
while
x
<
2
*
c2
:
y
=
self
.
assign
(
self
.
var
,
c4
)
x
=
x
+
1
while
y
<
c2
:
z
=
self
.
assign
(
self
.
var
,
c4
)
while
z
<
c2
:
z
=
z
+
1
if
x
<
c2
:
y
=
y
-
1
else
:
y
=
y
+
1
out
=
out
+
z
out
=
out
+
y
out
=
out
+
x
return
out
def
test_mixed_while_if
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
x
=
np
.
array
(
2
).
astype
(
np
.
int32
)
y
=
np
.
array
(
14
).
astype
(
np
.
int32
)
z
=
np
.
array
(
1
).
astype
(
np
.
int32
)
c2
=
Tensor
([
14
],
ms
.
int32
)
c4
=
Tensor
([
0
],
ms
.
int32
)
net
=
ControlMixedWhileIf
()
output
=
net
(
Tensor
(
x
),
Tensor
(
y
),
Tensor
(
z
),
c2
,
c4
)
expect
=
np
.
array
(
3318
).
astype
(
np
.
int32
)
assert
np
.
allclose
(
expect
,
output
.
asnumpy
(),
0.0001
,
0.0001
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
tests/vm_impl/array_ops_vm_impl.py
浏览文件 @
95212b55
...
@@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
...
@@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from
.vm_interface
import
vm
from
.vm_interface
import
vm
# pylint: disable=unused-argument
# pylint: disable=unused-argument
@
vm_impl_getters
.
register
(
P
.
Assign
)
def
vm_impl_assign
(
self
):
"""Generate vm_impl function for Assign"""
def
vm_impl
(
x
,
value
):
x
.
assign_value
(
value
)
return
x
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
ExpandDims
)
@
vm_impl_getters
.
register
(
P
.
ExpandDims
)
def
vm_impl_expand_dims
(
self
):
def
vm_impl_expand_dims
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录