Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ffa33520
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ffa33520
编写于
6月 10, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix partial primitive poly node
上级
703c1b26
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
54 addition
and
12 deletion
+54
-12
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
...pore/ccsrc/pipeline/static_analysis/program_specialize.cc
+20
-10
tests/ut/python/ops/test_ops_attr_infer.py
tests/ut/python/ops/test_ops_attr_infer.py
+34
-2
未找到文件。
mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc
浏览文件 @
ffa33520
...
@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
...
@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
}
}
auto
real_eval
=
dyn_cast
<
BaseFuncGraphEvaluator
>
(
eval
);
auto
real_eval
=
dyn_cast
<
BaseFuncGraphEvaluator
>
(
eval
);
if
(
func
->
context
()
!=
nullptr
)
{
if
(
func
->
context
()
==
nullptr
)
{
if
(
!
IsVisible
(
func_graph_
,
func
->
context
()
->
func_graph
()))
{
MS_LOG
(
EXCEPTION
)
<<
"Func is not visible NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph_
->
debug_info
());
}
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Func context is nullptr NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph_
->
debug_info
());
MS_LOG
(
EXCEPTION
)
<<
"Func context is nullptr NodeInfo: "
<<
trace
::
GetDebugInfo
(
func_graph_
->
debug_info
());
}
}
AnalysisContextPtr
context
=
real_eval
->
MakeContext
(
engine_
,
argvals
);
AnalysisContextPtr
context
=
real_eval
->
MakeContext
(
engine_
,
argvals
);
...
@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
...
@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
// First element is partial, second is func so arg is start from 2
// First element is partial, second is func so arg is start from 2
(
void
)
args
.
insert
(
args
.
begin
(),
inputs
.
begin
()
+
2
,
inputs
.
end
());
(
void
)
args
.
insert
(
args
.
begin
(),
inputs
.
begin
()
+
2
,
inputs
.
end
());
func
=
inputs
[
1
];
func
=
inputs
[
1
];
new_inputs
=
args
;
(
void
)
new_inputs
.
insert
(
new_inputs
.
begin
(),
func
);
}
}
new_inputs
=
args
;
(
void
)
new_inputs
.
insert
(
new_inputs
.
begin
(),
func
);
AbstractBasePtrList
argvals
;
AbstractBasePtrList
argvals
;
MS_EXCEPTION_IF_NULL
(
new_inputs
[
0
]);
MS_EXCEPTION_IF_NULL
(
new_inputs
[
0
]);
...
@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
...
@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
<<
new_inputs
[
i
]
->
DebugString
()
<<
", abstract: "
<<
new_inputs
[
i
]
->
abstract
()
->
ToString
();
<<
new_inputs
[
i
]
->
DebugString
()
<<
", abstract: "
<<
new_inputs
[
i
]
->
abstract
()
->
ToString
();
}
}
if
(
func
->
isa
<
Parameter
>
()
&&
func
->
func_graph
()
->
has_flag
(
FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER
))
{
if
(
!
func
->
isa
<
ValueNode
>
())
{
auto
wrapped_node
=
BuildSpecializedParameterNode
(
new_node
);
MS_LOG
(
DEBUG
)
<<
func
->
abstract
()
->
type_name
()
<<
" | "
<<
func
->
abstract
()
->
ToString
();
new_inputs
[
0
]
=
wrapped_node
;
if
(
func
->
abstract
()
->
isa
<
AbstractFunction
>
()
&&
!
func
->
abstract
()
->
isa
<
AbstractFuncUnion
>
())
{
auto
func_abs
=
func
->
abstract
()
->
cast
<
AbstractFunctionPtr
>
();
EvaluatorPtr
eval
=
engine_
->
GetEvaluatorFor
(
func_abs
);
std
::
pair
<
AbstractBasePtrList
,
AbstractBasePtr
>
result
;
AbstractBasePtrList
empty_args
;
auto
status
=
FindUniqueArgvals
(
func_abs
,
eval
,
empty_args
,
&
result
);
MS_LOG
(
DEBUG
)
<<
"FindUniqueArgvals return status: "
<<
status
;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if
(
status
==
kSpecializeFindUniqueArgvalPoly
||
(
func
->
isa
<
Parameter
>
()
&&
(
func
->
func_graph
()
->
has_flag
(
FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER
)
||
func
->
abstract
()
->
isa
<
PartialAbstractClosure
>
())))
{
auto
wrapped_node
=
BuildSpecializedParameterNode
(
new_node
);
new_inputs
[
0
]
=
wrapped_node
;
}
}
}
}
if
(
CanSpecializeNode
(
func
))
{
if
(
CanSpecializeNode
(
func
))
{
...
...
tests/ut/python/ops/test_ops_attr_infer.py
浏览文件 @
ffa33520
...
@@ -14,9 +14,12 @@
...
@@ -14,9 +14,12 @@
# ============================================================================
# ============================================================================
""" test nn ops """
""" test nn ops """
import
numpy
as
np
import
numpy
as
np
from
numpy.random
import
normal
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.context
as
context
import
mindspore.context
as
context
from
mindspore.ops.composite
import
core
from
mindspore.common.api
import
ms_function
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
functional
as
F
...
@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
...
@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
net
(
t1
,
t2
)
net
(
t1
,
t2
)
# test free variable function list as parameter
def
test_remove_and_fv_2
():
@
core
(
loop_can_uroll
=
True
)
def
inner_loop
(
x
,
input_data
,
fv_func_list
):
ret
=
()
for
fv_fn
in
fv_func_list
:
ele
=
fv_fn
(
input_data
)
ret
+=
(
ele
,)
return
ret
@
ms_function
def
out_loop
(
input1
,
input_data
):
ret
=
()
def
fv_func1
(
y
):
return
input1
*
y
def
fv_func2
(
y
):
return
input1
-
y
fv_func_list
=
[
fv_func1
,
fv_func2
]
ele0
=
inner_loop
(
input1
,
input_data
[
0
],
fv_func_list
)
ele1
=
inner_loop
(
input1
,
input_data
[
1
],
fv_func_list
)
ret
=
(
ele0
,
ele1
)
return
ret
input_data
=
(
Tensor
(
normal
(
0
,
0.1
,
(
3
,
3
))),
Tensor
(
normal
(
0
,
0.1
,
(
3
,
1
))))
input1
=
Tensor
(
normal
(
0
,
0.1
,
(
3
,
3
)))
out_loop
(
input1
,
input_data
)
# test cell as high order argument
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
# because of the limit of inference specialize system
def
Xtest_conv2d_op_with_arg
():
def
test_conv2d_op_with_argi_1
():
class
Conv2dNet
(
nn
.
Cell
):
class
Conv2dNet
(
nn
.
Cell
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Conv2dNet
,
self
).
__init__
()
super
(
Conv2dNet
,
self
).
__init__
()
...
@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
...
@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
# The partial application used as argument is not supported yet
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
# because of the limit of inference specialize system
def
X
test_partial_as_arg
():
def
test_partial_as_arg
():
class
PartialArgNet
(
nn
.
Cell
):
class
PartialArgNet
(
nn
.
Cell
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
PartialArgNet
,
self
).
__init__
()
super
(
PartialArgNet
,
self
).
__init__
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录