magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 44e74ad5
Apply indexed_slices
Authored Jul 02, 2020 by panyifeng
Parent: e03bd975
35 changed files with 198 additions and 356 deletions (+198 -356)
mindspore/ccsrc/ir/func_graph.cc                                 +2   -1
mindspore/ccsrc/ir/func_graph.h                                  +4   -0
mindspore/ccsrc/ir/func_graph_cloner.cc                          +2   -0
mindspore/ccsrc/operator/composite/multitype_funcgraph.cc        +39  -26
mindspore/ccsrc/operator/prim_others.cc                          +5   -112
mindspore/ccsrc/optimizer/irpass/inline.h                        +2   -2
mindspore/ccsrc/parallel/step_parallel.cc                        +0   -1
mindspore/ccsrc/pipeline/action.cc                               +0   -2
mindspore/ccsrc/pipeline/init.cc                                 +2   -2
mindspore/ccsrc/pipeline/pass.cc                                 +7   -9
mindspore/ccsrc/pipeline/resource.cc                             +26  -28
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc       +7   -30
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h        +1   -9
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc            +5   -1
mindspore/ccsrc/pipeline/static_analysis/evaluator.h             +8   -0
mindspore/ccsrc/pipeline/static_analysis/prim.cc                 +21  -38
mindspore/ccsrc/pipeline/static_analysis/prim.h                  +0   -1
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc   +12  -3
mindspore/ccsrc/pipeline/static_analysis/program_specialize.h    +3   -2
mindspore/ccsrc/utils/context/ms_context.cc                      +1   -1
mindspore/ccsrc/utils/context/ms_context.h                       +3   -3
mindspore/common/parameter.py                                    +1   -28
mindspore/context.py                                             +8   -9
mindspore/nn/optim/adam.py                                       +2   -2
mindspore/nn/optim/ftrl.py                                       +2   -2
mindspore/nn/optim/lazyadam.py                                   +2   -2
mindspore/nn/optim/proximal_ada_grad.py                          +2   -2
mindspore/ops/functional.py                                      +0   -1
tests/ut/python/ir/test_indexed_slices.py                        +16  -25
tests/ut/python/nn/optim/test_adam.py                            +3   -3
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py            +2   -1
tests/ut/python/nn/optim/test_ftrl.py                            +3   -3
tests/ut/python/nn/optim/test_lazyadam.py                        +3   -3
tests/ut/python/nn/optim/test_proximal_ada_grad.py               +3   -3
tests/ut/python/pipeline/infer/test_hypermap_specialize.py       +1   -1
mindspore/ccsrc/ir/func_graph.cc
@@ -45,7 +45,8 @@ FuncGraph::FuncGraph()
       hyper_param_count_(0),
       is_generated_(false),
       return_(nullptr),
-      manager_(std::weak_ptr<FuncGraphManager>()) {
+      manager_(std::weak_ptr<FuncGraphManager>()),
+      stub_(false) {
   debug_info_ = std::make_shared<GraphDebugInfo>();
 }
mindspore/ccsrc/ir/func_graph.h
@@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase {
   void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
   bool HasEffect(const CNodePtr &cnode);
 
+  bool stub() const { return stub_; }
+  void set_stub(bool stub) { stub_ = stub; }
+
  private:
   // graph is manipulated by manager and others
   friend FuncGraphManager;
@@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase {
   // CNode order which relates to origin code order
   std::list<CNodePtr> order_;
+  bool stub_;
 };
 
 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
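Together with the func_graph.cc hunk above, this gives every FuncGraph a stub marker that later passes consult. A minimal sketch of the intended round trip, using only the accessors added in this hunk (the wrapper function is hypothetical):

    #include <memory>
    #include "ir/func_graph.h"

    // Sketch: mark a graph as a sparse stub and query the flag later.
    void MarkAsStub(const mindspore::FuncGraphPtr &fg) {
      fg->set_stub(true);  // GenerateStubFunc in multitype_funcgraph.cc does this
      if (fg->stub()) {
        // inline passes and the specializer skip graphs marked this way
      }
    }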
mindspore/ccsrc/ir/func_graph_cloner.cc
@@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
   (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
   (*target_func_graph)->set_is_generate(func_graph->is_generated());
+  (*target_func_graph)->set_stub(func_graph->stub());
   TraceManager::EndTrace();
 }
@@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
   new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
   new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
+  new_func_graph->set_stub(func_graph->stub());
   for (auto &item : func_graph->parameter_default_value()) {
     new_func_graph->set_param_default_value(item.first, cloner[item.second]);
   }
mindspore/ccsrc/operator/composite/multitype_funcgraph.cc
@@ -30,6 +30,7 @@
 #include "pipeline/static_analysis/param_validator.h"
 #include "operator/cc_implementations.h"
 #include "optimizer/opt.h"
+#include "utils/context/ms_context.h"
 #include "utils/symbolic.h"
 #include "pybind_api/api_register.h"
 #include "./common.h"
@@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
     }
     return item.second;
   }
-  // Try best match
-  py::function py_fn_subclass;
-  size_t subclass_match_cnt = 0;
-  for (auto &item : fn_cache_py_) {
-    TypePtrList sign = item.first;
-    if (sign.size() != types.size()) {
-      continue;
-    }
-    auto match = true;
-    for (size_t i = 0; i < sign.size(); ++i) {
-      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
-          !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
-        match = false;
-        break;
-      }
-    }
-    if (!match) {
-      continue;
-    }
-    py_fn_subclass = item.second;
-    subclass_match_cnt++;
-  }
-  if (subclass_match_cnt > 1) {
-    MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
-  }
-  if (subclass_match_cnt == 1) {
-    MS_LOG(DEBUG) << "Found one subclass match";
-    return py_fn_subclass;
-  }
   return py::none();
 }
 
+FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  if (!enable_sparse) {
+    return nullptr;
+  }
+
+  std::vector<AnfNodePtr> parameters;
+  ParameterPtr undetermined_param = nullptr;
+  auto stub = std::make_shared<FuncGraph>();
+  for (size_t i = 0; i < types.size(); ++i) {
+    auto param = stub->add_parameter();
+    parameters.push_back(param);
+    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
+      undetermined_param = param;
+    }
+  }
+  if (undetermined_param != nullptr) {
+    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
+    for (size_t i = 0; i < types.size(); ++i) {
+      if (types[i]->type_id() == kObjectTypeFunction) {
+        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
+        inputs.push_back(stub->NewCNode(call_prim));
+      } else {
+        inputs.push_back(parameters[i]);
+      }
+    }
+    auto stub_output = stub->NewCNode(inputs);
+    stub->set_output(stub_output);
+    stub->set_stub(true);
+    return stub;
+  }
+  return nullptr;
+}
+
 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
@@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
     MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
     return func_graph;
   }
+  auto stub = GenerateStubFunc(types);
+  if (stub != nullptr) {
+    MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
+    return stub;
+  }
   std::ostringstream oss;
   oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
       << "`, corresponding location info:\n";
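Reading the new GenerateStubFunc: when enable_sparse is on and one argument type is still undetermined, it returns a placeholder graph whose output is a make_tuple over the parameters, with each function-typed parameter applied to the undetermined one and every other parameter passed through. A toy transcription of that output shape over strings instead of AnfNodes (names hypothetical, illustration only):

    #include <string>
    #include <vector>

    // Toy model of the tuple body GenerateStubFunc assembles for parameters p0..pn.
    std::vector<std::string> StubTupleElements(const std::vector<bool> &is_function,
                                               const std::string &undetermined) {
      std::vector<std::string> elems;
      for (size_t i = 0; i < is_function.size(); ++i) {
        std::string param = "p" + std::to_string(i);
        if (is_function[i]) {
          elems.push_back(param + "(" + undetermined + ")");  // apply to the undetermined arg
        } else {
          elems.push_back(param);  // pass through unchanged
        }
      }
      return elems;  // the real code wraps these in a make_tuple CNode
    }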
mindspore/ccsrc/operator/prim_others.cc
@@ -23,8 +23,8 @@
 #include "pipeline/static_analysis/param_validator.h"
 #include "pipeline/static_analysis/prim.h"
 #include "pipeline/static_analysis/utils.h"
-#include "utils/symbolic.h"
 #include "utils/context/ms_context.h"
+#include "utils/symbolic.h"
 
 namespace mindspore {
 namespace abstract {
@@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
   return AbstractFunction::MakeAbstractFunction(jv);
 }
 
-class UndeterminedShapeType {
- public:
-  explicit UndeterminedShapeType(const std::string &env_str) {
-    // param_name indices_shape indices_type values_shape values_type dense_shape
-    // export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
-    // 2:Float32:3 1 2"
-    std::vector<string> fields;
-    string tmp;
-    std::stringstream input(env_str);
-    while (std::getline(input, tmp, ':')) {
-      fields.push_back(tmp);
-    }
-    if (fields.size() != fields_num) {
-      MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
-    }
-    param_name_ = fields[0];
-    indices_shape_ = GetShape(fields[1]);
-    indices_type_ = StringToType(fields[2]);
-    values_shape_ = GetShape(fields[3]);
-    values_type_ = StringToType(fields[4]);
-    auto dense_shape_vec = GetShape(fields[5]);
-    AbstractBasePtrList dense_shape_list;
-    (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
-                         [](const auto &elem) { return FromValue(elem, false); });
-    dense_shape_ = dense_shape_list;
-  }
-  ~UndeterminedShapeType() = default;
-  const std::string &param_name() { return param_name_; }
-  const std::vector<int> &indices_shape() { return indices_shape_; }
-  const TypePtr &indices_type() { return indices_type_; }
-  const std::vector<int> &values_shape() { return values_shape_; }
-  const TypePtr &values_type() { return values_type_; }
-  const AbstractBasePtrList &dense_shape() { return dense_shape_; }
-
- private:
-  std::string param_name_;
-  std::vector<int> indices_shape_;
-  TypePtr indices_type_;
-  std::vector<int> values_shape_;
-  TypePtr values_type_;
-  AbstractBasePtrList dense_shape_;
-  static const size_t fields_num;
-
-  std::vector<int> GetShape(const std::string &shape_str);
-};
-std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
-  std::vector<int> ret;
-  std::istringstream iss(shape_str);
-  int elem;
-  while (iss.good()) {
-    iss >> elem;
-    ret.emplace_back(elem);
-  }
-  return ret;
-}
-const size_t UndeterminedShapeType::fields_num = 6;
-
-std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
-
-void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
-  std::string tmp;
-  std::stringstream input(sparse_shape_types);
-  g_undetermined_configs.clear();
-  while (std::getline(input, tmp, ';')) {
-    auto config = UndeterminedShapeType(tmp);
-    g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
-    MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
-  }
-}
-
 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   MS_EXCEPTION_IF_NULL(primitive);
@@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
     MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
   }
 
-  if (!key->sparse_grad().empty()) {
-    // Will be fixed once undetermined type ready
-    if (g_undetermined_configs.empty()) {
-      auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
-      MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
-      if (sparse_shape_types.empty()) {
-        sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
-      }
-      InitUndeterminedFromEnv(sparse_shape_types);
-    }
-    auto shape_types = g_undetermined_configs.find(key->sparse_grad());
-    if (shape_types == g_undetermined_configs.end()) {
-      MS_LOG(EXCEPTION) << "Param " << key->ToString()
-                        << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
-    }
-    MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
-    AbstractBasePtrList sparse_list;
-    // indices
-    auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
-    auto indices = std::make_shared<AbstractTensor>(indices_ele,
-                                                    std::make_shared<Shape>(shape_types->second.indices_shape()));
-    sparse_list.emplace_back(indices);
-    // values
-    auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
-    auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
-    sparse_list.emplace_back(dout);
-    // dense_shape
-    sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
-    return std::make_shared<AbstractTuple>(sparse_list);
-  }
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
+  bool enable_sparse = context->enable_sparse();
+  if (enable_sparse && dflt->isa<AbstractTensor>()) {
     auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
     return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
   }
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }
@@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
-  ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
-  ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
 }
 
 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
mindspore/ccsrc/optimizer/irpass/inline.h
@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
     }
 
     auto fg = GetValueNode<FuncGraphPtr>(node);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
       return nullptr;
     }
@@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor {
     // G
     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
       return nullptr;
     }
     // Do not inline GraphKernel to Cell.
mindspore/ccsrc/parallel/step_parallel.cc
@@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
   std::string env = common::GetEnv("SLICE_ENV");
   if (!env.empty()) {
     MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
-    abstract::InitUndeterminedFromEnv(env);
   }
 }
mindspore/ccsrc/pipeline/action.cc
@@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
         ValuePtr value = param_value->value();
         constexpr bool broaden = true;
         AbstractBasePtr ptr = abstract::FromValue(value, broaden);
-        ptr->set_sparse_grad(param_value->sparse_grad());
-        ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
         parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
         args_spec.push_back(ptr);
mindspore/ccsrc/pipeline/init.cc
@@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
          "Set the GraphKernel switch to on or off.")
     .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
-    .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
-    .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
+    .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
+    .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
mindspore/ccsrc/pipeline/pass.cc
@@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
   return true;
 }
 
-std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
-                                   {"opt_a", OptPassAGroup},
+std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
+                                   {"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_b", OptPassBGroup},
                                    {"cconv", CconvPass},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                    {"add_control_depend", AddControlDependPass}};
 
-std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
-                                   {"opt_a", OptPassAGroup},
+std::vector<PassItem> kGePasses = {{"opt_a", OptPassAGroup},
+                                   {"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_b", OptPassBGroup},
                                    {"add_control_depend", AddControlDependPass},
                                    {"opt_control", ControlGroup},
                                    {"opt_prepare", PrepareGroup},
                                    {"cconv", CconvPass}};
 
 std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
 }  // namespace pipeline
mindspore/ccsrc/pipeline/resource.cc
@@ -146,37 +146,35 @@ MethodMap &GetMethodMap() {
                    }},
   {kObjectTypeTensorType,
    {
      {"__add__", std::string("add")},                    // C.add
      {"__sub__", std::string("sub")},                    // C.sub
      {"__mul__", std::string("mul")},                    // C.mul
      {"__truediv__", std::string("truediv")},            // C.truediv
      {"__floordiv__", std::string("floordiv")},          // C.floordiv
      {"__mod__", std::string("mod")},                    // C.mod
      {"__pow__", std::string("pow_")},                   // C.pow
      {"__floor__", std::string("array_floor")},          // C.array_floor
      {"__trunc__", std::string("array_trunc")},          // C.array_trunc
      {"__pos__", std::string("array_uadd")},             // C.array_uadd
      {"__neg__", std::string("array_usub")},             // C.array_usub
      {"__eq__", std::string("eq")},                      // C.eq
      {"__ne__", std::string("ne")},                      // C.ne
      {"__lt__", std::string("lt")},                      // C.lt
      {"__gt__", std::string("gt")},                      // C.gt
      {"__le__", std::string("le")},                      // C.le
      {"__ge__", std::string("ge")},                      // C.ge
      {"__matmul__", prim::kPrimDot},                     // P.dot,
      {"__len__", prim::kPrimArrayLen},                   // P.array_len,
      {"__getitem__", prim::kPrimArrayGetItem},           // P.array_getitem,
      {"__setitem__", prim::kPrimArraySetItem},           // P.array_setitem,
      {"__ms_iter__", std::string("array_iter")},         // C.array_iter
      {"__ms_to_array__", prim::kPrimIdentity},           // P.identity,
      {"item", prim::kPrimArrayToScalar},                 // P.array_to_scalar,
      {"transpose", std::string("transpose")},            // P.transpose
      {"__bool__", std::string("tensor_bool")},           // C.tensor_bool
-     {"is_indexed_slices", prim::kPrimIsIndexedSlices},  // F.is_indexed_slices
    }},
   {kObjectTypeIndexedSlicesType,
    {
-     {"is_indexed_slices", prim::kPrimIsIndexedSlices},       // F.is_indexed_slices
      {"values", prim::kPrimIndexedSlicesGetValues},           // F.indexed_slices_get_values
      {"indices", prim::kPrimIndexedSlicesGetIndices},         // F.indexed_slices_get_indices
      {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},  // F.indexed_slices_get_dense_shape
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
@@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const {
 AbstractBasePtr AbstractBase::Broaden() const {
   AbstractBasePtr clone = Clone();
   clone->set_value(kAnyValue);
-  clone->set_sparse_grad(sparse_grad_);
   return clone;
 }
@@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(type_);
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
-         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
-         << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
+         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
   return buffer.str();
 }
@@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
 AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   MS_EXCEPTION_IF_NULL(other);
   if (*this == *other) {
-    auto ret = shared_from_base<AbstractBase>();
-    ret->set_sparse_grad(sparse_grad());
-    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-    return ret;
+    return shared_from_base<AbstractBase>();
   }
   auto value_self = GetValueTrack();
   MS_EXCEPTION_IF_NULL(value_self);
   ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
   TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
   if (res_value == value_self) {
-    auto ret = shared_from_base<AbstractBase>();
-    ret->set_sparse_grad(sparse_grad());
-    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-    return ret;
+    return shared_from_base<AbstractBase>();
   }
-  auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
-  ret->set_sparse_grad(sparse_grad());
-  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractScalar>(res_value, res_type);
 }
 
 AbstractBasePtr AbstractType::Clone() const {
@@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
   if (*this == *other) {
-    if (sparse_grad() == other->sparse_grad()) {
-      return shared_from_base<AbstractBase>();
-    }
+    return shared_from_base<AbstractBase>();
   }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
-  auto ret = std::make_shared<AbstractTensor>(element, shape);
-  ret->set_sparse_grad(sparse_grad());
-  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractTensor>(element, shape);
 }
 
 bool AbstractTensor::operator==(const AbstractTensor &other) const {
@@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const {
   ShapePtr shp = shape();
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
-  clone->set_sparse_grad(sparse_grad());
-  clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return clone;
 }
@@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   auto shp = shape();
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
-  broaden->set_sparse_grad(sparse_grad());
-  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
@@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   shp->Broaden();
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
-  broaden->set_sparse_grad(sparse_grad());
-  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
@@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const {
   MS_EXCEPTION_IF_NULL(value_track);
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
-         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
-         << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
   return buffer.str();
 }
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
@@ -44,7 +44,7 @@ class AbstractBase : public Base {
  public:
   explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                         const BaseShapePtr &shape = kNoShape)
-      : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
+      : value_(value), type_(type), shape_(shape) {}
   ~AbstractBase() override = default;
   MS_DECLARE_PARENT(AbstractBase, Base)
@@ -53,17 +53,11 @@ class AbstractBase : public Base {
   virtual bool operator==(const AbstractBase &other) const;
   void set_value(const ValuePtr &value) { value_ = value; }
-  void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
-  void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
-    has_indexed_slices_grad_ = has_indexed_slices_grad;
-  }
   void set_type(const TypePtr &type) { type_ = type; }
   void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
   void set_value_desc(const std::string &desc) { value_desc_ = desc; }
   const std::string &value_desc() const { return value_desc_; }
   ValuePtr GetValueTrack() const { return value_; }
-  const std::string &sparse_grad() const { return sparse_grad_; }
-  const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
   TypePtr GetTypeTrack() const { return type_; }
   BaseShapePtr GetShapeTrack() const { return shape_; }
@@ -91,8 +85,6 @@ class AbstractBase : public Base {
   TypePtr type_;
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
-  std::string sparse_grad_;
-  bool has_indexed_slices_grad_;
 };
 
 class AbstractScalar : public AbstractBase {
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
@@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
   }
   MS_EXCEPTION_IF_NULL(ret_base);
-  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
+  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
+                << ", is stub: " << fg->stub();
+  if (fg->stub()) {
+    return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
+  }
   return std::make_shared<EvalResult>(ret_base, nullptr);
 }
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "pipeline/static_analysis/static_analysis.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace abstract {
@@ -59,6 +60,13 @@ class Evaluator : public Base {
   }
 
   virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse = context->enable_sparse();
+    if (!enable_sparse) {
+      return nullptr;
+    }
+
     auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
       if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
         return true;
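Since AbstractEval now checks enable_sparse itself and returns nullptr when the switch is off, derived evaluators can call it unconditionally; the prim.cc hunks below delete exactly those per-caller flag checks. A hedged sketch of the resulting caller pattern (a free function for illustration; the real callers are the EvalPrim methods below):

    // Sketch of the pattern StandardPrimEvaluator and friends follow after this change.
    EvalResultPtr EvalWithSparseGuard(Evaluator *evaluator, const AbstractBasePtrList &args) {
      auto ret_abstract = evaluator->AbstractEval(args);  // nullptr when enable_sparse is off
      if (ret_abstract != nullptr) {
        return ret_abstract;  // an Undetermined input short-circuits normal inference
      }
      // ... fall through to the evaluator's normal per-primitive inference ...
      return nullptr;  // placeholder in this sketch
    }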
mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 using mindspore::parse::PyObjectWrapper;
 
 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
+  if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
     auto ret_abstract = AbstractEval(args);
     if (ret_abstract != nullptr) {
       MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
@@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
+  auto ret_abstract = AbstractEval(args_spec_list);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+    return ret_abstract;
+  }
+
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
@@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   }
   AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
 
-  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
-                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
   ScopePtr scope = kDefaultScope;
   if (out_conf != nullptr) {
     scope = out_conf->node()->scope();
@@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 }  // end anonymous namespace
 
 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+    return ret_abstract;
   }
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
@@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
 }
 
 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+    return ret_abstract;
   }
   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
   if (nargs_ != args.size()) {
@@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
       auto ret = std::make_shared<AbstractScalar>(type);
       auto ref_value = ref_abs->ref();
       MS_EXCEPTION_IF_NULL(ref_value);
-      ret->set_sparse_grad(ref_value->sparse_grad());
-      ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
       return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
     }
@@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     x = SensitivityTransform(x);
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
-    abs_scalar->set_sparse_grad(x->sparse_grad());
-    abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };
@@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
-    auto context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context);
-    bool enable_sparse_flag = context->enable_sparse_flag();
-    if (enable_sparse_flag) {
-      auto ret_abstract = AbstractEval(args_spec_list);
-      if (ret_abstract != nullptr) {
-        MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
-        return ret_abstract;
-      }
+    auto ret_abstract = AbstractEval(args_spec_list);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+      return ret_abstract;
     }
     // Inputs: data, item
     if (args_spec_list.size() != 2) {
mindspore/ccsrc/pipeline/static_analysis/prim.h
@@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
-void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
 AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list);
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
@@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
   AbstractFunctionPtr func = real_a->GetUnique();
   SpecializeStatusCode errcode;
   ScopeGuard scope_guard(node->scope());
-  AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
+  AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
   if (repl == nullptr) {
     if (errcode == kSpecializeFindUniqueArgvalDead) {
       const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
@@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
   return repl;
 }
 
-AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
+AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                                           const AbstractFunctionPtr &func,
                                                            const AbstractBasePtrList &args,
                                                            SpecializeStatusCode *errcode) {
   MS_EXCEPTION_IF_NULL(abs);
@@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
   AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
   MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
                 << ", graph: " << context->func_graph()->get_return()->DebugString();
+  if (context->func_graph()->stub()) {
+    MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
+                  << ", args: " << argvals.size()
+                  << ", graph: " << context->func_graph()->get_return()->DebugString() << ", " << node->ToString();
+    return node;
+  }
   FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
+  v->set_flag(kFuncGraphFlagUndetermined, false);
   return BuildValueNode(v, abs);
 }
@@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
     *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
     return kSpecializeSuccess;
   } else if (choices->empty()) {
-    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
+    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
+                  << func->type_name();
     return kSpecializeFindUniqueArgvalDead;
   } else {
     if (IsPolyFunc(func, argvals)) {
mindspore/ccsrc/pipeline/static_analysis/program_specialize.h
@@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
   // Build a specialized node from given argvals;
   AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                   const AbstractBasePtrList &argvals);
-  AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
-                                       const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
+  AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                       const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
+                                       SpecializeStatusCode *errcode);
 
   // Find the unique argument values which can be used to specialize a primitive or graph function.
   SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
mindspore/ccsrc/utils/context/ms_context.cc
@@ -89,7 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   max_device_memory_ = kDefaultMaxDeviceMemory;
   print_file_path_ = "";
   enable_graph_kernel_ = false;
-  enable_sparse_flag_ = false;
+  enable_sparse_ = false;
 }
 
 std::shared_ptr<MsContext> MsContext::GetInstance() {
mindspore/ccsrc/utils/context/ms_context.h
浏览文件 @
44e74ad5
...
@@ -161,8 +161,8 @@ class MsContext {
...
@@ -161,8 +161,8 @@ class MsContext {
  void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
  bool enable_graph_kernel() const { return enable_graph_kernel_; }
- bool enable_sparse_flag() const { return enable_sparse_flag_; }
- void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
+ bool enable_sparse() const { return enable_sparse_; }
+ void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }

 private:
  MsContext(const std::string &backend_policy, const std::string &target);
...
@@ -207,7 +207,7 @@ class MsContext {
  float max_device_memory_;
  std::string print_file_path_;
  bool enable_graph_kernel_;
- bool enable_sparse_flag_;
+ bool enable_sparse_;
};
}  // namespace mindspore
...
mindspore/common/parameter.py
@@ -51,18 +51,13 @@ class Parameter:
         requires_grad (bool): True if the parameter requires gradient. Default: True.
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
             broadcast and gradients communication would not be applied on parameters. Default: False.
-        sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
-        has_indexed_slices_grad (bool): Set if the parameter's gradient is indexed_slices. Default: False.
     """
-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
-                 sparse_grad="", has_indexed_slices_grad=False):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
         self._value = ParamValue()
         self.set_parameter_data(default_input)
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
-        self.sparse_grad = sparse_grad
-        self.has_indexed_slices_grad = has_indexed_slices_grad
         self._is_init = False
         self._sliced = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
...
@@ -177,28 +172,6 @@ class Parameter:
             raise TypeError("`requires_grad` parameter must be bool type")
         self._value.requires_grad = value

-    @property
-    def sparse_grad(self):
-        """Return whether the parameter's gradient is sparse."""
-        return self._value.sparse_grad
-
-    @sparse_grad.setter
-    def sparse_grad(self, value=""):
-        if not isinstance(value, str):
-            raise TypeError("`sparse_grad` parameter must be str type")
-        self._value.sparse_grad = value
-
-    @property
-    def has_indexed_slices_grad(self):
-        """Return whether the parameter's gradient is indexed_slices."""
-        return self._value.has_indexed_slices_grad
-
-    @has_indexed_slices_grad.setter
-    def has_indexed_slices_grad(self, value=False):
-        if not isinstance(value, bool):
-            raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
-        self._value.has_indexed_slices_grad = value
-
     @property
     def data(self):
         return self.default_input
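With these two hunks, `Parameter` no longer carries any per-parameter sparse annotation: both `sparse_grad` and `has_indexed_slices_grad` disappear from the constructor and the property list, and whether a gradient materializes as an IndexedSlices is inferred from the forward network instead. A minimal sketch of the user-facing effect, assuming the post-commit signature shown above (shapes and names are illustrative only):

    import numpy as np
    from mindspore import Tensor, Parameter

    # After this commit: data and name are all that is declared.
    weight = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")

    # Before this commit the same weight needed a manual hint, e.g.
    #   Parameter(..., name="w1", sparse_grad="sparse_key_w1")
    #   Parameter(..., name="w1", has_indexed_slices_grad=True)
    # both of which are removed here in favor of type inference.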
...
mindspore/context.py
@@ -367,14 +367,6 @@ class _Context:
     def check_bprop(self, check_bprop_flag):
         self._context_handle.set_check_bprop_flag(check_bprop_flag)

-    @property
-    def enable_sparse(self):
-        return self._context_handle.get_enable_sparse_flag()
-
-    @enable_sparse.setter
-    def enable_sparse(self, enable_sparse_flag):
-        self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
-
     @property
     def max_device_memory(self):
         return self._context_handle.get_max_device_memory()
...
@@ -408,6 +400,13 @@ class _Context:
             full_file_name = print_file_path
         self._context_handle.set_print_file_path(full_file_name)

+    @property
+    def enable_sparse(self):
+        return self._context_handle.get_enable_sparse()
+
+    @enable_sparse.setter
+    def enable_sparse(self, enable_sparse):
+        self._context_handle.set_enable_sparse(enable_sparse)
+

 def check_input_format(x):
     import re
...
@@ -601,7 +600,7 @@ def set_context(**kwargs):
         print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
             a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
             suffix to the file.
-        enable_sparse (bool): Whether to enable sparse feature. Default: False.
+        enable_sparse (bool): Whether to enable sparsity feature. Default: False.

     Raises:
         ValueError: If input key is not an attribute in context.
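The `enable_sparse` property is thus relocated within `_Context` and rewired to the renamed C++ accessors (`get_enable_sparse`/`set_enable_sparse`), while the public keyword keeps its old spelling. A minimal usage sketch, consistent with what the test files below add at import time:

    from mindspore import context

    # Graph mode plus the sparsity switch; this is the exact call the
    # updated unit tests add near their imports.
    context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)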
...
mindspore/nn/optim/adam.py
@@ -162,8 +162,8 @@ class Adam(Optimizer):
     To improve parameter groups performance, the customized order of parameters can be supported.

-    The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
-    `sparse_grad` of `Parameter` is set. The sparse feature is under continuous development. The sparse
+    The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
+    The sparse feature is under continuous development. The sparse
     behavior is currently performed on the CPU.

     Args:
...
mindspore/nn/optim/ftrl.py
@@ -72,8 +72,8 @@ class FTRL(Optimizer):
     <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.

     Note:
-        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
-        `sparse_grad` of `Parameter` is set. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
+        The sparse feature is under continuous development. The sparse
         behavior is currently performed on the CPU.

     Args:
...
mindspore/nn/optim/lazyadam.py
@@ -91,8 +91,8 @@ class LazyAdam(Optimizer):
     value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
     applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.

-    The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
-    `sparse_grad` of `Parameter` is set. Note that the sparse behavior is not equivalent to the
+    The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
+    Note that the sparse behavior is not equivalent to the
     original Adam algorithm, as only the parameters at the current indices will be updated. The sparse feature is under
     continuous development. The sparse behavior is currently performed on the CPU.
...
mindspore/nn/optim/proximal_ada_grad.py
@@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer):
     <http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.

     Note:
-        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network and the
-        `sparse_grad` of `Parameter` is set to True. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
+        The sparse feature is under continuous development. The sparse
         behavior is currently performed on the CPU.

     Args:
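All four optimizer notes now state the same rule: sparsity is triggered solely by using SparseGatherV2 in the forward network, with no `sparse_grad` flag on the `Parameter`. A sketch of a network that satisfies the rule, condensed from the test fixtures changed later in this commit (shapes are illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            # SparseGatherV2 in construct() is what marks weight1's gradient as sparse.
            self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
            self.gather = P.SparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gather(self.weight1, indices, self.axis)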
...
mindspore/ops/functional.py
@@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices')
 indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
 indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
 indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
-is_indexed_slices = Primitive('IsIndexedSlices')

 tensor_operator_registry.register('__add__', tensor_add)
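Dropping `is_indexed_slices` is consistent with the test changes below: once IndexedSlices is a first-class type in the frontend, code dispatches on it statically rather than probing at run time. A hedged sketch of that dispatch style (the helper name `scale_grad` is hypothetical; the registration strings mirror the test file):

    from mindspore.ops import composite as C

    scale_grad = C.MultitypeFuncGraph("scale_grad")  # hypothetical name

    @scale_grad.register("IndexedSlices")
    def _scale_grad_sparse(grad):
        # Replaces the removed grad.is_indexed_slices() runtime check.
        return grad.values()

    @scale_grad.register("Tensor")
    def _scale_grad_dense(grad):
        return grad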
...
tests/ut/python/ir/test_indexed_slices.py
@@ -36,6 +36,8 @@ from mindspore._checkparam import Rel
 from mindspore.nn import Optimizer
 from mindspore.nn import TrainOneStepCell, WithLossCell

+context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
+
 reduce_sum = P.ReduceSum()
 unsorted_segment_sum = P.UnsortedSegmentSum()
 transpose = P.Transpose()
...
@@ -44,7 +46,6 @@ reshape = P.Reshape()
 size_op = P.Size()
 invert_permutation = P.InvertPermutation()
 logical_and = P.LogicalAnd()
-context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

 @constexpr
 def _generate_shape_index(out_shape, indices_shape, axis):
...
@@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self):
 adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")

-@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                           "Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
-def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
-    if gradient.is_indexed_slices():
-        return gradient.values()
+@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                           "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
+def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
+                                          m, v, gradient, decay_flag):
+    return gradient.values()
+
+@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
+                                  m, v, gradient, decay_flag):
     op_mul = P.Mul()
     op_square = P.Square()
     op_sqrt = P.Sqrt()
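The rewrite above turns one function with a runtime branch into two overloads of the same MultitypeFuncGraph, selected by the static type of the ninth argument. A hedged usage sketch, assuming it runs inside a graph-mode Cell (the wrapper class is purely illustrative):

    import mindspore.nn as nn

    class ApplyAdamOpt(nn.Cell):  # hypothetical wrapper for illustration
        def construct(self, beta1, beta2, eps, lr, weight_decay_tensor,
                      param, m, v, grad, decay_flag):
            # Single call site; the compiler picks the "IndexedSlices" or the
            # "Tensor" registration from grad's inferred type.
            return adam_opt_for_map(beta1, beta2, eps, lr, weight_decay_tensor,
                                    param, m, v, grad, decay_flag)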
...
@@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices():
             self.dense_shape = (3, 4)
         def construct(self, indices, values):
             ret = (IndexedSlices(indices, values, self.dense_shape),)
-            return ret[0].is_indexed_slices()
+            return ret[0]
     indices = Tensor([[0, 0], [1, 2]])
     values = Tensor([1, 2], dtype=ms.float32)
     MakeIndexedSlices()(indices, values)
...
@@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
             self.network = network
         def construct(self, x, y):
             grad = grad_all(self.network)(x, y)
-            return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
+            return grad, grad[0], grad[1]
     class SparseGatherV2(nn.Cell):
         def __init__(self):
             super(SparseGatherV2, self).__init__()
...
@@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
             weights = self.weights
             grad = grad_by_list(self.network, weights)(x)
             x = grad[0]
-            return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
+            return x, x.values(), x.indices(), x.dense_shape()
     class SparseGatherV2(nn.Cell):
         def __init__(self):
             super(SparseGatherV2, self).__init__()
             self.sparse_gatherv2 = MySparseGatherV2()
             self.axis = 0
             self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
-                                    name="params",
-                                    has_indexed_slices_grad=True)
+                                    name="params")
         def construct(self, indices):
             return self.sparse_gatherv2(self.params, indices, self.axis)
     indices = Tensor(np.array([0, 1]).astype(np.int32))
...
@@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
     network(indices)

-def test_indexed_slices_is_indexed_slices():
-    class MakeIndexedSlices(nn.Cell):
-        def __init__(self):
-            super(MakeIndexedSlices, self).__init__()
-            self.dense_shape = (3, 4)
-        def construct(self, indices, values):
-            indexed_slices = IndexedSlices(indices, values, self.dense_shape)
-            ret = indexed_slices.is_indexed_slices()
-            return ret
-    indices = Tensor([[0, 0], [1, 2]])
-    values = Tensor([1, 2], dtype=ms.float32)
-    MakeIndexedSlices()(indices, values)
-
 def test_indexed_slices_env_get():
     class Loss(nn.Cell):
         def __init__(self):
...
@@ -271,7 +262,7 @@ def test_indexed_slices_env_get():
     class NetWithSparseGatherV2(nn.Cell):
         def __init__(self):
             super(NetWithSparseGatherV2, self).__init__()
-            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True)
+            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
             self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
             self.gatherv2 = MySparseGatherV2()
             self.axis = 0
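Taken together, these test updates pin down the new contract: an IndexedSlices flows out of construct() as an ordinary typed value whose payload is read with values(), indices() and dense_shape(), and the removed is_indexed_slices() probe needs no replacement because the type is known statically. A condensed sketch of the pattern (graph mode with enable_sparse=True is assumed, as set at the top of the file; the IndexedSlices import path is taken to match the test's):

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore import IndexedSlices  # import path assumed from the test file

    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            # Returned directly; no runtime type probe is needed anymore.
            return IndexedSlices(indices, values, self.dense_shape)

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    out = MakeIndexedSlices()(indices, values)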
...
tests/ut/python/nn/optim/test_adam.py
@@ -17,12 +17,13 @@ import numpy as np
 import pytest
 import mindspore.nn as nn
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, context
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore.ops import operations as P

+context.set_context(enable_sparse=True)
+
 class Net(nn.Cell):
     """ Net definition """
...
@@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
""" NetWithSparseGatherV2 definition """
def
__init__
(
self
):
def
__init__
(
self
):
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight1"
)
name
=
"weight1"
,
sparse_grad
=
"sparse_key_w1"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
axis
=
0
self
.
axis
=
0
self
.
gather
=
P
.
SparseGatherV2
()
self
.
gather
=
P
.
SparseGatherV2
()
...
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
@@ -27,6 +27,7 @@ from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel

+context.set_context(enable_sparse=True)

 adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")

 @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
...
@@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse():
     class NetWithSparseGatherV2(nn.Cell):
         def __init__(self):
             super(NetWithSparseGatherV2, self).__init__()
-            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
+            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
             self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
             self.gatherv2 = P.SparseGatherV2()
             self.axis = 0
...
tests/ut/python/nn/optim/test_ftrl.py
@@ -17,12 +17,13 @@
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, context
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import FTRL
 from mindspore.ops import operations as P

+context.set_context(enable_sparse=True)
+
 class Net(nn.Cell):
     def __init__(self):
...
@@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
""" NetWithSparseGatherV2 definition """
def
__init__
(
self
):
def
__init__
(
self
):
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight1"
)
name
=
"weight1"
,
sparse_grad
=
"sparse_key_w1"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
axis
=
0
self
.
axis
=
0
self
.
gather
=
P
.
SparseGatherV2
()
self
.
gather
=
P
.
SparseGatherV2
()
...
tests/ut/python/nn/optim/test_lazyadam.py
@@ -17,12 +17,13 @@ import numpy as np
 import pytest
 import mindspore.nn as nn
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, context
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import LazyAdam
 from mindspore.ops import operations as P

+context.set_context(enable_sparse=True)
+
 class Net(nn.Cell):
     """ Net definition """
...
@@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
""" NetWithSparseGatherV2 definition """
def
__init__
(
self
):
def
__init__
(
self
):
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight1"
)
name
=
"weight1"
,
sparse_grad
=
"sparse_key_w1"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
((
np
.
float32
))),
name
=
"weight2"
)
self
.
axis
=
0
self
.
axis
=
0
self
.
gather
=
P
.
SparseGatherV2
()
self
.
gather
=
P
.
SparseGatherV2
()
...
tests/ut/python/nn/optim/test_proximal_ada_grad.py
@@ -17,12 +17,13 @@
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, context
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import ProximalAdagrad
 from mindspore.ops import operations as P

+context.set_context(enable_sparse=True)
+
 class Net(nn.Cell):
     def __init__(self):
...
@@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
""" NetWithSparseGatherV2 definition """
def
__init__
(
self
):
def
__init__
(
self
):
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
super
(
NetWithSparseGatherV2
,
self
).
__init__
()
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight1"
,
self
.
weight1
=
Parameter
(
Tensor
(
np
.
ones
([
3
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight1"
)
sparse_grad
=
"sparse_key_w1"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight2"
)
self
.
weight2
=
Parameter
(
Tensor
(
np
.
ones
([
2
,
1
,
2
]).
astype
(
np
.
float32
)),
name
=
"weight2"
)
self
.
axis
=
0
self
.
axis
=
0
self
.
gather
=
P
.
SparseGatherV2
()
self
.
gather
=
P
.
SparseGatherV2
()
...
tests/ut/python/pipeline/infer/test_hypermap_specialize.py
@@ -53,4 +53,4 @@ def test_hypermap_specialize_param():
     expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
     ret = hypermap_specialize_param()
-    assert ret == (expected_ret, expected_ret)
+    assert ret == (expected_ret, list(expected_ret))
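The adjusted assertion suggests that specialization now preserves the container type of a hypermapped argument: mapping over a tuple still returns a tuple, while mapping over a list returns a list. A hedged sketch of what the test appears to pin down (the internals of hypermap_specialize_param are not shown in this diff):

    ret = hypermap_specialize_param()
    # The first element presumably comes from hypermapping a tuple, the second from a list.
    assert isinstance(ret[0], tuple)
    assert isinstance(ret[1], list)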