Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8003a89a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8003a89a
编写于
5月 12, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 12, 2020
浏览文件
操作
浏览文件
下载
差异文件
!766 bugfix(SA): Add the support of nested loop.
Merge pull request !766 from gongchen/nest_loop
上级
fd72534a
425a2076
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
190 addition
and
25 deletion
+190
-25
mindspore/ccsrc/optimizer/optimizer.h
mindspore/ccsrc/optimizer/optimizer.h
+11
-11
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+12
-1
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
+24
-0
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
+6
-0
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+67
-10
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
+2
-0
tests/ut/python/pynative_mode/test_framstruct.py
tests/ut/python/pynative_mode/test_framstruct.py
+52
-1
tests/ut/python/pynative_mode/test_multigraph_sink.py
tests/ut/python/pynative_mode/test_multigraph_sink.py
+16
-2
未找到文件。
mindspore/ccsrc/optimizer/optimizer.h
浏览文件 @
8003a89a
...
@@ -27,14 +27,13 @@
...
@@ -27,14 +27,13 @@
#include <utility>
#include <utility>
#include <initializer_list>
#include <initializer_list>
#ifdef DEBUG
#include "debug/draw.h"
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_dump.h"
#
endif
#
include "debug/trace.h"
#include "optimizer/opt.h"
#include "optimizer/opt.h"
#include "pipeline/resource.h"
#include "pipeline/resource.h"
#include "pipeline/action.h"
#include "pipeline/action.h"
#include "
debug/trace
.h"
#include "
utils/context/ms_context
.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
...
@@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
...
@@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
FuncGraphPtr
step
(
FuncGraphPtr
func_graph
,
bool
use_profile
=
true
)
{
FuncGraphPtr
step
(
FuncGraphPtr
func_graph
,
bool
use_profile
=
true
)
{
// Optimizer step counter;
// Optimizer step counter;
int
counter
=
1
;
int
counter
=
-
1
;
bool
changes
=
true
;
bool
changes
=
true
;
while
(
changes
)
{
while
(
changes
)
{
...
@@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
...
@@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
}
};
};
use_profile
?
(
WITH
(
MsProfile
::
GetProfile
()
->
Step
(
pass_names_
[
i
]))
opt_func
)
:
opt_func
();
use_profile
?
(
WITH
(
MsProfile
::
GetProfile
()
->
Step
(
pass_names_
[
i
]))
opt_func
)
:
opt_func
();
#ifdef DEBUG
if
(
IS_OUTPUT_ON
(
mindspore
::
DEBUG
)
&&
MsContext
::
GetInstance
()
->
save_graphs_flag
())
{
MS_LOG
(
DEBUG
)
<<
name_
<<
" round "
<<
counter
<<
" OptPass "
<<
pass_names_
[
i
]
<<
" end."
;
MS_LOG
(
DEBUG
)
<<
name_
<<
" round "
<<
counter
<<
" OptPass "
<<
pass_names_
[
i
]
<<
" end."
;
auto
fg_name
=
name_
+
"_r"
+
std
::
to_string
(
counter
)
+
"_"
+
std
::
to_string
(
i
)
+
"_"
+
pass_names_
[
i
];
auto
fg_name
=
func_graph
->
DumpFuncGraph
(
fg_name
);
"opt_substep_"
+
name_
+
"_r"
+
std
::
to_string
(
counter
)
+
"_"
+
std
::
to_string
(
i
)
+
"_"
+
pass_names_
[
i
];
DumpIR
(
fg_name
+
".ir"
,
func_graph
);
func_graph
->
DumpFuncGraph
(
fg_name
);
MS_LOG
(
DEBUG
)
<<
"Dump "
<<
pass_names_
[
i
]
<<
" func graph."
;
DumpIR
(
fg_name
+
".ir"
,
func_graph
);
#endif
MS_LOG
(
DEBUG
)
<<
"Dump "
<<
pass_names_
[
i
]
<<
" func graph."
;
}
}
}
};
};
use_profile
?
(
WITH
(
MsProfile
::
GetProfile
()
->
Lap
(
counter
++
))
run_runc
)
:
run_runc
();
use_profile
?
(
WITH
(
MsProfile
::
GetProfile
()
->
Lap
(
counter
++
))
run_runc
)
:
run_runc
();
...
...
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
8003a89a
...
@@ -32,6 +32,7 @@
...
@@ -32,6 +32,7 @@
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/program_specialize.h"
#include "pipeline/static_analysis/program_specialize.h"
#include "pipeline/resource.h"
#include "pipeline/resource.h"
#include "utils/context/ms_context.h"
#include "pipeline/remove_value_node_dup.h"
#include "pipeline/remove_value_node_dup.h"
#include "optimizer/optimizer.h"
#include "optimizer/optimizer.h"
#include "vm/transform.h"
#include "vm/transform.h"
...
@@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
...
@@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
}
bool
OptimizeAction
(
const
ResourcePtr
&
res
,
const
std
::
vector
<
PassItem
>
&
passes
)
{
bool
OptimizeAction
(
const
ResourcePtr
&
res
,
const
std
::
vector
<
PassItem
>
&
passes
)
{
size_t
counter
=
0
;
for
(
auto
&
pass
:
passes
)
{
for
(
auto
&
pass
:
passes
)
{
WITH
(
MsProfile
::
GetProfile
()
->
Step
(
pass
.
first
))[
&
pass
,
&
res
]()
{
WITH
(
MsProfile
::
GetProfile
()
->
Step
(
pass
.
first
))[
&
pass
,
&
res
,
&
counter
]()
{
MS_LOG
(
DEBUG
)
<<
"Pass "
<<
pass
.
first
<<
" start ..."
;
MS_LOG
(
DEBUG
)
<<
"Pass "
<<
pass
.
first
<<
" start ..."
;
auto
result
=
pass
.
second
(
res
);
auto
result
=
pass
.
second
(
res
);
if
(
!
result
)
{
if
(
!
result
)
{
MS_LOG
(
EXCEPTION
)
<<
"Pass running to end, failed in pass:"
<<
pass
.
first
;
MS_LOG
(
EXCEPTION
)
<<
"Pass running to end, failed in pass:"
<<
pass
.
first
;
}
}
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
()
&&
res
->
func_graph
()
!=
nullptr
)
{
auto
fg_name
=
"opt_pass_"
+
std
::
to_string
(
counter
)
+
"_"
+
pass
.
first
;
auto
func_graph
=
res
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
func_graph
);
func_graph
->
DumpFuncGraph
(
fg_name
);
DumpIR
(
fg_name
+
".ir"
,
func_graph
);
MS_LOG
(
DEBUG
)
<<
"Dump "
<<
fg_name
<<
" func graph."
;
}
counter
++
;
MS_LOG
(
DEBUG
)
<<
"Pass "
<<
pass
.
first
<<
" end."
;
MS_LOG
(
DEBUG
)
<<
"Pass "
<<
pass
.
first
<<
" end."
;
};
};
}
}
...
...
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
浏览文件 @
8003a89a
...
@@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &
...
@@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &
AnalysisContextPtr
BaseFuncGraphEvaluator
::
MakeContext
(
const
AnalysisEnginePtr
&
engine
,
AnalysisContextPtr
BaseFuncGraphEvaluator
::
MakeContext
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
)
{
const
AbstractBasePtrList
&
args_spec_list
)
{
AbstractBasePtrList
normalized_args_spec_list
=
NormalizeArgs
(
args_spec_list
);
AbstractBasePtrList
normalized_args_spec_list
=
NormalizeArgs
(
args_spec_list
);
normalized_args_spec_list
=
BroadenUndeterminedArgs
(
normalized_args_spec_list
);
FuncGraphPtr
fg
=
GetFuncGraph
(
engine
,
normalized_args_spec_list
);
FuncGraphPtr
fg
=
GetFuncGraph
(
engine
,
normalized_args_spec_list
);
MS_EXCEPTION_IF_NULL
(
parent_context_
);
MS_EXCEPTION_IF_NULL
(
parent_context_
);
AnalysisContextPtr
context
=
parent_context_
->
NewFuncGraphContext
(
fg
,
normalized_args_spec_list
);
AnalysisContextPtr
context
=
parent_context_
->
NewFuncGraphContext
(
fg
,
normalized_args_spec_list
);
...
@@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
...
@@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
<<
", broaded: "
<<
mindspore
::
ToString
(
broaded_list
);
<<
", broaded: "
<<
mindspore
::
ToString
(
broaded_list
);
return
broaded_list
;
return
broaded_list
;
}
}
return
args_spec_list
;
}
AbstractBasePtrList
FuncGraphEvaluator
::
BroadenUndeterminedArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
{
MS_EXCEPTION_IF_NULL
(
func_graph_
);
if
(
func_graph_
->
has_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
))
{
return
args_spec_list
;
}
if
(
func_graph_
->
has_flag
(
kFuncGraphFlagUndetermined
))
{
if
(
func_graph_
->
has_flag
(
kFuncGraphFlagUndetermined
))
{
if
(
parent_context_
)
{
if
(
parent_context_
)
{
MS_LOG
(
DEBUG
)
<<
"Undeterminate FuncGraphEvaluator "
<<
ToString
()
MS_LOG
(
DEBUG
)
<<
"Undeterminate FuncGraphEvaluator "
<<
ToString
()
...
@@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
...
@@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
return
joined_args_spec_list
;
return
joined_args_spec_list
;
}
}
}
}
if
(
trace_
.
size
()
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
"Current eval args: "
<<
::
mindspore
::
ToString
(
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"Last eval args: "
<<
::
mindspore
::
ToString
(
trace_
.
back
());
// Join the last eval arguments and current arguments to check if there are loop variant.
auto
joined_args_spec_list
=
AbstractJoin
(
args_spec_list
,
trace_
.
back
());
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if
(
!
(
joined_args_spec_list
==
args_spec_list
))
{
trace_
.
push_back
(
joined_args_spec_list
);
func_graph_
->
set_flags
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
}
MS_LOG
(
DEBUG
)
<<
"Joined eval args: "
<<
::
mindspore
::
ToString
(
joined_args_spec_list
);
return
joined_args_spec_list
;
}
else
{
trace_
.
push_back
(
args_spec_list
);
}
}
}
return
args_spec_list
;
return
args_spec_list
;
}
}
...
@@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
...
@@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
return
conf
->
GetEvaluatedValue
();
return
conf
->
GetEvaluatedValue
();
});
});
args_spec_list
=
NormalizeArgs
(
args_spec_list
);
args_spec_list
=
NormalizeArgs
(
args_spec_list
);
args_spec_list
=
BroadenUndeterminedArgs
(
args_spec_list
);
trace
::
TraceGraphInferEnter
(
shared_from_base
<
Evaluator
>
(),
out_conf
);
trace
::
TraceGraphInferEnter
(
shared_from_base
<
Evaluator
>
(),
out_conf
);
InferEntryLogging
(
shared_from_base
<
Evaluator
>
(),
args_spec_list
,
out_conf
);
InferEntryLogging
(
shared_from_base
<
Evaluator
>
(),
args_spec_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
cache_
);
MS_EXCEPTION_IF_NULL
(
cache_
);
...
...
mindspore/ccsrc/pipeline/static_analysis/evaluator.h
浏览文件 @
8003a89a
...
@@ -47,6 +47,10 @@ class Evaluator : public Base {
...
@@ -47,6 +47,10 @@ class Evaluator : public Base {
virtual
AbstractBasePtrList
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
{
return
args_spec_list
;
}
virtual
AbstractBasePtrList
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
{
return
args_spec_list
;
}
virtual
AbstractBasePtrList
BroadenUndeterminedArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
{
return
args_spec_list
;
}
std
::
string
ToString
()
const
override
{
return
identifier_
;
}
std
::
string
ToString
()
const
override
{
return
identifier_
;
}
virtual
AnfNodePtr
bound_node
()
const
{
return
bound_node_
.
lock
();
}
virtual
AnfNodePtr
bound_node
()
const
{
return
bound_node_
.
lock
();
}
...
@@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
...
@@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
FuncGraphPtr
func_graph
()
{
return
func_graph_
;
}
FuncGraphPtr
func_graph
()
{
return
func_graph_
;
}
AbstractBasePtrList
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
override
;
AbstractBasePtrList
NormalizeArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
const
override
;
AbstractBasePtrList
BroadenUndeterminedArgs
(
const
AbstractBasePtrList
&
args_spec_list
)
override
;
std
::
string
ToString
()
const
override
{
return
identifier_
+
"_"
+
func_graph_
->
ToString
();
}
std
::
string
ToString
()
const
override
{
return
identifier_
+
"_"
+
func_graph_
->
ToString
();
}
private:
private:
FuncGraphPtr
func_graph_
;
FuncGraphPtr
func_graph_
;
std
::
unordered_map
<
AbstractBasePtrList
,
FuncGraphPtr
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
std
::
unordered_map
<
AbstractBasePtrList
,
FuncGraphPtr
,
AbstractBasePtrListHasher
,
AbstractBasePtrListEqual
>
func_graph_cache_
;
func_graph_cache_
;
std
::
vector
<
AbstractBasePtrList
>
trace_
;
};
};
using
FuncGraphEvaluatorPtr
=
std
::
shared_ptr
<
FuncGraphEvaluator
>
;
using
FuncGraphEvaluatorPtr
=
std
::
shared_ptr
<
FuncGraphEvaluator
>
;
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
浏览文件 @
8003a89a
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/static_analysis.h"
#include <algorithm>
#include <algorithm>
#include <set>
#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/prim.h"
...
@@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC
...
@@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC
for
(
std
::
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
for
(
std
::
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
const
AnfNodePtr
&
node
=
inputs
[
i
];
const
AnfNodePtr
&
node
=
inputs
[
i
];
args_conf_list
.
push_back
(
MakeConfig
(
node
,
context
));
args_conf_list
.
push_back
(
MakeConfig
(
node
,
context
));
MS_LOG
(
DEBUG
)
<<
"Current CNode args_conf_list["
<<
i
<<
"] node: "
<<
node
->
DebugString
();
}
}
std
::
vector
<
EvaluatorPtr
>
infs
;
std
::
vector
<
EvaluatorPtr
>
infs
;
...
@@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
...
@@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
const
AnfNodeConfigPtr
&
out_conf
,
const
AnfNodeConfigPtr
&
out_conf
,
const
ConfigPtrList
&
args_conf_list
)
{
const
ConfigPtrList
&
args_conf_list
)
{
AbstractBasePtrList
out_specs
;
AbstractBasePtrList
out_specs
;
if
(
!
multi_poss_
.
count
(
evaluators
[
0
]))
{
multi_poss_
[
evaluators
[
0
]]
=
evaluators
[
1
];
multi_poss_
[
evaluators
[
1
]]
=
evaluators
[
0
];
}
AbstractBasePtrList
args_spec_list
;
AbstractBasePtrList
args_spec_list
;
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
(
void
)
std
::
transform
(
args_conf_list
.
begin
(),
args_conf_list
.
end
(),
std
::
back_inserter
(
args_spec_list
),
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
[](
const
ConfigPtr
&
conf
)
->
AbstractBasePtr
{
...
@@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
...
@@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
for
(
auto
eval
:
evaluators
)
{
for
(
auto
eval
:
evaluators
)
{
auto
fg_eval
=
eval
->
cast
<
FuncGraphEvaluatorPtr
>
();
auto
fg_eval
=
eval
->
cast
<
FuncGraphEvaluatorPtr
>
();
if
(
fg_eval
)
{
if
(
fg_eval
)
{
auto
undetermined_fgs
=
fg_eval
->
func_graph
()
->
recursive_graphs
();
auto
fg
=
fg_eval
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
fg
);
auto
undetermined_fgs
=
fg
->
recursive_graphs
();
if
(
undetermined_fgs
)
{
if
(
undetermined_fgs
)
{
for
(
auto
undetermined_fg
:
*
undetermined_fgs
)
{
auto
fg_parent
=
fg
->
parent
();
MS_LOG
(
DEBUG
)
<<
"Set graph undetermined: "
<<
undetermined_fg
->
ToString
();
MS_EXCEPTION_IF_NULL
(
fg_parent
);
// As the current evaluator has multiple possibles, all the func_graphs which
fg_parent
->
set_flags
(
kFuncGraphFlagUndetermined
,
true
);
// are recursive with the current func_graph are undetermined in control flow.
MS_LOG
(
DEBUG
)
<<
"Set graph undetermined: "
<<
fg_parent
->
ToString
();
undetermined_fg
->
set_flags
(
kFuncGraphFlagUndetermined
,
true
);
}
}
}
}
}
auto
current_inf
=
std
::
make_pair
(
eval
,
args_spec_list
);
auto
current_inf
=
std
::
make_pair
(
eval
,
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"Check Evaluator "
<<
eval
->
ToString
();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
// If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring.
auto
it
=
std
::
find
(
eval_trace_
.
begin
(),
eval_trace_
.
end
(),
current_inf
);
auto
it
=
std
::
find
(
eval_trace_
.
rbegin
(),
eval_trace_
.
r
end
(),
current_inf
);
if
(
it
==
eval_trace_
.
end
())
{
if
(
it
==
eval_trace_
.
r
end
())
{
eval_trace_
.
push_back
(
current_inf
);
eval_trace_
.
push_back
(
current_inf
);
MS_LOG
(
DEBUG
)
<<
"Trace Evaluator "
<<
eval
->
ToString
()
<<
" ptr: "
<<
eval
.
get
();
MS_EXCEPTION_IF_NULL
(
eval
);
MS_EXCEPTION_IF_NULL
(
eval
);
auto
out_spec
=
eval
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
auto
out_spec
=
eval
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_spec
);
MS_EXCEPTION_IF_NULL
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
eval
->
ToString
()
<<
" return out_spec: "
<<
out_spec
->
ToString
();
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
eval
->
ToString
()
<<
" return out_spec: "
<<
out_spec
->
ToString
();
out_specs
.
push_back
(
out_spec
);
out_specs
.
push_back
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Pop Evaluator "
<<
eval
->
ToString
();
eval_trace_
.
pop_back
();
eval_trace_
.
pop_back
();
if
(
eval_trace_
.
empty
())
{
multi_poss_
.
clear
();
}
}
else
if
(
it
!=
eval_trace_
.
rbegin
())
{
// Find latest entry function to handle nested recursion.
EvaluatorPtr
latest_entry
=
eval
;
auto
latest_entry_iter
=
eval_trace_
.
rbegin
();
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
*
r_it
!=
*
it
;)
{
auto
it_temp
=
std
::
find
(
evaluators
.
begin
(),
evaluators
.
end
(),
r_it
->
first
);
if
(
it_temp
!=
evaluators
.
end
())
{
latest_entry
=
*
it_temp
;
latest_entry_iter
=
r_it
;
break
;
}
latest_entry_iter
=
++
r_it
;
}
if
(
latest_entry
!=
eval
)
{
MS_LOG
(
DEBUG
)
<<
"Continue Evaluator "
<<
eval
->
ToString
();
continue
;
}
bool
has_undetermined
=
false
;
// Check whether sub loop has untraced undetermined evaluator.
std
::
set
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>
undetermined_evals
;
for
(
auto
r_it
=
eval_trace_
.
rbegin
();
r_it
!=
latest_entry_iter
;
r_it
++
)
{
undetermined_evals
.
insert
(
*
r_it
);
}
MS_LOG
(
DEBUG
)
<<
"undetermined_evals size(): "
<<
undetermined_evals
.
size
();
for
(
auto
u_eval
:
undetermined_evals
)
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" check undetermined."
;
if
(
!
undetermined_evals
.
count
(
std
::
make_pair
(
multi_poss_
[
u_eval
.
first
],
args_spec_list
)))
{
MS_LOG
(
DEBUG
)
<<
u_eval
.
first
->
ToString
()
<<
" has undetermined."
;
has_undetermined
=
true
;
break
;
}
}
if
(
has_undetermined
==
false
)
{
MS_LOG
(
DEBUG
)
<<
eval
->
ToString
()
<<
" has no undetermined."
;
continue
;
}
// Try to travel the latest undetermined.
if
(
latest_entry
!=
eval_trace_
.
rbegin
()
->
first
)
{
MS_LOG
(
DEBUG
)
<<
"Direct Run Evaluator "
<<
eval
->
ToString
();
auto
out_spec
=
latest_entry
->
Run
(
shared_from_this
(),
args_conf_list
,
out_conf
);
MS_EXCEPTION_IF_NULL
(
out_spec
);
MS_LOG
(
DEBUG
)
<<
"Evaluator "
<<
latest_entry
->
ToString
()
<<
" return out_spec: "
<<
out_spec
->
ToString
();
return
out_spec
;
}
}
}
}
}
if
(
out_specs
.
size
()
==
0
)
{
if
(
out_specs
.
size
()
==
0
)
{
...
...
mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
浏览文件 @
8003a89a
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include <utility>
#include <utility>
#include <map>
#ifdef DEBUG
#ifdef DEBUG
#include <stack>
#include <stack>
...
@@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
...
@@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnfNodeConfigMap
anfnode_config_map_
;
AnfNodeConfigMap
anfnode_config_map_
;
// Use a list to trace multiple evaluators.
// Use a list to trace multiple evaluators.
std
::
list
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>
eval_trace_
;
std
::
list
<
std
::
pair
<
EvaluatorPtr
,
AbstractBasePtrList
>>
eval_trace_
;
std
::
map
<
EvaluatorPtr
,
EvaluatorPtr
>
multi_poss_
;
AnalysisContextPtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnalysisContextPtr
&
context
,
AnalysisContextPtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnalysisContextPtr
&
context
,
const
ConfigPtrList
&
args_conf_list
);
const
ConfigPtrList
&
args_conf_list
);
...
...
tests/ut/python/pynative_mode/test_framstruct.py
浏览文件 @
8003a89a
...
@@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
...
@@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
def
setup_module
(
module
):
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
@
ms_function
@
ms_function
def
while_upper_bound
(
upper
):
def
while_upper_bound
(
upper
):
rval
=
2
rval
=
2
...
@@ -392,6 +391,58 @@ def test_grad_factorial():
...
@@ -392,6 +391,58 @@ def test_grad_factorial():
res
=
C
.
grad
(
factorial
)(
3
)
res
=
C
.
grad
(
factorial
)(
3
)
assert
res
==
11
assert
res
==
11
@
ms_function
def
factorial2
(
n
):
""" factorial """
if
n
!=
0
:
return
n
*
factorial2
(
n
-
1
)
elif
n
==
1
:
return
1
*
factorial2
(
n
-
1
)
else
:
return
1
def
test_factorial2
():
res
=
factorial2
(
3
)
assert
res
==
6
@
ms_function
def
foo
(
n
):
if
n
<=
1
:
if
n
==
1
:
return
foo
(
n
-
1
)
else
:
return
1
else
:
return
foo
(
n
-
1
)
def
test_foo
():
res
=
foo
(
5
)
assert
res
==
1
@
ms_function
def
double_nested_loop
(
x
):
i
=
0
s
=
0
while
(
i
<
x
):
j
=
0
i
=
i
+
1
while
(
j
<
3
):
j
=
j
+
1
s
=
s
+
j
return
s
def
test_nested_loop
():
res
=
double_nested_loop
(
3
)
assert
res
==
18
@
ms_function
def
double_nested_loop2
(
x
):
s
=
0
for
i
in
range
(
x
):
for
j
in
range
(
3
):
s
=
s
+
j
return
s
def
test_nested_loop2
():
res
=
double_nested_loop
(
1
)
assert
res
==
6
def
_for
(
x
):
def
_for
(
x
):
""" _for """
""" _for """
ret
=
x
*
x
ret
=
x
*
x
...
...
tests/ut/python/pynative_mode/test_multigraph_sink.py
浏览文件 @
8003a89a
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
...
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
def
setup_module
(
module
):
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
save_graphs
=
Tru
e
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
save_graphs
=
Fals
e
,
device_target
=
"Ascend"
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
0
)
context
.
set_context
(
enable_task_sink
=
True
,
device_id
=
0
)
...
@@ -86,7 +86,17 @@ def while_by_while(x, y, z):
...
@@ -86,7 +86,17 @@ def while_by_while(x, y, z):
x
=
x
+
1
x
=
x
+
1
x
=
x
+
1
x
=
x
+
1
return
x
return
x
@
ms_function
def
while_in_while
(
x
,
y
,
z
):
out
=
c4
while
x
<
y
:
z
=
c4
+
c4
while
z
<
y
:
z
=
z
+
1
out
=
out
+
z
x
=
x
+
1
out
=
out
+
x
return
out
def
test_simple_if
():
def
test_simple_if
():
output
=
simple_if
(
c1
,
c2
,
c3
)
output
=
simple_if
(
c1
,
c2
,
c3
)
...
@@ -117,3 +127,7 @@ def test_while_by_while():
...
@@ -117,3 +127,7 @@ def test_while_by_while():
expect
=
Tensor
([
28
],
mstype
.
int32
)
expect
=
Tensor
([
28
],
mstype
.
int32
)
assert
output
==
expect
assert
output
==
expect
def
test_while_in_while
():
output
=
while_in_while
(
c1
,
c2
,
c3
)
expect
=
Tensor
([
1274
],
mstype
.
int32
)
assert
output
==
expect
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录