magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit fe82d821
Authored on Jun 30, 2020 by mindspore-ci-bot; committed via Gitee on Jun 30, 2020
!1904 Add IndexedSlices
Merge pull request !1904 from riemann_penn/add_indexed_slices
Parents: 3cb53143, d6635bbb
Showing 41 changed files with 1208 additions and 192 deletions (+1208 −192)
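For orientation before the per-file diffs: the IndexedSlices added here is a sparse-gradient triple (indices, values, dense_shape), as the MakeIndexedSlices / IndexedSlicesGet* primitives below suggest. A minimal Python sketch of that container — the class and its names are illustrative, not MindSpore's implementation:

from dataclasses import dataclass
from typing import List, Sequence

@dataclass
class IndexedSlicesSketch:
    """Sparse gradient: row indices, their values, and the dense shape."""
    indices: List[int]           # which rows of the dense tensor are present
    values: List[List[float]]    # one row of values per index
    dense_shape: Sequence[int]   # shape of the equivalent dense tensor

    def to_dense(self) -> List[List[float]]:
        rows, cols = self.dense_shape
        dense = [[0.0] * cols for _ in range(rows)]
        for i, row in zip(self.indices, self.values):
            dense[i] = list(row)
        return dense

slices = IndexedSlicesSketch(indices=[0, 2], values=[[1.0, 2.0], [3.0, 4.0]], dense_shape=(3, 2))
assert slices.to_dense()[2] == [3.0, 4.0]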
Changed files:

mindspore/_extends/parse/resources.py (+4 −0)
mindspore/ccsrc/debug/dump_proto.cc (+4 −0)
mindspore/ccsrc/ir/dtype.cc (+84 −0)
mindspore/ccsrc/ir/dtype.h (+51 −2)
mindspore/ccsrc/ir/dtype/type.cc (+4 −0)
mindspore/ccsrc/ir/dtype/type.h (+7 −2)
mindspore/ccsrc/ir/dtype/type_id.h (+2 −0)
mindspore/ccsrc/ir/dtype_extends.cc (+58 −0)
mindspore/ccsrc/operator/composite/multitype_funcgraph.cc (+42 −8)
mindspore/ccsrc/operator/composite/multitype_funcgraph.h (+1 −0)
mindspore/ccsrc/operator/ops.cc (+7 −0)
mindspore/ccsrc/operator/ops.h (+7 −0)
mindspore/ccsrc/operator/prim_others.cc (+76 −0)
mindspore/ccsrc/optimizer/clean.cc (+2 −1)
mindspore/ccsrc/optimizer/irpass.cc (+6 −0)
mindspore/ccsrc/optimizer/irpass.h (+3 −0)
mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h (+75 −0)
mindspore/ccsrc/pipeline/action.cc (+3 −0)
mindspore/ccsrc/pipeline/init.cc (+3 −1)
mindspore/ccsrc/pipeline/pass.cc (+1 −0)
mindspore/ccsrc/pipeline/resource.cc (+151 −142)
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc (+87 −10)
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h (+61 −18)
mindspore/ccsrc/pipeline/static_analysis/evaluator.h (+14 −0)
mindspore/ccsrc/pipeline/static_analysis/param_validator.h (+1 −0)
mindspore/ccsrc/pipeline/static_analysis/prim.cc (+49 −0)
mindspore/ccsrc/pipeline/static_analysis/prim.h (+11 −0)
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc (+4 −0)
mindspore/ccsrc/pipeline/validator.cc (+3 −1)
mindspore/ccsrc/utils/context/ms_context.cc (+1 −0)
mindspore/ccsrc/utils/context/ms_context.h (+4 −0)
mindspore/common/__init__.py (+2 −2)
mindspore/common/parameter.py (+15 −1)
mindspore/common/tensor.py (+6 −1)
mindspore/context.py (+11 −1)
mindspore/ops/functional.py (+8 −0)
mindspore/ops/operations/array_ops.py (+1 −1)
tests/ut/cpp/optimizer/lib_test.cc (+13 −0)
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py (+35 −0)
tests/ut/python/ir/test_indexed_slices.py (+290 −0)
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py (+1 −1)
mindspore/_extends/parse/resources.py

@@ -17,6 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
+from mindspore import IndexedSlices
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M

@@ -135,4 +136,7 @@ convert_object_map = {
     math.sin: NO_IMPLEMENT,
     math.cos: NO_IMPLEMENT,
     math.tan: NO_IMPLEMENT,
+
+    # user defined
+    IndexedSlices: F.make_indexed_slices,
 }
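convert_object_map tells the AST parser how to lower a Python object reference into a graph primitive; registering IndexedSlices → F.make_indexed_slices is what lets IndexedSlices(...) appear in cell code. A tiny Python sketch of such a lowering table (all names here are stand-ins, not the parser's real machinery):

def make_indexed_slices(indices, values, dense_shape):
    # stand-in for the graph primitive the parser would emit
    return ("MakeIndexedSlices", indices, values, dense_shape)

class IndexedSlices:  # stand-in for the user-facing class
    pass

# parser-side lowering table: user object -> graph constructor
convert_object_map = {IndexedSlices: make_indexed_slices}

callee = convert_object_map[IndexedSlices]
assert callee([0, 2], [[1.0], [2.0]], (3, 1))[0] == "MakeIndexedSlices"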
mindspore/ccsrc/debug/dump_proto.cc

@@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
         type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
       }
     }
+  } else if (type->isa<IndexedSlicesType>()) {
+    // Do Nothing
+  } else if (type->isa<UndeterminedType>()) {
+    // Do Nothing
   } else if (type->isa<Tuple>()) {
     TuplePtr tuple_type = dyn_cast<Tuple>(type);
     type_proto->set_data_type(irpb::DT_TUPLE);
mindspore/ccsrc/ir/dtype.cc

@@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const {
 std::string Slice::DumpText() const { return ToString(); }
 
+TypePtr UndeterminedType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<UndeterminedType>();
+  }
+  return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
+}
+
+std::string UndeterminedType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToReprString() + "]";
+}
+
+std::string UndeterminedType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToString() + "]";
+}
+
+std::string UndeterminedType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->DumpText() + "]";
+}
+
+bool UndeterminedType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 TypePtr TensorType::DeepCopy() const {
   MS_EXCEPTION_IF_NULL(element_type_);
   if (IsGeneric()) {

@@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }
 
+TypePtr IndexedSlicesType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
+}
+
+std::string IndexedSlicesType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToReprString() + "]";
+}
+
+std::string IndexedSlicesType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToString() + "]";
+}
+
+std::string IndexedSlicesType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->DumpText() + "]";
+}
+
+bool IndexedSlicesType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;
mindspore/ccsrc/ir/dtype.h

@@ -108,10 +108,34 @@ class Slice : public Object {
 };
 using SlicePtr = std::shared_ptr<Slice>;
 
+class UndeterminedType : public Object {
+ public:
+  UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
+  explicit UndeterminedType(const TypePtr &ele)
+      : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
+  ~UndeterminedType() override = default;
+  MS_DECLARE_PARENT(UndeterminedType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ protected:
+  TypePtr element_type_;
+};
+using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
+
 class TensorType : public Object {
  public:
-  TensorType() : Object(kObjectTypeTensorType) {}
-  explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
+  TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
+  explicit TensorType(const TypePtr &ele)
+      : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
   ~TensorType() override = default;
   MS_DECLARE_PARENT(TensorType, Object)

@@ -130,6 +154,29 @@ class TensorType : public Object {
 };
 using TensorTypePtr = std::shared_ptr<TensorType>;
 
+class IndexedSlicesType : public Object {
+ public:
+  IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
+  explicit IndexedSlicesType(const TypePtr &ele)
+      : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~IndexedSlicesType() override = default;
+  MS_DECLARE_PARENT(IndexedSlicesType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ private:
+  TypePtr element_type_;
+};
+using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
+
 class Function : public Object {
  public:
   Function();

@@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name);
 // Judge whether x is predicate or is a subclass of predicate.
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
 
+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type);
+
 // Whether t1 is identity or a subclass of t2.
 bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
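The pattern in this header: each container type wraps an element TypePtr and reports a generic_type_id plus, via Object's new parent_type, a shared parent generic (kObjectTypeUndeterminedType for both TensorType and IndexedSlicesType). A rough Python sketch of that relationship — illustrative names only, not the MindSpore API:

class TypeSketch:
    """Mimics the Object(object_type, parent_type) pattern from dtype.h."""
    def __init__(self, object_type, parent_type=None, element=None):
        self.object_type = object_type
        self.parent_type = parent_type   # shared generic parent, e.g. "Undetermined"
        self.element = element           # wrapped element type, may be None (generic)

    def __str__(self):
        name = self.object_type
        return f"{name}[{self.element}]" if self.element else name

def make_tensor_type(element=None):
    return TypeSketch("Tensor", parent_type="Undetermined", element=element)

def make_indexed_slices_type(element=None):
    return TypeSketch("IndexedSlices", parent_type="Undetermined", element=element)

print(make_indexed_slices_type("Float32"))  # IndexedSlices[Float32]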
mindspore/ccsrc/ir/dtype/type.cc

@@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeKeyword";
     case kObjectTypeTensorType:
       return "kObjectTypeTensorType";
+    case kObjectTypeIndexedSlicesType:
+      return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeUndeterminedType:
+      return "kObjectTypeUndeterminedType";
     case kObjectTypeDictionary:
       return "kObjectTypeDictionary";
     case kObjectTypeClass:
mindspore/ccsrc/ir/dtype/type.h

@@ -67,6 +67,7 @@ class Type : public Value {
   virtual bool equal(const TypePtr other) const { return *this == *other; }
   virtual TypeId object_type() const { return kTypeUnknown; }
+  virtual TypeId parent_type() const { return kTypeUnknown; }
   virtual TypeId number_type() const { return kTypeUnknown; }
   virtual TypePtr DeepCopy() const = 0;
   virtual TypePtr Clone() const { return DeepCopy(); }

@@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>;
 //
 class Object : public Type {
  public:
-  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {}
+  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
   explicit Object(const TypeId object_type, bool is_generic = true)
-      : Type(kMetaTypeObject, is_generic), object_type_(object_type) {}
+      : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
+  explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
+      : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
   ~Object() override = default;
   MS_DECLARE_PARENT(Object, Type)
 
   TypeId object_type() const override { return object_type_; }
+  TypeId parent_type() const override { return parent_type_; }
   TypeId type_id() const override { return object_type_; }
   TypeId generic_type_id() const override { return kMetaTypeObject; }
   bool equal(const TypePtr other) const override;

@@ -114,6 +118,7 @@ class Object : public Type {
  private:
   const TypeId object_type_;
+  const TypeId parent_type_;
 };
 
 std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
mindspore/ccsrc/ir/dtype/type_id.h

@@ -50,6 +50,8 @@ enum TypeId : int {
   kObjectTypeSlice,
   kObjectTypeKeyword,
   kObjectTypeTensorType,
+  kObjectTypeIndexedSlicesType,
+  kObjectTypeUndeterminedType,
   kObjectTypeClass,
   kObjectTypeDictionary,
   kObjectTypeFunction,
mindspore/ccsrc/ir/dtype_extends.cc

@@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) {
   return type;
 }
 
+TypePtr IndexedSlicesStrToType(const std::string &type_name) {
+  if (type_name == "IndexedSlices") {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+TypePtr UndeterminedStrToType(const std::string &type_name) {
+  if (type_name == "Undetermined") {
+    return std::make_shared<UndeterminedType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<UndeterminedType>(element_type);
+}
+
 TypePtr ListStrToType(const std::string &type_name) {
   TypePtr type = nullptr;
   if (type_name == "List") {

@@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) {
     type = StringToNumberType<Float>(type_name, "Float");
   } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
     type = TensorStrToType(type_name);
+  } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
+    type = UndeterminedStrToType(type_name);
+  } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
+    type = IndexedSlicesStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
     type = ListStrToType(type_name);
   } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {

@@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) {
   return type;
 }
 
+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
+  if (x == nullptr || base_type == nullptr) {
+    MS_LOG(ERROR) << "Type is nullptr.";
+    return false;
+  }
+  if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
+    return false;
+  }
+  if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) {
+    return true;
+  }
+  return false;
+}
+
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
   if (x == nullptr || base_type == nullptr) {
     MS_LOG(ERROR) << "Type is nullptr.";

@@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE(
                            TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
                            return data;
                          }));
+                   (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
+                     .def(py::init());
+                   (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
+                     .def(py::init());
                    (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
                      .def(py::init())
                      .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));

@@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>();
 const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
+const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
 const TypePtr kList = std::make_shared<List>();
 const TypePtr kTuple = std::make_shared<Tuple>();
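IndexedSlicesStrToType parses names like "IndexedSlices[Float32]" by slicing between the first '[' and the last ']'. A small runnable Python sketch of the same parsing rule (the resolver here is a stand-in for StringToType, and the extra malformed-name guard is mine):

def parse_wrapped_type(type_name, wrapper, resolve):
    """Parse 'Wrapper' or 'Wrapper[Element]', mirroring IndexedSlicesStrToType."""
    if type_name == wrapper:
        return (wrapper, None)              # generic form, element unknown
    start = type_name.find("[") + 1
    end = type_name.rfind("]")
    if start <= 0 or end < start:
        return None                          # malformed name
    element = resolve(type_name[start:end])
    return (wrapper, element) if element is not None else None

resolve = {"Float32": "Float32", "Int32": "Int32"}.get
assert parse_wrapped_type("IndexedSlices[Float32]", "IndexedSlices", resolve) == ("IndexedSlices", "Float32")
assert parse_wrapped_type("IndexedSlices", "IndexedSlices", resolve) == ("IndexedSlices", None)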
mindspore/ccsrc/operator/composite/multitype_funcgraph.cc

@@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) {
   }
   return type;
 }
 
-FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
-  bool find_fn = false;
-  py::function py_fn;
+// Return Exact match if exists, else return non ambiguous sub class match
+// Return py::none() if matching is ambiguous
+const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
+  // Exact match
   for (auto &item : fn_cache_py_) {
     TypePtrList sign = item.first;
     if (sign.size() != types.size()) {
       continue;
     }
-    bool match = true;
+    auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
       if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
         match = false;

@@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
     if (!match) {
       continue;
     }
-    find_fn = true;
-    py_fn = item.second;
-    break;
+    return item.second;
   }
+  // Try best match
+  py::function py_fn_subclass;
+  size_t subclass_match_cnt = 0;
+  for (auto &item : fn_cache_py_) {
+    TypePtrList sign = item.first;
+    if (sign.size() != types.size()) {
+      continue;
+    }
+    auto match = true;
+    for (size_t i = 0; i < sign.size(); ++i) {
+      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
+          !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
+        match = false;
+        break;
+      }
+    }
+    if (!match) {
+      continue;
+    }
+    py_fn_subclass = item.second;
+    subclass_match_cnt++;
+  }
+  if (subclass_match_cnt > 1) {
+    MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
+  }
+  if (subclass_match_cnt == 1) {
+    MS_LOG(DEBUG) << "Found one subclass match";
+    return py_fn_subclass;
+  }
+  return py::none();
+}
+
+FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
+  auto py_fn = SignMatch(types);
   std::ostringstream buffer;
   buffer << types;
-  if (find_fn) {
+  if (py_fn != py::none()) {
     FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
     if (func_graph == nullptr) {
       MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
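SignMatch thus resolves an overload in two passes: an exact/subclass match wins immediately, then a looser parent-or-child pass runs and must be unambiguous. A hedged Python sketch of that resolution order (the predicates are simplified stand-ins for IsIdentidityOrSubclass and IsParentOrChildrenType):

def sign_match(registry, types, is_subclass, is_parent_or_child):
    """registry: list of (signature_tuple, fn). Mirrors MultitypeFuncGraph::SignMatch."""
    # Pass 1: exact / subclass match wins immediately.
    for sign, fn in registry:
        if len(sign) == len(types) and all(is_subclass(t, s) for t, s in zip(types, sign)):
            return fn
    # Pass 2: allow parent/child relations, but only if exactly one candidate matches.
    candidates = [fn for sign, fn in registry
                  if len(sign) == len(types)
                  and all(is_subclass(t, s) or is_parent_or_child(t, s)
                          for t, s in zip(types, sign))]
    if len(candidates) > 1:
        raise RuntimeError("more than one prototype matches by subclass")
    return candidates[0] if candidates else None

is_sub = lambda t, s: t == s
is_rel = lambda t, s: {t, s} == {"Tensor", "IndexedSlices"}
reg = [(("Tensor", "Tensor"), "fn_tensor"), (("IndexedSlices", "Tensor"), "fn_slices")]
assert sign_match(reg, ("Tensor", "Tensor"), is_sub, is_rel) == "fn_tensor"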
mindspore/ccsrc/operator/composite/multitype_funcgraph.h

@@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
   }
 
  private:
+  const py::function SignMatch(const TypePtrList &types);
   std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
   std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
 };
mindspore/ccsrc/operator/ops.cc

@@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
 const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
 const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
 const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
+
+// IndexedSlices
+const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
+const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
+const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
+const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
+const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
 }  // namespace prim
 }  // namespace mindspore
mindspore/ccsrc/operator/ops.h

@@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
 
+// IndexedSlices
+extern const PrimitivePtr kPrimMakeIndexedSlices;
+extern const PrimitivePtr kPrimIndexedSlicesGetValues;
+extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
+extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
+extern const PrimitivePtr kPrimIsIndexedSlices;
+
 class DoSignaturePrimitive : public Primitive {
  public:
   explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
mindspore/ccsrc/operator/prim_others.cc

@@ -24,6 +24,7 @@
 #include "pipeline/static_analysis/prim.h"
 #include "pipeline/static_analysis/utils.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace abstract {

@@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
     return std::make_shared<AbstractTuple>(sparse_list);
   }
 
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
+    auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
+    return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
+  }
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }

@@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
   }
   auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
   ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
+  ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
   return ret;
 }

@@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
   }
   return std::make_shared<AbstractScalar>(kAnyValue, kBool);
 }
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->values());
+  return indexed_slices->values();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->indices());
+  return indexed_slices->indices();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
+  return indexed_slices->dense_shape();
+}
+
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  bool ret = false;
+  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
+    ret = true;
+  }
+  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
+  return std::make_shared<AbstractScalar>(ret);
+}
 }  // namespace abstract
 }  // namespace mindspore
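InferImplMakeIndexedSlices checks the argument kinds (tensor, tensor, tuple), folds the dense-shape tuple into a vector of ints, and returns an abstract IndexedSlices carrying all three parts. A simplified, runnable Python sketch of that inference step, with abstract kinds reduced to strings (the encoding is mine, not MindSpore's):

def infer_make_indexed_slices(args):
    """args: [(kind, payload), ...] — a toy stand-in for AbstractBasePtrList."""
    if len(args) != 3:
        raise ValueError("MakeIndexedSlices expects 3 args")
    (k0, indices), (k1, values), (k2, dense_shape) = args
    if (k0, k1, k2) != ("tensor", "tensor", "tuple"):
        raise TypeError("expected (tensor, tensor, tuple), got %s" % ((k0, k1, k2),))
    dense_shape_vec = [int(e) for e in dense_shape]  # mirrors the std::transform fold
    return {"kind": "indexed_slices", "indices": indices,
            "values": values, "dense_shape": dense_shape_vec}

abs_out = infer_make_indexed_slices([("tensor", "i"), ("tensor", "v"), ("tuple", (3, 2))])
assert abs_out["dense_shape"] == [3, 2]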
mindspore/ccsrc/optimizer/clean.cc

@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
 using mindspore::abstract::AbstractTuple;
+using mindspore::abstract::AbstractUndetermined;
 
 static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
   if (t == nullptr) {

@@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(cons);
 
   auto dt = data->abstract();
-  if (dt == nullptr) {
+  if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
     return nullptr;
   }
mindspore/ccsrc/optimizer/irpass.cc

@@ -42,6 +42,7 @@
 #include "optimizer/irpass/tile_eliminate.h"
 #include "optimizer/irpass/transpose_eliminate.h"
 #include "optimizer/opt.h"
+#include "optimizer/irpass/indexed_slices_eliminate.h"
 
 namespace mindspore {
 namespace opt {

@@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Mark interface fusion
   mark_interface_fusion_ =
     MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
+
+  // IndexedSlices Eliminate
+  indexed_slices_eliminate_ = MakeSubstitution(
+    std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
+    {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
mindspore/ccsrc/optimizer/irpass.h

@@ -104,6 +104,9 @@ class OptimizeIRPassLib {
   // Fusion
   SubstitutionPtr mark_interface_fusion_;
 
+  // IndexedSlices Eliminate
+  SubstitutionPtr indexed_slices_eliminate_;
 };
 
 // the collection of irpass for resolve action
mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_

#include <vector>
#include <algorithm>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
class IndexedSlicesEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);

    if (is_match_) {
      return tuple_->input(1);
    }
    AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);

    if (is_match_) {
      return tuple_->input(2);
    }
    AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);

    if (is_match_) {
      return tuple_->input(3);
    }
    return nullptr;
  }

  void Visit(const CNodePtr &cnode) override {
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
      tuple_ = cnode;
      is_match_ = true;
    }
  }

  void Reset() {
    tuple_ = nullptr;
    is_match_ = false;
  }

 private:
  bool is_match_{false};
  CNodePtr tuple_{nullptr};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
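The pass pattern-matches an accessor applied directly over the constructor and forwards the matching constructor input (input 1/2/3 for indices/values/dense_shape). A compact Python sketch of the same rewrite on a toy expression tree — the node encoding is mine, not MindSpore IR:

# Toy IR node: ("Op", arg0, arg1, ...). Mirrors the three rewrite rules above.
ACCESSOR_TO_SLOT = {
    "IndexedSlicesGetIndices": 1,
    "IndexedSlicesGetValues": 2,
    "IndexedSlicesGetDenseShape": 3,
}

def eliminate(node):
    """Rewrite Get*(MakeIndexedSlices(i, v, d)) -> i / v / d; else return None."""
    if not isinstance(node, tuple) or node[0] not in ACCESSOR_TO_SLOT:
        return None
    inner = node[1]
    if isinstance(inner, tuple) and inner[0] == "MakeIndexedSlices":
        return inner[ACCESSOR_TO_SLOT[node[0]]]
    return None

make = ("MakeIndexedSlices", "i", "v", "d")
assert eliminate(("IndexedSlicesGetValues", make)) == "v"
assert eliminate(("IndexedSlicesGetDenseShape", make)) == "d"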
mindspore/ccsrc/pipeline/action.cc

@@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
         auto sparse_grad =
           py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
         ptr->set_sparse_grad(sparse_grad);
+        auto has_indexed_slices_grad =
+          py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
+        ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);
 
         parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
         args_spec.push_back(ptr);
mindspore/ccsrc/pipeline/init.cc

@@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
     .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
          "Set the GraphKernel switch to on or off.")
-    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
+    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
+    .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
+    .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
 
   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
mindspore/ccsrc/pipeline/pass.cc

@@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
+    irpass.indexed_slices_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
mindspore/ccsrc/pipeline/resource.cc (+151 −142)
(This diff is collapsed in the source view and its content is not shown here.)
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc

@@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
   if (tid() != other.tid()) {
     return false;
   }
+  if (BuildType()->type_id() == kObjectTypeUndeterminedType &&
+      other.BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    return true;
+  }
   if (value_ == nullptr || other.value_ == nullptr) {
     MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
                       << this->ToString() << ", other: " << other.ToString();

@@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
          << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
-         << " sparse_grad: " << sparse_grad_ << ")";
+         << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
   return buffer.str();
 }

@@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   if (*this == *other) {
     auto ret = shared_from_base<AbstractBase>();
     ret->set_sparse_grad(sparse_grad());
+    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
     return ret;
   }
   auto value_self = GetValueTrack();

@@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   if (res_value == value_self) {
     auto ret = shared_from_base<AbstractBase>();
     ret->set_sparse_grad(sparse_grad());
+    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
     return ret;
   }
   auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
   ret->set_sparse_grad(sparse_grad());
+  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return ret;
 }

@@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const {
   return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
 }
 
+ShapePtr AbstractUndetermined::shape() const {
+  auto shp = dyn_cast<Shape>(GetShapeTrack());
+  if (shp == nullptr) {
+    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
+  }
+  return shp;
+}
+
 TypePtr AbstractTensor::BuildType() const {
   MS_EXCEPTION_IF_NULL(element_);
   TypePtr element_type = element_->BuildType();

@@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const {
 }
 
 AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
+  if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    auto other_tensor = dyn_cast<AbstractUndetermined>(other);
+    auto element = element_->Join(other_tensor->element());
+    auto shape = ShapeJoin(this->shape(), other_tensor->shape());
+    auto ret = std::make_shared<AbstractUndetermined>(element, shape);
+    return ret;
+  }
   auto other_tensor = dyn_cast<AbstractTensor>(other);
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();

@@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
   ret->set_sparse_grad(sparse_grad());
+  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return ret;
 }

@@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
   clone->set_sparse_grad(sparse_grad());
+  clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return clone;
 }

@@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }

@@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
 
-ShapePtr AbstractTensor::shape() const {
-  auto shp = dyn_cast<Shape>(GetShapeTrack());
-  if (shp == nullptr) {
-    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
-  }
-  return shp;
-}
-
 std::string AbstractTensor::ToString() const {
   std::ostringstream buffer;
   BaseShapePtr shape_track = GetShapeTrack();

@@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const {
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
-         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
-         << ")";
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
+         << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
   return buffer.str();
 }

@@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg
 bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
   return AbstractBasePtrListDeepEqual(lhs, rhs);
 }
+
+// IndexedSlices
+TypePtr AbstractIndexedSlices::BuildType() const {
+  MS_EXCEPTION_IF_NULL(element());
+  TypePtr element_type = element()->BuildType();
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+AbstractBasePtr AbstractIndexedSlices::Clone() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
+  ShapePtr shp = shape();
+  clone->set_shape(shp->Clone());
+  clone->set_value(GetValueTrack());
+  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return clone;
+}
+
+AbstractBasePtr AbstractIndexedSlices::Broaden() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape();
+  broaden->set_shape(shp->Clone());
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape()->Clone();
+  shp->Broaden();
+  broaden->set_shape(shp);
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+std::string AbstractIndexedSlices::ToString() const {
+  std::ostringstream buffer;
+  BaseShapePtr shape_track = GetShapeTrack();
+  MS_EXCEPTION_IF_NULL(shape_track);
+  MS_EXCEPTION_IF_NULL(element());
+  auto value_track = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_track);
+  buffer << type_name() << "("
+         << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
+         << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
+         << ", dense_shape: " << dense_shape_->ToString();
+  return buffer.str();
+}
 }  // namespace abstract
 }  // namespace mindspore
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h

@@ -44,7 +44,7 @@ class AbstractBase : public Base {
  public:
   explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                         const BaseShapePtr &shape = kNoShape)
-      : value_(value), type_(type), shape_(shape), sparse_grad_("") {}
+      : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
   ~AbstractBase() override = default;
   MS_DECLARE_PARENT(AbstractBase, Base)

@@ -54,12 +54,16 @@ class AbstractBase : public Base {
   virtual bool operator==(const AbstractBase &other) const;
   void set_value(const ValuePtr &value) { value_ = value; }
   void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
+  void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
+    has_indexed_slices_grad_ = has_indexed_slices_grad;
+  }
   void set_type(const TypePtr &type) { type_ = type; }
   void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
   void set_value_desc(const std::string &desc) { value_desc_ = desc; }
   const std::string &value_desc() const { return value_desc_; }
   ValuePtr GetValueTrack() const { return value_; }
   const std::string &sparse_grad() const { return sparse_grad_; }
+  const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
   TypePtr GetTypeTrack() const { return type_; }
   BaseShapePtr GetShapeTrack() const { return shape_; }

@@ -88,6 +92,7 @@ class AbstractBase : public Base {
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
   std::string sparse_grad_;
+  bool has_indexed_slices_grad_;
 };
 
 class AbstractScalar : public AbstractBase {

@@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase {
 };
 using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
 
-class AbstractTensor : public AbstractBase {
+class AbstractUndetermined : public AbstractBase {
  public:
+  // shape and type are all unknown
+  AbstractUndetermined() : AbstractBase(kAnyValue) {}
   // only element_ and value, shape track are valid member, type track are unknown.
-  explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+  explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
       : AbstractBase(kAnyValue), element_(element) {
     if (element == nullptr) {
       MS_LOG(EXCEPTION) << "element is nullptr";
     }
-    if (element->isa<AbstractTensor>()) {
+    if (element->isa<AbstractUndetermined>()) {
      MS_LOG(EXCEPTION) << "element type error";
    }
    set_shape(shape);
  }
-  AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
+  AbstractUndetermined(const TypePtr &element_type, const std::vector<int> &shape)
      : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
    if (element_type == nullptr) {
      MS_LOG(EXCEPTION) << "element_type is nullptr";
    }
    set_shape(std::make_shared<Shape>(shape));
  }
-  explicit AbstractTensor(const tensor::TensorPtr &tensor)
-      : AbstractBase(tensor), element_(std::make_shared<AbstractScalar>(kAnyValue, tensor->Dtype())) {
-    if (tensor == nullptr) {
-      MS_LOG(EXCEPTION) << "tensor is nullptr";
-    }
-    set_shape(std::make_shared<Shape>(tensor->shape()));
-  }
+  ~AbstractUndetermined() override = default;
+  MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
+  TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
+  AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
+  const AbstractBasePtr element() const { return element_; }
+  ShapePtr shape() const;
+
+ protected:
+  AbstractBasePtr element_;
+};
+
+class AbstractTensor : public AbstractUndetermined {
+ public:
+  // only element_ and value, shape track are valid member, type track are unknown.
+  explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
   ~AbstractTensor() override = default;
-  MS_DECLARE_PARENT(AbstractTensor, AbstractBase)
+  MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
 
   TypePtr BuildType() const override;
   BaseShapePtr BuildShape() const override;

@@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase {
   bool operator==(const AbstractTensor &other) const;
   bool operator==(const AbstractBase &other) const override;
 
-  ShapePtr shape() const;
   std::string ToString() const override;
-  const AbstractBasePtr element() const { return element_; }
   std::size_t hash() const override {
     auto value = GetValueTrack();
     auto hash_sum = hash_combine(tid(), element_->hash());

@@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase {
     }
     return hash_sum;
   }
-
- private:
-  AbstractBasePtr element_;
 };
 
 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;

@@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual {
 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
 
+// IndexedSlices
+class AbstractIndexedSlices : public AbstractUndetermined {
+ public:
+  explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  ~AbstractIndexedSlices() override = default;
+  MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
+
+  const AbstractTensorPtr indices() const { return indices_; }
+  const AbstractTensorPtr values() const { return values_; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
+  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
+  TypePtr BuildType() const override;
+  AbstractBasePtr Clone() const override;
+  AbstractBasePtr Broaden() const override;
+  AbstractBasePtr BroadenWithShape() const;
+  std::string ToString() const override;
+
+ private:
+  AbstractTensorPtr indices_;
+  AbstractTensorPtr values_;
+  AbstractTuplePtr dense_shape_;
+};
 }  // namespace abstract
 }  // namespace mindspore
 #endif  // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_
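After this refactor, AbstractUndetermined is the common base: AbstractTensor and AbstractIndexedSlices both carry an element abstract plus a shape track, and (per the abstract_value.cc hunks) joining a tensor with an undetermined value degrades to undetermined. A toy Python sketch of that hierarchy — class names mirror the header, bodies are simplifications:

class AbstractUndetermined:
    """Common base: an element dtype plus a shape, possibly both unknown."""
    def __init__(self, element=None, shape=None):
        self.element = element
        self.shape = shape

class AbstractTensor(AbstractUndetermined):
    pass

class AbstractIndexedSlices(AbstractUndetermined):
    def __init__(self, element=None, shape=None):
        super().__init__(element, shape)
        self.indices = self.values = self.dense_shape = None

def join(a, b):
    """If either side is plain Undetermined, the join stays Undetermined."""
    if type(a) is AbstractUndetermined or type(b) is AbstractUndetermined:
        return AbstractUndetermined(a.element or b.element, a.shape or b.shape)
    if type(a) is not type(b):
        raise TypeError("join failed as type mismatch")
    return type(a)(a.element, a.shape)

t = AbstractTensor("Float32", (3, 2))
u = AbstractUndetermined("Float32", (3, 2))
assert type(join(t, u)) is AbstractUndetermined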
mindspore/ccsrc/pipeline/static_analysis/evaluator.h

@@ -58,6 +58,20 @@ class Evaluator : public Base {
     return args_spec_list;
   }
 
+  virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
+    auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
+      if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+        return true;
+      }
+      return false;
+    });
+    if (is_abstract) {
+      MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
+      return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
+    }
+    return nullptr;
+  }
+
   std::string ToString() const override { return identifier_; }
 
   virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
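AbstractEval short-circuits inference: if any argument is already Undetermined, the evaluator returns an Undetermined result instead of running the primitive's own inference. A minimal Python sketch of that guard, with EvalResult reduced to a plain string:

def abstract_eval(args, identifier="evaluator"):
    """Return an 'undetermined' result when any arg is undetermined, else None."""
    if any(arg == "undetermined" for arg in args):
        print(f"Eval {identifier} return abstract result")
        return "undetermined"
    return None  # caller falls through to the real inference

assert abstract_eval(["tensor", "undetermined"]) == "undetermined"
assert abstract_eval(["tensor", "scalar"]) is None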
mindspore/ccsrc/pipeline/static_analysis/param_validator.h

@@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
 ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
+ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
 
 template <typename T>
 std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
mindspore/ccsrc/pipeline/static_analysis/prim.cc

@@ -36,6 +36,7 @@
 #include "pipeline/parse/resolve.h"
 #include "ir/tensor.h"
 #include "utils/convert_utils.h"
+#include "utils/context/ms_context.h"
 #include "pipeline/parse/data_converter.h"
 #include "pipeline/static_analysis/param_validator.h"
 #include "common/utils.h"

@@ -132,6 +133,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimControlDepend, {InferImplControlDepend, true}},
     // Debug
     {prim::kPrimDebug, {InferImplDebug, true}},
+    // IndexedSlices
+    {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
+    {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
+    {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
+    {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
+    {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
   };
   return prim_eval_implement_map;
 }

@@ -139,6 +146,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 using mindspore::parse::PyObjectWrapper;
 
 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   prim_->BeginRecordAddAttr();
   AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
   prim_->EndRecordAddAttr();

@@ -485,6 +502,16 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 }
 // end anonymous namespace
 
 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
   const auto &iter = cache_->find(args);

@@ -512,6 +539,16 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
 }
 
 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
   if (nargs_ != args.size()) {
     MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";

@@ -871,6 +908,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     auto ref_value = ref_abs->ref();
     MS_EXCEPTION_IF_NULL(ref_value);
     ret->set_sparse_grad(ref_value->sparse_grad());
+    ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
   }

@@ -886,6 +924,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
     abs_scalar->set_sparse_grad(x->sparse_grad());
+    abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };

@@ -897,6 +936,16 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse_flag = context->enable_sparse_flag();
+    if (enable_sparse_flag) {
+      auto ret_abstract = AbstractEval(args_spec_list);
+      if (ret_abstract != nullptr) {
+        MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+        return ret_abstract;
+      }
+    }
     // Inputs: data, item
     if (args_spec_list.size() != 2) {
       MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
mindspore/ccsrc/pipeline/static_analysis/prim.h

@@ -350,6 +350,17 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
 
+void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc

@@ -228,6 +228,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
     MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
                       << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
   }
+  if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
+    return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
+  }
   AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
   if (func == nullptr) {
     MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString()
mindspore/ccsrc/pipeline/validator.cc
...
...
@@ -32,6 +32,7 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractIndexedSlices;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
...
...
@@ -93,7 +94,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
  }
  if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
-     ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
+     ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
+     ptrBase->isa<abstract::AbstractRefKey>()) {
    return;
  }
...
...
mindspore/ccsrc/utils/context/ms_context.cc
...
...
@@ -89,6 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
  max_device_memory_ = kDefaultMaxDeviceMemory;
  print_file_path_ = "";
  enable_graph_kernel_ = false;
  enable_sparse_flag_ = false;
}

std::shared_ptr<MsContext> MsContext::GetInstance() {
...
...
mindspore/ccsrc/utils/context/ms_context.h
...
...
@@ -161,6 +161,9 @@ class MsContext {
  void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
  bool enable_graph_kernel() const { return enable_graph_kernel_; }
  bool enable_sparse_flag() const { return enable_sparse_flag_; }
  void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }

 private:
  MsContext(const std::string &backend_policy, const std::string &target);
  void GetGeOptions(std::map<std::string, std::string> *ge_options) const;
...
...
@@ -204,6 +207,7 @@ class MsContext {
  float max_device_memory_;
  std::string print_file_path_;
  bool enable_graph_kernel_;
  bool enable_sparse_flag_;
};
}  // namespace mindspore
...
...
mindspore/common/__init__.py
...
...
@@ -17,10 +17,10 @@ from . import dtype
from .api import ms_function
from .dtype import *
from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor
+from .tensor import MetaTensor, Tensor, IndexedSlices

__all__ = [
-    "MetaTensor", "Tensor",  # tensor
+    "MetaTensor", "Tensor", "IndexedSlices",  # tensor
    'ms_function',  # api
    'Parameter', 'ParameterTuple',  # parameter
    "dtype"
...
...
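With this hunk the new class is re-exported from the package root, so user code only needs:

    from mindspore import IndexedSlices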
mindspore/common/parameter.py
...
...
@@ -52,13 +52,16 @@ class Parameter:
        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
            broadcast and gradients communication would not be applied to parameters. Default: False.
        sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
        has_indexed_slices_grad (bool): Set if the parameter's gradient is an IndexedSlices. Default: False.
    """

-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=""):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad="",
+                 has_indexed_slices_grad=False):
        self.set_parameter_data(default_input)
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.sparse_grad = sparse_grad
        self.has_indexed_slices_grad = has_indexed_slices_grad
        self._is_init = False
        self._sliced = False
        self.clone_info = _CloneInfo()
...
...
@@ -186,6 +189,17 @@ class Parameter:
            raise TypeError("`sparse_grad` parameter must be str type")
        self._sparse_grad = value

    @property
    def has_indexed_slices_grad(self):
        """Return whether the parameter's gradient is indexed_slices."""
        return self._has_indexed_slices_grad

    @has_indexed_slices_grad.setter
    def has_indexed_slices_grad(self, value=False):
        if not isinstance(value, bool):
            raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
        self._has_indexed_slices_grad = value

    @property
    def data(self):
        return self.default_input
...
...
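A hedged usage sketch of the new flag, mirroring the UT added at the end of this commit:

    import numpy as np
    from mindspore import Tensor, Parameter

    w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1",
                   has_indexed_slices_grad=True)
    assert w1.has_indexed_slices_grad
    # w1.has_indexed_slices_grad = "yes"  # would raise TypeError: the setter only accepts bool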
mindspore/common/tensor.py
...
...
@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry

-__all__ = ['Tensor', 'MetaTensor']
+__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']

np_types = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
            np.uint32, np.uint64, np.float16, np.float32, np.float64, np.bool_)
...
...
@@ -214,3 +214,8 @@ class Tensor(Tensor_):
            raise TypeError("init_flag must be bool.")
        self.set_init_flag(value)
        self._init_flag = value


class IndexedSlices:
    def __init__(self, indices, values, dense_shape):
        raise NotImplementedError
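Note that IndexedSlices is a graph-mode-only construct at this point: calling it eagerly raises NotImplementedError, while inside a Cell's construct it is presumably lowered to the MakeIndexedSlices primitive registered in functional.py below. A sketch based on the UTs in this commit:

    import mindspore.nn as nn
    from mindspore import IndexedSlices

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            x = IndexedSlices(indices, values, self.dense_shape)
            # The accessors are methods, not properties, in this revision.
            return x.values(), x.indices(), x.dense_shape()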
mindspore/context.py
...
...
@@ -355,6 +355,14 @@ class _Context:
    def check_bprop(self, check_bprop_flag):
        self._context_handle.set_check_bprop_flag(check_bprop_flag)

    @property
    def enable_sparse(self):
        return self._context_handle.get_enable_sparse_flag()

    @enable_sparse.setter
    def enable_sparse(self, enable_sparse_flag):
        self._context_handle.set_enable_sparse_flag(enable_sparse_flag)

    @property
    def max_device_memory(self):
        return self._context_handle.get_max_device_memory()
...
...
@@ -510,7 +518,8 @@ def reset_auto_parallel_context():
                 save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                 enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
-                enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str)
+                enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
+                enable_sparse=bool)
def set_context(**kwargs):
    """
    Sets context for running environment.
...
...
@@ -567,6 +576,7 @@ def set_context(**kwargs):
            The format is "xxGB". Default: "1024GB".
        print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
            a file by default, and printing to the screen is turned off.
        enable_sparse (bool): Whether to enable the sparse feature. Default: False.

    Raises:
        ValueError: If input key is not an attribute in context.
...
...
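The flag round-trips through the context handle; a minimal sketch (get_context is the usual read-side API and is assumed to pick up the new key):

    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
    assert context.get_context("enable_sparse")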
mindspore/ops/functional.py
...
...
@@ -153,6 +153,14 @@ shape_mul = Primitive("shape_mul")
# a primitive to compare between tuple.
stop_gradient = Primitive("stop_gradient")

make_indexed_slices = Primitive('MakeIndexedSlices')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
is_indexed_slices = Primitive('IsIndexedSlices')

tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
...
...
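These module-level primitives are the graph-level counterparts of the IndexedSlices constructor and its accessor methods. A hedged sketch of the correspondence, only meaningful inside a compiled graph:

    from mindspore.ops import functional as F

    def build_and_unpack(indices, values, dense_shape):
        slices = F.make_indexed_slices(indices, values, dense_shape)
        # Each getter projects out the matching constructor input; the
        # optimizer UT below checks exactly these identities.
        return (F.indexed_slices_get_indices(slices),
                F.indexed_slices_get_values(slices),
                F.indexed_slices_get_dense_shape(slices))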
mindspore/ops/operations/array_ops.py
...
...
@@ -564,7 +564,7 @@ class SparseGatherV2(GatherV2):
        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
        >>> axis = 1
-        >>> out = P.GatherV2()(input_params, input_indices, axis)
+        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
    """
...
...
tests/ut/cpp/optimizer/lib_test.cc
...
...
@@ -603,5 +603,18 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
}

TEST_F(TestOptLib, test_indexed_slices) {
  FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices");
  FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices");
  FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values");
  FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values");
  FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape");
  FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape");
  auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_});
  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
}
}  // namespace opt
}  // namespace mindspore
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
...
...
@@ -1130,3 +1130,38 @@ def test_adjust_allreduce_mul_add(tag):
        return Mul(AllReduce(AddN((Mul(z, z), x))), y)

    return fns[tag]


def test_indexed_slices(tag):
    """ test_indexed_slices """
    fns = FnDict()
    make_indexed_slices = Primitive('MakeIndexedSlices')
    indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
    indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
    indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')

    @fns
    def before_get_indices(x, y, z):
        return indexed_slices_get_indices(make_indexed_slices(x, y, z))

    @fns
    def after_get_indices(x, y, z):
        return x

    @fns
    def before_get_values(x, y, z):
        return indexed_slices_get_values(make_indexed_slices(x, y, z))

    @fns
    def after_get_values(x, y, z):
        return y

    @fns
    def before_get_dense_shape(x, y, z):
        return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z))

    @fns
    def after_get_dense_shape(x, y, z):
        return z

    return fns[tag]
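The pass under test is pure projection folding: a getter applied to a freshly constructed IndexedSlices collapses to the corresponding constructor argument. A toy, pure-Python model of the rewrite (a stand-in, not MindSpore code):

    def make_indexed_slices(indices, values, dense_shape):
        return (indices, values, dense_shape)

    def get_indices(s): return s[0]
    def get_values(s): return s[1]
    def get_dense_shape(s): return s[2]

    # indexed_slices_eliminate_ rewrites each left-hand side into the right-hand side:
    assert get_indices(make_indexed_slices("i", "v", "d")) == "i"
    assert get_values(make_indexed_slices("i", "v", "d")) == "v"
    assert get_dense_shape(make_indexed_slices("i", "v", "d")) == "d"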
tests/ut/python/ir/test_indexed_slices.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_indexed_slices.py
@Author:
@Date : 2020-06-08
@Desc : test mindspore indexed_slices's operation
"""
import
numpy
as
np
import
mindspore
as
ms
import
mindspore.nn
as
nn
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.composite.multitype_ops.zeros_like_impl
import
zeros_like
from
mindspore.ops.primitive
import
constexpr
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore
import
Tensor
,
IndexedSlices
,
context
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
from
mindspore.common
import
dtype
as
mstype
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore.nn
import
Optimizer
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
reduce_sum
=
P
.
ReduceSum
()
unsorted_segment_sum
=
P
.
UnsortedSegmentSum
()
transpose
=
P
.
Transpose
()
shape_op
=
P
.
Shape
()
reshape
=
P
.
Reshape
()
size_op
=
P
.
Size
()
invert_permutation
=
P
.
InvertPermutation
()
logical_and
=
P
.
LogicalAnd
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
enable_sparse
=
True
)
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
    return perm


@constexpr
def _generate_inverse_index(x_shape, axis):
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm
class MySparseGatherV2(P.GatherV2):
    """
    For test
    """


@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2"""

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            # The sparse path: instead of scattering dout into a dense gradient,
            # pair the gathered rows with their indices as an IndexedSlices.
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor,
                           param, m, v, gradient, decay_flag):
    # `gradient` is registered as "Undetermined": it may turn out to be either a
    # Tensor or an IndexedSlices, so the sparse case is dispatched at run time.
    if gradient.is_indexed_slices():
        return gradient.values()
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))
    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)
    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class AdamWeightDecaySparse(Optimizer):
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2,
                                              self.eps, lr, self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity
def test_indexed_slices_make_indexed_slices():
    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            ret = (IndexedSlices(indices, values, self.dense_shape),)
            return ret[0].is_indexed_slices()

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeIndexedSlices()(indices, values)
def test_indexed_slices_attr():
    class IndexedSlicesGetAttr(nn.Cell):
        def __init__(self):
            super(IndexedSlicesGetAttr, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            x = IndexedSlices(indices, values, self.dense_shape)
            return x.values(), x.indices(), x.dense_shape()

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    IndexedSlicesGetAttr()(indices, values)
def test_indexed_slices_sparse_gatherv2_grad_all():
    grad_all = C.GradOperation('get_all', get_all=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)

    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)
def test_indexed_slices_sparse_gatherv2_grad_with_pram():
    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
                                    name="params", has_indexed_slices_grad=True)

        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)

    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)
def test_indexed_slices_is_indexed_slices():
    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            indexed_slices = IndexedSlices(indices, values, self.dense_shape)
            ret = indexed_slices.is_indexed_slices()
            return ret

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeIndexedSlices()(indices, values)
def test_indexed_slices_env_get():
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
                                name="w1", has_indexed_slices_grad=True)
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
...
...
@@ -155,7 +155,7 @@ def test_AdamWeightDecaySparse():
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
-            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2", sparse_grad="sparse_key_w2")
+            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = P.SparseGatherV2()
            self.axis = 0

        def construct(self, indices):
...
...
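The string-keyed sparse_grad mechanism stays on w1 while w2 drops it. For contrast with the boolean flag introduced by this commit, a hedged side-by-side (w3 is a made-up name for illustration):

    import numpy as np
    from mindspore import Tensor, Parameter

    # legacy: gradient sparsity keyed by a string name
    w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1",
                   sparse_grad="sparse_key_w1")
    # new in this commit: a plain boolean
    w3 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w3",
                   has_indexed_slices_grad=True)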