Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e6f82af8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6f82af8
编写于
9月 03, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cell class to c++
上级
879a5191
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
450 addition
and
42 deletion
+450
-42
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+5
-2
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
...ore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
+0
-1
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+18
-5
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+5
-5
mindspore/ccsrc/pipeline/jit/parse/parse.cc
mindspore/ccsrc/pipeline/jit/parse/parse.cc
+107
-20
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+3
-0
mindspore/ccsrc/pybind_api/ir/cell_py.cc
mindspore/ccsrc/pybind_api/ir/cell_py.cc
+50
-0
mindspore/ccsrc/pybind_api/ir/cell_py.h
mindspore/ccsrc/pybind_api/ir/cell_py.h
+44
-0
mindspore/core/ir/cell.cc
mindspore/core/ir/cell.cc
+94
-0
mindspore/core/ir/cell.h
mindspore/core/ir/cell.h
+69
-0
mindspore/core/ir/func_graph_extends.cc
mindspore/core/ir/func_graph_extends.cc
+2
-1
mindspore/core/utils/label.cc
mindspore/core/utils/label.cc
+2
-1
mindspore/nn/cell.py
mindspore/nn/cell.py
+47
-3
mindspore/nn/layer/container.py
mindspore/nn/layer/container.py
+2
-2
tests/ut/python/parallel/test_get_parameter_layout.py
tests/ut/python/parallel/test_get_parameter_layout.py
+1
-1
tests/ut/python/parallel/test_split_grad_sens.py
tests/ut/python/parallel/test_split_grad_sens.py
+1
-1
未找到文件。
mindspore/_extends/parse/parser.py
浏览文件 @
e6f82af8
...
@@ -194,9 +194,12 @@ def get_object_key(obj):
...
@@ -194,9 +194,12 @@ def get_object_key(obj):
obj_key
=
"%s_ID"
%
(
str
(
obj
.
__class__
.
__name__
)
+
str
(
obj
.
__name__
)
+
obj
.
cell_init_args
)
obj_key
=
"%s_ID"
%
(
str
(
obj
.
__class__
.
__name__
)
+
str
(
obj
.
__name__
)
+
obj
.
cell_init_args
)
obj_id
=
"%s_ID%d"
%
(
str
(
obj
.
__class__
.
__name__
)
+
str
(
obj
.
__name__
),
id
(
obj
))
obj_id
=
"%s_ID%d"
%
(
str
(
obj
.
__class__
.
__name__
)
+
str
(
obj
.
__name__
),
id
(
obj
))
else
:
else
:
# `<class 'xxxxxxx'>`
# -> `xxxxxxx`
tag
=
str
(
obj
.
__class__
)[
8
:
-
2
]
if
hasattr
(
obj
,
"cell_init_args"
):
if
hasattr
(
obj
,
"cell_init_args"
):
obj_key
=
"%s_ID"
%
(
str
(
obj
.
__class__
.
__name__
)
+
obj
.
cell_init_args
)
obj_key
=
"%s_ID"
%
(
tag
+
obj
.
cell_init_args
)
obj_id
=
"%s_ID%d"
%
(
str
(
obj
.
__class__
.
__name__
)
,
id
(
obj
))
obj_id
=
"%s_ID%d"
%
(
tag
,
id
(
obj
))
logger
.
debug
(
"obj_key %s obj_id = %s"
,
obj_key
,
obj_id
)
logger
.
debug
(
"obj_key %s obj_id = %s"
,
obj_key
,
obj_id
)
# method has same id of different instance
# method has same id of different instance
...
...
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
浏览文件 @
e6f82af8
...
@@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor {
...
@@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor {
}
}
}
}
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
new_parameters
.
push_back
(
new_param
);
new_parameters
.
push_back
(
new_param
);
curr_input_idx
++
;
curr_input_idx
++
;
}
}
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
e6f82af8
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include "ir/func_graph_cloner.h"
#include "ir/func_graph_cloner.h"
#include "ir/param_info.h"
#include "ir/param_info.h"
#include "ir/cell.h"
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/pass.h"
...
@@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) {
...
@@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) {
parse
::
python_adapter
::
set_python_env_flag
(
true
);
parse
::
python_adapter
::
set_python_env_flag
(
true
);
parse
::
python_adapter
::
SetPythonPath
(
dir
);
parse
::
python_adapter
::
SetPythonPath
(
dir
);
FuncGraphPtr
fg
=
parse
::
ConvertToFuncGraph
(
input
);
ValuePtr
converted_ret
=
nullptr
;
if
(
fg
==
nullptr
)
{
bool
converted
=
parse
::
ConvertData
(
input
,
&
converted_ret
,
true
);
MS_LOG
(
EXCEPTION
)
<<
"Parse error."
;
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attribute convert error with type:"
<<
std
::
string
(
py
::
str
(
input
));
}
}
res
->
set_func_graph
(
fg
);
FuncGraphPtr
top_graph
=
nullptr
;
if
(
py
::
isinstance
<
Cell
>
(
input
))
{
top_graph
=
parse
::
MakeTopGraph
(
input
,
converted_ret
);
}
else
if
(
converted_ret
->
isa
<
FuncGraph
>
())
{
top_graph
=
converted_ret
->
cast
<
FuncGraphPtr
>
();
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Object to parse "
<<
std
::
string
(
py
::
str
(
input
))
<<
" is not function or cell."
;
}
parse
::
Parser
::
UpdateTopFuncGraph
(
top_graph
);
res
->
set_func_graph
(
top_graph
);
FuncGraphManagerPtr
manager
=
res
->
manager
();
FuncGraphManagerPtr
manager
=
res
->
manager
();
if
(
manager
==
nullptr
)
{
if
(
manager
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Manager is nullptr."
;
MS_LOG
(
EXCEPTION
)
<<
"Manager is nullptr."
;
}
}
manager
->
AddFuncGraph
(
fg
);
manager
->
AddFuncGraph
(
top_graph
);
return
true
;
return
true
;
}
}
...
...
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
浏览文件 @
e6f82af8
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include "frontend/operator/ops.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph_cloner.h"
#include "ir/cell.h"
#include "utils/symbolic.h"
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "utils/ms_context.h"
...
@@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
...
@@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
return
true
;
return
true
;
}
}
bool
ConvertCellObjToFuncGraph
(
py
::
object
obj
,
ValuePtr
*
const
data
)
{
bool
ConvertCellObjToFuncGraph
(
const
CellPtr
&
cell
,
ValuePtr
*
const
data
)
{
auto
obj
=
py
::
cast
(
cell
);
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
FuncGraphPtr
func_graph
=
ConvertToFuncGraph
(
obj
);
if
(
func_graph
==
nullptr
)
{
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Parse resolve function error."
;
MS_LOG
(
ERROR
)
<<
"Parse resolve function error."
;
...
@@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
...
@@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
if
(
obj_type
==
RESOLVE_TYPE_CLASS_INSTANCE
)
{
if
(
obj_type
==
RESOLVE_TYPE_CLASS_INSTANCE
)
{
// Create the namespace for common class instance
// Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct'
// When the obj is Cell, default parse the 'construct'
if
(
data_converter
::
IsCellInstance
(
obj
))
{
return
ConvertCellObjToFuncGraph
(
obj
,
data
);
}
py
::
module
mod
=
python_adapter
::
GetPyModule
(
PYTHON_MOD_PARSE_MODULE
);
py
::
module
mod
=
python_adapter
::
GetPyModule
(
PYTHON_MOD_PARSE_MODULE
);
py
::
object
namespace_var
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
,
obj
);
py
::
object
namespace_var
=
python_adapter
::
CallPyModFn
(
mod
,
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
,
obj
);
*
data
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
namespace_var
);
*
data
=
std
::
make_shared
<
NameSpace
>
(
RESOLVE_NAMESPACE_NAME_CLASS_MEMBER
,
namespace_var
);
...
@@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
...
@@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
ret
=
ConvertTuple
(
obj
,
&
converted
,
use_signature
);
ret
=
ConvertTuple
(
obj
,
&
converted
,
use_signature
);
}
else
if
(
py
::
hasattr
(
obj
,
PYTHON_CELL_AS_LIST
))
{
}
else
if
(
py
::
hasattr
(
obj
,
PYTHON_CELL_AS_LIST
))
{
ret
=
ConvertCellList
(
obj
,
&
converted
,
use_signature
);
ret
=
ConvertCellList
(
obj
,
&
converted
,
use_signature
);
}
else
if
(
py
::
isinstance
<
Cell
>
(
obj
))
{
return
ConvertCellObjToFuncGraph
(
obj
.
cast
<
CellPtr
>
(),
data
);
}
else
if
(
py
::
isinstance
<
py
::
list
>
(
obj
))
{
}
else
if
(
py
::
isinstance
<
py
::
list
>
(
obj
))
{
ret
=
ConvertList
(
obj
,
&
converted
,
use_signature
);
ret
=
ConvertList
(
obj
,
&
converted
,
use_signature
);
}
else
if
(
py
::
isinstance
<
py
::
module
>
(
obj
))
{
}
else
if
(
py
::
isinstance
<
py
::
module
>
(
obj
))
{
...
...
mindspore/ccsrc/pipeline/jit/parse/parse.cc
浏览文件 @
e6f82af8
...
@@ -140,34 +140,80 @@ void Parser::CleanParserResource() {
...
@@ -140,34 +140,80 @@ void Parser::CleanParserResource() {
ScopeManager
::
GetInstance
().
ClearScope
();
ScopeManager
::
GetInstance
().
ClearScope
();
}
}
FuncGraphPtr
Parser
::
ParseFuncGraph
()
{
AnfNodePtr
AppendParameterObj
(
const
FuncGraphPtr
&
func_graph
,
const
py
::
object
&
obj
)
{
// get ast FunctionDef node
MS_EXCEPTION_IF_NULL
(
func_graph
);
py
::
object
node
=
ast_
->
GetAstNode
();
auto
value
=
py
::
cast
<
tensor
::
MetaTensorPtr
>
(
obj
);
FunctionBlockPtr
pFnBlock
=
ParseFunction
(
node
);
// parameter object should not be none
if
(
errcode
()
!=
PARSE_SUCCESS
)
{
if
(
value
==
nullptr
||
!
value
->
is_parameter
())
{
MS_LOG
(
ERROR
)
<<
"Parse function error, code is "
<<
errcode
();
MS_LOG
(
EXCEPTION
)
<<
"Parameter error: because obj is not Parameter object."
;
return
nullptr
;
}
// get the parameter name from parameter object
auto
param_name
=
value
->
param_info
()
->
name
();
auto
top_graph
=
func_graph
;
// if the parameter node has been created , return it
AnfNodePtr
para_node
=
nullptr
;
for
(
auto
param
:
top_graph
->
parameters
())
{
auto
param_node
=
dyn_cast
<
Parameter
>
(
param
);
if
(
param_node
!=
nullptr
&&
param_node
->
name
()
==
param_name
)
{
para_node
=
param
;
break
;
}
}
}
if
(
para_node
==
nullptr
)
{
auto
node
=
top_graph
->
AddWeightParameter
(
param_name
);
RemoveUnnecessaryPhis
();
node
->
set_default_param
(
value
);
// set_abstract for parameter
auto
abs
=
value
->
ToAbstract
();
// boarden value
abs
=
abs
->
Broaden
();
node
->
set_abstract
(
abs
);
para_node
=
node
;
}
return
para_node
;
}
MS_EXCEPTION_IF_NULL
(
pFnBlock
);
void
UpdataParam
(
const
FuncGraphPtr
&
top_graph
,
const
py
::
object
&
cell
)
{
auto
params
=
py
::
list
(
cell
.
attr
(
"get_parameters"
)()).
cast
<
std
::
vector
<
py
::
object
>>
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
i
++
)
{
(
void
)
AppendParameterObj
(
top_graph
,
params
[
i
]);
}
}
void
CheckFuncReturn
(
const
FuncGraphPtr
&
fn
,
const
std
::
shared_ptr
<
ParseAst
>
&
ast
)
{
// check whether the functions refered by this function and itself are missing 'return' statement
// check whether the functions refered by this function and itself are missing 'return' statement
auto
mng
=
Manage
(
pFnBlock
->
func_graph
()
,
false
);
auto
mng
=
Manage
(
fn
,
false
);
for
(
auto
func_graph
:
mng
->
func_graphs
())
{
for
(
auto
func_graph
:
mng
->
func_graphs
())
{
if
(
func_graph
->
get_return
()
!=
nullptr
)
{
if
(
func_graph
->
get_return
()
!=
nullptr
)
{
continue
;
continue
;
}
}
py
::
list
ret
=
ast_
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
object
node
=
ast
->
GetAstNode
();
py
::
list
ret
=
ast
->
CallParserObjMethod
(
PYTHON_PARSE_GET_LOCATION
,
node
);
py
::
str
desc
=
py
::
str
desc
=
python_adapter
::
CallPyModFn
(
ast
_
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast_
->
function
(),
ret
[
0
],
ret
[
1
]);
python_adapter
::
CallPyModFn
(
ast
->
module
(),
PYTHON_MOD_GET_OBJECT_DESCRIPTION
,
ast
->
function
(),
ret
[
0
],
ret
[
1
]);
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
MS_EXCEPTION
(
TypeError
)
<<
"Missing return statement in "
<<
desc
.
cast
<
std
::
string
>
()
<<
"."
;
}
}
// clear manager info after checking missing return
// clear manager info after checking missing return
for
(
auto
fg
:
mng
->
func_graphs
())
{
for
(
auto
fg
:
mng
->
func_graphs
())
{
fg
->
ClearAllManagerInfo
();
fg
->
ClearAllManagerInfo
();
}
}
}
FuncGraphPtr
Parser
::
ParseFuncGraph
()
{
// get ast FunctionDef node
py
::
object
node
=
ast_
->
GetAstNode
();
FunctionBlockPtr
pFnBlock
=
ParseFunction
(
node
);
if
(
errcode
()
!=
PARSE_SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Parse function error, code is "
<<
errcode
();
return
nullptr
;
}
RemoveUnnecessaryPhis
();
MS_EXCEPTION_IF_NULL
(
pFnBlock
);
CheckFuncReturn
(
pFnBlock
->
func_graph
(),
ast_
);
return
pFnBlock
->
func_graph
();
return
pFnBlock
->
func_graph
();
}
}
...
@@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
...
@@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
return
GenerateAnfNodeForCall
(
block
,
call_function_anf_node
,
packed_arguments
,
group_arguments
,
need_unpack
);
return
GenerateAnfNodeForCall
(
block
,
call_function_anf_node
,
packed_arguments
,
group_arguments
,
need_unpack
);
}
}
AnfNodePtr
Parser
::
GenerateAnfNodeForCall
(
const
FunctionBlockPtr
&
block
,
const
AnfNodePtr
&
call_function_anf_node
,
CNodePtr
MakeUnpackCall
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
call_function_anf_node
,
const
std
::
vector
<
AnfNodePtr
>
&
packed_arguments
,
const
std
::
vector
<
AnfNodePtr
>
&
packed_arguments
)
{
const
std
::
vector
<
AnfNodePtr
>
&
group_arguments
,
bool
need_unpack
)
const
{
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
if
(
need_unpack
)
{
std
::
vector
<
AnfNodePtr
>
unpack_call_nodes
;
std
::
vector
<
AnfNodePtr
>
unpack_call_nodes
;
auto
unpack_call_op
=
NewValueNode
(
std
::
make_shared
<
prim
::
UnpackCall
>
(
NAMED_METAGRAPH_UNPACKCALL
));
auto
unpack_call_op
=
NewValueNode
(
std
::
make_shared
<
prim
::
UnpackCall
>
(
NAMED_METAGRAPH_UNPACKCALL
));
unpack_call_nodes
.
push_back
(
unpack_call_op
);
unpack_call_nodes
.
push_back
(
unpack_call_op
);
unpack_call_nodes
.
push_back
(
call_function_anf_node
);
unpack_call_nodes
.
push_back
(
call_function_anf_node
);
(
void
)
std
::
transform
(
packed_arguments
.
begin
(),
packed_arguments
.
end
(),
std
::
back_inserter
(
unpack_call_nodes
),
(
void
)
std
::
transform
(
packed_arguments
.
begin
(),
packed_arguments
.
end
(),
std
::
back_inserter
(
unpack_call_nodes
),
[](
AnfNodePtr
node
)
->
AnfNodePtr
{
return
node
;
});
[](
AnfNodePtr
node
)
->
AnfNodePtr
{
return
node
;
});
CNodePtr
unpack_call
=
block
->
func_graph
()
->
NewCNode
(
unpack_call_nodes
);
CNodePtr
unpack_call
=
func_graph
->
NewCNode
(
unpack_call_nodes
);
return
unpack_call
;
return
unpack_call
;
}
AnfNodePtr
Parser
::
GenerateAnfNodeForCall
(
const
FunctionBlockPtr
&
block
,
const
AnfNodePtr
&
call_function_anf_node
,
const
std
::
vector
<
AnfNodePtr
>
&
packed_arguments
,
const
std
::
vector
<
AnfNodePtr
>
&
group_arguments
,
bool
need_unpack
)
const
{
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
if
(
need_unpack
)
{
return
MakeUnpackCall
(
block
->
func_graph
(),
call_function_anf_node
,
packed_arguments
);
}
}
// else there is no keyword arguments and starred, parsed as normal arguments without unpack
// else there is no keyword arguments and starred, parsed as normal arguments without unpack
std
::
vector
<
AnfNodePtr
>
func_call_nodes
;
std
::
vector
<
AnfNodePtr
>
func_call_nodes
;
...
@@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
...
@@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
return
true
;
return
true
;
}
}
FuncGraphPtr
MakeTopGraph
(
const
py
::
object
&
cell
,
const
ValuePtr
&
cell_ptr
)
{
auto
func_graph
=
std
::
make_shared
<
FuncGraph
>
();
func_graph
->
debug_info
()
->
set_name
(
"top"
);
// def top(*arg, *kwargs):
auto
param_vargs
=
func_graph
->
add_parameter
();
auto
args_name
=
"args"
;
param_vargs
->
set_name
(
args_name
);
param_vargs
->
debug_info
()
->
set_name
(
args_name
);
auto
param_vkwargs
=
func_graph
->
add_parameter
();
args_name
=
"kwargs"
;
param_vkwargs
->
set_name
(
args_name
);
param_vkwargs
->
debug_info
()
->
set_name
(
args_name
);
func_graph
->
set_has_vararg
(
true
);
func_graph
->
set_has_kwarg
(
true
);
func_graph
->
set_kwonlyargs_count
(
0
);
// cell_obj
parse
::
UpdateFuncGraphFlags
(
cell
,
func_graph
);
// top graph's construct flag
if
(
py
::
hasattr
(
cell
,
"construct"
))
{
parse
::
UpdateFuncGraphFlags
(
cell
.
attr
(
"construct"
),
func_graph
);
}
UpdataParam
(
func_graph
,
cell
);
// ret = cell_obj(*arg, *kwargs)
auto
call_fn
=
MakeUnpackCall
(
func_graph
,
NewValueNode
(
cell_ptr
),
{
param_vargs
,
param_vkwargs
});
// return ret
func_graph
->
set_output
(
call_fn
);
MS_LOG
(
DEBUG
)
<<
"add Flag for "
<<
std
::
string
(
py
::
str
(
cell
));
return
func_graph
;
}
}
// namespace parse
}
// namespace parse
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
e6f82af8
...
@@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
...
@@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
// Parse the python object to graph
// Parse the python object to graph
FuncGraphPtr
ParsePythonCode
(
const
py
::
object
&
obj
,
FuncGraphPtr
ParsePythonCode
(
const
py
::
object
&
obj
,
const
std
::
string
&
python_mod_get_parse_method
=
PYTHON_MOD_GET_PARSE_METHOD
);
const
std
::
string
&
python_mod_get_parse_method
=
PYTHON_MOD_GET_PARSE_METHOD
);
// add wrap for cell top graph.
FuncGraphPtr
MakeTopGraph
(
const
py
::
object
&
cell
,
const
ValuePtr
&
cell_ptr
);
}
// namespace parse
}
// namespace parse
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/pybind_api/ir/cell_py.cc
0 → 100644
浏览文件 @
e6f82af8
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind_api/ir/cell_py.h"
#include <string>
#include "pybind_api/api_register.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/parse/python_adapter.h"
namespace
mindspore
{
void
CellPy
::
AddAttr
(
CellPtr
cell
,
const
std
::
string
&
name
,
const
py
::
object
&
obj
)
{
std
::
string
attr_name
=
name
;
ValuePtr
converted_ret
=
nullptr
;
if
(
py
::
isinstance
<
py
::
module
>
(
obj
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cell set_attr failed, attr should not be py::module"
;
}
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
,
true
);
if
(
!
converted
)
{
MS_LOG
(
DEBUG
)
<<
"Attribute convert error with type: "
<<
std
::
string
(
py
::
str
(
obj
));
}
else
{
MS_LOG
(
DEBUG
)
<<
cell
->
ToString
()
<<
" add attr "
<<
attr_name
<<
converted_ret
->
ToString
();
cell
->
AddAttr
(
attr_name
,
converted_ret
);
}
}
// Define python 'Cell' class.
REGISTER_PYBIND_DEFINE
(
Cell
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
Cell
,
std
::
shared_ptr
<
Cell
>>
(
*
m
,
"Cell_"
)
.
def
(
py
::
init
<
std
::
string
&>
())
.
def
(
"__str__"
,
&
Cell
::
ToString
)
.
def
(
"_add_attr"
,
&
CellPy
::
AddAttr
,
"Add Cell attr."
)
.
def
(
"_del_attr"
,
&
Cell
::
DelAttr
,
"Delete Cell attr."
)
.
def
(
"construct"
,
[]()
{
MS_LOG
(
EXCEPTION
)
<<
"we should define `construct` for all `cell`."
;
},
"construct"
);
}));
}
// namespace mindspore
mindspore/ccsrc/pybind_api/ir/cell_py.h
0 → 100644
浏览文件 @
e6f82af8
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_
#define MINDSPORE_CCSRC_UTILS_CELL_PY_H_
#include <memory>
#include <string>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "ir/cell.h"
namespace
py
=
pybind11
;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace
mindspore
{
// Cell python wrapper and adapter class.
class
CellPy
{
public:
static
void
AddAttr
(
CellPtr
cell
,
const
std
::
string
&
name
,
const
py
::
object
&
obj
);
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_
mindspore/core/ir/cell.cc
0 → 100644
浏览文件 @
e6f82af8
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/cell.h"
#include <utility>
#include <map>
#include <algorithm>
#include "abstract/abstract_value.h"
namespace
mindspore
{
using
mindspore
::
abstract
::
AbstractFunction
;
abstract
::
AbstractBasePtr
Cell
::
ToAbstract
()
{
/*
std::vector<abstract::AbstractAttribute> abs_attrs;
std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs),
[](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute {
return std::make_pair(attr.first, attr.second->ToAbstract());
});
auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs);
abs->set_value(shared_from_base<Value>());
return abs;
*/
return
nullptr
;
}
bool
Cell
::
operator
==
(
const
Value
&
other
)
const
{
if
(
other
.
isa
<
Cell
>
())
{
auto
other_prim
=
static_cast
<
const
Cell
&>
(
other
);
return
*
this
==
other_prim
;
}
else
{
return
false
;
}
}
bool
Cell
::
operator
==
(
const
Cell
&
other
)
const
{
if
(
name
()
!=
other
.
name
())
{
return
false
;
}
if
(
attrs_
.
size
()
!=
other
.
attrs_
.
size
())
{
return
false
;
}
auto
all
=
std
::
all_of
(
attrs_
.
begin
(),
attrs_
.
end
(),
[
&
other
](
const
std
::
pair
<
std
::
string
,
ValuePtr
>
&
item
)
->
bool
{
if
(
item
.
second
==
nullptr
)
{
return
false
;
}
auto
iter
=
other
.
attrs_
.
find
(
item
.
first
);
if
(
iter
==
other
.
attrs_
.
end
())
{
return
false
;
}
return
*
item
.
second
==
*
iter
->
second
;
});
return
all
;
}
std
::
string
Cell
::
GetAttrString
()
const
{
std
::
ostringstream
buffer
;
bool
begin
=
true
;
buffer
<<
"{"
<<
std
::
endl
;
for
(
auto
&
attr
:
attrs_
)
{
if
(
!
begin
)
{
buffer
<<
", "
<<
std
::
endl
;
}
else
{
begin
=
false
;
}
buffer
<<
attr
.
first
<<
":"
<<
attr
.
second
->
ToString
();
}
buffer
<<
"}"
;
return
buffer
.
str
();
}
std
::
string
Cell
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
"Cell "
<<
name
();
return
buffer
.
str
();
}
void
Cell
::
DelAttr
(
const
std
::
string
&
name
)
{
attrs_
.
erase
(
name
);
}
}
// namespace mindspore
mindspore/core/ir/cell.h
0 → 100644
浏览文件 @
e6f82af8
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_CELL_H_
#define MINDSPORE_CCSRC_IR_CELL_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include "abstract/abstract_value.h"
#include "utils/misc.h"
namespace
mindspore
{
using
abstract
::
AbstractBasePtr
;
using
abstract
::
AbstractBasePtrList
;
// value for Cell
class
Cell
:
public
Named
{
public:
explicit
Cell
(
const
std
::
string
&
name
)
:
Named
(
name
)
{}
MS_DECLARE_PARENT
(
Cell
,
Named
);
abstract
::
AbstractBasePtr
ToAbstract
()
override
;
std
::
string
ToString
()
const
override
;
std
::
string
GetAttrString
()
const
;
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
void
set_attrs
(
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs_input
)
{
attrs_
=
attrs_input
;
}
void
AddAttr
(
const
std
::
string
&
name
,
const
ValuePtr
&
attr
)
{
attrs_
[
name
]
=
attr
;
}
void
DelAttr
(
const
std
::
string
&
name
);
ValuePtr
GetAttr
(
const
std
::
string
&
attr_name
)
const
{
auto
iter
=
attrs_
.
find
(
attr_name
);
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
bool
HasAttr
(
const
std
::
string
&
attr_name
)
const
{
auto
iter
=
attrs_
.
find
(
attr_name
);
return
!
(
iter
==
attrs_
.
cend
());
}
bool
operator
==
(
const
Value
&
other
)
const
override
;
bool
operator
==
(
const
Cell
&
other
)
const
;
~
Cell
()
override
=
default
;
const
bool
parse_info_
=
true
;
private:
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attrs_
;
};
using
CellPtr
=
std
::
shared_ptr
<
Cell
>
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_CELL_H_
mindspore/core/ir/func_graph_extends.cc
浏览文件 @
e6f82af8
...
@@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
...
@@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
MS_LOG
(
EXCEPTION
)
<<
"Function:"
<<
this
->
ToString
()
<<
", variable_args_count "
<<
variable_args_count
MS_LOG
(
EXCEPTION
)
<<
"Function:"
<<
this
->
ToString
()
<<
", variable_args_count "
<<
variable_args_count
<<
" were given."
;
<<
" were given."
;
}
}
auto
varg_name
=
specialized_graph
->
GetVariableArgName
();
// for python variable argument input , there is no upper limit
// for python variable argument input , there is no upper limit
for
(
int
i
=
0
;
i
<
variable_args_count
;
++
i
)
{
for
(
int
i
=
0
;
i
<
variable_args_count
;
++
i
)
{
ParameterPtr
p
=
std
::
make_shared
<
Parameter
>
(
specialized_graph
);
ParameterPtr
p
=
std
::
make_shared
<
Parameter
>
(
specialized_graph
);
std
::
string
param_name
=
specialized_graph
->
GetVariableArgName
()
+
std
::
to_string
(
i
);
std
::
string
param_name
=
varg_name
+
std
::
to_string
(
i
);
p
->
set_name
(
param_name
);
p
->
set_name
(
param_name
);
MS_EXCEPTION_IF_NULL
(
p
->
debug_info
());
MS_EXCEPTION_IF_NULL
(
p
->
debug_info
());
p
->
debug_info
()
->
set_name
(
param_name
);
p
->
debug_info
()
->
set_name
(
param_name
);
...
...
mindspore/core/utils/label.cc
浏览文件 @
e6f82af8
...
@@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe
...
@@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe
while
(
temp_info
!=
nullptr
)
{
while
(
temp_info
!=
nullptr
)
{
if
(
temp_info
->
trace_info
()
!=
nullptr
)
{
if
(
temp_info
->
trace_info
()
!=
nullptr
)
{
if
(
temp_info
->
trace_info
()
->
isa
<
TraceResolve
>
()
||
temp_info
->
trace_info
()
->
isa
<
TraceExpandJ
>
()
||
if
(
temp_info
->
trace_info
()
->
isa
<
TraceResolve
>
()
||
temp_info
->
trace_info
()
->
isa
<
TraceExpandJ
>
()
||
temp_info
->
trace_info
()
->
isa
<
TraceGenMetaFuncGraph
>
())
{
temp_info
->
trace_info
()
->
isa
<
TraceGenMetaFuncGraph
>
()
||
temp_info
->
trace_info
()
->
isa
<
TraceGenerateVarArg
>
()
||
temp_info
->
trace_info
()
->
isa
<
TraceGenerateKwArg
>
())
{
break
;
break
;
}
}
trace_name
.
trace_labels
.
push_back
(
GetTraceName
(
temp_info
->
trace_info
(),
trace_label
));
trace_name
.
trace_labels
.
push_back
(
GetTraceName
(
temp_info
->
trace_info
(),
trace_label
));
...
...
mindspore/nn/cell.py
浏览文件 @
e6f82af8
...
@@ -24,14 +24,14 @@ from ..common import dtype as mstype
...
@@ -24,14 +24,14 @@ from ..common import dtype as mstype
from
..common.api
import
_executor
,
_pynative_exec
from
..common.api
import
_executor
,
_pynative_exec
from
.._checkparam
import
_check_str_by_regular
from
.._checkparam
import
_check_str_by_regular
from
..common.parameter
import
Parameter
,
ParameterTuple
from
..common.parameter
import
Parameter
,
ParameterTuple
from
.._c_expression
import
init_backend
from
.._c_expression
import
init_backend
,
Cell_
from
..ops.primitive
import
Primitive
from
..ops.primitive
import
Primitive
from
..ops.operations
import
HookBackward
from
..ops.operations
import
HookBackward
from
..ops.functional
import
cast
from
..ops.functional
import
cast
from
..parallel._tensor
import
_load_tensor_by_layout
from
..parallel._tensor
import
_load_tensor_by_layout
from
..common.tensor
import
Tensor
from
..common.tensor
import
Tensor
class
Cell
:
class
Cell
(
Cell_
)
:
"""
"""
Base class for all neural networks.
Base class for all neural networks.
...
@@ -58,14 +58,21 @@ class Cell:
...
@@ -58,14 +58,21 @@ class Cell:
>>> def construct(self, x):
>>> def construct(self, x):
>>> return self.relu(x)
>>> return self.relu(x)
"""
"""
IGNORE_LIST
=
[
'_scope'
,
'_cell_init_args'
,
'_auto_prefix'
,
'_cells'
,
'_params'
,
'_construct_inputs_names'
,
'_construct_inputs_num'
,
'_create_time'
,
'_mindspore_flags'
,
'_parallel_inputs_run'
,
'_parameter_layout_dict'
,
'_already_run'
,
'_params_list'
,
'_phase'
,
'_auto_parallel_mode'
,
'_backward_hook'
,
'_bprop_debug'
,
'_is_run'
,
'_param_prefix'
,
'_attr_synced'
,
'enable_hook'
,
'pynative'
,
'requires_grad'
,
'_auto_parallel_compile_and_run'
,
'cell_type'
]
def
__init__
(
self
,
auto_prefix
=
True
,
flags
=
None
):
def
__init__
(
self
,
auto_prefix
=
True
,
flags
=
None
):
Cell_
.
__init__
(
self
,
self
.
_cell_tag
)
self
.
_params
=
OrderedDict
()
self
.
_params
=
OrderedDict
()
self
.
_cells
=
OrderedDict
()
self
.
_cells
=
OrderedDict
()
self
.
_params_list
=
OrderedDict
()
self
.
_params_list
=
OrderedDict
()
self
.
training
=
False
self
.
training
=
False
self
.
requires_grad
=
False
self
.
requires_grad
=
False
self
.
pynative
=
False
self
.
pynative
=
False
self
.
_attr_synced
=
False
self
.
_param_prefix
=
''
self
.
_param_prefix
=
''
self
.
_auto_prefix
=
auto_prefix
self
.
_auto_prefix
=
auto_prefix
self
.
_scope
=
None
self
.
_scope
=
None
...
@@ -92,6 +99,12 @@ class Cell:
...
@@ -92,6 +99,12 @@ class Cell:
def
already_run
(
self
):
def
already_run
(
self
):
return
self
.
_already_run
return
self
.
_already_run
@
property
def
_cell_tag
(
self
):
# `<class 'xxxxxxx'>`
# -> `xxxxxxx`
return
str
(
self
.
__class__
)[
8
:
-
2
]
@
already_run
.
setter
@
already_run
.
setter
def
already_run
(
self
,
value
):
def
already_run
(
self
,
value
):
self
.
_already_run
=
value
self
.
_already_run
=
value
...
@@ -222,6 +235,7 @@ class Cell:
...
@@ -222,6 +235,7 @@ class Cell:
del
self
.
_cells
[
name
]
del
self
.
_cells
[
name
]
else
:
else
:
object
.
__delattr__
(
self
,
name
)
object
.
__delattr__
(
self
,
name
)
self
.
_attr_synced
=
False
def
cast_inputs
(
self
,
inputs
,
dst_type
):
def
cast_inputs
(
self
,
inputs
,
dst_type
):
res
=
list
()
res
=
list
()
...
@@ -277,6 +291,34 @@ class Cell:
...
@@ -277,6 +291,34 @@ class Cell:
self
.
_already_run
=
True
self
.
_already_run
=
True
return
output
return
output
def
_add_attr
(
self
,
name
,
value
):
if
name
and
name
[:
2
]
!=
'__'
and
name
not
in
Cell
.
IGNORE_LIST
:
super
(
Cell
,
self
).
_add_attr
(
name
,
value
)
def
_sync_attr_for_compile
(
self
):
"""Sync the attr to c++ object."""
if
self
.
_attr_synced
:
return
cells
=
self
.
__dict__
.
get
(
'_cells'
)
for
key
in
cells
:
cell
=
cells
[
key
]
cell
.
_sync_attr_for_compile
()
self
.
_add_attr
(
key
,
cell
)
params
=
self
.
__dict__
.
get
(
'_params'
)
for
key
in
params
:
if
'.'
in
key
:
continue
param
=
params
[
key
]
self
.
_add_attr
(
key
,
param
)
params_list
=
self
.
__dict__
.
get
(
'_params_list'
)
for
key
in
params_list
:
params_list_item
=
params_list
[
key
]
self
.
_add_attr
(
key
,
params_list_item
)
for
key
in
self
.
__dict__
:
value
=
self
.
__dict__
[
key
]
self
.
_add_attr
(
key
,
value
)
self
.
_attr_synced
=
True
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
,
value
):
cells
=
self
.
__dict__
.
get
(
'_cells'
)
cells
=
self
.
__dict__
.
get
(
'_cells'
)
params
=
self
.
__dict__
.
get
(
'_params'
)
params
=
self
.
__dict__
.
get
(
'_params'
)
...
@@ -329,6 +371,8 @@ class Cell:
...
@@ -329,6 +371,8 @@ class Cell:
if
isinstance
(
value
,
Primitive
):
if
isinstance
(
value
,
Primitive
):
value
.
set_prim_instance_name
(
name
)
value
.
set_prim_instance_name
(
name
)
object
.
__setattr__
(
self
,
name
,
value
)
object
.
__setattr__
(
self
,
name
,
value
)
if
name
not
in
Cell
.
IGNORE_LIST
:
self
.
_attr_synced
=
False
def
extend_repr
(
self
):
def
extend_repr
(
self
):
"""
"""
...
@@ -451,7 +495,7 @@ class Cell:
...
@@ -451,7 +495,7 @@ class Cell:
Object, the result of executing.
Object, the result of executing.
"""
"""
self
.
_auto_parallel_compile_and_run
=
True
self
.
_auto_parallel_compile_and_run
=
True
_executor
.
compile
(
self
,
*
inputs
,
phase
=
self
.
phase
,
auto_parallel_mode
=
self
.
_auto_parallel_mode
)
self
.
compile
(
*
inputs
)
if
self
.
_auto_parallel_mode
:
if
self
.
_auto_parallel_mode
:
if
inputs
and
isinstance
(
inputs
[
0
],
Tensor
)
and
inputs
[
0
].
virtual_flag
:
if
inputs
and
isinstance
(
inputs
[
0
],
Tensor
)
and
inputs
[
0
].
virtual_flag
:
...
...
mindspore/nn/layer/container.py
浏览文件 @
e6f82af8
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# ============================================================================
# ============================================================================
"""container"""
"""container"""
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
from
..cell
import
Cell
from
..cell
import
Cell
__all__
=
[
'SequentialCell'
,
'CellList'
]
__all__
=
[
'SequentialCell'
,
'CellList'
]
...
@@ -34,7 +34,7 @@ def _valid_cell(cell):
...
@@ -34,7 +34,7 @@ def _valid_cell(cell):
raise
TypeError
(
'Cell {} is not subclass of Cell'
.
format
(
cell
))
raise
TypeError
(
'Cell {} is not subclass of Cell'
.
format
(
cell
))
class
_CellListBase
(
metaclass
=
ABCMeta
):
class
_CellListBase
():
"""
"""
An interface for base the cell as list.
An interface for base the cell as list.
...
...
tests/ut/python/parallel/test_get_parameter_layout.py
浏览文件 @
e6f82af8
...
@@ -51,7 +51,7 @@ def test_get_parameter_layout():
...
@@ -51,7 +51,7 @@ def test_get_parameter_layout():
exe
.
compile
(
net
,
x
,
phase
=
'train'
,
auto_parallel_mode
=
True
)
exe
.
compile
(
net
,
x
,
phase
=
'train'
,
auto_parallel_mode
=
True
)
x_layout
=
[[
2
,
4
],
[
1
,
-
1
],
[
16
,
32
],
[
0
],
[
1
]]
# device_arrangement = [2, 4], tensor_map = [1, -1]
x_layout
=
[[
2
,
4
],
[
1
,
-
1
],
[
16
,
32
],
[
0
],
[
1
]]
# device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout
=
[[
2
,
4
],
[
0
,
-
1
],
[
16
,
32
],
[
0
],
[
1
]]
# device_arrangement = [2, 4], tensor_map = [0, -1]
weight_layout
=
[[
2
,
4
],
[
0
,
-
1
],
[
16
,
32
],
[
0
],
[
1
]]
# device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict
=
{
'
x
'
:
x_layout
,
'w1'
:
weight_layout
}
expect_dict
=
{
'
args0
'
:
x_layout
,
'w1'
:
weight_layout
}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert
net
.
parameter_layout_dict
==
expect_dict
assert
net
.
parameter_layout_dict
==
expect_dict
...
...
tests/ut/python/parallel/test_split_grad_sens.py
浏览文件 @
e6f82af8
...
@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
...
@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
y_layout
=
[[
8
,
8
],
[
-
1
,
0
],
[
32
,
8
],
[
0
],
[
1
]]
y_layout
=
[[
8
,
8
],
[
-
1
,
0
],
[
32
,
8
],
[
0
],
[
1
]]
b_layout
=
[[
8
,
8
],
[
0
,
-
1
],
[
8
,
64
],
[
0
],
[
1
]]
b_layout
=
[[
8
,
8
],
[
0
,
-
1
],
[
8
,
64
],
[
0
],
[
1
]]
sens_layout
=
[[
8
,
8
],
[
1
,
-
1
],
[
16
,
64
],
[
0
],
[
1
]]
sens_layout
=
[[
8
,
8
],
[
1
,
-
1
],
[
16
,
64
],
[
0
],
[
1
]]
expect_dict
=
{
'
x'
:
x_layout
,
'y'
:
y_layout
,
'b'
:
b_layout
,
'sens
'
:
sens_layout
}
expect_dict
=
{
'
args0'
:
x_layout
,
'args1'
:
y_layout
,
'args2'
:
b_layout
,
'args3
'
:
sens_layout
}
assert
net
.
parameter_layout_dict
==
expect_dict
assert
net
.
parameter_layout_dict
==
expect_dict
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录