Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
1bdb26f9
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
1bdb26f9
编写于
9月 03, 2020
作者:
B
BowenK
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Warming up python pass by adding inline passes before it
上级
0118930c
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
129 addition
and
79 deletion
+129
-79
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+9
-1
mindspore/ccsrc/frontend/optimizer/py_pass.cc
mindspore/ccsrc/frontend/optimizer/py_pass.cc
+38
-28
mindspore/ccsrc/frontend/optimizer/py_pass.h
mindspore/ccsrc/frontend/optimizer/py_pass.h
+2
-1
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
+17
-11
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
+3
-3
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+12
-3
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+10
-0
mindspore/ccsrc/pipeline/jit/pass.h
mindspore/ccsrc/pipeline/jit/pass.h
+1
-0
mindspore/graph_utils/python_pass/__init__.py
mindspore/graph_utils/python_pass/__init__.py
+4
-4
mindspore/graph_utils/python_pass/python_pass_register.py
mindspore/graph_utils/python_pass/python_pass_register.py
+13
-8
tests/ut/python/optimizer/test_python_pass.py
tests/ut/python/optimizer/test_python_pass.py
+20
-20
未找到文件。
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
浏览文件 @
1bdb26f9
...
...
@@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto
scope
=
std
::
make_shared
<
Scope
>
(
gradients_scope
+
ScopeManager
::
GetInstance
().
GetCurrentScope
()
->
name
()
+
grad_op_child_scope_prefix
+
prim
->
name
());
ScopeGuard
scope_guard
(
scope
);
py
::
function
fn
=
prim
->
is_base
()
?
GetBpropFunction
(
prim
->
name
())
:
prim
->
cast
<
PrimitivePyPtr
>
()
->
GetBpropFunction
();
py
::
function
fn
;
if
(
prim
->
is_base
())
{
fn
=
GetBpropFunction
(
prim
->
name
());
}
else
{
fn
=
prim
->
cast
<
PrimitivePyPtr
>
()
->
GetBpropFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
fn
))
{
fn
=
GetBpropFunction
(
prim
->
name
());
}
}
if
(
!
fn
||
py
::
isinstance
<
py
::
none
>
(
fn
))
{
MS_LOG
(
DEBUG
)
<<
"Fail to find bprop function for "
<<
prim
->
name
()
<<
"."
;
return
nullptr
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.cc
浏览文件 @
1bdb26f9
...
...
@@ -35,8 +35,10 @@ namespace internal {
const
char
PARAMETER_MODULE
[]
=
"mindspore.common.parameter"
;
const
char
PARAMETER_CLASS
[]
=
"Parameter"
;
const
char
SET_PARAM
[]
=
"__setattr__"
;
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
);
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
);
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
MatchResultPtr
&
res
);
void
ReflectParamBackToPython
(
const
AnfNodePtr
&
param
,
string
param_name
,
tensor
::
TensorPtr
default_input
,
bool
requires_grad
,
bool
layerwise_parallel
);
...
...
@@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return
std
::
make_shared
<
ValueNode
>
(
input_tensor
);
}
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
)
{
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
,
const
FuncGraphPtr
&
top_graph
)
{
auto
call_pattern
=
pattern
->
cast
<
CallPtr
>
();
MS_EXCEPTION_IF_NULL
(
call_pattern
);
auto
prim
=
call_pattern
->
prim_value
();
...
...
@@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto
prim_pattern
=
call_pattern
->
prim_pattern
();
MS_EXCEPTION_IF_NULL
(
prim_pattern
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
,
top_graph
);
}
AnfNodePtr
BuildNewParameter
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func
_graph
)
{
AnfNodePtr
BuildNewParameter
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
top
_graph
)
{
auto
new_para_pattern
=
pattern
->
cast
<
NewParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
new_para_pattern
);
if
(
!
new_para_pattern
->
built
())
{
static
int
parameter_id
=
0
;
auto
para_name
=
new_para_pattern
->
para_name
()
+
new_para_pattern
->
unique_name
()
+
std
::
to_string
(
parameter_id
++
);
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
func
_graph
);
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
top
_graph
);
MS_EXCEPTION_IF_NULL
(
para_node
);
para_node
->
set_name
(
para_name
);
// Set function graph
para_node
->
set_func_graph
(
func
_graph
);
para_node
->
set_func_graph
(
top
_graph
);
// Set Debug Info
auto
debug_info
=
std
::
make_shared
<
NodeDebugInfo
>
(
para_name
);
para_node
->
set_debug_info
(
debug_info
);
...
...
@@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
MS_EXCEPTION_IF_NULL
(
default_value
);
para_node
->
set_abstract
(
default_value
->
ToAbstract
()
->
Broaden
());
res
->
add_entry
(
pattern
,
para_node
);
func
_graph
->
add_parameter
(
para_node
);
top
_graph
->
add_parameter
(
para_node
);
// Reflect back to Cell._params
internal
::
ReflectParamBackToPython
(
para_node
,
para_name
,
default_value
,
new_para_pattern
->
requires_grad
(),
new_para_pattern
->
layerwise_parallel
());
...
...
@@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
return
std
::
make_shared
<
ValueNode
>
(
scalar_value_ptr
);
}
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
)
{
auto
target_node
=
res
->
get_node
(
pattern
);
if
(
target_node
!=
nullptr
)
{
// If pattern is NewParameter, check whether it shouldn't last and is not built
...
...
@@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
else
if
(
pattern
->
isa
<
NewTensor
>
())
{
return
BuildNewTensor
(
pattern
,
res
);
}
else
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
,
top_graph
);
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
return
BuildNewParameter
(
pattern
,
res
,
func_graph
);
// Add new parameter to top graph instead of current graph
return
BuildNewParameter
(
pattern
,
res
,
top_graph
);
}
else
if
(
pattern
->
isa
<
Imm
>
())
{
return
BuildImmNode
(
pattern
,
res
);
}
else
{
...
...
@@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
AnfNodePtr
ProcessComplexPatternFirstInput
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
)
{
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
,
top_graph
);
}
return
nullptr
;
}
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
MatchResultPtr
&
res
)
{
auto
target_inputs
=
pattern
->
inputs
();
if
(
target_inputs
.
size
()
==
0
)
{
auto
new_node
=
ProcessSinglePattern
(
pattern
,
res
,
func_graph
);
auto
new_node
=
ProcessSinglePattern
(
pattern
,
res
,
func_graph
,
top_graph
);
if
(
new_node
!=
nullptr
)
{
res
->
add_entry
(
pattern
,
new_node
);
}
...
...
@@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
}
// Build up the AnfNode in a recursive manner
std
::
vector
<
AnfNodePtr
>
new_inputs
;
auto
prim_value_node
=
ProcessComplexPatternFirstInput
(
pattern
,
res
,
func_graph
);
auto
prim_value_node
=
ProcessComplexPatternFirstInput
(
pattern
,
res
,
func_graph
,
top_graph
);
MS_EXCEPTION_IF_NULL
(
prim_value_node
);
new_inputs
.
push_back
(
prim_value_node
);
for
(
auto
&
iter
:
target_inputs
)
{
if
(
iter
==
pattern
)
{
MS_LOG
(
EXCEPTION
)
<<
"Circle references. Got pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
}
auto
input_node
=
BuildTarget
(
iter
,
func_graph
,
res
);
auto
input_node
=
BuildTarget
(
iter
,
func_graph
,
top_graph
,
res
);
if
(
input_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to build input node for pattern : "
+
iter
->
unique_name
()
+
"
\n
"
;
}
...
...
@@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {
}
// namespace internal
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
)
{
auto
match_res
=
src_pattern_
->
match
(
node
);
if
(
match_res
!=
nullptr
)
{
res
->
merge
(
match_res
);
auto
new_node
=
internal
::
BuildTarget
(
dst_pattern_
,
func_graph
,
res
);
auto
new_node
=
internal
::
BuildTarget
(
dst_pattern_
,
func_graph
,
top_graph
,
res
);
internal
::
Reset
(
dst_pattern
());
return
new_node
;
}
...
...
@@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
}
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
func_graph
);
auto
graph_nodes_sorted
=
TopoSort
(
func_graph
->
output
());
auto
func_graphs
=
manager
->
func_graphs
();
bool
changes
=
false
;
// Traverse once
for
(
auto
&
node
:
graph_nodes_sorted
)
{
AnfNodePtr
new_node
=
Run
(
func_graph
,
node
,
res
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
(
void
)
manager
->
Replace
(
node
,
new_node
);
changes
=
true
;
for
(
auto
&
fg
:
func_graphs
)
{
manager
->
AddFuncGraph
(
fg
);
auto
graph_nodes_sorted
=
TopoSort
(
fg
->
output
());
// Traverse once
for
(
auto
&
node
:
graph_nodes_sorted
)
{
AnfNodePtr
new_node
=
Run
(
fg
,
func_graph
,
node
,
res
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
MS_LOG
(
WARNING
)
<<
"Matched"
;
(
void
)
manager
->
Replace
(
node
,
new_node
);
changes
=
true
;
}
}
}
return
changes
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.h
浏览文件 @
1bdb26f9
...
...
@@ -39,7 +39,8 @@ class PythonPass {
~
PythonPass
()
=
default
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
std
::
string
name
()
const
{
return
name_
;
}
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
);
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
);
PatternPtr
src_pattern
()
{
return
src_pattern_
;
}
PatternPtr
dst_pattern
()
{
return
dst_pattern_
;
}
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
浏览文件 @
1bdb26f9
...
...
@@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
}
PyPassManager
::
PyPassManager
()
{
phase_to_group_
[
Phase
::
RESOLVE
]
=
std
::
make_shared
<
PassGroup
>
(
);
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
();
phase_to_group_
[
Phase
::
PREAD
]
=
std
::
make_shared
<
PassGroup
>
(
"Pre_AD_PassGroup"
);
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
(
"After_OPT_PassGroup"
);
res_
=
std
::
make_shared
<
MatchResult
>
();
}
void
PyPassManager
::
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
run_only_once
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
bool
requires_grad
,
bool
run_only_once
)
{
PassGroupPtr
cur_pg
;
if
(
requires_grad
)
{
cur_pg
=
GetPassGroup
(
Phase
::
PREAD
);
}
else
{
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
}
MS_EXCEPTION_IF_NULL
(
cur_pg
);
cur_pg
->
SetRunOnlyOnce
(
run_only_once
);
MS_EXCEPTION_IF_NULL
(
pattern
);
...
...
@@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
}
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pm
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
if
(
!
cur_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"No such pass : "
+
pass_name
+
"
\n
"
;
auto
opt_pm
=
GetPassGroup
(
Phase
::
OPT
);
if
(
!
opt_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"Opt has no such pass : "
+
pass_name
+
"
\n
"
;
}
auto
pre_ad_pm
=
GetPassGroup
(
Phase
::
PREAD
);
if
(
!
pre_ad_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"Pre_AD has no such pass : "
+
pass_name
+
"
\n
"
;
}
}
...
...
@@ -92,7 +98,7 @@ void PyPassManager::ClearRes() {
REGISTER_PYBIND_DEFINE
(
PyPassManager_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
Phase
>
(
*
m
,
"phase"
,
py
::
arithmetic
()).
value
(
"
resolve"
,
Phase
::
RESOLVE
).
value
(
"opt"
,
Phase
::
OPT
);
(
void
)
py
::
enum_
<
Phase
>
(
*
m
,
"phase"
,
py
::
arithmetic
()).
value
(
"
pre_ad"
,
Phase
::
PREAD
).
value
(
"opt"
,
Phase
::
OPT
);
(
void
)
py
::
class_
<
PyPassManager
,
std
::
shared_ptr
<
PyPassManager
>>
(
*
m
,
"PyPassManager_"
)
.
def
(
py
::
init
([]()
{
return
PyPassManager
::
GetInstance
();
}))
.
def
(
"registe"
,
&
PyPassManager
::
Registe
,
"Registe python pass"
)
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
浏览文件 @
1bdb26f9
...
...
@@ -38,7 +38,7 @@ namespace python_pass {
class
PyPassManager
;
using
PyPassManagerPtr
=
std
::
shared_ptr
<
PyPassManager
>
;
enum
Phase
{
RESOLVE
,
OPT
};
enum
Phase
{
PREAD
,
OPT
};
class
PyPassManager
{
protected:
...
...
@@ -52,8 +52,8 @@ class PyPassManager {
// Access the only global instance
static
PyPassManagerPtr
GetInstance
();
virtual
~
PyPassManager
()
=
default
;
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
run_only_once
=
false
);
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
requires_grad
,
bool
run_only_once
);
void
Unregiste
(
const
std
::
string
&
pass_name
);
void
GenNewParameter
(
const
PatternPtr
&
parameter
);
PassGroupPtr
GetPassGroup
(
Phase
phase
);
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
1bdb26f9
...
...
@@ -288,6 +288,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return
true
;
}
bool
OptInlineAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kInlinePasses
);
}
bool
GeOptimizeAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kGePasses
);
}
bool
VmOptimizeAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kVmPasses
);
}
...
...
@@ -460,7 +462,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
return
ppm
->
GetPassGroup
(
phase
)
->
Run
(
res
->
func_graph
());
}
bool
ResolveActionPyStub
(
const
ResourcePtr
&
res
)
{
return
true
||
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
RESOLVE
);
}
bool
PreAdActionPyStub
(
const
ResourcePtr
&
res
)
{
if
(
!
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
PREAD
))
{
MS_LOG
(
DEBUG
)
<<
"No Match."
;
}
return
true
;
}
bool
OptActionVmPyStub
(
const
ResourcePtr
&
res
)
{
if
(
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
OPT
))
{
...
...
@@ -516,12 +523,14 @@ static std::vector<ActionItem> CommonPipeline() {
if
(
!
multi_graphs
)
{
actions
.
emplace_back
(
std
::
make_pair
(
"combine_like_graphs"
,
CombineLikeGraphs
));
}
// Add resolve-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_resolve"
,
ResolveActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"inference_opt_prepare"
,
InferenceOptPrepareAction
));
// Evaluate type and shape, and specialize
actions
.
emplace_back
(
std
::
make_pair
(
"abstract_specialize"
,
AbstractSpecializeAction
));
// Do data structure simplifications and inline
actions
.
emplace_back
(
std
::
make_pair
(
"inline"
,
OptInlineAction
));
// Add pre-ad, post-inline python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_pre_ad"
,
PreAdActionPyStub
));
return
actions
;
}
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
1bdb26f9
...
...
@@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return
map_a
;
}
OptPassGroupMap
GetA1A2
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
auto
opt_a
=
GetOptPassesA
(
irpass
);
OptPassGroupMap
a1_a2
({
opt_a
[
0
],
opt_a
[
1
]});
return
a1_a2
;
}
OptPassGroupMap
GetOptPassesAfterCconv
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
c_1
=
opt
::
OptPassConfig
({
// Safe inlining,
...
...
@@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts =
void
InitOpt
(
const
ResourcePtr
&
res
)
{
if
(
g_pass_opts
.
size
()
==
0
)
{
opt
::
irpass
::
OptimizeIRPassLib
irpass
;
g_pass_opts
[
"a1a2"
]
=
Optimizer
::
MakeOptimizer
(
"a1a2"
,
res
,
GetA1A2
(
irpass
));
g_pass_opts
[
"opt_a"
]
=
Optimizer
::
MakeOptimizer
(
"opt_a"
,
res
,
GetOptPassesA
(
irpass
));
g_pass_opts
[
"opt_b"
]
=
Optimizer
::
MakeOptimizer
(
"opt_b"
,
res
,
GetOptPassesB
(
irpass
),
false
,
true
);
g_pass_opts
[
"opt_after_cconv"
]
=
...
...
@@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
return
true
;
}
bool
OptPassA1A2
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"a1a2"
);
}
bool
OptPassAGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_a"
);
}
bool
OptPassBGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_b"
);
}
bool
OptPassAfterCconvGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_after_cconv"
);
}
...
...
@@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{
"cconv"
,
CconvPass
},
{
"transform_top"
,
TransformTopGraphPass
},
{
"transform_graph"
,
OptPassTransformGraphGroup
}};
std
::
vector
<
PassItem
>
kInlinePasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"a1a2"
,
OptPassA1A2
}};
}
// namespace pipeline
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/pass.h
浏览文件 @
1bdb26f9
...
...
@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
extern
std
::
vector
<
PassItem
>
kGePasses
;
extern
std
::
vector
<
PassItem
>
kVmPasses
;
extern
std
::
vector
<
PassItem
>
kInlinePasses
;
extern
std
::
vector
<
PassItem
>
kPynativePasses
;
bool
CconvPass
(
const
ResourcePtr
&
res
);
...
...
mindspore/graph_utils/python_pass/__init__.py
浏览文件 @
1bdb26f9
...
...
@@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from
.python_pass_register
import
registe_pass
,
unregiste_pass
,
gen_new_parameter
,
cancel_new_parameter
,
set_renorm
,
\
set_reopt
from
.python_pass_register
import
registe_pass
,
unregiste_pass
,
gen_new_parameter
,
cancel_new_parameter
,
_
set_renorm
,
\
_
set_reopt
__all__
=
[
"registe_pass"
,
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
,
"set_reopt"
"
_
set_renorm"
,
"
_
set_reopt"
]
mindspore/graph_utils/python_pass/python_pass_register.py
浏览文件 @
1bdb26f9
...
...
@@ -23,22 +23,26 @@ __all__ = [
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
,
"set_reopt"
"
_
set_renorm"
,
"
_
set_reopt"
]
class
PyPassManager
(
PyPassManager_
):
r
"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
Raises:
TypeError: If argument has invalid type.
"""
def
__init__
(
self
,
run_only_once
=
False
):
def
__init__
(
self
,
requires_grad
=
True
,
run_only_once
=
False
):
if
not
isinstance
(
requires_grad
,
bool
):
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
requires_grad
)
}
)
{
requires_grad
}
"
)
if
not
isinstance
(
run_only_once
,
bool
):
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
self
.
requires_grad
=
requires_grad
self
.
run_only_once_
=
run_only_once
PyPassManager_
.
__init__
(
self
)
...
...
@@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect pattern of Pattern type, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
if
not
isinstance
(
target
,
Pattern
):
raise
TypeError
(
f
"Expect target of Pattern type, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
run_only_once_
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
r
equires_grad
,
self
.
r
un_only_once_
)
def
unregiste
(
self
,
py_pass
):
if
isinstance
(
py_pass
,
str
):
...
...
@@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect do_reopt to be a bool, got
{
do_reopt
}
"
)
super
().
set_reopt
(
do_reopt
)
def
registe_pass
(
run_only_once
=
False
):
def
registe_pass
(
r
equires_grad
=
True
,
r
un_only_once
=
False
):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
...
...
@@ -99,7 +104,7 @@ def registe_pass(run_only_once=False):
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return
PyPassManager
(
run_only_once
)
return
PyPassManager
(
r
equires_grad
,
r
un_only_once
)
def
unregiste_pass
(
py_pass
):
"""
...
...
@@ -157,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm
=
PyPassManager
()
ppm
.
unregiste
(
pattern
.
para_name
)
def
set_renorm
(
should_renorm
):
def
_
set_renorm
(
should_renorm
):
"""
Set whether or not to do renormalization after modified graph in python pass(es).
...
...
@@ -171,7 +176,7 @@ def set_renorm(should_renorm):
ppm
=
PyPassManager
()
ppm
.
set_renorm
(
should_renorm
)
def
set_reopt
(
do_reopt
):
def
_
set_reopt
(
do_reopt
):
"""
Set whether or not to do optimization after modified graph in python pass(es).
...
...
tests/ut/python/optimizer/test_python_pass.py
浏览文件 @
1bdb26f9
...
...
@@ -19,8 +19,8 @@ import mindspore.nn as nn
from
mindspore
import
context
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.graph_utils.python_pass
import
registe_pass
,
unregiste_pass
,
set_renorm
,
gen_new_parameter
,
\
cancel_new_parameter
,
set_reopt
from
mindspore.graph_utils.python_pass
import
registe_pass
,
unregiste_pass
,
_
set_renorm
,
gen_new_parameter
,
\
cancel_new_parameter
,
_
set_reopt
from
mindspore.common.api
import
_generate_pip_args
from
mindspore._c_expression
import
generate_key
,
Executor_
from
mindspore.graph_utils.graph_pattern
import
OneOf
,
Prim
,
Call
,
NoneOf
,
Any
,
NewTensor
,
NewParameter
,
Imm
...
...
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
class
ConvBN
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ConvBN
,
self
).
__init__
()
...
...
@@ -176,7 +176,7 @@ def test_isnot_pattern_0():
inputs
=
Tensor
(
np
.
random
.
normal
(
0
,
1
,
(
10
,
32
,
32
,
32
)),
mindspore
.
float32
)
conv_bn_model
=
ConvBN
()
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
single_bn_pass
():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
...
...
@@ -188,7 +188,7 @@ def test_isnot_pattern_0():
target
=
Call
(
P
.
ReLU6
(),
[
pattern_0
])
return
pattern
,
target
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
bn_pass
():
"""
Sub a BN to Softmax.
...
...
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass
(
bn_pass
)
assert
"ReLU6"
not
in
transformed_repr
assert
"Softmax"
in
transformed_repr
set_renorm
(
True
)
_
set_renorm
(
True
)
def
test_isnot_pattern_1
():
"""
...
...
@@ -234,12 +234,12 @@ def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_addn_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass
(
softmax_addn_pass
)
assert
"AddN"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
set_renorm
(
True
)
_
set_renorm
(
True
)
def
test_newparameter_pattern
():
"""
...
...
@@ -261,9 +261,9 @@ def test_newparameter_pattern():
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
set_renorm
(
False
)
set_reopt
(
False
)
@
registe_pass
(
run_only_once
=
True
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_addn_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -288,9 +288,9 @@ def test_imm_target():
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
set_renorm
(
False
)
set_reopt
(
False
)
@
registe_pass
(
run_only_once
=
True
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -313,10 +313,10 @@ def test_gen_new_parameter():
default_tensor
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para
=
NewParameter
(
"Merlin"
,
default_tensor
)
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
gen_new_parameter
(
new_para
)
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_make_tuple_pass
():
x
=
Any
()
softmax
=
P
.
Softmax
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录