Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
88f5cbe5
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
88f5cbe5
编写于
9月 09, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5692 Add requires_grad option for python pass
Merge pull request !5692 from BowenK/pre_ad
上级
2e6a5a90
1bdb26f9
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
129 addition
and
79 deletion
+129
-79
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+9
-1
mindspore/ccsrc/frontend/optimizer/py_pass.cc
mindspore/ccsrc/frontend/optimizer/py_pass.cc
+38
-28
mindspore/ccsrc/frontend/optimizer/py_pass.h
mindspore/ccsrc/frontend/optimizer/py_pass.h
+2
-1
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
+17
-11
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
+3
-3
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+12
-3
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+10
-0
mindspore/ccsrc/pipeline/jit/pass.h
mindspore/ccsrc/pipeline/jit/pass.h
+1
-0
mindspore/graph_utils/python_pass/__init__.py
mindspore/graph_utils/python_pass/__init__.py
+4
-4
mindspore/graph_utils/python_pass/python_pass_register.py
mindspore/graph_utils/python_pass/python_pass_register.py
+13
-8
tests/ut/python/optimizer/test_python_pass.py
tests/ut/python/optimizer/test_python_pass.py
+20
-20
未找到文件。
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
浏览文件 @
88f5cbe5
...
...
@@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto
scope
=
std
::
make_shared
<
Scope
>
(
gradients_scope
+
ScopeManager
::
GetInstance
().
GetCurrentScope
()
->
name
()
+
grad_op_child_scope_prefix
+
prim
->
name
());
ScopeGuard
scope_guard
(
scope
);
py
::
function
fn
=
prim
->
is_base
()
?
GetBpropFunction
(
prim
->
name
())
:
prim
->
cast
<
PrimitivePyPtr
>
()
->
GetBpropFunction
();
py
::
function
fn
;
if
(
prim
->
is_base
())
{
fn
=
GetBpropFunction
(
prim
->
name
());
}
else
{
fn
=
prim
->
cast
<
PrimitivePyPtr
>
()
->
GetBpropFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
fn
))
{
fn
=
GetBpropFunction
(
prim
->
name
());
}
}
if
(
!
fn
||
py
::
isinstance
<
py
::
none
>
(
fn
))
{
MS_LOG
(
DEBUG
)
<<
"Fail to find bprop function for "
<<
prim
->
name
()
<<
"."
;
return
nullptr
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.cc
浏览文件 @
88f5cbe5
...
...
@@ -35,8 +35,10 @@ namespace internal {
const
char
PARAMETER_MODULE
[]
=
"mindspore.common.parameter"
;
const
char
PARAMETER_CLASS
[]
=
"Parameter"
;
const
char
SET_PARAM
[]
=
"__setattr__"
;
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
);
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
);
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
MatchResultPtr
&
res
);
void
ReflectParamBackToPython
(
const
AnfNodePtr
&
param
,
string
param_name
,
tensor
::
TensorPtr
default_input
,
bool
requires_grad
,
bool
layerwise_parallel
);
...
...
@@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return
std
::
make_shared
<
ValueNode
>
(
input_tensor
);
}
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
)
{
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
,
const
FuncGraphPtr
&
top_graph
)
{
auto
call_pattern
=
pattern
->
cast
<
CallPtr
>
();
MS_EXCEPTION_IF_NULL
(
call_pattern
);
auto
prim
=
call_pattern
->
prim_value
();
...
...
@@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto
prim_pattern
=
call_pattern
->
prim_pattern
();
MS_EXCEPTION_IF_NULL
(
prim_pattern
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
,
top_graph
);
}
AnfNodePtr
BuildNewParameter
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func
_graph
)
{
AnfNodePtr
BuildNewParameter
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
top
_graph
)
{
auto
new_para_pattern
=
pattern
->
cast
<
NewParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
new_para_pattern
);
if
(
!
new_para_pattern
->
built
())
{
static
int
parameter_id
=
0
;
auto
para_name
=
new_para_pattern
->
para_name
()
+
new_para_pattern
->
unique_name
()
+
std
::
to_string
(
parameter_id
++
);
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
func
_graph
);
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
top
_graph
);
MS_EXCEPTION_IF_NULL
(
para_node
);
para_node
->
set_name
(
para_name
);
// Set function graph
para_node
->
set_func_graph
(
func
_graph
);
para_node
->
set_func_graph
(
top
_graph
);
// Set Debug Info
auto
debug_info
=
std
::
make_shared
<
NodeDebugInfo
>
(
para_name
);
para_node
->
set_debug_info
(
debug_info
);
...
...
@@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
MS_EXCEPTION_IF_NULL
(
default_value
);
para_node
->
set_abstract
(
default_value
->
ToAbstract
()
->
Broaden
());
res
->
add_entry
(
pattern
,
para_node
);
func
_graph
->
add_parameter
(
para_node
);
top
_graph
->
add_parameter
(
para_node
);
// Reflect back to Cell._params
internal
::
ReflectParamBackToPython
(
para_node
,
para_name
,
default_value
,
new_para_pattern
->
requires_grad
(),
new_para_pattern
->
layerwise_parallel
());
...
...
@@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
return
std
::
make_shared
<
ValueNode
>
(
scalar_value_ptr
);
}
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
)
{
auto
target_node
=
res
->
get_node
(
pattern
);
if
(
target_node
!=
nullptr
)
{
// If pattern is NewParameter, check whether it shouldn't last and is not built
...
...
@@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
else
if
(
pattern
->
isa
<
NewTensor
>
())
{
return
BuildNewTensor
(
pattern
,
res
);
}
else
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
,
top_graph
);
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
return
BuildNewParameter
(
pattern
,
res
,
func_graph
);
// Add new parameter to top graph instead of current graph
return
BuildNewParameter
(
pattern
,
res
,
top_graph
);
}
else
if
(
pattern
->
isa
<
Imm
>
())
{
return
BuildImmNode
(
pattern
,
res
);
}
else
{
...
...
@@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
AnfNodePtr
ProcessComplexPatternFirstInput
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
)
{
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
,
top_graph
);
}
return
nullptr
;
}
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
MatchResultPtr
&
res
)
{
auto
target_inputs
=
pattern
->
inputs
();
if
(
target_inputs
.
size
()
==
0
)
{
auto
new_node
=
ProcessSinglePattern
(
pattern
,
res
,
func_graph
);
auto
new_node
=
ProcessSinglePattern
(
pattern
,
res
,
func_graph
,
top_graph
);
if
(
new_node
!=
nullptr
)
{
res
->
add_entry
(
pattern
,
new_node
);
}
...
...
@@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
}
// Build up the AnfNode in a recursive manner
std
::
vector
<
AnfNodePtr
>
new_inputs
;
auto
prim_value_node
=
ProcessComplexPatternFirstInput
(
pattern
,
res
,
func_graph
);
auto
prim_value_node
=
ProcessComplexPatternFirstInput
(
pattern
,
res
,
func_graph
,
top_graph
);
MS_EXCEPTION_IF_NULL
(
prim_value_node
);
new_inputs
.
push_back
(
prim_value_node
);
for
(
auto
&
iter
:
target_inputs
)
{
if
(
iter
==
pattern
)
{
MS_LOG
(
EXCEPTION
)
<<
"Circle references. Got pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
}
auto
input_node
=
BuildTarget
(
iter
,
func_graph
,
res
);
auto
input_node
=
BuildTarget
(
iter
,
func_graph
,
top_graph
,
res
);
if
(
input_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to build input node for pattern : "
+
iter
->
unique_name
()
+
"
\n
"
;
}
...
...
@@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {
}
// namespace internal
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
)
{
auto
match_res
=
src_pattern_
->
match
(
node
);
if
(
match_res
!=
nullptr
)
{
res
->
merge
(
match_res
);
auto
new_node
=
internal
::
BuildTarget
(
dst_pattern_
,
func_graph
,
res
);
auto
new_node
=
internal
::
BuildTarget
(
dst_pattern_
,
func_graph
,
top_graph
,
res
);
internal
::
Reset
(
dst_pattern
());
return
new_node
;
}
...
...
@@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
}
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
func_graph
);
auto
graph_nodes_sorted
=
TopoSort
(
func_graph
->
output
());
auto
func_graphs
=
manager
->
func_graphs
();
bool
changes
=
false
;
// Traverse once
for
(
auto
&
node
:
graph_nodes_sorted
)
{
AnfNodePtr
new_node
=
Run
(
func_graph
,
node
,
res
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
(
void
)
manager
->
Replace
(
node
,
new_node
);
changes
=
true
;
for
(
auto
&
fg
:
func_graphs
)
{
manager
->
AddFuncGraph
(
fg
);
auto
graph_nodes_sorted
=
TopoSort
(
fg
->
output
());
// Traverse once
for
(
auto
&
node
:
graph_nodes_sorted
)
{
AnfNodePtr
new_node
=
Run
(
fg
,
func_graph
,
node
,
res
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
MS_LOG
(
WARNING
)
<<
"Matched"
;
(
void
)
manager
->
Replace
(
node
,
new_node
);
changes
=
true
;
}
}
}
return
changes
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.h
浏览文件 @
88f5cbe5
...
...
@@ -39,7 +39,8 @@ class PythonPass {
~
PythonPass
()
=
default
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
std
::
string
name
()
const
{
return
name_
;
}
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
);
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
FuncGraphPtr
&
top_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
);
PatternPtr
src_pattern
()
{
return
src_pattern_
;
}
PatternPtr
dst_pattern
()
{
return
dst_pattern_
;
}
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
浏览文件 @
88f5cbe5
...
...
@@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
}
PyPassManager
::
PyPassManager
()
{
phase_to_group_
[
Phase
::
RESOLVE
]
=
std
::
make_shared
<
PassGroup
>
(
);
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
();
phase_to_group_
[
Phase
::
PREAD
]
=
std
::
make_shared
<
PassGroup
>
(
"Pre_AD_PassGroup"
);
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
(
"After_OPT_PassGroup"
);
res_
=
std
::
make_shared
<
MatchResult
>
();
}
void
PyPassManager
::
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
run_only_once
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
bool
requires_grad
,
bool
run_only_once
)
{
PassGroupPtr
cur_pg
;
if
(
requires_grad
)
{
cur_pg
=
GetPassGroup
(
Phase
::
PREAD
);
}
else
{
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
}
MS_EXCEPTION_IF_NULL
(
cur_pg
);
cur_pg
->
SetRunOnlyOnce
(
run_only_once
);
MS_EXCEPTION_IF_NULL
(
pattern
);
...
...
@@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
}
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pm
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
if
(
!
cur_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"No such pass : "
+
pass_name
+
"
\n
"
;
auto
opt_pm
=
GetPassGroup
(
Phase
::
OPT
);
if
(
!
opt_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"Opt has no such pass : "
+
pass_name
+
"
\n
"
;
}
auto
pre_ad_pm
=
GetPassGroup
(
Phase
::
PREAD
);
if
(
!
pre_ad_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"Pre_AD has no such pass : "
+
pass_name
+
"
\n
"
;
}
}
...
...
@@ -92,7 +98,7 @@ void PyPassManager::ClearRes() {
REGISTER_PYBIND_DEFINE
(
PyPassManager_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
Phase
>
(
*
m
,
"phase"
,
py
::
arithmetic
()).
value
(
"
resolve"
,
Phase
::
RESOLVE
).
value
(
"opt"
,
Phase
::
OPT
);
(
void
)
py
::
enum_
<
Phase
>
(
*
m
,
"phase"
,
py
::
arithmetic
()).
value
(
"
pre_ad"
,
Phase
::
PREAD
).
value
(
"opt"
,
Phase
::
OPT
);
(
void
)
py
::
class_
<
PyPassManager
,
std
::
shared_ptr
<
PyPassManager
>>
(
*
m
,
"PyPassManager_"
)
.
def
(
py
::
init
([]()
{
return
PyPassManager
::
GetInstance
();
}))
.
def
(
"registe"
,
&
PyPassManager
::
Registe
,
"Registe python pass"
)
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
浏览文件 @
88f5cbe5
...
...
@@ -38,7 +38,7 @@ namespace python_pass {
class
PyPassManager
;
using
PyPassManagerPtr
=
std
::
shared_ptr
<
PyPassManager
>
;
enum
Phase
{
RESOLVE
,
OPT
};
enum
Phase
{
PREAD
,
OPT
};
class
PyPassManager
{
protected:
...
...
@@ -52,8 +52,8 @@ class PyPassManager {
// Access the only global instance
static
PyPassManagerPtr
GetInstance
();
virtual
~
PyPassManager
()
=
default
;
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
run_only_once
=
false
);
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
bool
requires_grad
,
bool
run_only_once
);
void
Unregiste
(
const
std
::
string
&
pass_name
);
void
GenNewParameter
(
const
PatternPtr
&
parameter
);
PassGroupPtr
GetPassGroup
(
Phase
phase
);
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
88f5cbe5
...
...
@@ -301,6 +301,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return
true
;
}
bool
OptInlineAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kInlinePasses
);
}
bool
GeOptimizeAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kGePasses
);
}
bool
VmOptimizeAction
(
const
ResourcePtr
&
res
)
{
return
OptimizeAction
(
res
,
kVmPasses
);
}
...
...
@@ -473,7 +475,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
return
ppm
->
GetPassGroup
(
phase
)
->
Run
(
res
->
func_graph
());
}
bool
ResolveActionPyStub
(
const
ResourcePtr
&
res
)
{
return
true
||
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
RESOLVE
);
}
bool
PreAdActionPyStub
(
const
ResourcePtr
&
res
)
{
if
(
!
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
PREAD
))
{
MS_LOG
(
DEBUG
)
<<
"No Match."
;
}
return
true
;
}
bool
OptActionVmPyStub
(
const
ResourcePtr
&
res
)
{
if
(
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
OPT
))
{
...
...
@@ -529,12 +536,14 @@ static std::vector<ActionItem> CommonPipeline() {
if
(
!
multi_graphs
)
{
actions
.
emplace_back
(
std
::
make_pair
(
"combine_like_graphs"
,
CombineLikeGraphs
));
}
// Add resolve-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_resolve"
,
ResolveActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"inference_opt_prepare"
,
InferenceOptPrepareAction
));
// Evaluate type and shape, and specialize
actions
.
emplace_back
(
std
::
make_pair
(
"abstract_specialize"
,
AbstractSpecializeAction
));
// Do data structure simplifications and inline
actions
.
emplace_back
(
std
::
make_pair
(
"inline"
,
OptInlineAction
));
// Add pre-ad, post-inline python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_pre_ad"
,
PreAdActionPyStub
));
return
actions
;
}
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
88f5cbe5
...
...
@@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return
map_a
;
}
OptPassGroupMap
GetA1A2
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
auto
opt_a
=
GetOptPassesA
(
irpass
);
OptPassGroupMap
a1_a2
({
opt_a
[
0
],
opt_a
[
1
]});
return
a1_a2
;
}
OptPassGroupMap
GetOptPassesAfterCconv
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
c_1
=
opt
::
OptPassConfig
({
// Safe inlining,
...
...
@@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts =
void
InitOpt
(
const
ResourcePtr
&
res
)
{
if
(
g_pass_opts
.
size
()
==
0
)
{
opt
::
irpass
::
OptimizeIRPassLib
irpass
;
g_pass_opts
[
"a1a2"
]
=
Optimizer
::
MakeOptimizer
(
"a1a2"
,
res
,
GetA1A2
(
irpass
));
g_pass_opts
[
"opt_a"
]
=
Optimizer
::
MakeOptimizer
(
"opt_a"
,
res
,
GetOptPassesA
(
irpass
));
g_pass_opts
[
"opt_b"
]
=
Optimizer
::
MakeOptimizer
(
"opt_b"
,
res
,
GetOptPassesB
(
irpass
),
false
,
true
);
g_pass_opts
[
"opt_after_cconv"
]
=
...
...
@@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
return
true
;
}
bool
OptPassA1A2
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"a1a2"
);
}
bool
OptPassAGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_a"
);
}
bool
OptPassBGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_b"
);
}
bool
OptPassAfterCconvGroup
(
const
ResourcePtr
&
res
)
{
return
OptPassGroup
(
res
,
"opt_after_cconv"
);
}
...
...
@@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{
"cconv"
,
CconvPass
},
{
"transform_top"
,
TransformTopGraphPass
},
{
"transform_graph"
,
OptPassTransformGraphGroup
}};
std
::
vector
<
PassItem
>
kInlinePasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"a1a2"
,
OptPassA1A2
}};
}
// namespace pipeline
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/pass.h
浏览文件 @
88f5cbe5
...
...
@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
extern
std
::
vector
<
PassItem
>
kGePasses
;
extern
std
::
vector
<
PassItem
>
kVmPasses
;
extern
std
::
vector
<
PassItem
>
kInlinePasses
;
extern
std
::
vector
<
PassItem
>
kPynativePasses
;
bool
CconvPass
(
const
ResourcePtr
&
res
);
...
...
mindspore/graph_utils/python_pass/__init__.py
浏览文件 @
88f5cbe5
...
...
@@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from
.python_pass_register
import
registe_pass
,
unregiste_pass
,
gen_new_parameter
,
cancel_new_parameter
,
set_renorm
,
\
set_reopt
from
.python_pass_register
import
registe_pass
,
unregiste_pass
,
gen_new_parameter
,
cancel_new_parameter
,
_
set_renorm
,
\
_
set_reopt
__all__
=
[
"registe_pass"
,
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
,
"set_reopt"
"
_
set_renorm"
,
"
_
set_reopt"
]
mindspore/graph_utils/python_pass/python_pass_register.py
浏览文件 @
88f5cbe5
...
...
@@ -23,22 +23,26 @@ __all__ = [
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
,
"set_reopt"
"
_
set_renorm"
,
"
_
set_reopt"
]
class
PyPassManager
(
PyPassManager_
):
r
"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
Raises:
TypeError: If argument has invalid type.
"""
def
__init__
(
self
,
run_only_once
=
False
):
def
__init__
(
self
,
requires_grad
=
True
,
run_only_once
=
False
):
if
not
isinstance
(
requires_grad
,
bool
):
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
requires_grad
)
}
)
{
requires_grad
}
"
)
if
not
isinstance
(
run_only_once
,
bool
):
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
self
.
requires_grad
=
requires_grad
self
.
run_only_once_
=
run_only_once
PyPassManager_
.
__init__
(
self
)
...
...
@@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect pattern of Pattern type, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
if
not
isinstance
(
target
,
Pattern
):
raise
TypeError
(
f
"Expect target of Pattern type, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
run_only_once_
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
r
equires_grad
,
self
.
r
un_only_once_
)
def
unregiste
(
self
,
py_pass
):
if
isinstance
(
py_pass
,
str
):
...
...
@@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect do_reopt to be a bool, got
{
do_reopt
}
"
)
super
().
set_reopt
(
do_reopt
)
def
registe_pass
(
run_only_once
=
False
):
def
registe_pass
(
r
equires_grad
=
True
,
r
un_only_once
=
False
):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
...
...
@@ -99,7 +104,7 @@ def registe_pass(run_only_once=False):
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return
PyPassManager
(
run_only_once
)
return
PyPassManager
(
r
equires_grad
,
r
un_only_once
)
def
unregiste_pass
(
py_pass
):
"""
...
...
@@ -157,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm
=
PyPassManager
()
ppm
.
unregiste
(
pattern
.
para_name
)
def
set_renorm
(
should_renorm
):
def
_
set_renorm
(
should_renorm
):
"""
Set whether or not to do renormalization after modified graph in python pass(es).
...
...
@@ -171,7 +176,7 @@ def set_renorm(should_renorm):
ppm
=
PyPassManager
()
ppm
.
set_renorm
(
should_renorm
)
def
set_reopt
(
do_reopt
):
def
_
set_reopt
(
do_reopt
):
"""
Set whether or not to do optimization after modified graph in python pass(es).
...
...
tests/ut/python/optimizer/test_python_pass.py
浏览文件 @
88f5cbe5
...
...
@@ -19,8 +19,8 @@ import mindspore.nn as nn
from
mindspore
import
context
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.graph_utils.python_pass
import
registe_pass
,
unregiste_pass
,
set_renorm
,
gen_new_parameter
,
\
cancel_new_parameter
,
set_reopt
from
mindspore.graph_utils.python_pass
import
registe_pass
,
unregiste_pass
,
_
set_renorm
,
gen_new_parameter
,
\
cancel_new_parameter
,
_
set_reopt
from
mindspore.common.api
import
_generate_pip_args
from
mindspore._c_expression
import
generate_key
,
Executor_
from
mindspore.graph_utils.graph_pattern
import
OneOf
,
Prim
,
Call
,
NoneOf
,
Any
,
NewTensor
,
NewParameter
,
Imm
...
...
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
class
ConvBN
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ConvBN
,
self
).
__init__
()
...
...
@@ -176,7 +176,7 @@ def test_isnot_pattern_0():
inputs
=
Tensor
(
np
.
random
.
normal
(
0
,
1
,
(
10
,
32
,
32
,
32
)),
mindspore
.
float32
)
conv_bn_model
=
ConvBN
()
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
single_bn_pass
():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
...
...
@@ -188,7 +188,7 @@ def test_isnot_pattern_0():
target
=
Call
(
P
.
ReLU6
(),
[
pattern_0
])
return
pattern
,
target
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
bn_pass
():
"""
Sub a BN to Softmax.
...
...
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass
(
bn_pass
)
assert
"ReLU6"
not
in
transformed_repr
assert
"Softmax"
in
transformed_repr
set_renorm
(
True
)
_
set_renorm
(
True
)
def
test_isnot_pattern_1
():
"""
...
...
@@ -234,12 +234,12 @@ def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_addn_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass
(
softmax_addn_pass
)
assert
"AddN"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
set_renorm
(
True
)
_
set_renorm
(
True
)
def
test_newparameter_pattern
():
"""
...
...
@@ -261,9 +261,9 @@ def test_newparameter_pattern():
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
set_renorm
(
False
)
set_reopt
(
False
)
@
registe_pass
(
run_only_once
=
True
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_addn_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -288,9 +288,9 @@ def test_imm_target():
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
set_renorm
(
False
)
set_reopt
(
False
)
@
registe_pass
(
run_only_once
=
True
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
[
x
])
...
...
@@ -313,10 +313,10 @@ def test_gen_new_parameter():
default_tensor
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para
=
NewParameter
(
"Merlin"
,
default_tensor
)
set_renorm
(
False
)
set_reopt
(
False
)
_
set_renorm
(
False
)
_
set_reopt
(
False
)
gen_new_parameter
(
new_para
)
@
registe_pass
(
run_only_once
=
True
)
@
registe_pass
(
r
equires_grad
=
False
,
r
un_only_once
=
True
)
def
softmax_make_tuple_pass
():
x
=
Any
()
softmax
=
P
.
Softmax
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录