Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
45ad430a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45ad430a
编写于
7月 15, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 15, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3048 support use valuelist or valuetuple of primitives
Merge pull request !3048 from amongo/SupportPrimitiveList
上级
3bb04abc
ee2039fb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
73 addition
and
11 deletion
+73
-11
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+15
-11
tests/ut/python/pipeline/parse/test_for_stmt.py
tests/ut/python/pipeline/parse/test_for_stmt.py
+58
-0
未找到文件。
mindspore/ccsrc/pipeline/jit/parse/resolve.cc
浏览文件 @
45ad430a
...
...
@@ -168,15 +168,15 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return
true
;
}
bool
IsAll
Graph
InValueSequence
(
const
std
::
vector
<
ValuePtr
>
&
value_vec
)
{
bool
IsAll
Func
InValueSequence
(
const
std
::
vector
<
ValuePtr
>
&
value_vec
)
{
for
(
auto
&
elem
:
value_vec
)
{
if
(
elem
->
isa
<
ValueTuple
>
()
||
elem
->
isa
<
ValueList
>
())
{
const
auto
&
vec
=
GetValue
<
std
::
vector
<
ValuePtr
>>
(
elem
);
auto
is_graph
=
IsAll
Graph
InValueSequence
(
vec
);
auto
is_graph
=
IsAll
Func
InValueSequence
(
vec
);
if
(
!
is_graph
)
{
return
false
;
}
}
else
if
(
!
elem
->
isa
<
FuncGraph
>
())
{
}
else
if
(
!
elem
->
isa
<
FuncGraph
>
()
&&
!
elem
->
isa
<
Primitive
>
()
)
{
return
false
;
}
}
...
...
@@ -196,6 +196,8 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
FuncGraphPtr
new_fg
=
elem
->
cast
<
FuncGraphPtr
>
();
manager
->
AddFuncGraph
(
new_fg
);
node
=
NewValueNode
(
new_fg
);
}
else
if
(
elem
->
isa
<
Primitive
>
())
{
node
=
NewValueNode
(
elem
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"TransformToMakeTupleNodes error, expect funcgraph, got "
<<
elem
->
ToString
();
}
...
...
@@ -205,19 +207,21 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
return
cnode
;
}
// transform the ValueTuple or ValueList of graph
node to make tuple of const graph
node
bool
TransformVector
Graph
ValueNode
(
const
FuncGraphManagerPtr
&
manager
,
const
FuncGraphPtr
&
func_graph
,
const
ValueNodePtr
&
value_node
,
AnfNodePtr
*
const
transformed
)
{
// transform the ValueTuple or ValueList of graph
/primitve node to make tuple of const graph/primitve
node
bool
TransformVector
Func
ValueNode
(
const
FuncGraphManagerPtr
&
manager
,
const
FuncGraphPtr
&
func_graph
,
const
ValueNodePtr
&
value_node
,
AnfNodePtr
*
const
transformed
)
{
MS_EXCEPTION_IF_NULL
(
value_node
);
const
auto
&
value_vec
=
GetValue
<
std
::
vector
<
ValuePtr
>>
(
value_node
->
value
());
if
(
!
IsAll
Graph
InValueSequence
(
value_vec
))
{
if
(
!
IsAll
Func
InValueSequence
(
value_vec
))
{
return
false
;
}
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
//
(1)
The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node.
// we do this because the graphmanger won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node
// change the vector of graph to be make_tuple of graph value node.
// (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes.
auto
node_tuple_graphs
=
TransformToMakeTupleNodes
(
manager
,
func_graph
,
value_vec
);
// replace the ret ptr to be make tuple of graph value node
*
transformed
=
node_tuple_graphs
;
...
...
@@ -251,8 +255,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
// if the constant node is constant of vector of graph ,add graph to manager
if
(
IsValueNode
<
ValueTuple
>
(
resolved_node
)
||
IsValueNode
<
ValueList
>
(
resolved_node
))
{
(
void
)
TransformVector
Graph
ValueNode
(
manager
,
node
->
func_graph
(),
resolved_node
->
cast
<
ValueNodePtr
>
(),
&
resolved_node
);
(
void
)
TransformVector
Func
ValueNode
(
manager
,
node
->
func_graph
(),
resolved_node
->
cast
<
ValueNodePtr
>
(),
&
resolved_node
);
}
TraceManager
::
EndTrace
();
...
...
tests/ut/python/pipeline/parse/test_for_stmt.py
浏览文件 @
45ad430a
...
...
@@ -17,6 +17,9 @@ from dataclasses import dataclass
import
numpy
as
np
from
mindspore
import
Tensor
,
Model
,
context
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
functional
as
F
from
mindspore.nn
import
Cell
from
mindspore.nn
import
ReLU
from
...ut_filter
import
non_graph_engine
...
...
@@ -66,3 +69,58 @@ def function_access_base(number):
def
test_access_0040
():
""" test_access_0040 """
function_access_base
(
2
)
class
OpSeqNet
(
Cell
):
def
__init__
(
self
,
loop_count
=
1
):
super
().
__init__
()
self
.
loop_count
=
loop_count
self
.
op_seq
=
(
P
.
Sqrt
(),
P
.
Reciprocal
(),
P
.
Square
())
def
construct
(
self
,
x
):
t
=
x
for
op
in
self
.
op_seq
:
t
=
op
(
t
)
return
t
def
test_op_seq_test
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
OpSeqNet
()
input_np
=
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
)
input_me
=
Tensor
(
input_np
)
net
(
input_me
)
_grad_fusion
=
C
.
MultitypeFuncGraph
(
"grad_fushion"
)
@
_grad_fusion
.
register
(
"Tensor"
,
"Function"
)
def
tensor_grad_scale
(
x
,
op
):
return
op
(
x
)
class
AllReduceTest
(
Cell
):
def
__init__
(
self
,
loop_count
=
1
):
super
().
__init__
()
self
.
op_list
=
()
self
.
fushion_flag
=
[
0
,
1
,
1
,
0
,
1
,
0
]
for
i
in
self
.
fushion_flag
:
op
=
P
.
AllReduce
().
add_prim_attr
(
'fusion'
,
i
)
self
.
op_list
=
self
.
op_list
+
(
op
,)
self
.
hyper_map
=
C
.
HyperMap
()
def
construct
(
self
,
x
):
ret
=
()
for
_
in
self
.
fushion_flag
:
ret
=
ret
+
(
x
,)
fushion_res
=
self
.
hyper_map
(
F
.
partial
(
_grad_fusion
),
ret
,
self
.
op_list
)
return
fushion_res
def
test_allreduce_fushio_test
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
AllReduceTest
()
input_np
=
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
)
input_me
=
Tensor
(
input_np
)
net
(
input_me
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录