magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit acaa66a7
Authored on May 21, 2020 by panyifeng
Parent commit: 54991615

Commit message: sparse grad for gatherv2

16 changed files with 841 additions and 21 deletions (+841 -21)
Changed files:

  mindspore/ccsrc/debug/anf_ir_utils.cc                        +15   -4
  mindspore/ccsrc/operator/composite/map.cc                    +289  -0
  mindspore/ccsrc/operator/composite/map.h                     +98   -0
  mindspore/ccsrc/operator/prim_others.cc                      +93   -5
  mindspore/ccsrc/pipeline/action.cc                           +3    -0
  mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc   +20   -6
  mindspore/ccsrc/pipeline/static_analysis/abstract_value.h    +4    -1
  mindspore/ccsrc/pipeline/static_analysis/prim.cc             +6    -1
  mindspore/common/parameter.py                                +14   -1
  mindspore/ops/_grad/grad_array_ops.py                        +32   -0
  mindspore/ops/composite/__init__.py                          +1    -1
  mindspore/ops/composite/base.py                              +64   -1
  mindspore/ops/operations/__init__.py                         +2    -1
  mindspore/ops/operations/array_ops.py                        +23   -0
  mindspore/ops/operations/other_ops.py                        +4    -0
  tests/ut/python/nn/optim/test_adam_with_tuple_grad.py        +173  -0
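Taken together, these changes let a Parameter opt into sparse gradients and be read through the new SparseGatherV2 primitive, so that its gradient flows back as an (indices, values, dense_shape) tuple instead of a dense tensor. A minimal usage sketch, assuming the post-commit API (the cell name NetWithEmbedding is illustrative only, loosely based on the unit test added in this commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class NetWithEmbedding(nn.Cell):
    """Illustrative cell: gathers rows of a sparse-grad parameter."""
    def __init__(self):
        super(NetWithEmbedding, self).__init__()
        # sparse_grad=True marks the parameter so its gradient is kept sparse.
        self.table = Parameter(Tensor(np.ones([3, 2]).astype(np.float32)),
                               name="table", sparse_grad=True)
        self.gather = P.SparseGatherV2()

    def construct(self, indices):
        # Gathering along axis 0 makes the backward pass return
        # (indices, values, dense_shape) for self.table.
        return self.gather(self.table, indices, 0)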
mindspore/ccsrc/debug/anf_ir_utils.cc

@@ -30,6 +30,7 @@
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/parse/resolve.h"
 #include "operator/composite/composite.h"
+#include "operator/composite/map.h"
 #include "utils/ordered_map.h"
 #include "utils/ordered_set.h"
 #include "utils/utils.h"

@@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
 * ├── MultitypeGraph
 * ├── HyperMap
 * │   └── HyperMapPy
+* ├── Map
+* │   └── MapPy
 * ├── Tail
 * ├── MakeTupleGradient
 * ├── GradOperation

@@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
     oss << GetMultitypeFuncGraphText(mt_func_graph);
   } else if (meta_func_graph->isa<prim::HyperMapPy>()) {
     // this statement must before 'meta_graph->isa<prim::HyperMap>()'
-    prim::HyperMapPyPtr hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
-    MS_EXCEPTION_IF_NULL(hyper_map);
+    auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
     if (hyper_map->GetFnLeaf() != nullptr) {
       oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
     }
   } else if (meta_func_graph->isa<prim::HyperMap>()) {
-    prim::HyperMapPtr hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
-    MS_EXCEPTION_IF_NULL(hyper_map);
+    auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
     if (hyper_map->GetFnLeaf() != nullptr) {
       oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
     }
+  } else if (meta_func_graph->isa<prim::MapPy>()) {
+    // this statement must before 'meta_graph->isa<prim::Map>()'
+    auto map = meta_func_graph->cast<prim::MapPyPtr>();
+    if (map->GetFnLeaf() != nullptr) {
+      oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
+    }
+  } else if (meta_func_graph->isa<prim::Map>()) {
+    auto map = meta_func_graph->cast<prim::MapPtr>();
+    if (map->GetFnLeaf() != nullptr) {
+      oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
+    }
   } else if (meta_func_graph->isa<prim::GradOperation>()) {
     prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
     oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_
mindspore/ccsrc/operator/composite/map.cc  (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "operator/composite/map.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
#include "pybind_api/api_register.h"
#include "debug/trace.h"
#include "operator/ops.h"
#include "./common.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;

AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  if (fn_arg != nullptr) {
    inputs.emplace_back(fn_arg);
  } else {
    inputs.emplace_back(NewValueNode(fn_leaf_));
  }
  inputs.insert(inputs.end(), args.begin(), args.end());
  return func_graph->NewCNode(inputs);
}

FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
  // Generate func for leaf nodes
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("map");
  AnfNodePtr ptrFnArg = nullptr;
  if (fn_leaf_ == nullptr) {
    ptrFnArg = ptrGraph->add_parameter();
  }
  AnfNodePtrList args;
  for (size_t i = 0; i < args_size; ++i) {
    args.emplace_back(ptrGraph->add_parameter());
  }
  ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
  return ptrGraph;
}

AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                             const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  std::size_t size = type->elements().size();
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
      auto lhs = std::dynamic_pointer_cast<List>(item.second);
      MS_EXCEPTION_IF_NULL(lhs);
      return lhs->elements().size() != size;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in Map should have same length";
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeList));
  for (int i = 0; i < SizeToInt(size); ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
      [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  std::size_t size = type->elements().size();
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
      auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
      MS_EXCEPTION_IF_NULL(lhs);
      return lhs->elements().size() != size;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "tuple in Map should have same length";
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (int i = 0; i < SizeToInt(size); ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
      [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  std::size_t attrSize = type->GetAttributes().size();
  for (std::size_t i = 0; i < attrSize; ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    int j = 0;
    for (auto item : arg_pairs) {
      inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
      j++;
    }

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (auto &item : arg_pairs) {
    pair = item;
    MS_LOG(DEBUG) << "Map " << pair.second->ToString();
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same =
      std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
        if (item.first != pair.first) {
          return item.second->type_id() != pair.second->type_id();
        }
        return false;
      });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int idx = 0;
      for (auto &item : arg_pairs) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
                        << oss.str() << pair.second->ToString() << "\n";
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMakeList(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
    }
    default:
      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
                        << ", but got " << pair.second->ToString();
  }
}

FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("map");

  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  if (fn_leaf_ == nullptr) {
    ptrFnArg = ptrGraph->add_parameter();
    i = 1;
  }

  ArgsPairList arg_pairs;
  std::size_t size = args_spec_list.size();
  for (; i < size; ++i) {
    MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
    arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  }

  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
  return ptrGraph;
}

abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  if (fn_leaf_ == nullptr) {
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    // Assert that map's function param does not contain free variables
    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
      auto func_graph = graph_func->func_graph();
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
      }
    }
  }

  AbstractBasePtrList broadened;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}

REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
                         (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                       }));
}  // namespace prim
}  // namespace mindspore
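For list and tuple inputs, Map::GenerateFromTypes expands the call into one leaf call per element and repacks the results with make_list / make_tuple. A rough Python sketch of the graph this builds, for illustration only (it is not MindSpore API):

def map_expand(fn, *sequences):
    """Sketch of what Map generates: apply fn element-wise across
    same-length sequences and repack the results into a tuple."""
    length = len(sequences[0])
    if any(len(seq) != length for seq in sequences):
        raise ValueError("List in Map should have same length")
    return tuple(fn(*row) for row in zip(*sequences))

# map_expand(lambda a, b: a + b, (1, 2, 3), (10, 20, 30)) -> (11, 22, 33)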
mindspore/ccsrc/operator/composite/map.h  (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_

#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
#include "operator/composite/multitype_funcgraph.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;

class Map : public MetaFuncGraph {
 public:
  explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
      : MetaFuncGraph("map"),
        fn_leaf_(fn_leaf),
        broadcast_(false),
        nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
    Init();
  }
  Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
    Init();
  }
  Map &operator=(const Map &h) {
    if (this != &h) {
      fn_leaf_ = h.fn_leaf_;
      broadcast_ = h.broadcast_;
      nonleaf_ = h.nonleaf_;
      if (fn_leaf_) {
        name_ = "map[" + fn_leaf_->name() + "]";
      }
    }
    return *this;
  }
  ~Map() override = default;
  MS_DECLARE_PARENT(Map, MetaFuncGraph)
  abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }

 private:
  FuncGraphPtr GenerateLeafFunc(const size_t &args_size);
  AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args);
  AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                          const ArgsPairList &arg_pairs);
  AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                           const ArgsPairList &arg_pairs);
  AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                           const ArgsPairList &arg_pairs);
  AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);

  void Init() {
    if (fn_leaf_ != nullptr) {
      name_ = "map[" + fn_leaf_->name() + "]";
    }
    signatures_ =
      // def map(func:read, *args:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
  }

  MultitypeFuncGraphPtr fn_leaf_;
  bool broadcast_;
  std::set<TypeId> nonleaf_;
};
using MapPtr = std::shared_ptr<Map>;

class MapPy : public Map {
 public:
  explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {}
  ~MapPy() override = default;
  MS_DECLARE_PARENT(MapPy, Map)
};
using MapPyPtr = std::shared_ptr<MapPy>;
}  // namespace prim
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
mindspore/ccsrc/operator/prim_others.cc

@@ -14,9 +14,14 @@
  * limitations under the License.
  */
 
 #include <string>
+#include <sstream>
 #include "ir/dtype.h"
+#include "common/utils.h"
+#include "operator/ops.h"
+#include "pipeline/static_analysis/param_validator.h"
 #include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
 #include "pipeline/static_analysis/utils.h"
 #include "utils/symbolic.h"

@@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
   return AbstractFunction::MakeAbstractFunction(jv);
 }
 
+class UndeterminedShapeType {
+ public:
+  explicit UndeterminedShapeType(const std::string &env_str) {
+    // param_name indices_shape indices_type values_shape values_type dense_shape
+    // export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2"
+    std::vector<string> fields;
+    string tmp;
+    std::stringstream input(env_str);
+    while (std::getline(input, tmp, ':')) {
+      fields.push_back(tmp);
+    }
+    if (fields.size() != fields_num) {
+      MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
+    }
+    param_name_ = fields[0];
+    indices_shape_ = GetShape(fields[1]);
+    indices_type_ = StringToType(fields[2]);
+    values_shape_ = GetShape(fields[3]);
+    values_type_ = StringToType(fields[4]);
+    auto dense_shape_vec = GetShape(fields[5]);
+    AbstractBasePtrList dense_shape_list;
+    (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
+                         [](const auto &elem) { return FromValue(elem, false); });
+    dense_shape_ = dense_shape_list;
+  }
+  const std::string &param_name() { return param_name_; }
+  const std::vector<int> &indices_shape() { return indices_shape_; }
+  const TypePtr &indices_type() { return indices_type_; }
+  const std::vector<int> &values_shape() { return values_shape_; }
+  const TypePtr &values_type() { return values_type_; }
+  const AbstractBasePtrList &dense_shape() { return dense_shape_; }
+
+ private:
+  std::string param_name_;
+  std::vector<int> indices_shape_;
+  TypePtr indices_type_;
+  std::vector<int> values_shape_;
+  TypePtr values_type_;
+  AbstractBasePtrList dense_shape_;
+  static const size_t fields_num;
+  std::vector<int> GetShape(const std::string &shape_str);
+};
+
+std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
+  std::vector<int> ret;
+  std::istringstream iss(shape_str);
+  int elem;
+  while (iss.good()) {
+    iss >> elem;
+    ret.emplace_back(elem);
+  }
+  return ret;
+}
+const size_t UndeterminedShapeType::fields_num = 6;
+
 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   MS_EXCEPTION_IF_NULL(primitive);

@@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
   if (type->type_id() != kObjectTypeSymbolicKeyType) {
     MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
   }
+
+  if (key->sparse_grad()) {
+    // Will be fixed once undetermined type ready
+    auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
+    if (sparse_shape_types.empty()) {
+      sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2";
+    }
+    MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is "
+                  << sparse_shape_types;
+    auto shape_types = UndeterminedShapeType(sparse_shape_types);
+    AbstractBasePtrList sparse_list;
+    // indices
+    auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type());
+    auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape()));
+    sparse_list.emplace_back(indices);
+    // values
+    auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type());
+    auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape()));
+    sparse_list.emplace_back(dout);
+    // dense_shape
+    sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape()));
+    return std::make_shared<AbstractTuple>(sparse_list);
+  }
+
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }

@@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
   CheckArgsSize(primitive->name(), args_spec_list, 3);
 
   auto key = args_spec_list[1];
   auto value = args_spec_list[2];
 
   ValuePtr key_value_ptr = key->GetValueTrack();
   MS_EXCEPTION_IF_NULL(key_value_ptr);
   auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();

@@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
   }
   auto expected = key_value_track->abstract();
   MS_EXCEPTION_IF_NULL(expected);
   (void)expected->Join(value);
   return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
 }

@@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+  auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+  ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
+  return ret;
 }
 
 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
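The UNDETERMINED_SPARSE_SHAPE_TYPES string parsed by UndeterminedShapeType above is a colon-separated record: parameter name, indices shape, indices type, values shape, values type, dense shape. A small Python sketch of the same parsing, for illustration only (the helper name is hypothetical, not part of MindSpore):

def parse_sparse_shape_types(env_str="w1:2:Int32:2 1 2:Float32:3 1 2"):
    """Mirror of UndeterminedShapeType: split the six colon-separated
    fields and turn the shape fields into integer lists."""
    fields = env_str.split(":")
    if len(fields) != 6:
        raise ValueError("Expect 6 fields, but got %d" % len(fields))
    param_name, indices_shape, indices_type, values_shape, values_type, dense_shape = fields
    to_shape = lambda s: [int(x) for x in s.split()]
    return {
        "param_name": param_name,                  # e.g. "w1"
        "indices_shape": to_shape(indices_shape),  # e.g. [2]
        "indices_type": indices_type,              # e.g. "Int32"
        "values_shape": to_shape(values_shape),    # e.g. [2, 1, 2]
        "values_type": values_type,                # e.g. "Float32"
        "dense_shape": to_shape(dense_shape),      # e.g. [3, 1, 2]
    }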
mindspore/ccsrc/pipeline/action.cc

@@ -38,6 +38,7 @@
 #include "pipeline/remove_value_node_dup.h"
 #include "optimizer/optimizer.h"
 #include "vm/transform.h"
+#include "parse/python_adapter.h"
 
 namespace mindspore {
 namespace pipeline {

@@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
     if (param_node->has_default()) {
       auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
       AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
+      auto sparse_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
+      ptr->set_sparse_grad(sparse_grad);
       parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
       args_spec.push_back(ptr);
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc

@@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const {
 AbstractBasePtr AbstractBase::Broaden() const {
   AbstractBasePtr clone = Clone();
   clone->set_value(kAnyValue);
+  clone->set_sparse_grad(sparse_grad_);
   return clone;
 }

@@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(type_);
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
-         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
+         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
+         << " sparse_grad: " << sparse_grad_ << ")";
   return buffer.str();
 }

@@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
 AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   MS_EXCEPTION_IF_NULL(other);
   if (*this == *other) {
-    return shared_from_base<AbstractBase>();
+    auto ret = shared_from_base<AbstractBase>();
+    ret->set_sparse_grad(sparse_grad());
+    return ret;
   }
   auto value_self = GetValueTrack();
   MS_EXCEPTION_IF_NULL(value_self);
   ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
   TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
   if (res_value == value_self) {
-    return shared_from_base<AbstractBase>();
+    auto ret = shared_from_base<AbstractBase>();
+    ret->set_sparse_grad(sparse_grad());
+    return ret;
   }
-  return std::make_shared<AbstractScalar>(res_value, res_type);
+  auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
+  ret->set_sparse_grad(sparse_grad());
+  return ret;
 }
 
 AbstractBasePtr AbstractType::Clone() const {

@@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
-  return std::make_shared<AbstractTensor>(element, shape);
+  auto ret = std::make_shared<AbstractTensor>(element, shape);
+  ret->set_sparse_grad(sparse_grad());
+  return ret;
 }
 
 bool AbstractTensor::operator==(const AbstractTensor &other) const {

@@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
   ShapePtr shp = shape();
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
+  clone->set_sparse_grad(sparse_grad());
   return clone;
 }

@@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   auto shp = shape();
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
+  broaden->set_sparse_grad(sparse_grad());
   return broaden;
 }

@@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   shp->Broaden();
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
+  broaden->set_sparse_grad(sparse_grad());
   return broaden;
 }

@@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const {
   MS_EXCEPTION_IF_NULL(value_track);
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
-         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
+         << ")";
   return buffer.str();
 }
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h

@@ -44,7 +44,7 @@ class AbstractBase : public Base {
  public:
   explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                         const BaseShapePtr &shape = kNoShape)
-      : value_(value), type_(type), shape_(shape) {}
+      : value_(value), type_(type), shape_(shape), sparse_grad_(false) {}
   ~AbstractBase() override = default;
   MS_DECLARE_PARENT(AbstractBase, Base)

@@ -53,11 +53,13 @@ class AbstractBase : public Base {
   virtual bool operator==(const AbstractBase &other) const;
 
   void set_value(const ValuePtr &value) { value_ = value; }
+  void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; }
   void set_type(const TypePtr &type) { type_ = type; }
   void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
   void set_value_desc(const std::string &desc) { value_desc_ = desc; }
   const std::string &value_desc() const { return value_desc_; }
   ValuePtr GetValueTrack() const { return value_; }
+  bool sparse_grad() const { return sparse_grad_; }
   TypePtr GetTypeTrack() const { return type_; }
   BaseShapePtr GetShapeTrack() const { return shape_; }

@@ -85,6 +87,7 @@ class AbstractBase : public Base {
   TypePtr type_;
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
+  bool sparse_grad_;
 };
 
 class AbstractScalar : public AbstractBase {
mindspore/ccsrc/pipeline/static_analysis/prim.cc

@@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     }
     auto refkey = key_value->cast<RefKeyPtr>();
     if (refkey == nullptr) {
-      return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
+      auto ret = std::make_shared<AbstractScalar>(type);
+      auto ref_value = ref_abs->ref();
+      MS_EXCEPTION_IF_NULL(ref_value);
+      ret->set_sparse_grad(ref_value->sparse_grad());
+      return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
     }
 
     std::string name = refkey->tag();

@@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     x = SensitivityTransform(x);
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
+    abs_scalar->set_sparse_grad(x->sparse_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };
mindspore/common/parameter.py

@@ -50,12 +50,14 @@ class Parameter:
         requires_grad (bool): True if the parameter requires gradient. Default: True.
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
             broadcast and gradients communication would not be applied on parameters. Default: False.
+        sparse_grad (bool): True if the parameter's gradient is sparse. Default: False.
     """
 
-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False):
         self.set_parameter_data(default_input)
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
+        self.sparse_grad = sparse_grad
         self._is_init = False
         self._sliced = False
         self.clone_info = _CloneInfo()

@@ -168,6 +170,17 @@ class Parameter:
             raise TypeError("`requires_grad` parameter must be bool type")
         self._requires_grad = value
 
+    @property
+    def sparse_grad(self):
+        """Return whether the parameter's gradient is sparse."""
+        return self._sparse_grad
+
+    @sparse_grad.setter
+    def sparse_grad(self, value=True):
+        if not isinstance(value, bool):
+            raise TypeError("`sparse_grad` parameter must be bool type")
+        self._sparse_grad = value
+
     @property
     def data(self):
         return self.default_input
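With this change a parameter can opt into sparse gradients at construction time. A minimal sketch of the new flag, assuming the post-commit API (tensor values are arbitrary):

import numpy as np
from mindspore import Tensor, Parameter

# sparse_grad defaults to False; setting it to True flags the parameter so
# downstream passes may keep its gradient in (indices, values, dense_shape) form.
w = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
assert w.sparse_grad is True

# The setter validates the type, so non-bool values are rejected.
try:
    w.sparse_grad = 1
except TypeError as err:
    print(err)  # `sparse_grad` parameter must be bool type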
mindspore/ops/_grad/grad_array_ops.py

@@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum()
 transpose = P.Transpose()
 shape_op = P.Shape()
 reshape = P.Reshape()
+size_op = P.Size()
 invert_permutation = P.InvertPermutation()
 logical_and = P.LogicalAnd()

@@ -284,6 +285,37 @@ def get_bprop_gather_v2(self):
     return bprop
 
 
+@bprop_getters.register(P.SparseGatherV2)
+def get_bprop_sparse_gather_v2(self):
+    """Generate bprop for SparseGatherV2"""
+
+    def bprop(x, indices, axis, out, dout):
+        x_shp = shape_op(x)
+        if axis == 0:
+            indices_size = (size_op(indices),)
+            x_tail_shp = x_shp[1:]
+            values_shape = indices_size + x_tail_shp
+            values = reshape(dout, values_shape)
+            indices = reshape(indices, indices_size)
+            return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
+        if F.rank(dout) == 0:
+            dout = P.ExpandDims()(dout, -1)
+        if F.rank(indices) == 0:
+            indices = P.ExpandDims()(indices, -1)
+        out_shp = shape_op(dout)
+        ind_shp = shape_op(indices)
+        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
+        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
+        values_transpose = transpose(dout, perm_1)
+        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
+        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
+        perm_2 = _generate_inverse_index(x_shp, axis)
+        params_grad = transpose(params_grad, perm_2)
+        return params_grad, zeros_like(indices), zeros_like(axis)
+
+    return bprop
+
+
 @bprop_getters.register(P.Range)
 def get_bprop_range(self):
     """Generate bprop for Range"""
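For axis == 0 the bprop above returns the gradient of the gathered parameter as (indices, values, dense_shape) instead of scattering it into a dense tensor. A NumPy sketch of the equivalence, illustrative only, where densifying the sparse triple matches what unsorted_segment_sum produces on the dense path:

import numpy as np

def densify(indices, values, dense_shape):
    """Accumulate the sparse rows back into a dense gradient."""
    grad = np.zeros(dense_shape, dtype=values.dtype)
    for idx, val in zip(indices, values):
        grad[idx] += val
    return grad

# Two lookups of row 1 produce two value rows that sum into the same slot.
indices = np.array([1, 1])
values = np.array([[0.5, 0.5], [1.0, 1.0]], dtype=np.float32)
print(densify(indices, values, (3, 2)))
# [[0.  0. ]
#  [1.5 1.5]
#  [0.  0. ]]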
mindspore/ops/composite/__init__.py

@@ -20,7 +20,7 @@ Pre-defined combination of operators.
 """
 
-from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \
+from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
     grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \
     core, env_get, tail, zip_operation
 from .clip_ops import clip_by_value
mindspore/ops/composite/base.py

@@ -19,7 +19,7 @@
 from functools import partial
 
 from mindspore import context
-from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
+from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec

@@ -240,6 +240,69 @@ class HyperMap(HyperMap_):
             return func(*args_list)
         return tuple(map(hypermap, *args_list))
 
 
+class Map(Map_):
+    """
+    Map applies the given operation to each element of the input sequences.
+
+    Args:
+        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
+            the operation should be passed as the first input of the instance instead.
+
+    Inputs:
+        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences of the same
+          length, and the operation is applied row by row. e.g. If the length of args is 2, then for each `i`
+          within the length of each sequence, `(args[0][i], args[1][i])` will be the input of the operation.
+          If `ops` is `None`, the first input is the operation, and the others are the inputs.
+
+    Outputs:
+        Sequence, with the same type and length as the input sequences; each element is the result of applying
+        the operation to the corresponding row, e.g. `operation(args[0][i], args[1][i])`.
+    """
+
+    def __init__(self, ops=None):
+        self.ops = ops
+        if ops:
+            Map_.__init__(self, ops)
+        else:
+            Map_.__init__(self)
+
+    def __call__(self, *args):
+        func = args[0]
+        count = 0
+        count_max = 1
+        args_list = args[1:]
+        if self.ops is not None:
+            func = self.ops
+            args_list = args
+        for item in args_list:
+            if isinstance(item, (tuple, list)):
+                count_max = len(item)
+                break
+
+        def get_item(x):
+            nonlocal count
+            if isinstance(x, (tuple, list)):
+                return x[count]
+            return x
+
+        for i in range(count_max):
+            true_args = tuple(map(get_item, args_list))
+            func(*true_args)
+            count = i + 1
+        return True
+
+    def register(self, *type_names):
+        """Register a function for the given type string."""
+        def deco(fn):
+            self.register_fn(type_names, fn)
+            return fn
+        return deco
+
+
 class _ListAppend(ListAppend_):
     """
     A metafuncgraph class that append one element to list.
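A hedged usage sketch of the new C.Map wrapper together with a MultitypeFuncGraph; the square overload here is illustrative, not part of this commit:

from mindspore.ops import composite as C
from mindspore.ops import operations as P

square = C.MultitypeFuncGraph("square")

@square.register("Tensor")
def _square_tensor(x):
    # One overload, dispatched by the abstract type of each element.
    return P.Square()(x)

# With `ops` given at construction, every positional input must be a
# sequence of the same length; Map applies the op row by row, e.g. inside
# a graph-mode Cell: apply_square(tensors) squares each element in turn.
apply_square = C.Map(square)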
mindspore/ops/operations/__init__.py

@@ -21,7 +21,7 @@ A collection of operators to build nerual networks or computing functions.
 from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Diag, DiagPart, DType, ExpandDims, Eye,
-                        Fill, GatherNd, GatherV2, InvertPermutation,
+                        Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
                         IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
                         SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,

@@ -122,6 +122,7 @@ __all__ = [
     'Transpose',
     'OneHot',
     'GatherV2',
+    'SparseGatherV2',
     'Concat',
     'Pack',
     'Unpack',
mindspore/ops/operations/array_ops.py

@@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer):
         return out
 
 
+class SparseGatherV2(GatherV2):
+    """
+    Returns a slice of input tensor based on the specified indices and axis.
+
+    Inputs:
+        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+          The original Tensor.
+        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor. Must be in the range
+          `[0, input_param.shape()[axis])`.
+        - **axis** (int) - Specifies the dimension index to gather indices.
+
+    Outputs:
+        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
+        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
+        >>> axis = 1
+        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
+    """
+
+
 class Range(PrimitiveWithInfer):
     r"""
     Creates a sequence of numbers.
mindspore/ops/operations/other_ops.py

@@ -332,6 +332,8 @@ class CheckBprop(PrimitiveWithInfer):
     def infer_shape(self, xshapes, yshapes):
         tips = f'Bprop of {self.prim_to_check}'
+        validator.check_value_type('grads', xshapes, (tuple,), tips)
+        validator.check_value_type('params', yshapes, (tuple,), tips)
         if len(xshapes) < len(yshapes):
             raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
                             f" but got {len(xshapes)}.")

@@ -348,6 +350,8 @@ class CheckBprop(PrimitiveWithInfer):
     def infer_dtype(self, xdtypes, ydtypes):
         tips = f'Bprop of {self.prim_to_check}'
+        validator.check_value_type('grads', xdtypes, (tuple,), tips)
+        validator.check_value_type('params', ydtypes, (tuple,), tips)
         if len(xdtypes) < len(ydtypes):
             raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
                             f" but got {len(xdtypes)}.")
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py  (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel

adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor,
                           param, m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                                            op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tuple", "Bool")
def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor,
                                  param, m, v, gradient, decay_flag):
    return gradient[2][2]


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)


class AdamWeightDecaySparse(Optimizer):
    """
    Implements Adam algorithm weight decay fix.

    Args:
        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
            should be class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                        Iterable or a Tensor and the dims of the Tensor is 1,
                                                        use dynamic learning rate, then the i-th step will
                                                        take the i-th value as the learning rate.
                                                        When the learning_rate is float or learning_rate is a Tensor
                                                        but the dims of the Tensor is 0, use fixed learning rate.
                                                        Other cases are not supported. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
          and might be in sparse format.

    Outputs:
        tuple[Parameter], the updated velocity value, the shape is the same as `params`.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity


def test_AdamWeightDecaySparse():
    """ test_AdamWeightDecaySparse """
    context.set_context(mode=context.GRAPH_MODE)

    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = P.SparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)