Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0e898137
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0e898137
编写于
5月 09, 2020
作者:
H
huangdongrun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add resolve
transform valuetuple to maketuple of graphs add testcase
上级
0eb32593
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
122 addition
and
39 deletion
+122
-39
mindspore/ccsrc/pipeline/parse/resolve.cc
mindspore/ccsrc/pipeline/parse/resolve.cc
+48
-39
tests/ut/python/ops/test_tuple.py
tests/ut/python/ops/test_tuple.py
+74
-0
未找到文件。
mindspore/ccsrc/pipeline/parse/resolve.cc
浏览文件 @
0e898137
...
...
@@ -170,51 +170,59 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return
true
;
}
bool
IsAllGraphInValueSequence
(
const
std
::
vector
<
ValuePtr
>
&
value_vec
)
{
for
(
auto
&
elem
:
value_vec
)
{
if
(
elem
->
isa
<
ValueTuple
>
()
||
elem
->
isa
<
ValueList
>
())
{
const
auto
&
vec
=
GetValue
<
std
::
vector
<
ValuePtr
>>
(
elem
);
auto
is_graph
=
IsAllGraphInValueSequence
(
vec
);
if
(
!
is_graph
)
{
return
false
;
}
}
else
if
(
!
elem
->
isa
<
FuncGraph
>
())
{
return
false
;
}
}
return
true
;
}
AnfNodePtr
TransformToMakeTupleNodes
(
const
FuncGraphManagerPtr
&
manager
,
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
ValuePtr
>
&
value_vec
)
{
std
::
vector
<
AnfNodePtr
>
nodes
;
nodes
.
emplace_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
for
(
auto
&
elem
:
value_vec
)
{
AnfNodePtr
node
=
nullptr
;
if
(
elem
->
isa
<
ValueTuple
>
()
||
elem
->
isa
<
ValueList
>
())
{
const
auto
&
vec
=
GetValue
<
std
::
vector
<
ValuePtr
>>
(
elem
);
node
=
TransformToMakeTupleNodes
(
manager
,
func_graph
,
vec
);
}
else
if
(
elem
->
isa
<
FuncGraph
>
())
{
FuncGraphPtr
new_fg
=
elem
->
cast
<
FuncGraphPtr
>
();
manager
->
AddFuncGraph
(
new_fg
);
node
=
NewValueNode
(
new_fg
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"TransformToMakeTupleNodes error, expect funcgraph, got "
<<
elem
->
ToString
();
}
nodes
.
emplace_back
(
node
);
}
auto
cnode
=
func_graph
->
NewCNode
(
nodes
);
return
cnode
;
}
// transform the ValueTuple or ValueList of graph node to make tuple of const graph node
bool
TransformVectorGraphValueNode
(
const
FuncGraphManagerPtr
&
manager
,
const
AnfNodePtr
&
node
,
bool
TransformVectorGraphValueNode
(
const
FuncGraphManagerPtr
&
manager
,
const
FuncGraphPtr
&
func_graph
,
const
ValueNodePtr
&
value_node
,
AnfNodePtr
*
const
transformed
)
{
MS_EXCEPTION_IF_NULL
(
value_node
);
const
auto
&
value_vec
=
GetValue
<
std
::
vector
<
ValuePtr
>>
(
value_node
->
value
());
bool
has_graph_in_list
=
false
;
for
(
auto
&
elemv
:
value_vec
)
{
MS_EXCEPTION_IF_NULL
(
elemv
);
if
(
elemv
->
isa
<
FuncGraph
>
())
{
FuncGraphPtr
new_fg
=
elemv
->
cast
<
FuncGraphPtr
>
();
manager
->
AddFuncGraph
(
new_fg
);
has_graph_in_list
=
true
;
continue
;
}
if
(
has_graph_in_list
)
{
MS_LOG
(
EXCEPTION
)
<<
"List has graph in it, but not all is graph"
;
}
if
(
!
IsAllGraphInValueSequence
(
value_vec
))
{
return
false
;
}
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node.
if
(
has_graph_in_list
)
{
// change the vector of graph to be make_list of graph value node
std
::
vector
<
AnfNodePtr
>
list_vec
;
auto
make_list_op
=
NewValueNode
(
prim
::
kPrimMakeTuple
);
list_vec
.
emplace_back
(
make_list_op
);
(
void
)
std
::
transform
(
std
::
begin
(
value_vec
),
std
::
end
(
value_vec
),
std
::
back_inserter
(
list_vec
),
[](
const
ValuePtr
&
value
)
{
return
NewValueNode
(
value
);
});
FuncGraphPtr
cnode_graph
=
nullptr
;
auto
users
=
manager
->
node_users
()[
node
];
for
(
auto
&
use
:
users
)
{
auto
use_node
=
use
.
first
;
MS_EXCEPTION_IF_NULL
(
use_node
);
if
(
use_node
->
isa
<
CNode
>
())
{
cnode_graph
=
use_node
->
func_graph
();
}
}
if
(
cnode_graph
)
{
CNodePtr
list_app
=
cnode_graph
->
NewCNode
(
list_vec
);
// replace the ret ptr to be make_list of graph value node
*
transformed
=
list_app
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Can not find apply for node use when replacing node of vector of graph"
;
}
}
// we do this because the graphmanger won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node
auto
node_tuple_graphs
=
TransformToMakeTupleNodes
(
manager
,
func_graph
,
value_vec
);
// replace the ret ptr to be make tuple of graph value node
*
transformed
=
node_tuple_graphs
;
return
true
;
}
...
...
@@ -245,7 +253,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
// if the constant node is constant of vector of graph ,add graph to manager
if
(
IsValueNode
<
ValueTuple
>
(
resolved_node
)
||
IsValueNode
<
ValueList
>
(
resolved_node
))
{
(
void
)
TransformVectorGraphValueNode
(
manager
,
node
,
resolved_node
->
cast
<
ValueNodePtr
>
(),
&
resolved_node
);
(
void
)
TransformVectorGraphValueNode
(
manager
,
node
->
func_graph
(),
resolved_node
->
cast
<
ValueNodePtr
>
(),
&
resolved_node
);
}
TraceManager
::
EndTrace
();
...
...
tests/ut/python/ops/test_tuple.py
0 → 100644
浏览文件 @
0e898137
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
mindspore.context
as
context
import
functools
import
numpy
as
np
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore
import
context
from
..ut_filter
import
non_graph_engine
from
....mindspore_test_framework.mindspore_test
import
mindspore_test
from
....mindspore_test_framework.pipeline.forward.compile_forward
\
import
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
class
TupleGraphNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
TupleGraphNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
1
,
3
,
pad_mode
=
'same'
)
self
.
conv2
=
nn
.
Conv2d
(
3
,
1
,
7
,
pad_mode
=
'same'
)
self
.
conv3
=
nn
.
Conv2d
(
3
,
3
,
3
,
pad_mode
=
'same'
)
self
.
layers
=
(
self
.
conv1
,
self
.
conv2
,
self
.
conv3
)
def
construct
(
self
,
x
):
return
self
.
layers
[
0
](
x
)
class
NestTupleGraphNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NestTupleGraphNet
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
1
,
3
,
pad_mode
=
'same'
)
self
.
conv2
=
nn
.
Conv2d
(
3
,
1
,
7
,
pad_mode
=
'same'
)
self
.
conv3
=
nn
.
Conv2d
(
3
,
3
,
3
,
pad_mode
=
'same'
)
self
.
layers
=
((
self
.
conv1
,
self
.
conv2
),
(
self
.
conv2
,
self
.
conv1
,
self
.
conv3
))
def
construct
(
self
,
x
):
return
self
.
layers
[
0
][
1
](
x
)
test_case_ops
=
[
(
'TupleGraph'
,
{
'block'
:
TupleGraphNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
((
3
,
3
,
24
,
24
)),
mstype
.
float32
)]}),
(
'NestTupleGraph'
,
{
'block'
:
NestTupleGraphNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
((
3
,
3
,
24
,
24
)),
mstype
.
float32
)]}),
]
test_case_lists
=
[
test_case_ops
]
test_exec_case
=
functools
.
reduce
(
lambda
x
,
y
:
x
+
y
,
test_case_lists
)
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
@
non_graph_engine
@
mindspore_test
(
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
)
def
test_exec
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
return
test_exec_case
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录