Commit 24a10225

Author: Wei Luning
Date: August 24, 2020
Parent: 01aa8338
Repository: magicwindyyd/mindspore (fork of MindSpore/mindspore)

    change base class of ref to tensor in cpp

Showing 47 changed files with 813 additions and 721 deletions (+813, -721).
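The heart of the commit is the hunk in mindspore/core/abstract/abstract_value.h further down: AbstractRef used to derive from AbstractBase and carry the wrapped tensor abstract together with its own mixed-precision cast state (need_cast_, target_type_); it now derives from AbstractTensor and keeps only the ref-key members, with the cast decision moved into do_signature.cc. A minimal before/after sketch of that class shape (stub stand-in types and abbreviated member lists, not the full declarations):

    // Minimal sketch only; the real declarations live in abstract_value.h.
    #include <memory>

    class Type {};
    class RefKey {};
    class AbstractBase {};
    class AbstractTensor : public AbstractBase {};
    using AbstractBasePtr = std::shared_ptr<AbstractBase>;
    using TypePtr = std::shared_ptr<Type>;
    using RefKeyPtr = std::shared_ptr<RefKey>;

    namespace before_commit {
    // The ref wrapped a separate tensor abstract and tracked its own cast target.
    class AbstractRef : public AbstractBase {
      AbstractBasePtr ref_key_;
      AbstractBasePtr ref_;      // the wrapped tensor abstract
      bool need_cast_ = false;   // mixed-precision cast decided inside the ref
      TypePtr target_type_;
    };
    }  // namespace before_commit

    namespace after_commit {
    // The ref now *is* a tensor abstract: shape, join and hash are inherited,
    // and only the ref key (plus its cached value) remains as extra state.
    class AbstractRef : public AbstractTensor {
      AbstractBasePtr ref_key_;
      RefKeyPtr ref_key_value_;  // cached RefKey value, may be nullptr
    };
    }  // namespace after_commit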
Changed files (47):

mindspore/_checkparam.py  (+15, -173)
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc  (+0, -5)
mindspore/ccsrc/frontend/operator/composite/do_signature.cc  (+42, -50)
mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc  (+1, -7)
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc  (+12, -0)
mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc  (+2, -2)
mindspore/ccsrc/frontend/parallel/step_parallel.cc  (+3, -8)
mindspore/ccsrc/pipeline/jit/action.cc  (+8, -7)
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc  (+0, -6)
mindspore/ccsrc/pipeline/jit/parse/function_block.cc  (+10, -12)
mindspore/ccsrc/pipeline/jit/parse/function_block.h  (+0, -3)
mindspore/ccsrc/pipeline/jit/parse/parse.cc  (+2, -3)
mindspore/ccsrc/pipeline/jit/parse/parse.h  (+1, -1)
mindspore/ccsrc/pipeline/jit/parse/resolve.cc  (+4, -16)
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc  (+14, -11)
mindspore/ccsrc/pybind_api/ir/dtype_py.cc  (+1, -1)
mindspore/ccsrc/pybind_api/ir/param_info_py.cc  (+2, -2)
mindspore/ccsrc/pybind_api/ir/tensor_py.cc  (+1, -0)
mindspore/common/parameter.py  (+12, -12)
mindspore/core/abstract/abstract_value.cc  (+12, -36)
mindspore/core/abstract/abstract_value.h  (+19, -18)
mindspore/core/abstract/prim_others.cc  (+2, -11)
mindspore/core/ir/anf.cc  (+11, -0)
mindspore/core/ir/anf.h  (+2, -1)
mindspore/core/ir/dtype.cc  (+4, -169)
mindspore/core/ir/dtype.h  (+5, -93)
mindspore/core/ir/dtype/number.h  (+2, -0)
mindspore/core/ir/dtype/ref.cc  (+4, -4)
mindspore/core/ir/dtype/ref.h  (+6, -21)
mindspore/core/ir/dtype/tensor_type.cc  (+194, -0)
mindspore/core/ir/dtype/tensor_type.h  (+132, -0)
mindspore/core/ir/func_graph.h  (+0, -3)
mindspore/core/ir/meta_tensor.h  (+13, -0)
mindspore/core/ir/meta_tensor_extends.cc  (+10, -1)
mindspore/core/ir/named.h  (+15, -0)
mindspore/core/ir/param_info.h  (+6, -3)
mindspore/core/ir/tensor.cc  (+10, -1)
mindspore/core/ir/value.cc  (+0, -10)
mindspore/core/ir/value.h  (+6, -13)
mindspore/lite/test/CMakeLists.txt  (+2, -0)
mindspore/lite/tools/converter/CMakeLists.txt  (+2, -0)
mindspore/ops/operations/other_ops.py  (+1, -1)
mindspore/ops/primitive.py  (+6, -11)
tests/st/control/test_ascend_control_sink.py  (+10, -5)
tests/ut/python/pipeline/parse/test_parse.py  (+58, -0)
tests/ut/python/pipeline/parse/test_while_param.py  (+144, -0)
tests/vm_impl/array_ops_vm_impl.py  (+7, -1)
mindspore/_checkparam.py

@@ -185,14 +185,23 @@ class Validator:
         raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")

     @staticmethod
-    def check_subclass(arg_name, type_, template_type, prim_name):
+    def check_subclass(arg_name, type_, template_types, prim_name):
         """Checks whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
+        if not isinstance(template_types, Iterable):
+            template_types = (template_types,)
+        hit = False
+        for template_type in template_types:
+            if isinstance(template_type, mstype.Type):
+                if mstype.issubclass_(type_, template_type):
+                    hit = True
+                    break
+            elif type_ is template_type:
+                hit = True
+                break
+        if not hit:
             type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
             raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
+                            f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')

     @staticmethod
     def check_const_input(arg_name, arg_value, prim_name):

@@ -206,13 +215,7 @@ class Validator:
         def _check_tensor_type(arg):
             arg_key, arg_val = arg
             elem_type = arg_val
-            if not elem_type in valid_values:
-                type_names = []
-                for t in valid_values:
-                    type_names.append(str(t))
-                types_info = '[' + ', '.join(type_names) + ']'
-                raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
-                                f' but got {elem_type}.')
+            Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
             return (arg_key, elem_type)

         def _check_types_same(arg1, arg2):

@@ -335,12 +338,6 @@ class Validator:
 class ParamValidator:
     """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""

-    @staticmethod
-    def equal(arg_name, arg_value, cond_str, cond):
-        """Judging valid value."""
-        if not cond:
-            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
-
     @staticmethod
     def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
         """This method is only used for check int values, since when compare float values,

@@ -360,27 +357,6 @@ class ParamValidator:
             raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
         return arg_value

-    @staticmethod
-    def check_shape_length(arg_name, arg_value, value, rel):
-        """Shape length judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
-        return arg_value
-
-    @staticmethod
-    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """This method is only used for check int values,
-        since when compare float values, we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
     @staticmethod
     def check_isinstance(arg_name, arg_value, classes):
         """Check arg isinstance of classes"""

@@ -388,33 +364,6 @@ class ParamValidator:
             raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
         return arg_value

-    @staticmethod
-    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """Is it necessary to consider error when comparing float values."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_subclass(arg_name, type_, template_type, with_type_of=True):
-        """Check whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
-            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
-
-    @staticmethod
-    def check_args_tensor(args):
-        """Check whether args are all tensor."""
-        if not isinstance(args, dict):
-            raise TypeError("The args should be a dict.")
-        for arg, value in args.items():
-            ParamValidator.check_subclass(arg, value, mstype.tensor)
-
     @staticmethod
     def check_bool(arg_name, arg_value):
         """Check arg isinstance of bool"""

@@ -442,113 +391,6 @@ class ParamValidator:
             return arg_value
         raise_error_msg()

-    @staticmethod
-    def check_typename(arg_name, arg_type, valid_types):
-        """Does it contain the _name_ attribute."""
-
-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
-
-        if isinstance(arg_type, type(mstype.tensor)):
-            arg_type = arg_type.element_type()
-
-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        if len(valid_types) == 1:
-            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
-                             f' but got {get_typename(arg_type)}.')
-        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
-                         f' but got {get_typename(arg_type)}.')
-
-    @staticmethod
-    def check_string(arg_name, arg_value, valid_values):
-        """String type judgment."""
-        if isinstance(arg_value, str) and arg_value in valid_values:
-            return arg_value
-        if len(valid_values) == 1:
-            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
-                             f' but got {arg_value}.')
-        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
-                         f' but got {arg_value}.')
-
-    @staticmethod
-    def check_type_same(args, valid_values):
-        """Determine whether the types are the same."""
-        name = list(args.keys())[0]
-        value = list(args.values())[0]
-        if isinstance(value, type(mstype.tensor)):
-            value = value.element_type()
-        for arg_name, arg_value in args.items():
-            if isinstance(arg_value, type(mstype.tensor)):
-                arg_value = arg_value.element_type()
-
-            if arg_value not in valid_values:
-                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
-                                f' but `{arg_name}` is {arg_value}.')
-            if arg_value != value:
-                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
-                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
-
-    @staticmethod
-    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
-        """Determine whether the types of two variables are the same."""
-        if arg1_type != arg2_type:
-            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
-
-    @staticmethod
-    def check_value_on_integer(arg_name, arg_value, value, rel):
-        """Judging integer type."""
-        rel_fn = Rel.get_fns(rel)
-        type_match = isinstance(arg_value, int)
-        if type_match and (not rel_fn(arg_value, value)):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
-        """Judging the equality of parameters."""
-        if param1_value != param2_value:
-            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
-                             f" but got `{param1_name}` = {param1_value},"
-                             f" `{param2_name}` = {param2_value}.")
-
-    @staticmethod
-    def check_const_input(arg_name, arg_value):
-        """Check valid value."""
-        if arg_value is None:
-            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
-
-    @staticmethod
-    def check_float_positive(arg_name, arg_value):
-        """Float type judgment."""
-        if isinstance(arg_value, float):
-            if arg_value > 0:
-                return arg_value
-            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
-        raise TypeError(f"`{arg_name}` must be float!")
-
-    @staticmethod
-    def check_pad_value_by_mode(op_name, pad_mode, padding):
-        """Validate value of padding according to pad_mode"""
-        if pad_mode != 'pad' and padding != 0:
-            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
-        return padding
-
-    @staticmethod
-    def check_empty_shape_input(arg_name, arg_value):
-        """Check zeros value."""
-        if 0 in arg_value:
-            raise ValueError(f"Input `{arg_name}` cannot be empty.")
-
-    @staticmethod
-    def check_scalar_shape_input(arg_name, arg_value):
-        """Check scalar shape input."""
-        if arg_value != []:
-            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
-
 def check_int(input_param):
     """Int type judgment."""
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

@@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
     return get_single_type((*tuple_ptr)[output_idx]);
   };
   TypePtr type_ptr = node->Type();
-  if (type_ptr->isa<RefType>()) {
-    auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
-    MS_EXCEPTION_IF_NULL(ref_type_ptr);
-    return get_tuple_type(ref_type_ptr->subtype(), output_idx);
-  }
   return get_tuple_type(type_ptr, output_idx);
 }
mindspore/ccsrc/frontend/operator/composite/do_signature.cc

@@ -20,6 +20,7 @@
 #include "abstract/abstract_value.h"
 #include "ir/anf.h"
 #include "ir/dtype.h"
 #include "abstract/dshape.h"
 #include "abstract/param_validator.h"
 #include "frontend/operator/cc_implementations.h"

@@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
   return empty;
 }

-void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
-                    const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
+void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
+                    bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
   std::size_t sig_size = signature.size();
   auto positional_size = sig_size;
   if (has_var) {
     positional_size = sig_size - 1;
   }
-  if (args_spec_list.size() < positional_size) {
-    for (size_t i = args_spec_list.size(); i < sig_size; ++i) {
+  if (actual_param_number < positional_size) {
+    for (size_t i = actual_param_number; i < sig_size; ++i) {
       auto default_value = signature[i].default_value;
       if (default_value == nullptr) {
         MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";

@@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
   *max_type_number = type_number;
 }

-bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
+bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
                                TypeId *arg_type = nullptr) {
-  if (arg_value->isa<abstract::AbstractRef>()) {
-    auto ref = arg_value->cast<abstract::AbstractRefPtr>();
-    arg_value = ref->ref();
-    if (!is_write && ref->need_cast()) {
-      auto tensor_type = ref->target_type();
-      *arg_type_id = tensor_type->type_id();
-      if (arg_type != nullptr) {
-        *arg_type = kObjectTypeTensorType;
-      }
-      return true;
-    }
-  }
-  if (arg_value->isa<abstract::AbstractTensor>()) {
-    auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
-    auto tensor_type = tensor->element()->BuildType();
+  if (arg_type_origin->isa<TensorType>()) {
+    auto tensor = arg_type_origin->cast<TensorTypePtr>();
+    auto tensor_type = tensor->element();
     MS_EXCEPTION_IF_NULL(tensor_type);
     *arg_type_id = tensor_type->type_id();
     if (arg_type != nullptr) {

@@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
     }
     return true;
   }
-  if (arg_value->isa<abstract::AbstractScalar>()) {
-    auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
-    auto scalar_type = scalar->BuildType();
+  if (arg_type_origin->isa<Number>()) {
+    auto scalar_type = arg_type_origin->cast<NumberPtr>();
     MS_EXCEPTION_IF_NULL(scalar_type);
     *arg_type_id = scalar_type->type_id();
     if (arg_type != nullptr) {

@@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
   return false;
 }

-TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
+TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices,
                     const std::set<size_t> &write_indices) {
   TypeId max_type_id = kTypeUnknown;
   size_t max_type_number = 0;

@@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
     TypeId arg_type_id = kTypeUnknown;
     TypeId arg_type = kTypeUnknown;
     auto is_write = (write_indices.find(index) != write_indices.end());
-    if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
+    if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
       continue;
     }
     if (arg_type != kObjectTypeTensorType) {

@@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
 // Get the largest type of index in the same SignatureEnumDType of arguments.
 using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;

-MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
-                       const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) {
+MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
+                       const std::set<size_t> &write_indices) {
   // record index for signature.dtypes of the same type
   // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
   std::map<SignatureEnumDType, std::vector<size_t>> type_indices;

@@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
     }
     bool has_tensor = false;
     for (const auto &index : indices) {
-      AbstractBasePtr arg_value = args_spec_list[index];
-      if (arg_value->isa<abstract::AbstractRef>()) {
-        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
-      }
-      if (arg_value->isa<abstract::AbstractTensor>()) {
+      auto arg_value = input_types[index];
+      if (arg_value->isa<TensorType>()) {
         has_tensor = true;
         break;
       }

@@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
       (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
       continue;
     }
-    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
+    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices)));
   }
   return dst_type;
 }

@@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
 }

 void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
-                const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
+                const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
                 std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),

@@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     return;
   }
   // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
-  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
+  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
   // Identify which arg requires auto cast
-  for (size_t i = 0; i < args_spec_list.size(); ++i) {
+  for (size_t i = 0; i < input_types.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;

@@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     auto is_write = (rw_it != write_indices.end());
     TypeId arg_type_id = kTypeUnknown;
-    AbstractBasePtr arg_value = args_spec_list[i];
+    auto arg_value = input_types[i];
     (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
     auto it_map = type_name_map.find(arg_type_id);
     if (it_map == type_name_map.end()) {

@@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
       }
       continue;
     }
-    if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
+    if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
       continue;
     }
     MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id

@@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
   }
   std::vector<AnfNodePtr> op_inputs;
   std::set<size_t> write_indices;
+  std::vector<TypePtr> input_types;
   op_inputs.push_back(NewValueNode(function));
   // Assume, the write input of op is always the first input. We check if any write op,
   // and add cast op on other inputs to keep the same type with assigned parameter.

@@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
       sig = signature[sig_size - 1].rw;
     }

-    TypePtr type = args_spec_list[i]->GetTypeTrack();
-    if (type && type->type_id() == kObjectTypeRef) {
-      auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
+    TypePtr type = args_spec_list[i]->BuildType();
+    if (type && type->isa<RefType>()) {
+      auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
       if (sig == SignatureEnumRW::kRWRead) {
-        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
-        if (ref_abs && ref_abs->need_cast()) {
-          auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
-          param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
+        auto source_tensor_type = type->cast<TensorTypePtr>();
+        if (source_tensor_type != nullptr) {
+          auto source_element = source_tensor_type->element();
+          if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
+            auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+            param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph);
+            type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
+          }
         }
       } else if (sig == SignatureEnumRW::kRWWrite) {
-        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
         write_indices.insert(i);
       }
-    } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
-      MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
+      // If sig is SignatureEnumRW::kRWRef, not do anything.
+    } else if (sig == SignatureEnumRW::kRWWrite &&
+               !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
+      MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
+                              << type->ToString();
     }
+    MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
+                  << args_spec_list[i]->ToString();
+    input_types.push_back(type);
     op_inputs.push_back(param);
   }
   // process default
-  ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
-  DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
+  ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
+  DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
   return func_graph->NewCNode(op_inputs);
 }
 }  // namespace
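The GetMaxDtype hunk above keeps the existing strategy of first grouping argument positions by their signature dtype placeholder and then picking the widest concrete type per group, as its comment illustrates with [T, T1, T, T2, T, T1, T3] -> {T:(0,2,4), T1:(1,5), T2:(3), T3:(6)}. A small standalone sketch of that grouping step, using plain ints in place of SignatureEnumDType (an illustration under those simplifications, not the MindSpore implementation):

    #include <cstdio>
    #include <map>
    #include <vector>

    int main() {
      // Placeholder dtype tags per argument position: T=0, T1=1, T2=2, T3=3.
      std::vector<int> dtypes = {0, 1, 0, 2, 0, 1, 3};

      // Group argument indices that share the same dtype placeholder.
      std::map<int, std::vector<size_t>> type_indices;
      for (size_t i = 0; i < dtypes.size(); ++i) {
        type_indices[dtypes[i]].push_back(i);
      }

      // Prints: 0:(0,2,4)  1:(1,5)  2:(3)  3:(6)
      for (const auto &entry : type_indices) {
        std::printf("%d:(", entry.first);
        for (size_t j = 0; j < entry.second.size(); ++j) {
          std::printf("%s%zu", j ? "," : "", entry.second[j]);
        }
        std::printf(")  ");
      }
      std::printf("\n");
      return 0;
    }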
mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc

@@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
   }
   Register(types_name, py_fn);
 }

-static TypePtr UnwrapRef(const TypePtr &type) {
-  if (type->isa<RefType>()) {
-    return type->cast<RefTypePtr>()->subtype();
-  }
-  return type;
-}
 // Return Exact match if exists, else return non ambiguous sub class match
 // Return py::none() if matching is ambiguous

@@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
     }
     auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
-      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
+      if (!IsIdentidityOrSubclass(types[i], sign[i])) {
         match = false;
         break;
       }
mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc

@@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
   return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
 }

+AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+
+  MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
+  return args_spec_list[0];
+}
+
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);

@@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                       InferImplBroadcastGradientArgs);
+REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
+
 }  // namespace abstract
 }  // namespace mindspore
mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc

@@ -20,6 +20,7 @@
 #include "ir/anf.h"
 #include "ir/param_info.h"
 #include "ir/meta_tensor.h"
 #include "pipeline/jit/parse/python_adapter.h"

 namespace mindspore {

@@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
   if (!para_ptr->has_default()) {
     return false;
   }
-  auto obj = py::cast(para_ptr->default_param());
-  auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+  auto param_value = para_ptr->param_info();
   if (param_value == nullptr) {
     return false;
   }
mindspore/ccsrc/frontend/parallel/step_parallel.cc

@@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
   if (!cloned_parameter->has_default()) {
     return false;
   }
-  auto obj = py::cast(cloned_parameter->default_param());
-  auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+  auto param_value = cloned_parameter->param_info();
   if (param_value == nullptr) {
     return false;
   }

@@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     if (!ParameterIsCloned(cloned_parameter_node)) {
       continue;
     }
-    auto obj = py::cast(cloned_parameter->default_param());
-    auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+    auto param_value = cloned_parameter->param_info();
     if (param_value == nullptr) {
       continue;
     }

@@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
         continue;
       }
-      const auto &param_value_cloned = be_cloned_parameter->default_param();
-      auto obj_in = py::cast(param_value_cloned);
-      auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
+      auto param_value_in = be_cloned_parameter->param_info();
       if (param_value_in == nullptr) {
        continue;
      }
mindspore/ccsrc/pipeline/jit/action.cc

@@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   for (const auto &param : func_graph->parameters()) {
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
-      ValuePtr value = param_node->default_param();
-      constexpr bool broaden = true;
-      AbstractBasePtr ptr = abstract::FromValue(value, broaden);
-
-      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
-      args_spec.push_back(ptr);
-      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
+      auto value = param_node->default_param();
+      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
+      auto ref_key = std::make_shared<RefKey>(param_node->name());
+      auto abs_ref_key = ref_key->ToAbstract();
+      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
+      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
+      args_spec.push_back(abs_ref);
+      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
     }
   }
   // Analyze
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc

@@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
     converted = env;
   } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
     converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
-  } else if (py::hasattr(obj, "__parameter__")) {
-    auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
-    ret = ConvertData(to_convert, &converted);
   } else {
     ret = ConvertOtherObj(obj, &converted);
   }

@@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
 ValuePtr PyDataToValue(const py::object &obj) {
   py::object to_convert = obj;
-  if (py::hasattr(obj, "__parameter__")) {
-    to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
-  }
   ValuePtr value = nullptr;
   (void)ConvertData(to_convert, &value);
   return value;
mindspore/ccsrc/pipeline/jit/parse/function_block.cc

@@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
 }

 void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
-  state_assign_[target] = readid;
+  const std::string primitive_name("assign");
+  const std::string module_name("mindspore.ops.functional");
+  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
+  auto source = ReadVariable(readid);
+  auto assign = func_graph()->NewCNode({assign_op, target, source});
+  WriteVariable(readid, assign);
+  MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
+  AddAutoDepend(assign);
 }

 void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) {
   auto_depends_.push_back(target);
 }

@@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
   ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
   ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
-  const std::string primitive_name("assign");
-  const std::string module_name("mindspore.ops.functional");
-  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
-  if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
+  if (auto_depends_.size() == 0) {
     return;
   }
   AnfNodePtr state = nullptr;
   std::vector<AnfNodePtr> vec_states;
   vec_states.emplace_back(make_tuple_op);
-  for (auto &item : state_assign_) {
-    auto source = ReadVariable(item.second);
-    auto assign = func_graph()->NewCNode({assign_op, item.first, source});
-    MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
-    vec_states.emplace_back(assign);
-  }
   for (auto &item : auto_depends_) {
     MS_LOG(DEBUG) << "auto_depends " << item->ToString();
     vec_states.emplace_back(item);

@@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
   AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
   func_graph()->set_output(ret, true);
-  state_assign_.clear();
 }
 }  // namespace parse
 }  // namespace mindspore
mindspore/ccsrc/pipeline/jit/parse/function_block.h

@@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   // keeps all removable phis which will be removed in one pass.
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;

-  // set state nodes need to insert before function return nodes.
-  OrderedMap<AnfNodePtr, std::string> state_assign_;
-
   // hold declared global variables in function
   std::set<std::string> global_vars_;
mindspore/ccsrc/pipeline/jit/parse/parse.cc

@@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
   return func_graph;
 }

-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
-  TypePtr dst_type;
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
   if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
     return kFloat32;
   } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
     return kFloat16;
   } else {
-    return kNone;
+    return nullptr;
   }
 }
mindspore/ccsrc/pipeline/jit/parse/parse.h

@@ -359,7 +359,7 @@ class ParseAst {
 bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);

 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);

 }  // namespace parse
 }  // namespace mindspore
mindspore/ccsrc/pipeline/jit/parse/resolve.cc

@@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
     auto value = py::cast<tensor::MetaTensorPtr>(obj);
     node->set_default_param(value);
     // set_abstract for parameter
-    constexpr bool broaden = true;
-    node->set_abstract(abstract::FromValue(value, broaden));
+    auto abs = value->ToAbstract();
+    node->set_abstract(abs);
     para_node = node;
   }
-  auto iter = func_graph->make_ref_params().find(para_node);
-  if (iter == func_graph->make_ref_params().end()) {
-    ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
-    AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-    AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
-    AnfNodePtr target_type_node = NewValueNode(target_type);
-    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
-    func_graph->make_ref_params()[para_node] = ref_node;
-    func_graph->add_parameter_obj_node(ref_node);
-    return ref_node;
-  } else {
-    return iter->second;
-  }
+  return para_node;
 }

 bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   size_t size = op_exec_info->op_inputs.size();
   for (size_t i = 0; i < size; i++) {
     auto obj = op_exec_info->op_inputs[i];
-    bool op_mask = py::hasattr(obj, "__parameter__");
+    bool op_mask = false;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
+      if (meta_tensor) {
+        op_mask = meta_tensor->is_parameter();
+      }
+    }
+
     (*op_masks).push_back(op_mask);
     MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;

@@ -988,8 +995,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
     if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
       auto free_param = df_builder_->add_parameter();
       free_param->set_name(param_name);
-      free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
       free_param->debug_info()->set_name(param_name);
+      auto value = py::cast<tensor::TensorPtr>(obj);
+      free_param->set_default_param(value);
       MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
       graph_info_map_[df_builder_].param_map[obj_id] = free_param;
       return free_param;

@@ -1157,17 +1165,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
         auto param_name = py::cast<std::string>(name_attr);
         auto free_param = df_builder_->add_parameter();
         free_param->set_name(param_name);
-        free_param->set_default_param(py::cast<tensor::TensorPtr>(param));
+        auto value = py::cast<tensor::TensorPtr>(param);
+        free_param->set_default_param(value);
         free_param->debug_info()->set_name(param_name);
         para_node = free_param;
       }
-      ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
-      AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-      auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
-      AnfNodePtr ref_key_node = NewValueNode(refkey);
-      AnfNodePtr target_type_node = NewValueNode(target_type);
-      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
-      w_args.push_back(ref_node);
+      w_args.push_back(para_node);
     }
   } else {
     MS_LOG(DEBUG) << "training not paramter_tuple";

@@ -1195,7 +1198,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
       ValuePtr value = param_node->default_param();
-      AbstractBasePtr ptr = abstract::FromValue(value, true);
+      auto ptr = value->ToAbstract();
       if (ptr == nullptr) {
         MS_LOG(EXCEPTION) << "Args convert error";
       }
mindspore/ccsrc/pybind_api/ir/dtype_py.cc

@@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
      (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
      (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
      (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
-     (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
+     (void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
      (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
      (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
      (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
mindspore/ccsrc/pybind_api/ir/param_info_py.cc

@@ -21,7 +21,7 @@ namespace mindspore {
 namespace py = pybind11;

 REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
-                         (void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
+                         (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
                            .def(py::init())
                            .def("clone", &ParamInfo::Clone)
                            .def_property("name", &ParamInfo::name, &ParamInfo::set_name)

@@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
                               if (t.size() != 6) {
                                 std::runtime_error("Invalid state for ParamInfo!");
                               }
-                              ParamValuePtr p = std::make_shared<ParamInfo>();
+                              ParamInfoPtr p = std::make_shared<ParamInfo>();
                               p->set_name(t[1].cast<std::string>());
                               p->set_requires_grad(t[2].cast<bool>());
                               p->set_layerwise_parallel(t[3].cast<bool>());
mindspore/ccsrc/pybind_api/ir/tensor_py.cc

@@ -213,6 +213,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                          .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
                          .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
                          .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
+                         .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
                          .def(py::pickle(
                            [](const MetaTensor &t) {  // __getstate__
                              /* Return a tuple that fully encodes the state of the object */
mindspore/common/parameter.py

@@ -42,7 +42,7 @@ class Parameter(MetaTensor):
     In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
     an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
     only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
-    compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
+    compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.

     Note:
         Each parameter of Cell is represented by Parameter class.

@@ -108,7 +108,7 @@ class Parameter(MetaTensor):
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

     def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
-        self._value = ParamInfo()
+        self._param_info = ParamInfo()
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel

@@ -156,13 +156,13 @@ class Parameter(MetaTensor):
         value_str = MetaTensor.__str__(self)
         if isinstance(self, Tensor):
             value_str = Tensor.__str__(self)
-        return f'Parameter (name={self._value.name}, value={value_str})'
+        return f'Parameter (name={self._param_info.name}, value={value_str})'

     def __repr__(self):
         value_str = MetaTensor.__repr__(self)
         if isinstance(self, Tensor):
             value_str = Tensor.__repr__(self)
-        return f'Parameter (name={self._value.name}, value={value_str})'
+        return f'Parameter (name={self._param_info.name}, value={value_str})'

     def __parameter__(self):
         """For parse check."""

@@ -181,7 +181,7 @@ class Parameter(MetaTensor):
     @property
     def name(self):
         """Get the name of the parameter."""
-        return self._value.name
+        return self._param_info.name

     @name.setter
     def name(self, name_):

@@ -203,7 +203,7 @@ class Parameter(MetaTensor):
                                  format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
         else:
             raise ValueError("The type of the name should be `str` or `None`.")
-        self._value.name = name_
+        self._param_info.name = name_

     @property
     def cast_type(self):

@@ -254,8 +254,8 @@ class Parameter(MetaTensor):
         _check_str_by_regular(prefix)
         x = copy(self)
         # pylint: disable=protected-access
-        x._value = self._value.clone()
-        x._value.name = prefix + '.' + self._value.name
+        x._param_info = self._param_info.clone()
+        x._param_info.name = prefix + '.' + self._param_info.name
         x.is_init = False
         if init != 'same':
             shape = self.shape

@@ -265,24 +265,24 @@ class Parameter(MetaTensor):
     @property
     def layerwise_parallel(self):
-        return self._value.layerwise_parallel
+        return self._param_info.layerwise_parallel

     @layerwise_parallel.setter
     def layerwise_parallel(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
-        self._value.layerwise_parallel = value
+        self._param_info.layerwise_parallel = value

     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
-        return self._value.requires_grad
+        return self._param_info.requires_grad

     @requires_grad.setter
     def requires_grad(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`requires_grad` parameter must be bool type")
-        self._value.requires_grad = value
+        self._param_info.requires_grad = value

     @property
     def data(self):
mindspore/core/abstract/abstract_value.cc

@@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   }
   auto other_tensor = dyn_cast<AbstractTensor>(other);
   if (other_tensor == nullptr) {
-    auto ref_tensor = dyn_cast<AbstractRef>(other);
-    if (ref_tensor != nullptr) {
-      return this->Join(ref_tensor->ref());
-    }
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
   if (*this == *other) {

@@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   return std::make_shared<AbstractTensor>(element, shape);
 }

-bool AbstractTensor::operator==(const AbstractTensor &other) const {
+bool AbstractTensor::equal_to(const AbstractTensor &other) const {
   if (&other == this) {
     return true;
   }

@@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
   return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
 }

+bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
+
 bool AbstractTensor::operator==(const AbstractBase &other) const {
   if (&other == this) {
     return true;
   }

-  if (other.isa<AbstractTensor>()) {
+  if (other.tid() == tid()) {
     auto other_tensor = static_cast<const AbstractTensor *>(&other);
     return *this == *other_tensor;
   } else {

@@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
   return buffer.str();
 }

-AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
-                         TypePtr cast_target)
-    : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
+AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
+    : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
   set_type(std::make_shared<RefType>());
-  auto origin_type = ref_value->BuildType();
-  if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
-    auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
-    if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
-      if (cast_target != tensor_dtype) {
-        need_cast_ = true;
-        target_type_ = cast_target;
-      }
-    }
-  }
   if (ref_key && ref_key->isa<AbstractRefKey>()) {
     ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
   }
 }

-BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
-
 TypePtr AbstractRef::BuildType() const {
-  TypePtr subtype = ref_->BuildType();
-  TypePtr subtype_origin = subtype;
-  if (need_cast_) {
-    subtype_origin = std::make_shared<TensorType>(target_type_);
-  }
-  return std::make_shared<RefType>(subtype, subtype_origin);
+  auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>();
+  return std::make_shared<RefType>(subtype);
 }

 bool AbstractRef::operator==(const AbstractRef &other) const {
-  return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
-         (!need_cast_ || (*target_type_ == *other.target_type_));
+  return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
 }

 bool AbstractRef::operator==(const AbstractBase &other) const {

@@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
 AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
   auto other_ref = other->cast<AbstractRefPtr>();
   if (other_ref == nullptr) {
-    auto new_ref = ref_->Join(other);
-    return std::make_shared<AbstractRef>(ref_key_, new_ref);
+    return AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
   }
   if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
     return shared_from_base<AbstractBase>();
   }
   auto ref_key = ref_key_->Join(other_ref->ref_key_);
-  auto ref = ref_->Join(other_ref->ref());
+  auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>();
   return std::make_shared<AbstractRef>(ref_key, ref);
 }

 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
-         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
-  if (need_cast_) {
-    buffer << " cast to: " << target_type_->ToString();
-  }
+         << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
   auto value = GetValueTrack();
   if (value) {
     buffer << ", value: " << value->ToString();
mindspore/core/abstract/abstract_value.h

@@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
   AbstractBasePtr Clone() const override;
   AbstractBasePtr Broaden(uint8_t config = 0) const override;
   AbstractBasePtr BroadenWithShape() const;
-  AbstractBasePtr Join(const AbstractBasePtr &other) final;
+  AbstractBasePtr Join(const AbstractBasePtr &other);
   bool operator==(const AbstractTensor &other) const;
   bool operator==(const AbstractBase &other) const override;
   std::string ToString() const override;
   std::size_t hash() const override {
     auto value = GetValueTrack();

@@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
     }
     return hash_sum;
   }
+
+ protected:
+  bool equal_to(const AbstractTensor &other) const;
 };
 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;

@@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
 };
 using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;

-class AbstractRef : public AbstractBase {
+class AbstractRef : public AbstractTensor {
  public:
-  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
-              TypePtr cast_target = nullptr);
+  AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
   ~AbstractRef() override = default;
-  MS_DECLARE_PARENT(AbstractRef, AbstractBase)
+  MS_DECLARE_PARENT(AbstractRef, AbstractTensor)

   TypePtr BuildType() const override;
-  BaseShapePtr BuildShape() const override;
   bool operator==(const AbstractRef &other) const;
   bool operator==(const AbstractBase &other) const override;
   AbstractBasePtr Clone() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
+    auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
+    if (abs_tensor == nullptr) {
+      return nullptr;
+    }
+    return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
   }
   std::string ToString() const override;
-  inline AbstractBasePtr ref() const { return ref_; }
+  inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
   inline AbstractBasePtr ref_key() const { return ref_key_; }
   inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
-  inline TypePtr target_type() const { return target_type_; }
-  inline bool need_cast() const { return need_cast_; }
   AbstractBasePtr Broaden(uint8_t config = 0) const override {
     // always broaden for ref
-    return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
+    auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
+    if (abs_tensor == nullptr) {
+      return nullptr;
+    }
+    return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
   }
   AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
-    return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
+    return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
   }

  private:
   AbstractBasePtr ref_key_;
-  AbstractBasePtr ref_;
-  // For mix presicion, only float type need to cast to float16 of float32
-  bool need_cast_;
-  TypePtr target_type_;
   // cache for ref_key after build value, when value is null, return nullptr.
   RefKeyPtr ref_key_value_;
 };
mindspore/core/abstract/prim_others.cc

@@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
     MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is "
                       << args_spec_list.size() << ".";
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
-  ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  auto need_cast = !tensor_target_v->isa<None>();
-  if (need_cast && !tensor_target_v->isa<Type>()) {
-    MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
-  }
-  TypePtr cast_target = tensor_target_v->cast<TypePtr>();
-  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
+  auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
+  return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
 }

 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
mindspore/core/ir/anf.cc

@@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
   return buffer.str();
 }

+ParamInfoPtr Parameter::param_info() const {
+  if (!has_default()) {
+    return nullptr;
+  }
+  auto tensor = default_param()->cast<tensor::MetaTensorPtr>();
+  if (tensor == nullptr || !tensor->is_parameter()) {
+    return nullptr;
+  }
+  return tensor->param_info();
+}
+
 std::string ValueNode::ToString() const {
   MS_EXCEPTION_IF_NULL(value_);
   if (value_->isa<FuncGraph>()) {
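Note (illustrative): the new Parameter::param_info() accessor simply pulls the ParamInfo out of the default MetaTensor when one is attached. A minimal, assumed usage:

// Sketch only: read the parameter name off a graph Parameter (include paths assumed).
#include "ir/anf.h"
#include "utils/log_adapter.h"

void LogParamName(const mindspore::ParameterPtr &param) {
  auto info = param->param_info();  // nullptr if there is no default value, or it is not a parameter tensor
  if (info != nullptr) {
    MS_LOG(INFO) << "parameter: " << info->name();
  }
}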
mindspore/core/ir/anf.h

@@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
 class AnfIrVisitor;

 class ParamInfo;
-using ParamValuePtr = std::shared_ptr<ParamInfo>;
+using ParamInfoPtr = std::shared_ptr<ParamInfo>;

 // AnfNode is the basic class of the IR definition derived from Base.
 // Only two types of nodes are derived: CNode and ANode.

@@ -288,6 +288,7 @@ class Parameter : public ANode {
     has_default_ = true;
   }
   ValuePtr default_param() const { return default_param_; }
+  ParamInfoPtr param_info() const;

   bool operator==(const AnfNode &other) const override {
     if (!other.isa<Parameter>()) {
mindspore/core/ir/dtype.cc

@@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
 std::string Slice::DumpText() const { return ToString(); }

 (169 removed lines elided here: the UndeterminedType, TensorType, RowTensorType and SparseTensorType
  implementations that previously lived in this file move verbatim to mindspore/core/ir/dtype/tensor_type.cc,
  reproduced in full below.)

 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;

@@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
   os << problem->ToString();
   return os;
 }
+
+const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
+const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
 }  // namespace mindspore
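Note (illustrative): kTensorTypeFP16 and kTensorTypeFP32 are plain TensorType wrappers around Float(16)/Float(32); an equivalent hand-built type compares equal element-wise. Sketch only, assuming ir/dtype.h pulls in TensorType and Float.

#include <memory>
#include "ir/dtype.h"

bool SameAsFP16Constant() {
  auto hand_built = std::make_shared<mindspore::TensorType>(std::make_shared<mindspore::Float>(16));
  return *hand_built == *mindspore::kTensorTypeFP16;  // TensorType::operator== compares the element types
}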
mindspore/core/ir/dtype.h

@@ -32,10 +32,11 @@
 #include "ir/named.h"
 #include "ir/dtype/type.h"
-#include "ir/dtype/ref.h"
 #include "ir/dtype/number.h"
 #include "ir/dtype/container.h"
 #include "ir/dtype/empty.h"
+#include "ir/dtype/tensor_type.h"
+#include "ir/dtype/ref.h"

 /* namespace to support intermediate representation definition */
 namespace mindspore {

@@ -108,98 +109,6 @@ class Slice : public Object {
 };
 using SlicePtr = std::shared_ptr<Slice>;

 (92 removed lines elided here: the UndeterminedType, TensorType, RowTensorType and SparseTensorType class
  declarations move verbatim to mindspore/core/ir/dtype/tensor_type.h, reproduced in full below.)

 class Function : public Object {
  public:
  Function();

@@ -353,6 +262,9 @@ extern const TypePtr kDict;
 extern const TypePtr kSlice;
 extern const TypePtr kKeyword;
 extern const TypePtr kTensorType;
+extern const TypePtr kTensorTypeFP16;
+extern const TypePtr kTensorTypeFP32;
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_DTYPE_H_
mindspore/core/ir/dtype/number.h

@@ -68,6 +68,8 @@ class Number : public Object {
   const int nbits_;
 };

+using NumberPtr = std::shared_ptr<Number>;
+
 // Bool
 class Bool : public Number {
  public:
mindspore/core/ir/dtype/ref.cc

@@ -19,15 +19,15 @@
 #include <cstdlib>
 #include <algorithm>
 #include "utils/log_adapter.h"
+#include "ir/dtype/tensor_type.h"

 namespace mindspore {
 TypePtr RefType::DeepCopy() const {
   if (IsGeneric()) {
     return std::make_shared<RefType>();
   } else {
-    auto subtype = subtype_->DeepCopy();
-    auto subtype_origin = subtype_origin_->DeepCopy();
-    return std::make_shared<RefType>(subtype, subtype_origin);
+    auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>();
+    return std::make_shared<RefType>(subtype);
   }
 }

@@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
     buffer << "Ref";
   } else {
     buffer << "Ref[";
-    buffer << subtype_->DumpText() << "]";
+    buffer << TensorType::DumpText() << "]";
   }
   return buffer.str();
 }
mindspore/core/ir/dtype/ref.h

@@ -17,21 +17,13 @@
 #ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
 #define MINDSPORE_CORE_IR_DTYPE_REF_H_

 #include <cstddef>
 #include <iostream>
 #include <initializer_list>
 #include <map>
 #include <memory>
 #include <utility>
 #include <sstream>
 #include <string>
 #include <vector>
 #include <type_traits>
 #include <unordered_map>
 #include <algorithm>

 #include "base/base.h"
 #include "ir/named.h"
 #include "ir/dtype/type.h"
+#include "ir/dtype/tensor_type.h"

 namespace mindspore {
 // TypeRefKey type

@@ -48,23 +40,16 @@ class RefKeyType : public Object {
 };

 // TypeRef type
-class RefType : public Object {
+class RefType : public TensorType {
  public:
-  RefType() : Object(kObjectTypeRef) {}
-  RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
-      : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
+  RefType() : TensorType() {}
+  explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}
   ~RefType() override {}
-  MS_DECLARE_PARENT(RefType, Object)
+  MS_DECLARE_PARENT(RefType, TensorType)

-  TypePtr subtype() const { return subtype_; }
   TypeId generic_type_id() const override { return kObjectTypeRef; }
   TypePtr DeepCopy() const override;
   std::string ToString() const override;
   std::string DumpText() const override;
-
- private:
-  TypePtr subtype_;
-  TypePtr subtype_origin_;
 };
 using RefTypePtr = std::shared_ptr<RefType>;
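Note (illustrative): RefType is now constructed from a TensorTypePtr and forwards its element type to the TensorType base, so its text forms wrap the tensor's. Sketch only (Float pulled in via ir/dtype.h is an assumption).

#include <memory>
#include <string>
#include "ir/dtype.h"
#include "ir/dtype/ref.h"

std::string DumpRefOfFloat32() {
  auto tensor_type = std::make_shared<mindspore::TensorType>(std::make_shared<mindspore::Float>(32));
  auto ref_type = std::make_shared<mindspore::RefType>(tensor_type);
  return ref_type->DumpText();  // expected to read roughly "Ref[Tensor(Float32)]"
}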
mindspore/core/ir/dtype/tensor_type.cc  — new file (0 → 100644)

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/dtype/tensor_type.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"

namespace mindspore {
TypePtr UndeterminedType::DeepCopy() const {
  MS_EXCEPTION_IF_NULL(element_type_);
  if (IsGeneric()) {
    return std::make_shared<UndeterminedType>();
  }
  return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}

std::string UndeterminedType::ToReprString() const {
  if (element_type_ == nullptr) {
    return "Undetermined";
  }
  return "Undetermined[" + element_type_->ToReprString() + "]";
}

std::string UndeterminedType::ToString() const {
  if (element_type_ == nullptr) {
    return "Undetermined";
  }
  return "Undetermined[" + element_type_->ToString() + "]";
}

std::string UndeterminedType::DumpText() const {
  if (element_type_ == nullptr) {
    return "Undetermined";
  }
  return "Undetermined[" + element_type_->DumpText() + "]";
}

bool UndeterminedType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
  if (element_type_ == nullptr && other_elem_type == nullptr) {
    return true;
  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
    return false;
  }
  return *element_type_ == *other_elem_type;
}

TypePtr TensorType::DeepCopy() const {
  MS_EXCEPTION_IF_NULL(element_type_);
  if (IsGeneric()) {
    return std::make_shared<TensorType>();
  }
  return std::make_shared<TensorType>(element_type_->DeepCopy());
}

std::string TensorType::ToReprString() const {
  if (element_type_ == nullptr) {
    return "tensor";
  }
  return "tensor[" + element_type_->ToReprString() + "]";
}

std::string TensorType::ToString() const {
  if (element_type_ == nullptr) {
    return "Tensor";
  }
  return "Tensor[" + element_type_->ToString() + "]";
}

std::string TensorType::DumpText() const {
  if (element_type_ == nullptr) {
    return "Tensor";
  }
  return "Tensor(" + element_type_->DumpText() + ")";
}

bool TensorType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
  // When element_type_ = nullptr, which means any type of Array.
  if (element_type_ == nullptr && other_elem_type == nullptr) {
    return true;
  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
    return false;
  }
  return *element_type_ == *other_elem_type;
}

TypePtr RowTensorType::DeepCopy() const {
  MS_EXCEPTION_IF_NULL(element_type_);
  if (IsGeneric()) {
    return std::make_shared<RowTensorType>();
  }
  return std::make_shared<RowTensorType>(element_type_->DeepCopy());
}

std::string RowTensorType::ToReprString() const {
  if (element_type_ == nullptr) {
    return "RowTensor";
  }
  return "RowTensor[" + element_type_->ToReprString() + "]";
}

std::string RowTensorType::ToString() const {
  if (element_type_ == nullptr) {
    return "RowTensor";
  }
  return "RowTensor[" + element_type_->ToString() + "]";
}

std::string RowTensorType::DumpText() const {
  if (element_type_ == nullptr) {
    return "RowTensor";
  }
  return "RowTensor[" + element_type_->DumpText() + "]";
}

bool RowTensorType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
  if (element_type_ == nullptr && other_elem_type == nullptr) {
    return true;
  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
    return false;
  }
  return *element_type_ == *other_elem_type;
}

TypePtr SparseTensorType::DeepCopy() const {
  MS_EXCEPTION_IF_NULL(element_type_);
  if (IsGeneric()) {
    return std::make_shared<SparseTensorType>();
  }
  return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}

std::string SparseTensorType::ToReprString() const {
  if (element_type_ == nullptr) {
    return "SparseTensor";
  }
  return "SparseTensor[" + element_type_->ToReprString() + "]";
}

std::string SparseTensorType::ToString() const {
  if (element_type_ == nullptr) {
    return "SparseTensor";
  }
  return "SparseTensor[" + element_type_->ToString() + "]";
}

std::string SparseTensorType::DumpText() const {
  if (element_type_ == nullptr) {
    return "SparseTensor";
  }
  return "SparseTensor[" + element_type_->DumpText() + "]";
}

bool SparseTensorType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
  if (element_type_ == nullptr && other_elem_type == nullptr) {
    return true;
  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
    return false;
  }
  return *element_type_ == *other_elem_type;
}
}  // namespace mindspore
mindspore/core/ir/dtype/tensor_type.h  — new file (0 → 100644)

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_

#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>

#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"

namespace mindspore {
class UndeterminedType : public Object {
 public:
  UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
  explicit UndeterminedType(const TypePtr &ele)
      : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
  ~UndeterminedType() override = default;
  MS_DECLARE_PARENT(UndeterminedType, Object)

  TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
  const TypePtr element() const { return element_type_; }
  void set_element(const TypePtr &element_type) { element_type_ = element_type; }

  TypePtr DeepCopy() const override;
  std::string ToString() const override;
  std::string ToReprString() const override;
  std::string DumpText() const override;
  bool operator==(const Type &other) const override;

 protected:
  TypePtr element_type_;
};
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;

class TensorType : public Object {
 public:
  TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
  explicit TensorType(const TypePtr &ele)
      : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
  ~TensorType() override = default;
  MS_DECLARE_PARENT(TensorType, Object)

  TypeId generic_type_id() const override { return kObjectTypeTensorType; }
  const TypePtr element() const { return element_type_; }
  void set_element(const TypePtr &element_type) { element_type_ = element_type; }

  TypePtr DeepCopy() const override;
  std::string ToString() const override;
  std::string ToReprString() const override;
  std::string DumpText() const override;
  bool operator==(const Type &other) const override;

 private:
  TypePtr element_type_;
};
using TensorTypePtr = std::shared_ptr<TensorType>;

class RowTensorType : public Object {
 public:
  RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
  explicit RowTensorType(const TypePtr &ele)
      : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
  ~RowTensorType() override = default;
  MS_DECLARE_PARENT(RowTensorType, Object)

  TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
  const TypePtr element() const { return element_type_; }
  void set_element(const TypePtr &element_type) { element_type_ = element_type; }

  TypePtr DeepCopy() const override;
  std::string ToString() const override;
  std::string ToReprString() const override;
  std::string DumpText() const override;
  bool operator==(const Type &other) const override;

 private:
  TypePtr element_type_;
};
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;

class SparseTensorType : public Object {
 public:
  SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
  explicit SparseTensorType(const TypePtr &ele)
      : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
  ~SparseTensorType() override = default;
  MS_DECLARE_PARENT(SparseTensorType, Object)

  TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
  const TypePtr element() const { return element_type_; }
  void set_element(const TypePtr &element_type) { element_type_ = element_type; }

  TypePtr DeepCopy() const override;
  std::string ToString() const override;
  std::string ToReprString() const override;
  std::string DumpText() const override;
  bool operator==(const Type &other) const override;

 private:
  TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
}  // namespace mindspore
#endif  // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
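Note (illustrative): all four classes expose the same element-parameterised interface (element, set_element, DeepCopy and the string dumps). A quick sketch of that shared surface; the exact output strings are assumptions.

#include <memory>
#include <string>
#include "ir/dtype/tensor_type.h"
#include "ir/dtype/number.h"

void ExerciseTensorTypeInterface() {
  auto t = std::make_shared<mindspore::TensorType>(std::make_shared<mindspore::Float>(16));
  std::string repr = t->ToReprString();                     // lower-case form, e.g. "tensor[Float16]"
  t->set_element(std::make_shared<mindspore::Float>(32));   // element type can be swapped in place
  auto copy = t->DeepCopy();                                // deep-copies the element type as well
  (void)repr;
  (void)copy;
}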
mindspore/core/ir/func_graph.h

@@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
   const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
   void add_parameter_obj_node(const AnfNodePtr &p);
-  std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }

   std::unordered_map<std::string, ValuePtr> attrs_;
   std::vector<BaseShapePtr> joined_shapes_;
   std::unordered_map<std::string, FuncGraphTransform> transforms_;
   // parameter default value
   std::map<std::string, AnfNodePtr> parameter_default_value_;
-  std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
   size_t seen_;

   std::list<CNodePtr> GetOrderedCnodes();
mindspore/core/ir/meta_tensor.h

@@ -23,6 +23,7 @@
 #include <string>

 #include "base/base.h"
 #include "ir/param_info.h"
+#include "ir/dtype.h"
 #include "utils/convert_utils_base.h"
 #include "utils/hashing.h"

@@ -163,6 +164,15 @@ class MetaTensor : public Value {
       return false;
     }
   }
+  // Get tensor's param_info info.
+  ParamInfoPtr param_info() const { return param_info_; }
+  bool is_parameter() const { return is_parameter_; }
+  // Set tensor's param_info info.
+  void set_param_info(const ParamInfoPtr &param_info) {
+    is_parameter_ = true;
+    param_info_ = param_info;
+  }

  protected:
   // brief Data type of the tensor.

@@ -184,6 +194,9 @@ class MetaTensor : public Value {
   //
   // Includes the format and data type of a tensor on device.
   DeviceInfo device_info_;
+
+  bool is_parameter_{false};
+  ParamInfoPtr param_info_{nullptr};
 };

 using MetaTensorPtr = std::shared_ptr<MetaTensor>;
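Note (illustrative): set_param_info() both stores the ParamInfo and flips is_parameter_ to true, which is what later makes ToAbstract() produce an AbstractRef. The set_name() setter used here is an assumption; only the name() getter is visible in this diff.

#include <memory>
#include <string>
#include "ir/meta_tensor.h"   // assumed include path
#include "ir/param_info.h"

void MarkAsParameter(const mindspore::tensor::MetaTensorPtr &meta, const std::string &name) {
  auto info = std::make_shared<mindspore::ParamInfo>();
  info->set_name(name);        // assumed setter, mirroring the name() getter used elsewhere in this commit
  meta->set_param_info(info);  // meta->is_parameter() is now true; meta->param_info() returns info
}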
mindspore/core/ir/meta_tensor_extends.cc

@@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
   }
   auto tensor_shape = tens->shape();
   auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
-  abs_tensor->set_value(shared_from_base<MetaTensor>());
+
+  // if is parameter always no value.
+  if (is_parameter()) {
+    auto param_name = param_info()->name();
+    auto ref_key = std::make_shared<RefKey>(param_name);
+    auto abs_ref_key = ref_key->ToAbstract();
+    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
+  } else {
+    abs_tensor->set_value(shared_from_base<MetaTensor>());
+  }
   return abs_tensor;
 }
mindspore/core/ir/named.h

@@ -62,6 +62,21 @@ class Named : public Value {
 };
 using NamedPtr = std::shared_ptr<Named>;

+struct NamedHasher {
+  std::size_t operator()(NamedPtr const &name) const {
+    std::size_t hash = name->Hash();
+    return hash;
+  }
+};
+
+struct NamedEqual {
+  bool operator()(NamedPtr const &t1, NamedPtr const &t2) const {
+    MS_EXCEPTION_IF_NULL(t1);
+    MS_EXCEPTION_IF_NULL(t2);
+    return *t1 == *t2;
+  }
+};
+
 class None : public Named {
  public:
  None() : Named("None") {}
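Note (illustrative): the new NamedHasher/NamedEqual functors let NamedPtr keys hash and compare by value rather than by pointer. A minimal sketch:

#include <memory>
#include <unordered_map>
#include "ir/named.h"

void CountNamedValues() {
  std::unordered_map<mindspore::NamedPtr, int, mindspore::NamedHasher, mindspore::NamedEqual> counts;
  counts[std::make_shared<mindspore::None>()] += 1;
  counts[std::make_shared<mindspore::None>()] += 1;  // equal by value, so the same entry is incremented
  // counts.size() == 1 at this point
}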
mindspore/core/ir/param_info.h

@@ -21,10 +21,13 @@
 #include <memory>
 #include <string>
 #include <vector>

 #include "ir/anf.h"
-#include "ir/tensor.h"
+#include "ir/dtype.h"

 namespace mindspore {
 class ParamInfo;
 using ParamInfoPtr = std::shared_ptr<ParamInfo>;

 class ParamInfo {
  public:
  ParamInfo() {}

@@ -55,7 +58,7 @@ class ParamInfo {
   int32_t cloned_index() const { return cloned_index_; }

   // Make a cloned parameter and update clone info.
-  ParamValuePtr Clone() {
+  ParamInfoPtr Clone() {
     static std::atomic<int32_t> parameter_cloned_index{1};
     int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
     auto clone = std::make_shared<ParamInfo>(*this);
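Note (illustrative): ParamInfo::Clone() (now returning ParamInfoPtr rather than the old ParamValuePtr alias) draws a fresh index from a process-wide atomic counter on every call.

#include <memory>
#include "ir/param_info.h"

void CloneTwice() {
  auto info = std::make_shared<mindspore::ParamInfo>();
  auto first = info->Clone();
  auto second = info->Clone();
  // first->cloned_index() != second->cloned_index(); indices increase monotonically.
}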
mindspore/core/ir/tensor.cc

@@ -461,6 +461,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
   }
   return *this;
 }
+
 abstract::AbstractBasePtr Tensor::ToAbstract() {
   auto tens = shared_from_base<Tensor>();
   auto dtype = tens->Dtype();

@@ -469,7 +470,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
   }
   auto tensor_shape = tens->shape();
   auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
-  abs_tensor->set_value(shared_from_base<Tensor>());
+
+  // if is parameter always no value.
+  if (is_parameter()) {
+    auto param_name = param_info()->name();
+    auto ref_key = std::make_shared<RefKey>(param_name);
+    auto abs_ref_key = ref_key->ToAbstract();
+    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
+  } else {
+    abs_tensor->set_value(shared_from_base<Tensor>());
+  }
   return abs_tensor;
 }
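Note (illustrative): after this change a tensor that carries ParamInfo abstracts to an AbstractRef keyed by its parameter name, while an ordinary value tensor keeps the old AbstractTensor-with-value behaviour. Sketch only (AbstractRef include path assumed).

#include "ir/tensor.h"
#include "abstract/abstract_value.h"

bool AbstractsToRef(const mindspore::tensor::TensorPtr &t) {
  auto abs = t->ToAbstract();
  return abs->isa<mindspore::abstract::AbstractRef>();  // true iff t->is_parameter() was set
}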
mindspore/core/ir/value.cc

@@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
 }
 bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }

-bool RefKey::operator==(const Value &other) const {
-  if (other.isa<RefKey>()) {
-    auto other_ = static_cast<const RefKey &>(other);
-    return *this == other_;
-  } else {
-    return false;
-  }
-}
-bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
-
 bool AnyValue::operator==(const Value &other) const {
   if (other.isa<AnyValue>()) {
     return true;
mindspore/core/ir/value.h

@@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
 IMM_TRAITS(StringImmPtr, std::string)
 IMM_TRAITS(StringImmPtr, const char *)

-class RefKey : public Value {
+class RefKey : public Named {
  public:
-  explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
+  explicit RefKey(const std::string &tag) : Named(tag) {}
   ~RefKey() override = default;
-  MS_DECLARE_PARENT(RefKey, Value)
-  std::size_t hash() const override { return hash_; }
-  const std::string &tag() const { return tag_; }
-  bool operator==(const Value &other) const override;
-  bool operator==(const RefKey &other) const;
+  MS_DECLARE_PARENT(RefKey, Named)
+  const std::string &tag() const { return name(); }
   abstract::AbstractBasePtr ToAbstract() override;
-  std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
+  std::string ToString() const override { return "RefKey[" + name() + "]"; }
   std::string DumpText() const override {
     std::ostringstream oss;
-    oss << "RefKey[\"" << tag_ << "\"]";
+    oss << "RefKey[\"" << name() << "\"]";
     return oss.str();
   }
-
- private:
-  std::string tag_;
-  std::size_t hash_ = 0;
 };
 using RefKeyPtr = std::shared_ptr<RefKey>;
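Note (illustrative): with RefKey now a Named, its tag, hash and printed forms all derive from the stored name. Sketch only (include path assumed).

#include <memory>
#include <string>
#include "ir/value.h"

void RefKeyRoundTrip() {
  auto key = std::make_shared<mindspore::RefKey>("weight");
  std::string tag = key->tag();        // forwards to Named::name(), i.e. "weight"
  std::string text = key->ToString();  // "RefKey[weight]"
  (void)tag;
  (void)text;
}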
mindspore/lite/test/CMakeLists.txt

@@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc
mindspore/lite/tools/converter/CMakeLists.txt

@@ -29,6 +29,8 @@ set(ANF_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc
mindspore/ops/operations/other_ops.py

@@ -23,7 +23,7 @@ from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register

-class Assign(PrimitiveWithInfer):
+class Assign(Primitive):
     """
     Assign `Parameter` with a value.
mindspore/ops/primitive.py

@@ -18,7 +18,6 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
-from mindspore.common import Parameter
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type

@@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
     if op_name == "Cast" or obj.update_parameter:
         cast_args = args
     else:
-        cast_args = list()
-        for arg in args:
-            if isinstance(arg, Parameter):
-                if arg.cast_type:
-                    cast_args.append(cast(arg, arg.cast_type))
-                else:
-                    cast_args.append(arg)
-            else:
-                cast_args.append(arg)
-    output = real_run_op(obj, op_name, tuple(cast_args))
+        cast_args = args
+        for idx, arg in enumerate(args):
+            cast_type = getattr(arg, "cast_type", None)
+            if cast_type:
+                cast_args[idx] = cast(arg, cast_type)
+    output = real_run_op(obj, op_name, cast_args)
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
tests/st/control/test_ascend_control_sink.py

@@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
         self.var = Parameter(initializer(1, (1), mstype.float32), name="var")

     def construct(self, x, y, z, c2, c4):
-        out = self.assign(self.var, c4)
+        out = c4
+        self.assign(self.var, c4)
         while x < c2:
-            y = self.assign(self.var, c4)
+            y = c4
+            self.assign(self.var, c4)
             while y < c2 and x < c2:
                 if 2 * y < c2:
                     y = y + 2
                 else:
                     y = y + 1
                 out = out + y
-            z = self.assign(self.var, c4)
+            z = c4
+            self.assign(self.var, c4)
             while z < c2:
                 z = z + 1
             out = out + z
             x = x + 1
         out = out + x
         while x < 2 * c2:
-            y = self.assign(self.var, c4)
+            y = c4
+            self.assign(self.var, c4)
             x = x + 1
             while y < c2:
-                z = self.assign(self.var, c4)
+                z = c4
+                self.assign(self.var, c4)
                 while z < c2:
                     z = z + 1
                 if x < c2:
tests/ut/python/pipeline/parse/test_parse.py

@@ -27,6 +27,7 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
+from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.common.api import ms_function, _executor
 from mindspore.ops._grad.grad_base import bprop_getters
 from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer

@@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
     net = BpropWithWrongOutputShapeCell()
     net.set_grad()
     grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+class AssignWhenInsertGrad(nn.Cell):
+    """ NetWithNDarray definition """
+
+    def __init__(self):
+        super(AssignWhenInsertGrad, self).__init__()
+        self.gather = P.GatherV2()
+        self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
+        self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
+        self.freq = Tensor(278, ms.int32)
+        self.getG = P.InsertGradientOf(self.save_gradient)
+
+    def save_gradient(self, dout):
+        self.cov_step = self.cov_step + self.freq
+        return dout
+
+    def construct(self, x):
+        self.gather(self.damping, self.cov_step, 0)
+        out = P.ReLU()(x)
+        out = self.getG(out)
+        return out
+
+
+grad_all = C.GradOperation('get_all', get_all=True)
+
+
+class GradNet(nn.Cell):
+    def __init__(self, net):
+        super(GradNet, self).__init__()
+        self.net = net
+
+    def construct(self, *inputs):
+        out = self.net(*inputs)
+        return out, grad_all(self.net)(*inputs)
+
+
+def test_assign_in_insert_grad():
+    context.set_context(mode=context.GRAPH_MODE)
+    net = AssignWhenInsertGrad().to_float(ms.float16)
+    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
+    net_back = GradNet(net)
+    net_back(ms.Tensor(input_data))
+
+
+class Assign(nn.Cell):
+    """ NetWithNDarray definition """
+
+    def __init__(self):
+        super(Assign, self).__init__()
+        self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)
+
+    def construct(self, x):
+        self.cov_step = self.cov_step + x
+        return self.cov_step
+
+
+def test_assign():
+    context.set_context(mode=context.GRAPH_MODE)
+    net = Assign()
+    input_data = ms.Tensor(np.array(1).astype(np.int32))
+    net_back = GradNet(net)
+    net_back(input_data)
tests/ut/python/pipeline/parse/test_while_param.py  — new file (0 → 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
import numpy as np

import mindspore as ms
from mindspore import Tensor, context, nn, ms_function
from mindspore.nn import Cell
from mindspore.ops import operations as P


class WhileSubGraphParam(Cell):
    def __init__(self):
        super().__init__()
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        out1 = z
        while x < y:
            self.update = self.update + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update


def test_while_loop_phi():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)
    y = Tensor(10, ms.float32)
    z = Tensor(100, ms.float32)

    net = WhileSubGraphParam()
    net(x, y, z)


class WhileSubGraphParam2(Cell):
    def __init__(self):
        super().__init__()
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        out1 = z
        i = self.update
        while x < y:
            i = i + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update


def test_while_loop_phi_2():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)
    y = Tensor(10, ms.float32)
    z = Tensor(100, ms.float32)

    net = WhileSubGraphParam2()
    net(x, y, z)


class WhileSubGraphParam3(Cell):
    def __init__(self, initial_input_x):
        super().__init__()
        self.initial_input_x = initial_input_x
        self.X = ms.Parameter(initial_input_x, name="parameter_x")
        self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")

    def construct(self):
        a = 0
        while a < 3:
            self.X = self.X + self.Y
            a += 1
        return self.X


def test_while_loop_phi_3():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)

    net = WhileSubGraphParam3(x)
    net()


class ControlMixedWhileIf(nn.Cell):
    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")

    @ms_function
    def construct(self, x, y, z, c2, c4):
        out = self.assign(self.var, c4)
        while x < c2:
            y = self.assign(self.var, c4)
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
                out = out + y
            z = self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        while x < 2 * c2:
            y = self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out


def test_mixed_while_if():
    context.set_context(mode=context.PYNATIVE_MODE)
    x = np.array(2).astype(np.int32)
    y = np.array(14).astype(np.int32)
    z = np.array(1).astype(np.int32)
    c2 = Tensor([14], ms.int32)
    c4 = Tensor([0], ms.int32)
    net = ControlMixedWhileIf()
    output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
    expect = np.array(3318).astype(np.int32)
    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
    context.set_context(mode=context.GRAPH_MODE)
tests/vm_impl/array_ops_vm_impl.py

@@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
 from .vm_interface import vm

 # pylint: disable=unused-argument

+@vm_impl_getters.register(P.Assign)
+def vm_impl_assign(self):
+    """Generate vm_impl function for Assign"""
+    def vm_impl(x, value):
+        x.assign_value(value)
+        return x
+    return vm_impl
+
 @vm_impl_getters.register(P.ExpandDims)
 def vm_impl_expand_dims(self):