Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ffa33520
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ffa33520
编写于
6月 10, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix partial primitive poly node
上级
703c1b26
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
54 addition
and
12 deletion
+54
-12
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
...pore/ccsrc/pipeline/static_analysis/program_specialize.cc
+20
-10
tests/ut/python/ops/test_ops_attr_infer.py
tests/ut/python/ops/test_ops_attr_infer.py
+34
-2
未找到文件。
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
浏览文件 @
ffa33520
...
...
@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
}
auto
real_eval
=
dyn_cast
<
BaseFuncGraphEvaluator
>
(
eval
);
if
(
func
->
context
()
!=
nullptr
)
{
if
(
!
IsVisible
(
func_graph_
,
func
->
context
()
->
func_graph
()))
{
MS_LOG
(
EXCEPTION
)
<<
"Func is not visible NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph_
->
debug_info
());
}
}
else
{
if
(
func
->
context
()
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Func context is nullptr NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph_
->
debug_info
());
}
AnalysisContextPtr
context
=
real_eval
->
MakeContext
(
engine_
,
argvals
);
...
...
@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
// First element is partial, second is func so arg is start from 2
(
void
)
args
.
insert
(
args
.
begin
(),
inputs
.
begin
()
+
2
,
inputs
.
end
());
func
=
inputs
[
1
];
new_inputs
=
args
;
(
void
)
new_inputs
.
insert
(
new_inputs
.
begin
(),
func
);
}
new_inputs
=
args
;
(
void
)
new_inputs
.
insert
(
new_inputs
.
begin
(),
func
);
AbstractBasePtrList
argvals
;
MS_EXCEPTION_IF_NULL
(
new_inputs
[
0
]);
...
...
@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
<<
new_inputs
[
i
]
->
DebugString
()
<<
", abstract: "
<<
new_inputs
[
i
]
->
abstract
()
->
ToString
();
}
if
(
func
->
isa
<
Parameter
>
()
&&
func
->
func_graph
()
->
has_flag
(
FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER
))
{
auto
wrapped_node
=
BuildSpecializedParameterNode
(
new_node
);
new_inputs
[
0
]
=
wrapped_node
;
if
(
!
func
->
isa
<
ValueNode
>
())
{
MS_LOG
(
DEBUG
)
<<
func
->
abstract
()
->
type_name
()
<<
" | "
<<
func
->
abstract
()
->
ToString
();
if
(
func
->
abstract
()
->
isa
<
AbstractFunction
>
()
&&
!
func
->
abstract
()
->
isa
<
AbstractFuncUnion
>
())
{
auto
func_abs
=
func
->
abstract
()
->
cast
<
AbstractFunctionPtr
>
();
EvaluatorPtr
eval
=
engine_
->
GetEvaluatorFor
(
func_abs
);
std
::
pair
<
AbstractBasePtrList
,
AbstractBasePtr
>
result
;
AbstractBasePtrList
empty_args
;
auto
status
=
FindUniqueArgvals
(
func_abs
,
eval
,
empty_args
,
&
result
);
MS_LOG
(
DEBUG
)
<<
"FindUniqueArgvals return status: "
<<
status
;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if
(
status
==
kSpecializeFindUniqueArgvalPoly
||
(
func
->
isa
<
Parameter
>
()
&&
(
func
->
func_graph
()
->
has_flag
(
FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER
)
||
func
->
abstract
()
->
isa
<
PartialAbstractClosure
>
())))
{
auto
wrapped_node
=
BuildSpecializedParameterNode
(
new_node
);
new_inputs
[
0
]
=
wrapped_node
;
}
}
}
if
(
CanSpecializeNode
(
func
))
{
...
...
tests/ut/python/ops/test_ops_attr_infer.py
浏览文件 @
ffa33520
...
...
@@ -14,9 +14,12 @@
# ============================================================================
""" test nn ops """
import
numpy
as
np
from
numpy.random
import
normal
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore.ops.composite
import
core
from
mindspore.common.api
import
ms_function
from
mindspore
import
Tensor
from
mindspore.ops
import
functional
as
F
...
...
@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
net
(
t1
,
t2
)
# test free variable function list as parameter
def
test_remove_and_fv_2
():
@
core
(
loop_can_uroll
=
True
)
def
inner_loop
(
x
,
input_data
,
fv_func_list
):
ret
=
()
for
fv_fn
in
fv_func_list
:
ele
=
fv_fn
(
input_data
)
ret
+=
(
ele
,)
return
ret
@
ms_function
def
out_loop
(
input1
,
input_data
):
ret
=
()
def
fv_func1
(
y
):
return
input1
*
y
def
fv_func2
(
y
):
return
input1
-
y
fv_func_list
=
[
fv_func1
,
fv_func2
]
ele0
=
inner_loop
(
input1
,
input_data
[
0
],
fv_func_list
)
ele1
=
inner_loop
(
input1
,
input_data
[
1
],
fv_func_list
)
ret
=
(
ele0
,
ele1
)
return
ret
input_data
=
(
Tensor
(
normal
(
0
,
0.1
,
(
3
,
3
))),
Tensor
(
normal
(
0
,
0.1
,
(
3
,
1
))))
input1
=
Tensor
(
normal
(
0
,
0.1
,
(
3
,
3
)))
out_loop
(
input1
,
input_data
)
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def
Xtest_conv2d_op_with_arg
():
def
test_conv2d_op_with_argi_1
():
class
Conv2dNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Conv2dNet
,
self
).
__init__
()
...
...
@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def
X
test_partial_as_arg
():
def
test_partial_as_arg
():
class
PartialArgNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
PartialArgNet
,
self
).
__init__
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录