Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
93926230
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93926230
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!5181 Python pass pattern renaming and interface tweaking
Merge pull request !5181 from BowenK/new_parameter
上级
83b9d1c5
641d12d6
master
无相关合并请求
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
236 addition
and
287 deletion
+236
-287
mindspore/ccsrc/frontend/optimizer/pattern.cc
mindspore/ccsrc/frontend/optimizer/pattern.cc
+44
-29
mindspore/ccsrc/frontend/optimizer/pattern.h
mindspore/ccsrc/frontend/optimizer/pattern.h
+52
-68
mindspore/ccsrc/frontend/optimizer/py_pass.cc
mindspore/ccsrc/frontend/optimizer/py_pass.cc
+20
-30
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
+7
-6
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
+4
-3
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+1
-0
mindspore/graph_utils/graph_pattern.py
mindspore/graph_utils/graph_pattern.py
+31
-56
mindspore/graph_utils/python_pass/python_pass_register.py
mindspore/graph_utils/python_pass/python_pass_register.py
+19
-23
tests/ut/python/optimizer/test_python_pass.py
tests/ut/python/optimizer/test_python_pass.py
+58
-72
未找到文件。
mindspore/ccsrc/frontend/optimizer/pattern.cc
浏览文件 @
93926230
...
...
@@ -21,25 +21,23 @@ namespace opt {
namespace
python_pass
{
int
Pattern
::
g_id_
=
0
;
MatchResultPtr
IsPrimTypeOf
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
Prim
::
match
(
const
AnfNodePtr
&
node
)
{
if
(
!
IsValueNode
<
Primitive
>
(
node
))
{
return
nullptr
;
}
MatchResultPtr
res
=
std
::
make_shared
<
MatchResult
>
();
if
(
IsValueNode
<
Primitive
>
(
node
))
{
// iterate over all primitives
for
(
auto
&
iter
:
primitives_
)
{
if
(
IsPrimitive
(
node
,
iter
)
||
iter
->
name
()
==
"*"
)
{
matched_prim_
=
iter
;
res
->
add_entry
(
shared_from_base
<
IsPrimTypeOf
>
(),
node
);
return
res
;
}
// iterate over all primitives
for
(
auto
&
iter
:
primitives_
)
{
if
(
IsPrimitive
(
node
,
iter
)
||
iter
->
name
()
==
"*"
)
{
matched_prim_
=
iter
;
res
->
add_entry
(
shared_from_base
<
Prim
>
(),
node
);
return
res
;
}
}
return
nullptr
;
}
MatchResultPtr
Call
With
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
Call
::
match
(
const
AnfNodePtr
&
node
)
{
if
(
!
IsPrimitiveCNode
(
node
))
{
return
nullptr
;
}
...
...
@@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
}
// If inputs is not specified, add node without looking into its inputs
if
(
p_inputs_size
==
0
)
{
res
->
add_entry
(
shared_from_base
<
Call
With
>
(),
cnode
->
input
(
0
));
res
->
add_entry
(
shared_from_base
<
Call
>
(),
cnode
->
input
(
0
));
return
res
;
}
bool
failed
=
false
;
...
...
@@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
res
->
merge
(
input_match_result
);
}
if
(
!
failed
)
{
res
->
add_entry
(
shared_from_base
<
Call
With
>
(),
cnode
->
input
(
0
));
res
->
add_entry
(
shared_from_base
<
Call
>
(),
cnode
->
input
(
0
));
return
res
;
}
return
nullptr
;
}
MatchResultPtr
IsIn
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
OneOf
::
match
(
const
AnfNodePtr
&
node
)
{
for
(
auto
&
iter
:
patterns_
)
{
auto
res
=
iter
->
match
(
node
);
if
(
res
!=
nullptr
)
{
res
->
add_entry
(
shared_from_base
<
IsIn
>
(),
node
);
res
->
add_entry
(
shared_from_base
<
OneOf
>
(),
node
);
return
res
;
}
}
return
nullptr
;
}
MatchResultPtr
IsNot
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
NoneOf
::
match
(
const
AnfNodePtr
&
node
)
{
for
(
auto
&
iter
:
patterns_
)
{
auto
res
=
iter
->
match
(
node
);
if
(
res
!=
nullptr
)
{
...
...
@@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) {
}
}
auto
res
=
std
::
make_shared
<
MatchResult
>
();
res
->
add_entry
(
shared_from_base
<
IsNot
>
(),
node
);
res
->
add_entry
(
shared_from_base
<
NoneOf
>
(),
node
);
return
res
;
}
MatchResultPtr
Any
Pattern
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
Any
::
match
(
const
AnfNodePtr
&
node
)
{
MatchResultPtr
res
=
std
::
make_shared
<
MatchResult
>
();
res
->
add_entry
(
shared_from_base
<
Any
Pattern
>
(),
node
);
res
->
add_entry
(
shared_from_base
<
Any
>
(),
node
);
return
res
;
}
MatchResultPtr
Imm
::
match
(
const
AnfNodePtr
&
node
)
{
if
(
!
IsValueNode
<
Int32Imm
>
(
node
))
{
return
nullptr
;
}
// Check value
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
auto
value_ptr
=
value_node
->
value
()
->
cast
<
Int32ImmPtr
>
();
MS_EXCEPTION_IF_NULL
(
value_ptr
);
if
((
int32_t
)
value_ptr
->
value
()
==
value_
)
{
MatchResultPtr
res
=
std
::
make_shared
<
MatchResult
>
();
res
->
add_entry
(
shared_from_base
<
Imm
>
(),
node
);
return
res
;
}
return
nullptr
;
}
AnfNodePtr
MatchResult
::
get_node
(
const
PatternPtr
&
pattern
)
{
auto
entry
=
match_result_
.
find
(
pattern
);
if
(
entry
==
match_result_
.
end
())
{
...
...
@@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) {
REGISTER_PYBIND_DEFINE
(
Pattern
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
Pattern
,
std
::
shared_ptr
<
Pattern
>>
(
*
m
,
"Pattern"
).
def
(
py
::
init
<>
());
(
void
)
py
::
class_
<
IsIn
,
std
::
shared_ptr
<
IsIn
>
,
Pattern
>
(
*
m
,
"IsIn
_"
).
def
(
py
::
init
<
vector
<
PatternPtr
>>
());
(
void
)
py
::
class_
<
IsPrimTypeOf
,
std
::
shared_ptr
<
IsPrimTypeOf
>
,
Pattern
>
(
*
m
,
"IsPrimTypeOf
_"
,
py
::
dynamic_attr
())
.
def
(
py
::
init
<
vector
<
PrimitivePyPtr
>
,
string
,
bool
>
())
.
def
(
py
::
init
<
vector
<
string
>
,
string
,
bool
>
());
(
void
)
py
::
class_
<
Call
With
,
std
::
shared_ptr
<
CallWith
>
,
Pattern
>
(
*
m
,
"CallWith
_"
)
.
def
(
py
::
init
<
PatternPtr
,
vector
<
PatternPtr
>
,
bool
>
())
.
def
(
py
::
init
<
PrimitivePyPtr
,
vector
<
PatternPtr
>
,
bool
>
())
.
def
(
py
::
init
<
string
,
vector
<
PatternPtr
>
,
bool
>
());
(
void
)
py
::
class_
<
IsNot
,
std
::
shared_ptr
<
IsNot
>
,
Pattern
>
(
*
m
,
"IsNot
_"
).
def
(
py
::
init
<
vector
<
PatternPtr
>>
());
(
void
)
py
::
class_
<
Any
Pattern
,
std
::
shared_ptr
<
AnyPattern
>
,
Pattern
>
(
*
m
,
"AnyPattern
"
).
def
(
py
::
init
<>
());
(
void
)
py
::
class_
<
OneOf
,
std
::
shared_ptr
<
OneOf
>
,
Pattern
>
(
*
m
,
"OneOf
_"
).
def
(
py
::
init
<
vector
<
PatternPtr
>>
());
(
void
)
py
::
class_
<
Prim
,
std
::
shared_ptr
<
Prim
>
,
Pattern
>
(
*
m
,
"Prim
_"
,
py
::
dynamic_attr
())
.
def
(
py
::
init
<
vector
<
PrimitivePyPtr
>
,
string
>
())
.
def
(
py
::
init
<
vector
<
string
>
,
string
>
());
(
void
)
py
::
class_
<
Call
,
std
::
shared_ptr
<
Call
>
,
Pattern
>
(
*
m
,
"Call
_"
)
.
def
(
py
::
init
<
PatternPtr
,
vector
<
PatternPtr
>>
())
.
def
(
py
::
init
<
PrimitivePyPtr
,
vector
<
PatternPtr
>>
())
.
def
(
py
::
init
<
string
,
vector
<
PatternPtr
>>
());
(
void
)
py
::
class_
<
NoneOf
,
std
::
shared_ptr
<
NoneOf
>
,
Pattern
>
(
*
m
,
"NoneOf
_"
).
def
(
py
::
init
<
vector
<
PatternPtr
>>
());
(
void
)
py
::
class_
<
Any
,
std
::
shared_ptr
<
Any
>
,
Pattern
>
(
*
m
,
"Any
"
).
def
(
py
::
init
<>
());
(
void
)
py
::
class_
<
NewTensor
,
std
::
shared_ptr
<
NewTensor
>
,
Pattern
>
(
*
m
,
"NewTensor_"
)
.
def
(
py
::
init
<
tensor
::
TensorPtr
>
());
(
void
)
py
::
class_
<
NewParameter
,
std
::
shared_ptr
<
NewParameter
>
,
Pattern
>
(
*
m
,
"NewParameter_"
)
.
def
(
py
::
init
<
string
,
tensor
::
TensorPtr
,
bool
,
bool
,
bool
>
());
.
def
(
py
::
init
<
string
,
tensor
::
TensorPtr
,
bool
,
bool
>
());
(
void
)
py
::
class_
<
Imm
,
std
::
shared_ptr
<
Imm
>
,
Pattern
>
(
*
m
,
"Imm"
).
def
(
py
::
init
<
int
>
());
}));
}
// namespace python_pass
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/frontend/optimizer/pattern.h
浏览文件 @
93926230
...
...
@@ -36,10 +36,10 @@ class MatchResult;
using
MatchResultPtr
=
std
::
shared_ptr
<
MatchResult
>
;
class
Pattern
;
using
PatternPtr
=
std
::
shared_ptr
<
Pattern
>
;
class
IsPrimTypeOf
;
using
IsPrimTypeOfPtr
=
std
::
shared_ptr
<
IsPrimTypeOf
>
;
class
Call
With
;
using
Call
WithPtr
=
std
::
shared_ptr
<
CallWith
>
;
class
Prim
;
using
PrimPtr
=
std
::
shared_ptr
<
Prim
>
;
class
Call
;
using
Call
Ptr
=
std
::
shared_ptr
<
Call
>
;
class
NewTensor
;
using
NewTensorPtr
=
std
::
shared_ptr
<
NewTensor
>
;
class
NewParameter
;
...
...
@@ -58,8 +58,6 @@ class Pattern : public Base {
virtual
bool
operator
==
(
const
Pattern
&
other
)
const
{
return
unique_name_
==
other
.
unique_name_
;
}
string
unique_name
()
const
{
return
unique_name_
;
}
vector
<
PatternPtr
>
inputs
()
{
return
inputs_
;
}
bool
should_replace
()
{
return
should_replace_
;
}
void
set_should_replace
(
bool
should_replace
)
{
should_replace_
=
should_replace
;
}
virtual
void
reset
()
{}
protected:
...
...
@@ -67,7 +65,6 @@ class Pattern : public Base {
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
string
unique_name_
;
vector
<
PatternPtr
>
inputs_
;
bool
should_replace_
=
true
;
};
struct
PatternEqual
{
...
...
@@ -85,70 +82,61 @@ struct PatternHasher {
}
};
class
IsPrimTypeOf
:
public
Pattern
{
class
Prim
:
public
Pattern
{
public:
IsPrimTypeOf
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
IsPrimTypeOf
()
=
default
;
IsPrimTypeOf
(
vector
<
PrimitivePyPtr
>
prims
,
string
name
,
bool
should_replace
)
:
primitives_
(
prims
),
name_
(
name
),
matched_prim_
(
nullptr
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"IsPrimTypeOf_"
+
name
;
should_replace_
=
should_replace
;
if
(
!
should_replace
)
{
matched_prim_
=
prims
[
0
];
}
Prim
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
Prim
()
=
default
;
Prim
(
vector
<
PrimitivePyPtr
>
prims
,
string
name
)
:
primitives_
(
prims
),
name_
(
name
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Prim_"
+
name
;
// Default using the first prim to build target
matched_prim_
=
primitives_
[
0
];
}
IsPrimTypeOf
(
vector
<
string
>
types
,
string
name
,
bool
should_replac
e
)
:
types_
(
types
),
name_
(
name
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
IsPrimTypeOf
_"
+
name
;
Prim
(
vector
<
string
>
types
,
string
nam
e
)
:
types_
(
types
),
name_
(
name
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
Prim
_"
+
name
;
// Make primitives_
for
(
auto
&
iter
:
types
)
{
primitives_
.
push_back
(
std
::
make_shared
<
PrimitivePy
>
(
iter
,
py
::
cast
(
nullptr
)));
}
should_replace_
=
should_replace
;
if
(
!
should_replace
)
{
matched_prim_
=
primitives_
[
0
];
}
// Default using the first prim to build target
matched_prim_
=
primitives_
[
0
];
}
MS_DECLARE_PARENT
(
IsPrimTypeOf
,
Pattern
);
MS_DECLARE_PARENT
(
Prim
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
PrimitivePyPtr
matched_primitive
()
{
return
matched_prim_
;
}
void
reset
()
override
{
if
(
should_replace_
)
{
matched_prim_
=
nullptr
;
}
// Init before reset
MS_EXCEPTION_IF_NULL
(
matched_prim_
)
;
matched_prim_
=
primitives_
[
0
];
}
private:
vector
<
string
>
types_
;
vector
<
PrimitivePyPtr
>
primitives_
;
string
name_
;
PrimitivePyPtr
matched_prim_
;
PrimitivePyPtr
matched_prim_
{
nullptr
}
;
};
class
Call
With
:
public
Pattern
{
class
Call
:
public
Pattern
{
public:
Call
With
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
Call
With
()
=
default
;
Call
With
(
PatternPtr
prim_pattern
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
Call
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
Call
()
=
default
;
Call
(
PatternPtr
prim_pattern
,
vector
<
PatternPtr
>
inputs
)
{
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_
=
prim_pattern
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Call
WithPattern
_"
+
prim_pattern
->
unique_name
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Call_"
+
prim_pattern
->
unique_name
();
inputs_
=
inputs
;
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
should_replace_
=
prim_pattern
->
should_replace
();
}
Call
With
(
PrimitivePyPtr
prim
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
Call
(
PrimitivePyPtr
prim
,
vector
<
PatternPtr
>
inputs
)
{
prim_
=
prim
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Call
WithPrim
_"
+
prim_
->
ToString
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Call_"
+
prim_
->
ToString
();
inputs_
=
inputs
;
should_replace_
=
should_replace
;
}
Call
With
(
string
prim_str
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
Call
(
string
prim_str
,
vector
<
PatternPtr
>
inputs
)
{
prim_
=
std
::
make_shared
<
PrimitivePy
>
(
prim_str
,
py
::
cast
(
nullptr
));
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Call
With
Str_"
+
prim_
->
ToString
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"CallStr_"
+
prim_
->
ToString
();
inputs_
=
inputs
;
should_replace_
=
should_replace
;
}
MS_DECLARE_PARENT
(
Call
With
,
Pattern
);
MS_DECLARE_PARENT
(
Call
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
PrimitivePtr
prim_value
()
{
return
prim_
;
}
PatternPtr
prim_pattern
()
{
return
prim_pattern_
;
}
...
...
@@ -160,45 +148,45 @@ class CallWith : public Pattern {
string
name_
;
};
class
IsIn
:
public
Pattern
{
class
OneOf
:
public
Pattern
{
public:
IsIn
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
IsIn
()
=
default
;
explicit
IsIn
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
IsIn
"
;
OneOf
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
OneOf
()
=
default
;
explicit
OneOf
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
OneOf
"
;
for
(
auto
&
iter
:
patterns
)
{
unique_name_
=
unique_name_
+
"_"
+
iter
->
unique_name
();
}
}
MS_DECLARE_PARENT
(
IsIn
,
Pattern
);
MS_DECLARE_PARENT
(
OneOf
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
private:
vector
<
PatternPtr
>
patterns_
;
};
class
IsNot
:
public
Pattern
{
class
NoneOf
:
public
Pattern
{
public:
IsNot
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
IsNot
()
=
default
;
explicit
IsNot
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
IsNot
"
;
NoneOf
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
NoneOf
()
=
default
;
explicit
NoneOf
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
NoneOf
"
;
for
(
auto
&
iter
:
patterns
)
{
unique_name_
=
unique_name_
+
"_"
+
iter
->
unique_name
();
}
}
MS_DECLARE_PARENT
(
IsNot
,
Pattern
);
MS_DECLARE_PARENT
(
NoneOf
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
private:
vector
<
PatternPtr
>
patterns_
;
};
class
Any
Pattern
:
public
Pattern
{
class
Any
:
public
Pattern
{
public:
Any
Pattern
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"_AnyPattern
"
;
}
~
Any
Pattern
()
=
default
;
MS_DECLARE_PARENT
(
Any
Pattern
,
Pattern
);
Any
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"_Any
"
;
}
~
Any
()
=
default
;
MS_DECLARE_PARENT
(
Any
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
};
...
...
@@ -207,7 +195,6 @@ class NewTensor : public Pattern {
NewTensor
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
NewTensor
()
=
default
;
explicit
NewTensor
(
tensor
::
TensorPtr
input_tensor
)
:
input_tensor_
(
input_tensor
)
{
should_replace_
=
false
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"NewTensor"
;
}
MS_DECLARE_PARENT
(
NewTensor
,
Pattern
);
...
...
@@ -223,10 +210,8 @@ class NewTensor : public Pattern {
class
NewParameter
:
public
Pattern
{
public:
NewParameter
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
explicit
NewParameter
(
string
para_name
,
tensor
::
TensorPtr
default_tensor
,
bool
requires_grad
,
bool
layerwise_parallel
,
bool
should_replace
)
explicit
NewParameter
(
string
para_name
,
tensor
::
TensorPtr
default_tensor
,
bool
requires_grad
,
bool
layerwise_parallel
)
:
para_name_
(
para_name
),
requires_grad_
(
requires_grad
),
layerwise_parallel_
(
layerwise_parallel
)
{
should_replace_
=
should_replace
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"NewParameter_"
+
para_name
;
// clone input tensor
default_tensor_
=
std
::
make_shared
<
tensor
::
Tensor
>
(
*
default_tensor
.
get
());
...
...
@@ -243,11 +228,14 @@ class NewParameter : public Pattern {
bool
built
()
{
return
built_
;
}
void
set_built
(
bool
built
)
{
built_
=
built
;
}
void
reset
()
override
{
built_
=
false
;
}
bool
should_last
()
{
return
last_across_passes_
;
}
void
set_last
(
bool
last
)
{
last_across_passes_
=
last
;
}
private:
string
para_name_
;
bool
requires_grad_
;
bool
layerwise_parallel_
;
bool
last_across_passes_
{
false
};
bool
built_
;
tensor
::
TensorPtr
default_tensor_
;
};
...
...
@@ -255,13 +243,9 @@ class NewParameter : public Pattern {
class
Imm
:
public
Pattern
{
public:
Imm
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
explicit
Imm
(
int
value
)
:
value_
(
value
)
{
should_replace_
=
false
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Imm_"
+
std
::
to_string
(
value
);
}
explicit
Imm
(
int
value
)
:
value_
(
value
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Imm_"
+
std
::
to_string
(
value
);
}
MS_DECLARE_PARENT
(
Imm
,
Pattern
);
// NOTE: Doesn't support Imm in src pattern currently.
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
{
return
nullptr
;
}
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
;
int
value
()
{
return
value_
;
}
private:
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/frontend/optimizer/py_pass.cc
浏览文件 @
93926230
...
...
@@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) {
AnfNodePtr
BuildPrimitive
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
)
{
// Build up AnfNode from primitive
auto
prim_pattern
=
pattern
->
cast
<
IsPrimTypeOf
Ptr
>
();
auto
prim_pattern
=
pattern
->
cast
<
Prim
Ptr
>
();
MS_EXCEPTION_IF_NULL
(
prim_pattern
);
PrimitivePyPtr
prim
=
prim_pattern
->
matched_primitive
();
MS_EXCEPTION_IF_NULL
(
prim
);
...
...
@@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
}
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
)
{
auto
call_
with_pattern
=
pattern
->
cast
<
CallWith
Ptr
>
();
MS_EXCEPTION_IF_NULL
(
call_
with_
pattern
);
auto
prim
=
call_
with_
pattern
->
prim_value
();
auto
call_
pattern
=
pattern
->
cast
<
Call
Ptr
>
();
MS_EXCEPTION_IF_NULL
(
call_pattern
);
auto
prim
=
call_pattern
->
prim_value
();
if
(
prim
!=
nullptr
)
{
return
std
::
make_shared
<
ValueNode
>
(
prim
);
}
auto
prim_pattern
=
call_
with_
pattern
->
prim_pattern
();
auto
prim_pattern
=
call_pattern
->
prim_pattern
();
MS_EXCEPTION_IF_NULL
(
prim_pattern
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
);
}
...
...
@@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
}
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
pattern
->
should_replace
())
{
// Find replacement in the MatchResult
auto
target_node
=
res
->
get_node
(
pattern
);
if
(
target_node
==
nullptr
)
{
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
if
(
pattern
->
isa
<
IsPrimTypeOf
>
()
||
pattern
->
isa
<
NewTensor
>
()
||
pattern
->
isa
<
NewParameter
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot find target node, pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
return
nullptr
;
}
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
auto
new_node
=
BuildTarget
(
pattern
,
func_graph
,
res
);
if
(
new_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Try to build pattern node but FAILED. pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
}
return
new_node
;
}
if
(
pattern
->
isa
<
NewParameter
>
())
{
auto
target_node
=
res
->
get_node
(
pattern
);
if
(
target_node
!=
nullptr
)
{
// If pattern is NewParameter, check whether it shouldn't last and is not built
auto
new_para
=
pattern
->
cast
<
NewParameterPtr
>
();
if
(
new_para
==
nullptr
||
new_para
->
should_last
()
||
new_para
->
built
())
{
return
target_node
;
}
return
target_node
;
}
// Build up new node from pattern
if
(
pattern
->
isa
<
IsPrimTypeOf
>
())
{
if
(
pattern
->
isa
<
Prim
>
())
{
return
BuildPrimitive
(
pattern
,
res
);
}
else
if
(
pattern
->
isa
<
NewTensor
>
())
{
return
BuildNewTensor
(
pattern
,
res
);
}
else
if
(
pattern
->
isa
<
Call
With
>
())
{
}
else
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
return
BuildNewParameter
(
pattern
,
res
,
func_graph
);
}
else
if
(
pattern
->
isa
<
Imm
>
())
{
return
BuildImmNode
(
pattern
,
res
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot find or build target node, pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
return
nullptr
;
}
return
nullptr
;
}
AnfNodePtr
ProcessComplexPatternFirstInput
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
pattern
->
isa
<
Call
With
>
())
{
if
(
pattern
->
isa
<
Call
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
}
return
nullptr
;
...
...
@@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor
}
void
Reset
(
PatternPtr
pattern
)
{
if
(
pattern
->
isa
<
IsPrimTypeOf
>
())
{
auto
prim_pattern
=
pattern
->
cast
<
IsPrimTypeOf
Ptr
>
();
if
(
pattern
->
isa
<
Prim
>
())
{
auto
prim_pattern
=
pattern
->
cast
<
Prim
Ptr
>
();
prim_pattern
->
reset
();
return
;
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
auto
new_param_pattern
=
pattern
->
cast
<
NewParameterPtr
>
();
new_param_pattern
->
reset
();
return
;
}
else
if
(
pattern
->
isa
<
Call
With
>
())
{
auto
call_with_pattern
=
pattern
->
cast
<
Call
With
Ptr
>
();
}
else
if
(
pattern
->
isa
<
Call
>
())
{
auto
call_with_pattern
=
pattern
->
cast
<
CallPtr
>
();
for
(
auto
sub_pattern
:
call_with_pattern
->
inputs
())
{
Reset
(
sub_pattern
);
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
浏览文件 @
93926230
...
...
@@ -49,8 +49,9 @@ PyPassManager::PyPassManager() {
}
void
PyPassManager
::
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
Phase
phase
,
bool
run_only_once
)
{
auto
cur_pg
=
GetPassGroup
(
phase
);
bool
run_only_once
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pg
);
cur_pg
->
SetRunOnlyOnce
(
run_only_once
);
MS_EXCEPTION_IF_NULL
(
pattern
);
...
...
@@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
cur_pg
->
AddPass
(
new_pass
);
}
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
)
{
auto
cur_pm
=
GetPassGroup
(
phase
);
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
)
{
// NOTE: remove phase option to avoid unnecessary confusion.
auto
cur_pm
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
if
(
!
cur_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"No such pass : "
+
pass_name
+
"
\n
"
;
...
...
@@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
void
PyPassManager
::
GenNewParameter
(
const
PatternPtr
&
parameter
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
// Add new parameter after resolve
// NOTE: Add NewParameter at early stage will cause CSE problems
auto
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pg
);
...
...
@@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
auto
new_para_pattern
=
parameter
->
cast
<
NewParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
new_para_pattern
);
auto
pass_name
=
new_para_pattern
->
para_name
();
parameter
->
set_should_replace
(
fals
e
);
new_para_pattern
->
set_last
(
tru
e
);
auto
new_pass
=
std
::
make_shared
<
PythonPass
>
(
pass_name
,
nullptr
,
parameter
,
true
);
cur_pg
->
AddPass
(
new_pass
);
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
浏览文件 @
93926230
...
...
@@ -53,16 +53,17 @@ class PyPassManager {
static
PyPassManagerPtr
GetInstance
();
virtual
~
PyPassManager
()
=
default
;
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
Phase
phase
=
Phase
::
RESOLVE
,
bool
run_only_once
=
false
);
void
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
);
bool
run_only_once
=
false
);
void
Unregiste
(
const
std
::
string
&
pass_name
);
void
GenNewParameter
(
const
PatternPtr
&
parameter
);
PassGroupPtr
GetPassGroup
(
Phase
phase
);
void
ClearRes
();
MatchResultPtr
GetMatchResult
()
{
return
res_
;
}
void
SetRenorm
(
bool
should_renorm
)
{
should_renorm_
=
should_renorm
;
}
bool
ShouldRenorm
()
{
return
should_renorm_
;
}
void
SetResource
(
pipeline
::
ResourcePtr
resource
)
{
resource_
=
resource
;
}
pipeline
::
ResourcePtr
GetResource
()
{
return
resource_
;
}
void
ClearRes
();
void
ClearPipelineRes
()
{
resource_
=
nullptr
;
}
private:
bool
should_renorm_
=
true
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
93926230
...
...
@@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
// save the run graph func to MsPipeLine
SaveCompiledGraph
(
phase_s
);
opt
::
python_pass
::
PyPassManager
::
GetInstance
()
->
ClearPipelineRes
();
resource
->
Clean
();
// Reclaim all resource used by optimizer;
ReclaimOptimizer
();
...
...
This diff is collapsed.
Click to expand it.
mindspore/graph_utils/graph_pattern.py
浏览文件 @
93926230
...
...
@@ -15,50 +15,43 @@
"""Patterns for describing graphs"""
from
mindspore.ops
import
Primitive
from
mindspore.common.tensor
import
Tensor
from
mindspore._c_expression
import
Pattern
,
IsIn_
,
IsPrimTypeOf_
,
CallWith_
,
IsNot_
,
AnyPattern
,
NewTensor_
,
\
NewParameter_
,
Imm
from
mindspore._c_expression
import
Pattern
,
OneOf_
,
Prim_
,
Call_
,
NoneOf_
,
Any
,
NewTensor_
,
NewParameter_
,
Imm
__all__
=
[
"
IsIn
"
,
"
IsPrimTypeOf
"
,
"Call
With
"
,
"
IsNot
"
,
"Any
Pattern
"
,
"
OneOf
"
,
"
Prim
"
,
"Call"
,
"
NoneOf
"
,
"Any"
,
"NewTensor"
,
"NewParameter"
,
"Imm"
]
class
IsIn
(
IsIn
_
):
class
OneOf
(
OneOf
_
):
r
"""
Express a pattern which allows a list of patterns.
"""
def
__init__
(
self
,
patterns
=
None
,
should_replace
=
True
):
def
__init__
(
self
,
patterns
=
None
):
r
"""
Args:
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
tuple[:class:`mindspore.graph_utils.graph_pattern`],
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
each element should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False
TypeError: raise type error for invalid inputs.
"""
if
not
should_replace
:
raise
ValueError
(
"IsIn pattern does not have its own should_replace attribute. Set should_replace in
\
its sub-pattern instead."
)
self
.
patterns
=
patterns
if
patterns
is
None
:
IsIn_
.
__init__
(
self
,
())
elif
isinstance
(
patterns
,
Pattern
):
IsIn_
.
__init__
(
self
,
[
patterns
])
if
isinstance
(
patterns
,
Pattern
):
OneOf_
.
__init__
(
self
,
[
patterns
])
elif
isinstance
(
patterns
,
(
tuple
,
list
))
and
all
(
isinstance
(
pattern
,
Pattern
)
for
pattern
in
patterns
):
IsIn
_
.
__init__
(
self
,
patterns
)
OneOf
_
.
__init__
(
self
,
patterns
)
else
:
raise
TypeError
(
f
"Expect patterns to be a list of Patterns/Pattern, got :
{
patterns
}
"
)
class
IsPrimTypeOf
(
IsPrimTypeOf
_
):
class
Prim
(
Prim
_
):
r
"""
Express a pattern of certain primitive type(s).
...
...
@@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_):
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
"""
def
__init__
(
self
,
types
,
name
=
None
,
should_replace
=
True
):
def
__init__
(
self
,
types
,
name
=
None
):
r
"""
Args:
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
...
...
@@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_):
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional. Default: None.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
...
...
@@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_):
self
.
types
=
types
else
:
raise
TypeError
(
f
"Expecting a primitive type string or a list of Primitives, got :
{
types
}
"
)
IsPrimTypeOf_
.
__init__
(
self
,
self
.
types
,
self
.
name
,
should_replac
e
)
Prim_
.
__init__
(
self
,
self
.
types
,
self
.
nam
e
)
class
Call
With
(
CallWith
_
):
class
Call
(
Call
_
):
r
"""
Express a primitive CNode.
"""
def
__init__
(
self
,
prim_pattern
,
inputs
=
None
,
should_replace
=
True
):
def
__init__
(
self
,
prim_pattern
,
inputs
=
None
):
r
"""
Args:
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
...
...
@@ -118,9 +108,6 @@ class CallWith(CallWith_):
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
patterns should be of right order and each element should be one of the exposed Pattern instance.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
...
...
@@ -135,36 +122,31 @@ class CallWith(CallWith_):
self
.
inputs
=
inputs
else
:
raise
TypeError
(
f
"Expect inputs to be a list of Patterns, got :
{
inputs
}
"
)
Call
With_
.
__init__
(
self
,
self
.
prim_pattern
,
self
.
inputs
,
should_replace
)
Call
_
.
__init__
(
self
,
self
.
prim_pattern
,
self
.
inputs
)
class
IsNot
(
IsNot
_
):
class
NoneOf
(
NoneOf
_
):
r
"""
Express a pattern which forbids a list of patterns.
NOTE:
IsNot
pattern should not be the root pattern.
NoneOf
pattern should not be the root pattern.
"""
def
__init__
(
self
,
patterns
=
None
,
should_replace
=
True
):
def
__init__
(
self
,
patterns
=
None
):
r
"""
Args:
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False.
TypeError: raise type error for invalid argument.
"""
if
not
should_replace
:
raise
ValueError
(
"IsNot pattern does not have its own should_replace attribute. Set should_replace in
\
its sub-pattern instead."
)
self
.
patterns
=
patterns
if
patterns
is
None
:
IsNot
_
.
__init__
(
self
,
())
NoneOf
_
.
__init__
(
self
,
())
elif
isinstance
(
patterns
,
Pattern
):
IsNot
_
.
__init__
(
self
,
[
patterns
])
NoneOf
_
.
__init__
(
self
,
[
patterns
])
elif
isinstance
(
patterns
,
(
tuple
,
list
))
and
all
(
isinstance
(
pattern
,
Pattern
)
for
pattern
in
patterns
):
IsNot
_
.
__init__
(
self
,
patterns
)
NoneOf
_
.
__init__
(
self
,
patterns
)
else
:
raise
TypeError
(
f
"Expect list of Patterns/Pattern, got :
{
patterns
}
"
)
...
...
@@ -172,18 +154,14 @@ class NewTensor(NewTensor_):
r
"""
New Tensor to be used in the target.
"""
def
__init__
(
self
,
input_tensor
,
should_replace
=
False
):
def
__init__
(
self
,
input_tensor
):
r
"""
Args:
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
Raises:
ValueError: raise if should_replace is True
TypeError: raise type error for invalid argument.
"""
if
should_replace
:
raise
ValueError
(
"NewTensor should only appear in the target, thus should_replace can only be False."
)
self
.
input_tensor
=
input_tensor
if
isinstance
(
input_tensor
,
Tensor
):
NewTensor_
.
__init__
(
self
,
input_tensor
)
...
...
@@ -194,15 +172,13 @@ class NewParameter(NewParameter_):
r
"""
New Parameter to be used in the target.
"""
def
__init__
(
self
,
para_name
,
default_tensor
,
requires_grad
=
False
,
layerwise_parallel
=
False
,
should_replace
=
False
):
def
__init__
(
self
,
para_name
,
default_tensor
,
requires_grad
=
False
,
layerwise_parallel
=
False
):
r
"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: True
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
parameter everytime a pass target got built. Default: False
Raises:
TypeError: raise type error for invalid argument.
...
...
@@ -211,12 +187,11 @@ class NewParameter(NewParameter_):
self
.
default_tensor
=
default_tensor
self
.
requires_grad
=
requires_grad
self
.
layerwise_parallel
=
layerwise_parallel
self
.
should_replace
=
should_replace
if
isinstance
(
para_name
,
str
)
and
isinstance
(
default_tensor
,
Tensor
)
and
isinstance
(
requires_grad
,
bool
)
and
\
isinstance
(
layerwise_parallel
,
bool
)
and
isinstance
(
should_replace
,
bool
)
:
isinstance
(
layerwise_parallel
,
bool
):
NewParameter_
.
__init__
(
self
,
self
.
para_name
,
self
.
default_tensor
,
self
.
requires_grad
,
self
.
layerwise_parallel
,
self
.
should_replace
)
self
.
layerwise_parallel
)
else
:
raise
TypeError
(
f
"Expect para_name(str), default_tensor(Tensor), requires_grad(bool),
\
layerwise_parallel(bool)
should_replace(bool)
, got :
{
para_name
}
,
{
default_tensor
}
,
\
{
requires_grad
}
,
{
layerwise_parallel
}
,
{
should_replace
}
"
)
layerwise_parallel(bool), got :
{
para_name
}
,
{
default_tensor
}
,
\
{
requires_grad
}
,
{
layerwise_parallel
}
"
)
This diff is collapsed.
Click to expand it.
mindspore/graph_utils/python_pass/python_pass_register.py
浏览文件 @
93926230
...
...
@@ -15,7 +15,7 @@
"""Python pass register"""
from
inspect
import
isfunction
from
mindspore.graph_utils.graph_pattern
import
Pattern
,
NewParameter
from
mindspore._c_expression
import
PyPassManager_
,
phase
from
mindspore._c_expression
import
PyPassManager_
__all__
=
[
...
...
@@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_):
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
Raises:
TypeError: If argument has invalid type.
"""
def
__init__
(
self
,
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expect phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
def
__init__
(
self
,
run_only_once
=
False
):
if
not
isinstance
(
run_only_once
,
bool
):
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
PyPassManager_
.
__init__
(
self
)
self
.
phase_
=
pipeline_phase
self
.
run_only_once_
=
run_only_once
PyPassManager_
.
__init__
(
self
)
def
registe
(
self
,
py_pass
):
if
not
isfunction
(
py_pass
):
...
...
@@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect pattern of Pattern type, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
if
not
isinstance
(
target
,
Pattern
):
raise
TypeError
(
f
"Expect target of Pattern type, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
phase_
,
self
.
run_only_once_
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
run_only_once_
)
def
unregiste
(
self
,
py_pass
,
pipeline_phase
=
phase
.
opt
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expect phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
def
unregiste
(
self
,
py_pass
):
if
isinstance
(
py_pass
,
str
):
super
().
unregiste
(
py_pass
,
pipeline_phase
)
super
().
unregiste
(
py_pass
)
return
if
isfunction
(
py_pass
):
super
().
unregiste
(
py_pass
.
__name__
,
pipeline_phase
)
super
().
unregiste
(
py_pass
.
__name__
)
return
raise
TypeError
(
f
"Expect py_pass to be string or function, got (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
...
...
@@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_):
raise
TypeError
(
f
"Expect should_renorm to be a bool, got
{
should_renorm
}
"
)
super
().
set_renorm
(
should_renorm
)
def
registe_pass
(
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
):
def
registe_pass
(
run_only_once
=
False
):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
registed. Support phase.resolve and phase.opt. Default: phase.opt.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
...
...
@@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return
PyPassManager
(
pipeline_phase
,
run_only_once
)
return
PyPassManager
(
run_only_once
)
def
unregiste_pass
(
py_pass
,
pipeline_phase
=
phase
.
opt
):
def
unregiste_pass
(
py_pass
):
"""
Unregiste python pass.
Args:
py_pass(Union(str, function)): target python pass to unregiste.
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
"""
ppm
=
PyPassManager
()
ppm
.
unregiste
(
py_pass
,
pipeline_phase
)
ppm
.
unregiste
(
py_pass
)
def
gen_new_parameter
(
pattern
):
"""
...
...
@@ -164,7 +153,14 @@ def cancel_new_parameter(pattern):
def
set_renorm
(
should_renorm
):
"""
Set whether or not to do renorm after modified graph in python pass(es).
Set whether or not to do renormalization after modified graph in python pass(es).
Args:
should_renorm(bool): whether or not to do renormalization after modified graph in python pass(es).
NOTE:
This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
renormalization may BREAK the network.
"""
ppm
=
PyPassManager
()
ppm
.
set_renorm
(
should_renorm
)
This diff is collapsed.
Click to expand it.
tests/ut/python/optimizer/test_python_pass.py
浏览文件 @
93926230
...
...
@@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
cancel_new_parameter
from
mindspore.common.api
import
_generate_pip_args
from
mindspore._c_expression
import
generate_key
,
Executor_
from
mindspore.graph_utils.graph_pattern
import
IsIn
,
IsPrimTypeOf
,
CallWith
,
IsNot
,
AnyPattern
,
NewTensor
,
\
NewParameter
,
Imm
from
mindspore.graph_utils.graph_pattern
import
OneOf
,
Prim
,
Call
,
NoneOf
,
Any
,
NewTensor
,
NewParameter
,
Imm
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
@@ -50,11 +49,9 @@ def test_softmax_relu():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_relu_pass
():
x
=
AnyPattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
pattern
=
CallWith
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
(),
should_replace
=
False
)
target
=
CallWith
(
relu_pattern
,
inputs
=
[
x
])
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
inputs
=
[
x
])
target
=
Call
(
P
.
ReLU
(),
inputs
=
[
x
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
...
...
@@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_relu_pass
():
x
=
Any
Pattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
pattern
=
Call
With
(
softmax_pattern
,
inputs
=
[
x
])
sigmoid_pattern
=
IsPrimTypeOf
(
P
.
Sigmoid
(),
should_replace
=
False
)
call_sigmoid
=
Call
With
(
sigmoid_pattern
,
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
(),
should_replace
=
False
)
target
=
Call
With
(
relu_pattern
,
inputs
=
[
call_sigmoid
])
x
=
Any
()
softmax_pattern
=
Prim
(
P
.
Softmax
())
pattern
=
Call
(
softmax_pattern
,
inputs
=
[
x
])
sigmoid_pattern
=
Prim
(
P
.
Sigmoid
()
)
call_sigmoid
=
Call
(
sigmoid_pattern
,
[
x
])
relu_pattern
=
Prim
(
P
.
ReLU
()
)
target
=
Call
(
relu_pattern
,
inputs
=
[
call_sigmoid
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
3
)
...
...
@@ -99,15 +96,15 @@ def test_isin_pattern_0():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_relu_pass
():
x
=
Any
Pattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
call_softmax
=
Call
With
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
())
call_relu
=
Call
With
(
relu_pattern
,
inputs
=
[
x
])
pattern
=
IsIn
([
call_softmax
,
call_relu
])
relu6_pattern
=
IsPrimTypeOf
(
P
.
ReLU6
(),
should_replace
=
False
)
target
=
Call
With
(
relu6_pattern
,
inputs
=
[
x
])
x
=
Any
()
softmax_pattern
=
Prim
(
P
.
Softmax
())
call_softmax
=
Call
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
Prim
(
P
.
ReLU
())
call_relu
=
Call
(
relu_pattern
,
inputs
=
[
x
])
pattern
=
OneOf
([
call_softmax
,
call_relu
])
relu6_pattern
=
Prim
(
P
.
ReLU6
()
)
target
=
Call
(
relu6_pattern
,
inputs
=
[
x
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
unregiste_pass
(
softmax_relu_pass
)
...
...
@@ -123,18 +120,17 @@ def test_isin_pattern_1():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_neg_pass
():
x
=
Any
Pattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
call_softmax
=
Call
With
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
())
call_relu
=
Call
With
(
relu_pattern
,
inputs
=
[
x
])
pattern
=
IsIn
([
call_softmax
,
call_relu
])
neg_ops
=
IsPrimTypeOf
(
P
.
Neg
(),
should_replace
=
False
)
target
=
Call
With
(
neg_ops
,
inputs
=
[
pattern
])
x
=
Any
()
softmax_pattern
=
Prim
(
P
.
Softmax
())
call_softmax
=
Call
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
Prim
(
P
.
ReLU
())
call_relu
=
Call
(
relu_pattern
,
inputs
=
[
x
])
pattern
=
OneOf
([
call_softmax
,
call_relu
])
neg_ops
=
Prim
(
P
.
Neg
()
)
target
=
Call
(
neg_ops
,
inputs
=
[
pattern
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
4
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_neg_pass
)
assert
"Neg"
in
transformed_repr
assert
"Softmax"
in
transformed_repr
...
...
@@ -167,11 +163,11 @@ def test_isnot_pattern_0():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
"""
conv2d_prim
=
IsPrimTypeOf
(
"Conv2D"
)
conv2d
=
Call
With
(
conv2d_prim
)
pattern_0
=
IsNot
(
conv2d
)
pattern
=
Call
With
(
P
.
BatchNorm
(),
inputs
=
[
pattern_0
])
target
=
Call
With
(
P
.
ReLU6
(),
inputs
=
[
pattern_0
])
conv2d_prim
=
Prim
(
"Conv2D"
)
conv2d
=
Call
(
conv2d_prim
)
pattern_0
=
NoneOf
(
conv2d
)
pattern
=
Call
(
P
.
BatchNorm
(),
inputs
=
[
pattern_0
])
target
=
Call
(
P
.
ReLU6
(),
inputs
=
[
pattern_0
])
return
pattern
,
target
@
registe_pass
(
run_only_once
=
True
)
...
...
@@ -179,10 +175,8 @@ def test_isnot_pattern_0():
"""
Sub a BN to Softmax.
"""
bn
=
P
.
BatchNorm
()
pattern
=
CallWith
(
bn
)
softmax
=
P
.
Softmax
()
target
=
CallWith
(
softmax
,
should_replace
=
False
)
pattern
=
Call
(
P
.
BatchNorm
())
target
=
Call
(
P
.
Softmax
())
return
pattern
,
target
transformed_repr
=
get_func_graph
(
conv_bn_model
,
inputs
).
get_return
().
expanded_str
(
5
)
...
...
@@ -205,12 +199,12 @@ def test_isnot_pattern_1():
"""
Sub a BN which does NOT take MatMul as inputs to ReLU6.
"""
matmul
=
IsPrimTypeOf
(
"MatMul"
)
pattern_0
=
IsNot
(
matmul
)
matmul
=
Prim
(
"MatMul"
)
pattern_0
=
NoneOf
(
matmul
)
softmax
=
P
.
Softmax
()
pattern
=
Call
With
(
softmax
,
inputs
=
[
pattern_0
])
pattern
=
Call
(
softmax
,
inputs
=
[
pattern_0
])
relu6
=
P
.
ReLU6
()
target
=
Call
With
(
relu6
,
inputs
=
[
pattern_0
],
should_replace
=
False
)
target
=
Call
(
relu6
,
inputs
=
[
pattern_0
]
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
...
...
@@ -228,14 +222,12 @@ def test_newtensor_pattern():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_addn_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
inputs
=
[
x
])
weight_tensor
=
Tensor
(
np
.
zeros
([
42
]),
mindspore
.
float16
)
new_weight
=
NewTensor
(
weight_tensor
)
addn_ops
=
P
.
AddN
()
target
=
CallWith
(
addn_ops
,
inputs
=
[
x
,
new_weight
],
should_replace
=
False
)
target
=
Call
(
P
.
AddN
(),
inputs
=
[
x
,
new_weight
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
unregiste_pass
(
softmax_addn_pass
)
...
...
@@ -252,25 +244,23 @@ def test_newparameter_pattern():
@
registe_pass
(
run_only_once
=
True
)
def
softmax_addn_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
inputs
=
[
x
])
default_tensor0
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
default_tensor1
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para_0
=
NewParameter
(
"Merlin"
,
default_tensor0
)
new_para_1
=
NewParameter
(
"Arthur"
,
default_tensor1
)
target_0
=
Call
With
(
P
.
MatMul
(),
inputs
=
[
new_para_0
,
new_para_1
],
should_replace
=
False
)
target
=
Call
With
(
"make_tuple"
,
inputs
=
[
target_0
],
should_replace
=
False
)
target_0
=
Call
(
P
.
MatMul
(),
inputs
=
[
new_para_0
,
new_para_1
]
)
target
=
Call
(
"make_tuple"
,
inputs
=
[
target_0
]
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_addn_pass
)
assert
"MatMul"
in
transformed_repr
assert
"make_tuple"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_imm_
pattern
():
def
test_imm_
target
():
"""
Test NewParameter pattern in the target
"""
...
...
@@ -278,17 +268,15 @@ def test_imm_pattern():
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
def
softmax_addn_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
def
softmax_pass
():
x
=
Any
()
pattern
=
Call
(
P
.
Softmax
(),
inputs
=
[
x
])
imm
=
Imm
(
0
)
target_0
=
Call
With
(
"make_tuple"
,
inputs
=
[
pattern
],
should_replace
=
False
)
target
=
Call
With
(
"tuple_getitem"
,
inputs
=
[
target_0
,
imm
],
should_replace
=
False
)
target_0
=
Call
(
"make_tuple"
,
inputs
=
[
pattern
]
)
target
=
Call
(
"tuple_getitem"
,
inputs
=
[
target_0
,
imm
]
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_addn_pass
)
unregiste_pass
(
softmax_pass
)
assert
"make_tuple"
in
transformed_repr
assert
"tuple_getitem"
in
transformed_repr
assert
"Softmax"
in
transformed_repr
...
...
@@ -301,21 +289,19 @@ def test_gen_new_parameter():
softmax_model
=
nn
.
Softmax
()
default_tensor
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para
=
NewParameter
(
"Merlin"
,
default_tensor
,
should_replace
=
True
)
new_para
=
NewParameter
(
"Merlin"
,
default_tensor
)
gen_new_parameter
(
new_para
)
@
registe_pass
(
run_only_once
=
True
)
def
softmax_make_tuple_pass
():
x
=
Any
Pattern
()
x
=
Any
()
softmax
=
P
.
Softmax
()
pattern
=
Call
With
(
softmax
,
inputs
=
[
x
])
pattern
=
Call
(
softmax
,
inputs
=
[
x
])
target
=
Call
With
(
"make_tuple"
,
inputs
=
[
pattern
,
new_para
],
should_replace
=
False
)
target
=
Call
(
"make_tuple"
,
inputs
=
[
pattern
,
new_para
]
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
assert
"Merlin"
in
transformed_repr
unregiste_pass
(
softmax_make_tuple_pass
)
cancel_new_parameter
(
new_para
)
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
assert
"Merlin"
not
in
transformed_repr
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部