Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8d693306
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8d693306
编写于
8月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4126 Add new parameter
Merge pull request !4126 from BowenK/new_parameter
上级
6eb98f28
e7c6b7e6
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
796 addition
and
129 deletion
+796
-129
mindspore/ccsrc/frontend/optimizer/pass_group.cc
mindspore/ccsrc/frontend/optimizer/pass_group.cc
+6
-3
mindspore/ccsrc/frontend/optimizer/pass_group.h
mindspore/ccsrc/frontend/optimizer/pass_group.h
+4
-2
mindspore/ccsrc/frontend/optimizer/pattern.cc
mindspore/ccsrc/frontend/optimizer/pattern.cc
+4
-0
mindspore/ccsrc/frontend/optimizer/pattern.h
mindspore/ccsrc/frontend/optimizer/pattern.h
+67
-10
mindspore/ccsrc/frontend/optimizer/py_pass.cc
mindspore/ccsrc/frontend/optimizer/py_pass.cc
+212
-67
mindspore/ccsrc/frontend/optimizer/py_pass.h
mindspore/ccsrc/frontend/optimizer/py_pass.h
+6
-6
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
+28
-6
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
+11
-2
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+14
-0
mindspore/graph_utils/__init__.py
mindspore/graph_utils/__init__.py
+15
-0
mindspore/graph_utils/graph_pattern.py
mindspore/graph_utils/graph_pattern.py
+86
-18
mindspore/graph_utils/python_pass/__init__.py
mindspore/graph_utils/python_pass/__init__.py
+24
-0
mindspore/graph_utils/python_pass/python_pass_register.py
mindspore/graph_utils/python_pass/python_pass_register.py
+170
-0
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+1
-1
tests/ut/python/optimizer/test_python_pass.py
tests/ut/python/optimizer/test_python_pass.py
+148
-14
未找到文件。
mindspore/ccsrc/frontend/optimizer/pass_group.cc
浏览文件 @
8d693306
...
...
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "frontend/optimizer/pass_group.h"
#include "frontend/optimizer/py_pass_manager.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) {
return
false
;
}
bool
PassGroup
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
)
const
{
bool
PassGroup
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
,
const
MatchResultPtr
&
res
)
const
{
if
(
func_graph
==
nullptr
)
{
return
false
;
}
bool
changed
=
false
;
for
(
const
auto
&
pass
:
passes
)
{
if
(
pass
!=
nullptr
)
{
if
(
pass
->
Run
(
func_graph
))
{
if
(
pass
->
Run
(
func_graph
,
res
))
{
changed
=
true
;
}
}
...
...
@@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
bool
changed
=
false
;
// run all passes
bool
change
=
true
;
auto
res
=
PyPassManager
::
GetInstance
()
->
GetMatchResult
();
while
(
change
)
{
change
=
Run
(
func_graph
,
passes_
);
change
=
Run
(
func_graph
,
passes_
,
res
);
changed
=
change
||
changed
;
if
(
run_only_once_
)
{
break
;
...
...
mindspore/ccsrc/frontend/optimizer/pass_group.h
浏览文件 @
8d693306
...
...
@@ -41,12 +41,14 @@ class PassGroup {
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
)
const
;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [inout]
func_
graph The graph to be optimized
// @param [in] passes The given graph passes
// @param [inout] res MatchResult used to collect all matched patterns and nodes
// @return true, graph changed
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
)
const
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
,
const
MatchResultPtr
&
res
)
const
;
std
::
string
name
()
const
{
return
name_
;
}
void
SetRunOnlyOnce
(
bool
run_only_once
)
{
run_only_once_
=
run_only_once
;
}
private:
const
std
::
string
name_
;
...
...
mindspore/ccsrc/frontend/optimizer/pattern.cc
浏览文件 @
8d693306
...
...
@@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) {
for
(
auto
&
iter
:
patterns_
)
{
auto
res
=
iter
->
match
(
node
);
if
(
res
!=
nullptr
)
{
res
->
add_entry
(
shared_from_base
<
IsIn
>
(),
node
);
return
res
;
}
}
...
...
@@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE(
(
void
)
py
::
class_
<
AnyPattern
,
std
::
shared_ptr
<
AnyPattern
>
,
Pattern
>
(
*
m
,
"AnyPattern"
).
def
(
py
::
init
<>
());
(
void
)
py
::
class_
<
NewTensor
,
std
::
shared_ptr
<
NewTensor
>
,
Pattern
>
(
*
m
,
"NewTensor_"
)
.
def
(
py
::
init
<
tensor
::
TensorPtr
>
());
(
void
)
py
::
class_
<
NewParameter
,
std
::
shared_ptr
<
NewParameter
>
,
Pattern
>
(
*
m
,
"NewParameter_"
)
.
def
(
py
::
init
<
string
,
tensor
::
TensorPtr
,
bool
,
bool
,
bool
>
());
(
void
)
py
::
class_
<
Imm
,
std
::
shared_ptr
<
Imm
>
,
Pattern
>
(
*
m
,
"Imm"
).
def
(
py
::
init
<
int
>
());
}));
}
// namespace python_pass
}
// namespace opt
...
...
mindspore/ccsrc/frontend/optimizer/pattern.h
浏览文件 @
8d693306
...
...
@@ -42,6 +42,10 @@ class CallWith;
using
CallWithPtr
=
std
::
shared_ptr
<
CallWith
>
;
class
NewTensor
;
using
NewTensorPtr
=
std
::
shared_ptr
<
NewTensor
>
;
class
NewParameter
;
using
NewParameterPtr
=
std
::
shared_ptr
<
NewParameter
>
;
class
Imm
;
using
ImmPtr
=
std
::
shared_ptr
<
Imm
>
;
struct
PatternHasher
;
struct
PatternEqual
;
using
PatternNodeMap
=
std
::
unordered_map
<
PatternPtr
,
AnfNodePtr
,
PatternHasher
,
PatternEqual
>
;
...
...
@@ -55,6 +59,7 @@ class Pattern : public Base {
string
unique_name
()
const
{
return
unique_name_
;
}
vector
<
PatternPtr
>
inputs
()
{
return
inputs_
;
}
bool
should_replace
()
{
return
should_replace_
;
}
void
set_should_replace
(
bool
should_replace
)
{
should_replace_
=
should_replace
;
}
virtual
void
reset
()
{}
protected:
...
...
@@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern {
~
IsPrimTypeOf
()
=
default
;
IsPrimTypeOf
(
vector
<
PrimitivePyPtr
>
prims
,
string
name
,
bool
should_replace
)
:
primitives_
(
prims
),
name_
(
name
),
matched_prim_
(
nullptr
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"_"
+
name
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
IsPrimTypeOf
_"
+
name
;
should_replace_
=
should_replace
;
if
(
!
should_replace
)
{
matched_prim_
=
prims
[
0
];
}
}
IsPrimTypeOf
(
vector
<
string
>
types
,
string
name
,
bool
should_replace
)
:
types_
(
types
),
name_
(
name
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"_"
+
name
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"
IsPrimTypeOf
_"
+
name
;
// Make primitives_
for
(
auto
&
iter
:
types
)
{
primitives_
.
push_back
(
std
::
make_shared
<
PrimitivePy
>
(
iter
,
py
::
cast
(
nullptr
)));
...
...
@@ -126,19 +131,20 @@ class CallWith : public Pattern {
CallWith
(
PatternPtr
prim_pattern
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_
=
prim_pattern
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
prim_pattern
->
unique_name
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"CallWithPattern_"
+
prim_pattern
->
unique_name
();
inputs_
=
inputs
;
should_replace_
=
should_replace
;
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
should_replace_
=
prim_pattern
->
should_replace
();
}
CallWith
(
PrimitivePyPtr
prim
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
prim_
=
prim
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
prim_
->
ToString
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"CallWithPrim_"
+
prim_
->
ToString
();
inputs_
=
inputs
;
should_replace_
=
should_replace
;
}
CallWith
(
string
prim_str
,
vector
<
PatternPtr
>
inputs
,
bool
should_replace
)
{
prim_
=
std
::
make_shared
<
PrimitivePy
>
(
prim_str
,
py
::
cast
(
nullptr
));
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
prim_
->
ToString
();
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"CallWithStr_"
+
prim_
->
ToString
();
inputs_
=
inputs
;
should_replace_
=
should_replace
;
}
...
...
@@ -159,7 +165,7 @@ class IsIn : public Pattern {
IsIn
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
IsIn
()
=
default
;
explicit
IsIn
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"IsIn"
;
for
(
auto
&
iter
:
patterns
)
{
unique_name_
=
unique_name_
+
"_"
+
iter
->
unique_name
();
}
...
...
@@ -176,9 +182,9 @@ class IsNot : public Pattern {
IsNot
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
IsNot
()
=
default
;
explicit
IsNot
(
vector
<
PatternPtr
>
patterns
)
:
patterns_
(
patterns
)
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"IsNot"
;
for
(
auto
&
iter
:
patterns
)
{
unique_name_
=
"IsNot_"
+
unique_name_
+
"_"
+
iter
->
unique_name
();
unique_name_
=
unique_name_
+
"_"
+
iter
->
unique_name
();
}
}
MS_DECLARE_PARENT
(
IsNot
,
Pattern
);
...
...
@@ -200,7 +206,10 @@ class NewTensor : public Pattern {
public:
NewTensor
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
~
NewTensor
()
=
default
;
explicit
NewTensor
(
tensor
::
TensorPtr
input_tensor
)
:
input_tensor_
(
input_tensor
)
{
should_replace_
=
false
;
}
explicit
NewTensor
(
tensor
::
TensorPtr
input_tensor
)
:
input_tensor_
(
input_tensor
)
{
should_replace_
=
false
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"NewTensor"
;
}
MS_DECLARE_PARENT
(
NewTensor
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Find NewTensor in pattern, NewTensor should only appear in the target.
\n
"
;
...
...
@@ -211,6 +220,54 @@ class NewTensor : public Pattern {
tensor
::
TensorPtr
input_tensor_
;
};
class
NewParameter
:
public
Pattern
{
public:
NewParameter
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
explicit
NewParameter
(
string
para_name
,
tensor
::
TensorPtr
default_tensor
,
bool
requires_grad
,
bool
layerwise_parallel
,
bool
should_replace
)
:
para_name_
(
para_name
),
requires_grad_
(
requires_grad
),
layerwise_parallel_
(
layerwise_parallel
)
{
should_replace_
=
should_replace
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"NewParameter_"
+
para_name
;
// clone input tensor
default_tensor_
=
std
::
make_shared
<
tensor
::
Tensor
>
(
*
default_tensor
.
get
());
built_
=
false
;
}
MS_DECLARE_PARENT
(
NewParameter
,
Pattern
);
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
{
MS_LOG
(
EXCEPTION
)
<<
"Find NewParameter in pattern, NewParameter should only appear in the target.
\n
"
;
}
string
para_name
()
{
return
para_name_
;
}
tensor
::
TensorPtr
default_tensor
()
{
return
default_tensor_
;
}
bool
requires_grad
()
{
return
requires_grad_
;
}
bool
layerwise_parallel
()
{
return
layerwise_parallel_
;
}
bool
built
()
{
return
built_
;
}
void
set_built
(
bool
built
)
{
built_
=
built
;
}
void
reset
()
override
{
built_
=
false
;
}
private:
string
para_name_
;
bool
requires_grad_
;
bool
layerwise_parallel_
;
bool
built_
;
tensor
::
TensorPtr
default_tensor_
;
};
class
Imm
:
public
Pattern
{
public:
Imm
()
{
unique_name_
=
std
::
to_string
(
g_id_
++
);
}
explicit
Imm
(
int
value
)
:
value_
(
value
)
{
should_replace_
=
false
;
unique_name_
=
std
::
to_string
(
g_id_
++
)
+
"Imm_"
+
std
::
to_string
(
value
);
}
MS_DECLARE_PARENT
(
Imm
,
Pattern
);
// NOTE: Doesn't support Imm in src pattern currently.
MatchResultPtr
match
(
const
AnfNodePtr
&
node
)
override
{
return
nullptr
;
}
int
value
()
{
return
value_
;
}
private:
int
value_
;
};
class
MatchResult
{
public:
MatchResult
()
{}
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.cc
浏览文件 @
8d693306
...
...
@@ -21,13 +21,26 @@
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "pybind_api/ir/primitive_py.h"
#include "ir/scalar.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/draw.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
namespace
internal
{
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
);
const
char
PARAMETER_MODULE
[]
=
"mindspore.common.parameter"
;
const
char
PARAMETER_CLASS
[]
=
"Parameter"
;
const
char
SET_PARAM
[]
=
"__setattr__"
;
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
);
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
void
ReflectParamBackToPython
(
const
AnfNodePtr
&
param
,
string
param_name
,
tensor
::
TensorPtr
default_input
,
bool
requires_grad
,
bool
layerwise_parallel
);
std
::
string
GetNodeRepr
(
AnfNodePtr
node
)
{
if
(
node
!=
nullptr
)
{
...
...
@@ -42,8 +55,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
repr
+=
")"
;
return
repr
;
}
if
(
node
->
isa
<
ValueNode
>
())
{
return
GetValueNode
(
node
)
->
ToString
();
if
(
node
->
isa
<
Parameter
>
())
{
return
"[Parameter]"
+
node
->
ToString
();
}
else
if
(
node
->
isa
<
ValueNode
>
())
{
return
"[Value]"
+
GetValueNode
(
node
)
->
ToString
();
}
return
node
->
ToString
();
}
...
...
@@ -82,7 +97,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return
std
::
make_shared
<
ValueNode
>
(
input_tensor
);
}
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
BuildPrimitiveValueNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
fg
)
{
auto
call_with_pattern
=
pattern
->
cast
<
CallWithPtr
>
();
MS_EXCEPTION_IF_NULL
(
call_with_pattern
);
auto
prim
=
call_with_pattern
->
prim_value
();
...
...
@@ -91,15 +106,70 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto
prim_pattern
=
call_with_pattern
->
prim_pattern
();
MS_EXCEPTION_IF_NULL
(
prim_pattern
);
return
ProcessSinglePattern
(
prim_pattern
,
res
);
return
ProcessSinglePattern
(
prim_pattern
,
res
,
fg
);
}
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
)
{
AnfNodePtr
BuildNewParameter
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
auto
new_para_pattern
=
pattern
->
cast
<
NewParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
new_para_pattern
);
if
(
!
new_para_pattern
->
built
())
{
static
int
parameter_id
=
0
;
auto
para_name
=
new_para_pattern
->
para_name
()
+
new_para_pattern
->
unique_name
()
+
std
::
to_string
(
parameter_id
++
);
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
para_node
);
para_node
->
set_name
(
para_name
);
// Set function graph
para_node
->
set_func_graph
(
func_graph
);
// Set Debug Info
auto
debug_info
=
std
::
make_shared
<
NodeDebugInfo
>
(
para_name
);
para_node
->
set_debug_info
(
debug_info
);
// Set abstract
auto
default_value
=
new_para_pattern
->
default_tensor
();
MS_EXCEPTION_IF_NULL
(
default_value
);
para_node
->
set_abstract
(
default_value
->
ToAbstract
()
->
Broaden
());
res
->
add_entry
(
pattern
,
para_node
);
func_graph
->
add_parameter
(
para_node
);
// Reflect back to Cell._params
internal
::
ReflectParamBackToPython
(
para_node
,
para_name
,
default_value
,
new_para_pattern
->
requires_grad
(),
new_para_pattern
->
layerwise_parallel
());
MS_LOG
(
WARNING
)
<<
"Adding parameter: "
+
para_node
->
ToString
()
+
" parameter name:"
+
para_node
->
name
();
new_para_pattern
->
set_built
(
true
);
return
para_node
;
}
else
{
// Built, fetch the node
auto
para_node
=
res
->
get_node
(
pattern
);
MS_EXCEPTION_IF_NULL
(
para_node
);
return
para_node
;
}
}
AnfNodePtr
BuildImmNode
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
)
{
auto
imm_pattern
=
pattern
->
cast
<
ImmPtr
>
();
MS_EXCEPTION_IF_NULL
(
imm_pattern
);
auto
value
=
imm_pattern
->
value
();
auto
scalar_value_ptr
=
std
::
make_shared
<
Int32Imm
>
(
value
);
return
std
::
make_shared
<
ValueNode
>
(
scalar_value_ptr
);
}
AnfNodePtr
ProcessSinglePattern
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
pattern
->
should_replace
())
{
// Find replacement in the MatchResult
auto
target_node
=
res
->
get_node
(
pattern
);
if
(
target_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot find target node in pattern match result, pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
if
(
pattern
->
isa
<
IsPrimTypeOf
>
()
||
pattern
->
isa
<
NewTensor
>
()
||
pattern
->
isa
<
NewParameter
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot find target node, pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
return
nullptr
;
}
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
auto
new_node
=
BuildTarget
(
pattern
,
func_graph
,
res
);
if
(
new_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Try to build pattern node but FAILED. pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
}
return
new_node
;
}
if
(
pattern
->
isa
<
NewParameter
>
())
{
return
target_node
;
}
return
target_node
;
}
...
...
@@ -109,7 +179,19 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
else
if
(
pattern
->
isa
<
NewTensor
>
())
{
return
BuildNewTensor
(
pattern
,
res
);
}
else
if
(
pattern
->
isa
<
CallWith
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
);
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
return
BuildNewParameter
(
pattern
,
res
,
func_graph
);
}
else
if
(
pattern
->
isa
<
Imm
>
())
{
return
BuildImmNode
(
pattern
,
res
);
}
return
nullptr
;
}
AnfNodePtr
ProcessComplexPatternFirstInput
(
const
PatternPtr
&
pattern
,
const
MatchResultPtr
&
res
,
const
FuncGraphPtr
&
func_graph
)
{
if
(
pattern
->
isa
<
CallWith
>
())
{
return
BuildPrimitiveValueNode
(
pattern
,
res
,
func_graph
);
}
return
nullptr
;
}
...
...
@@ -117,91 +199,154 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
AnfNodePtr
BuildTarget
(
const
PatternPtr
&
pattern
,
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
)
{
auto
target_inputs
=
pattern
->
inputs
();
if
(
target_inputs
.
size
()
==
0
)
{
return
ProcessSinglePattern
(
pattern
,
res
);
auto
new_node
=
ProcessSinglePattern
(
pattern
,
res
,
func_graph
);
if
(
new_node
!=
nullptr
)
{
res
->
add_entry
(
pattern
,
new_node
);
}
return
new_node
;
}
// Build up the AnfNode in a recursive manner
std
::
vector
<
AnfNodePtr
>
new_inputs
;
auto
prim_value_node
=
Process
SinglePattern
(
pattern
,
res
);
auto
prim_value_node
=
Process
ComplexPatternFirstInput
(
pattern
,
res
,
func_graph
);
MS_EXCEPTION_IF_NULL
(
prim_value_node
);
new_inputs
.
push_back
(
prim_value_node
);
for
(
auto
&
iter
:
target_inputs
)
{
if
(
iter
==
pattern
)
{
MS_LOG
(
EXCEPTION
)
<<
"Circle references: Pattern takes itself as input. Got pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
MS_LOG
(
EXCEPTION
)
<<
"Circle references. Got pattern: "
+
pattern
->
unique_name
()
+
"
\n
"
;
}
auto
input_node
=
BuildTarget
(
iter
,
func_graph
,
res
);
if
(
input_node
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to build input node for pattern : "
+
iter
->
unique_name
()
+
"
\n
"
;
}
new_inputs
.
push_back
(
input_node
);
}
auto
new_node
=
func_graph
->
NewCNode
(
new_inputs
);
res
->
add_entry
(
pattern
,
new_node
);
return
new_node
;
}
void
DrawNode
(
string
name
,
AnfNodePtr
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
bool
save_graphs
=
context_ptr
->
save_graphs_flag
();
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
();
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
auto
new_func_graph
=
std
::
make_shared
<
FuncGraph
>
();
new_func_graph
->
set_output
(
node
,
true
);
if
(
save_graphs
)
{
auto
ir_dump_path
=
save_graphs_path
+
"/"
+
name
+
".ir"
;
auto
dot_dump_path
=
save_graphs_path
+
"/"
+
name
+
".dot"
;
DumpIR
(
ir_dump_path
,
new_func_graph
);
draw
::
Draw
(
dot_dump_path
,
new_func_graph
);
}
}
void
ReflectParamBackToPython
(
const
AnfNodePtr
&
param
,
string
param_name
,
tensor
::
TensorPtr
default_input
,
bool
requires_grad
,
bool
layerwise_parallel
)
{
// 1. Get current cell object
auto
ppm
=
opt
::
python_pass
::
PyPassManager
::
GetInstance
();
auto
resource
=
ppm
->
GetResource
();
py
::
object
top_cell
=
resource
->
input
();
if
(
py
::
isinstance
<
py
::
none
>
(
top_cell
))
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to get top cell from resource."
;
}
// 2. New a Parameter object with the above-specified args
py
::
object
parameter_class
=
py
::
module
::
import
(
PARAMETER_MODULE
).
attr
(
PARAMETER_CLASS
);
py
::
object
new_parameter
=
parameter_class
(
default_input
,
param_name
,
requires_grad
,
layerwise_parallel
);
// 3. Add the new python Parameter object to Cell's _params atttributes
top_cell
.
attr
(
SET_PARAM
)(
param_name
,
new_parameter
);
// 4. Set default_param for param_node
ValuePtr
param_value
=
nullptr
;
bool
converted
=
parse
::
ConvertData
(
new_parameter
,
&
param_value
,
false
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to convert new parameter to ValuePtr."
;
}
MS_EXCEPTION_IF_NULL
(
param
);
auto
param_node
=
param
->
cast
<
ParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
param_node
);
param_node
->
set_default_param
(
param_value
);
}
void
Reset
(
PatternPtr
pattern
)
{
if
(
pattern
->
isa
<
IsPrimTypeOf
>
())
{
auto
prim_pattern
=
pattern
->
cast
<
IsPrimTypeOfPtr
>
();
prim_pattern
->
reset
();
return
;
}
else
if
(
pattern
->
isa
<
NewParameter
>
())
{
auto
new_param_pattern
=
pattern
->
cast
<
NewParameterPtr
>
();
new_param_pattern
->
reset
();
return
;
}
else
if
(
pattern
->
isa
<
CallWith
>
())
{
auto
call_with_pattern
=
pattern
->
cast
<
CallWithPtr
>
();
for
(
auto
sub_pattern
:
call_with_pattern
->
inputs
())
{
Reset
(
sub_pattern
);
}
new_inputs
.
push_back
(
BuildTarget
(
iter
,
func_graph
,
res
))
;
return
;
}
return
func_graph
->
NewCNode
(
new_inputs
)
;
return
;
}
}
// namespace internal
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
src_pattern_
);
MS_EXCEPTION_IF_NULL
(
dst_pattern_
);
auto
res
=
src_pattern_
->
match
(
node
);
if
(
res
!=
nullptr
)
{
res
->
dump
();
MS_LOG
(
WARNING
)
<<
"Matched pattern: "
+
src_pattern_
->
unique_name
();
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
)
{
auto
match_res
=
src_pattern_
->
match
(
node
);
if
(
match_res
!=
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Matched pattern: "
+
src_pattern_
->
unique_name
()
+
" node : "
+
internal
::
GetNodeRepr
(
node
);
res
->
merge
(
match_res
);
auto
new_node
=
internal
::
BuildTarget
(
dst_pattern_
,
func_graph
,
res
);
dst_pattern_
->
reset
(
);
MS_LOG
(
DEBU
G
)
<<
"To be replaced node: "
+
internal
::
GetNodeRepr
(
new_node
)
+
"
\n
"
;
internal
::
Reset
(
dst_pattern
()
);
MS_LOG
(
WARNIN
G
)
<<
"To be replaced node: "
+
internal
::
GetNodeRepr
(
new_node
)
+
"
\n
"
;
return
new_node
;
}
src_pattern_
->
reset
(
);
internal
::
Reset
(
src_pattern
()
);
return
nullptr
;
}
bool
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
)
{
bool
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
dst_pattern_
);
if
(
src_pattern_
==
nullptr
)
{
// Add NewParameter
auto
new_para_pattern
=
dst_pattern_
->
cast
<
NewParameterPtr
>
();
if
(
new_para_pattern
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Expect NewParameter pattern for target if src pattern is null."
;
}
auto
para_name
=
new_para_pattern
->
para_name
()
+
new_para_pattern
->
unique_name
();
MS_LOG
(
DEBUG
)
<<
"Adding New parameter : "
+
para_name
;
auto
para_node
=
std
::
make_shared
<
Parameter
>
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
para_node
);
para_node
->
set_name
(
para_name
);
// Set function graph
para_node
->
set_func_graph
(
func_graph
);
// Set Debug Info
auto
debug_info
=
std
::
make_shared
<
NodeDebugInfo
>
(
para_name
);
para_node
->
set_debug_info
(
debug_info
);
// Set abstract
auto
default_value
=
new_para_pattern
->
default_tensor
();
MS_EXCEPTION_IF_NULL
(
default_value
);
para_node
->
set_abstract
(
default_value
->
ToAbstract
()
->
Broaden
());
res
->
add_entry
(
dst_pattern_
,
para_node
);
func_graph
->
add_parameter
(
para_node
);
// Reflect back to Cell._params
internal
::
ReflectParamBackToPython
(
para_node
,
para_name
,
default_value
,
new_para_pattern
->
requires_grad
(),
new_para_pattern
->
layerwise_parallel
());
MS_LOG
(
WARNING
)
<<
"Adding parameter: "
+
para_node
->
ToString
()
+
" parameter name:"
+
para_node
->
name
();
return
true
;
}
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
func_graph
);
auto
seen
=
NewSeenGeneration
();
// 1024 is for the initial capacity of deque
std
::
deque
<
AnfNodePtr
>
todo
(
1024
);
todo
.
push_back
(
func_graph
->
output
());
auto
graph_nodes_sorted
=
TopoSort
(
func_graph
->
output
());
bool
changes
=
false
;
auto
&
all_nodes
=
manager
->
all_nodes
();
while
(
!
todo
.
empty
())
{
AnfNodePtr
node
=
todo
.
front
();
todo
.
pop_front
();
// Check whether this node has been matched.
if
(
node
==
nullptr
||
node
->
seen_
==
seen
||
!
internal
::
IsTraversable
(
node
)
||
!
all_nodes
.
contains
(
node
))
{
continue
;
}
node
->
seen_
=
seen
;
// Select nodes that this transform can be applied.
AnfNodePtr
new_node
=
Run
(
func_graph
,
node
);
bool
change
=
(
new_node
!=
nullptr
);
// Traverse once
for
(
auto
&
node
:
graph_nodes_sorted
)
{
AnfNodePtr
new_node
=
Run
(
func_graph
,
node
,
res
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
internal
::
DrawNode
(
dst_pattern_
->
unique_name
(),
new_node
);
(
void
)
manager
->
Replace
(
node
,
new_node
);
}
else
if
(
new_node
==
nullptr
)
{
new_node
=
node
;
}
if
(
run_only_once_
)
{
return
change
;
}
// Find success, and add them to todo list
if
(
IsValueNode
<
FuncGraph
>
(
node
))
{
todo
.
push_back
(
GetValueNode
<
FuncGraphPtr
>
(
node
)
->
output
());
}
if
(
node
->
isa
<
CNode
>
())
{
auto
&
inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
(
void
)
std
::
copy
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
todo
));
}
auto
&
node_users
=
manager
->
node_users
();
if
(
change
&&
node_users
.
find
(
node
)
!=
node_users
.
end
())
{
for
(
auto
&
use
:
node_users
[
node
])
{
auto
use_node
=
use
.
first
;
if
(
use_node
==
nullptr
)
{
continue
;
}
todo
.
push_back
(
use_node
);
if
(
use_node
->
seen_
==
seen
)
{
use_node
->
seen_
--
;
}
}
changes
=
true
;
}
}
return
changes
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.h
浏览文件 @
8d693306
...
...
@@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
class
PythonPass
{
public:
explicit
PythonPass
(
const
std
::
string
&
name
,
const
PatternPtr
&
src
,
const
PatternPtr
&
dst
,
bool
run_only_once
=
false
,
bool
multigraph
=
true
)
:
src_pattern_
(
src
),
dst_pattern_
(
dst
),
name_
(
name
),
run_only_once_
(
run_only_once
),
multigraph_
(
multigraph
)
{}
explicit
PythonPass
(
const
std
::
string
&
name
,
const
PatternPtr
&
src
,
const
PatternPtr
&
dst
,
bool
run_only_once
=
false
)
:
src_pattern_
(
src
),
dst_pattern_
(
dst
),
name_
(
name
),
run_only_once_
(
run_only_once
)
{}
~
PythonPass
()
=
default
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
);
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
MatchResultPtr
&
res
);
std
::
string
name
()
const
{
return
name_
;
}
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
);
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
MatchResultPtr
&
res
);
PatternPtr
src_pattern
()
{
return
src_pattern_
;
}
PatternPtr
dst_pattern
()
{
return
dst_pattern_
;
}
private:
PatternPtr
src_pattern_
;
PatternPtr
dst_pattern_
;
const
std
::
string
name_
;
bool
run_only_once_
;
bool
multigraph_
=
true
;
};
using
PythonPassPtr
=
std
::
shared_ptr
<
PythonPass
>
;
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
浏览文件 @
8d693306
...
...
@@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
PyPassManager
::
PyPassManager
()
{
phase_to_group_
[
Phase
::
RESOLVE
]
=
std
::
make_shared
<
PassGroup
>
();
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
();
res_
=
std
::
make_shared
<
MatchResult
>
();
}
void
PyPassManager
::
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
Phase
phase
,
bool
run_only_once
,
bool
multigraph
)
{
auto
cur_pm
=
GetPassGroup
(
phase
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
PythonPassPtr
new_pass
=
std
::
make_shared
<
PythonPass
>
(
pass_name
,
pattern
,
target
,
run_only_once
,
multigraph
);
cur_pm
->
AddPass
(
new_pass
);
Phase
phase
,
bool
run_only_once
)
{
auto
cur_pg
=
GetPassGroup
(
phase
);
MS_EXCEPTION_IF_NULL
(
cur_pg
);
cur_pg
->
SetRunOnlyOnce
(
run_only_once
);
MS_EXCEPTION_IF_NULL
(
pattern
);
MS_EXCEPTION_IF_NULL
(
target
);
MS_EXCEPTION_IF_NULL
(
cur_pg
);
PythonPassPtr
new_pass
=
std
::
make_shared
<
PythonPass
>
(
pass_name
,
pattern
,
target
,
run_only_once
);
cur_pg
->
AddPass
(
new_pass
);
}
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
)
{
...
...
@@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
}
}
void
PyPassManager
::
GenNewParameter
(
const
PatternPtr
&
parameter
)
{
MS_EXCEPTION_IF_NULL
(
parameter
);
// Add new parameter after resolve
// NOTE: Add NewParameter at early stage will cause CSE problems
auto
cur_pg
=
GetPassGroup
(
Phase
::
OPT
);
MS_EXCEPTION_IF_NULL
(
cur_pg
);
cur_pg
->
SetRunOnlyOnce
(
true
);
auto
new_para_pattern
=
parameter
->
cast
<
NewParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
new_para_pattern
);
auto
pass_name
=
new_para_pattern
->
para_name
();
parameter
->
set_should_replace
(
false
);
auto
new_pass
=
std
::
make_shared
<
PythonPass
>
(
pass_name
,
nullptr
,
parameter
,
true
);
cur_pg
->
AddPass
(
new_pass
);
}
void
PyPassManager
::
ClearRes
()
{
MS_LOG
(
INFO
)
<<
"Clear PyPassManager resources!"
;
global_instance
=
nullptr
;
...
...
@@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE(
(
void
)
py
::
class_
<
PyPassManager
,
std
::
shared_ptr
<
PyPassManager
>>
(
*
m
,
"PyPassManager_"
)
.
def
(
py
::
init
([]()
{
return
PyPassManager
::
GetInstance
();
}))
.
def
(
"registe"
,
&
PyPassManager
::
Registe
,
"Registe python pass"
)
.
def
(
"unregiste"
,
&
PyPassManager
::
Unregiste
,
"Delete Python Pass"
);
.
def
(
"unregiste"
,
&
PyPassManager
::
Unregiste
,
"Delete Python Pass"
)
.
def
(
"gen_new_parameter"
,
&
PyPassManager
::
GenNewParameter
,
"Generate new parameter"
)
.
def
(
"set_renorm"
,
&
PyPassManager
::
SetRenorm
,
"Set whether or not to do renorm after modified graph"
);
}));
}
// namespace python_pass
}
// namespace opt
...
...
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
浏览文件 @
8d693306
...
...
@@ -27,7 +27,7 @@
#include "ir/graph_utils.h"
#include "utils/ms_utils.h"
#include "pipeline/jit/
parse/resolv
e.h"
#include "pipeline/jit/
resourc
e.h"
#include "frontend/optimizer/pattern.h"
#include "frontend/optimizer/py_pass.h"
#include "frontend/optimizer/pass_group.h"
...
...
@@ -53,12 +53,21 @@ class PyPassManager {
static
PyPassManagerPtr
GetInstance
();
virtual
~
PyPassManager
()
=
default
;
void
Registe
(
const
std
::
string
&
pass_name
,
const
PatternPtr
&
pattern
,
const
PatternPtr
&
target
,
Phase
phase
=
Phase
::
RESOLVE
,
bool
run_only_once
=
false
,
bool
multigraph
=
true
);
Phase
phase
=
Phase
::
RESOLVE
,
bool
run_only_once
=
false
);
void
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
);
void
GenNewParameter
(
const
PatternPtr
&
parameter
);
PassGroupPtr
GetPassGroup
(
Phase
phase
);
void
ClearRes
();
MatchResultPtr
GetMatchResult
()
{
return
res_
;
}
void
SetRenorm
(
bool
should_renorm
)
{
should_renorm_
=
should_renorm
;
}
bool
ShouldRenorm
()
{
return
should_renorm_
;
}
void
SetResource
(
pipeline
::
ResourcePtr
resource
)
{
resource_
=
resource
;
}
pipeline
::
ResourcePtr
GetResource
()
{
return
resource_
;
}
private:
bool
should_renorm_
=
true
;
MatchResultPtr
res_
;
pipeline
::
ResourcePtr
resource_
;
static
std
::
unordered_map
<
Phase
,
PassGroupPtr
>
phase_to_group_
;
};
}
// namespace python_pass
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
8d693306
...
...
@@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
MS_EXCEPTION_IF_NULL
(
res
->
manager
());
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
auto
ppm
=
opt
::
python_pass
::
PyPassManager
::
GetInstance
();
ppm
->
SetResource
(
res
);
if
(
!
ppm
->
GetPassGroup
(
phase
)
->
Run
(
res
->
func_graph
()))
{
MS_LOG
(
DEBUG
)
<<
"No match.
\n
"
;
}
else
if
(
phase
==
opt
::
python_pass
::
Phase
::
OPT
&&
opt
::
python_pass
::
PyPassManager
::
GetInstance
()
->
ShouldRenorm
())
{
MS_LOG
(
DEBUG
)
<<
"Entered PyStub Renorm"
;
// Renomalize
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
FuncGraphPtr
func_graph
=
res
->
func_graph
();
abstract
::
AbstractBasePtrList
args_spec
;
auto
parameters
=
func_graph
->
parameters
();
(
void
)
std
::
transform
(
parameters
.
begin
(),
parameters
.
end
(),
std
::
back_inserter
(
args_spec
),
[](
const
AnfNodePtr
&
p
)
->
AbstractBasePtr
{
return
p
->
abstract
();
});
FuncGraphPtr
new_fg
=
Renormalize
(
res
,
func_graph
,
args_spec
);
res
->
set_func_graph
(
new_fg
);
res
->
set_args_spec
(
args_spec
);
}
}
...
...
@@ -477,6 +490,7 @@ static std::vector<ActionItem> CommonPipeline() {
}
// Add resolve-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_resolve"
,
ResolveActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"inference_opt_prepare"
,
InferenceOptPrepareAction
));
// Evaluate type and shape, and specialize
actions
.
emplace_back
(
std
::
make_pair
(
"abstract_specialize"
,
AbstractSpecializeAction
));
...
...
mindspore/graph_utils/__init__.py
0 → 100644
浏览文件 @
8d693306
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
mindspore/
common
/graph_pattern.py
→
mindspore/
graph_utils
/graph_pattern.py
浏览文件 @
8d693306
...
...
@@ -15,7 +15,8 @@
"""Patterns for describing graphs"""
from
mindspore.ops
import
Primitive
from
mindspore.common.tensor
import
Tensor
from
mindspore._c_expression
import
Pattern
,
IsIn_
,
IsPrimTypeOf_
,
CallWith_
,
IsNot_
,
AnyPattern
,
NewTensor_
from
mindspore._c_expression
import
Pattern
,
IsIn_
,
IsPrimTypeOf_
,
CallWith_
,
IsNot_
,
AnyPattern
,
NewTensor_
,
\
NewParameter_
,
Imm
__all__
=
[
"IsIn"
,
...
...
@@ -24,17 +25,25 @@ __all__ = [
"IsNot"
,
"AnyPattern"
,
"NewTensor"
,
"NewParameter"
,
"Imm"
]
class
IsIn
(
IsIn_
):
"""
r
"""
Express a pattern which allows a list of patterns.
"""
def
__init__
(
self
,
patterns
=
None
,
should_replace
=
True
):
r
"""
Args:
patterns(list/tuple): list of allowed patterns
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
each element should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False
TypeError: raise type error for invalid inputs.
"""
if
not
should_replace
:
raise
ValueError
(
"IsIn pattern does not have its own should_replace attribute. Set should_replace in
\
...
...
@@ -52,19 +61,28 @@ class IsIn(IsIn_):
class
IsPrimTypeOf
(
IsPrimTypeOf_
):
r
"""
Express a pattern of certain primitive type(s).
NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
NOTE:
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
"""
def
__init__
(
self
,
types
,
name
=
None
,
should_replace
=
True
):
r
"""
Args:
types (str/(list/tuple of Primitives)): Specify allowed types.
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
tuple[:class:`mindspore.ops.Primitive`]):
Specify allowed types.
If it is a string, the form could be
1) a single primitive type, e.g. 'Conv2D'
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional
should_replace
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional. Default: None.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
"""
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
raise
TypeError
(
f
"Expect string, got :
{
name
}
"
)
...
...
@@ -91,12 +109,21 @@ class CallWith(CallWith_):
r
"""
Express a primitive CNode.
"""
def
__init__
(
self
,
prim_pattern
,
inputs
=
None
,
should_replace
=
Fals
e
):
def
__init__
(
self
,
prim_pattern
,
inputs
=
None
,
should_replace
=
Tru
e
):
r
"""
Args:
prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode.
inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs;
if specified, input patterns should be of right order.
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
:class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
patterns should be of right order and each element should be one of the exposed Pattern instance.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
"""
if
not
isinstance
(
prim_pattern
,
(
Pattern
,
str
,
Primitive
)):
raise
TypeError
(
f
"Expect prim_pattern to be Pattern, Primitive or string, got :
{
prim_pattern
}
"
)
...
...
@@ -110,17 +137,23 @@ class CallWith(CallWith_):
raise
TypeError
(
f
"Expect inputs to be a list of Patterns, got :
{
inputs
}
"
)
CallWith_
.
__init__
(
self
,
self
.
prim_pattern
,
self
.
inputs
,
should_replace
)
class
IsNot
(
IsNot_
):
r
"""
Express a pattern which forbids a list of patterns.
NOTE: IsNot pattern should not be the root pattern.
NOTE:
IsNot pattern should not be the root pattern.
"""
def
__init__
(
self
,
patterns
=
None
,
should_replace
=
True
):
r
"""
Args:
patterns(list/tuple): list of forbiden patterns
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False.
TypeError: raise type error for invalid argument.
"""
if
not
should_replace
:
raise
ValueError
(
"IsNot pattern does not have its own should_replace attribute. Set should_replace in
\
...
...
@@ -142,13 +175,48 @@ class NewTensor(NewTensor_):
def
__init__
(
self
,
input_tensor
,
should_replace
=
False
):
r
"""
Args:
input_tensor(
Tensor
): new tensor to be used in the target
input_tensor(
:class:`mindspore.common.tensor.Tensor`
): new tensor to be used in the target
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
Raises:
ValueError: raise if should_replace is True
TypeError: raise type error for invalid argument.
"""
if
should_replace
:
raise
ValueError
(
"NewTensor should only appear in the target, thus should_replace can only
u
be False."
)
raise
ValueError
(
"NewTensor should only appear in the target, thus should_replace can only be False."
)
self
.
input_tensor
=
input_tensor
if
isinstance
(
input_tensor
,
Tensor
):
NewTensor_
.
__init__
(
self
,
input_tensor
)
else
:
raise
TypeError
(
f
"Expect input_tensor to be a Tensor, got :
{
input_tensor
}
"
)
class
NewParameter
(
NewParameter_
):
r
"""
New Parameter to be used in the target.
"""
def
__init__
(
self
,
para_name
,
default_tensor
,
requires_grad
=
False
,
layerwise_parallel
=
False
,
should_replace
=
False
):
r
"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: True
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
parameter everytime a pass target got built. Default: False
Raises:
TypeError: raise type error for invalid argument.
"""
self
.
para_name
=
para_name
self
.
default_tensor
=
default_tensor
self
.
requires_grad
=
requires_grad
self
.
layerwise_parallel
=
layerwise_parallel
self
.
should_replace
=
should_replace
if
isinstance
(
para_name
,
str
)
and
isinstance
(
default_tensor
,
Tensor
)
and
isinstance
(
requires_grad
,
bool
)
and
\
isinstance
(
layerwise_parallel
,
bool
)
and
isinstance
(
should_replace
,
bool
):
NewParameter_
.
__init__
(
self
,
self
.
para_name
,
self
.
default_tensor
,
self
.
requires_grad
,
self
.
layerwise_parallel
,
self
.
should_replace
)
else
:
raise
TypeError
(
f
"Expect para_name(str), default_tensor(Tensor), requires_grad(bool),
\
layerwise_parallel(bool) should_replace(bool), got :
{
para_name
}
,
{
default_tensor
}
,
\
{
requires_grad
}
,
{
layerwise_parallel
}
,
{
should_replace
}
"
)
mindspore/graph_utils/python_pass/__init__.py
0 → 100644
浏览文件 @
8d693306
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
from
.python_pass_register
import
registe_pass
,
unregiste_pass
,
gen_new_parameter
,
cancel_new_parameter
,
set_renorm
__all__
=
[
"registe_pass"
,
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
]
mindspore/
common
/python_pass_register.py
→
mindspore/
graph_utils/python_pass
/python_pass_register.py
浏览文件 @
8d693306
...
...
@@ -14,10 +14,17 @@
# ============================================================================
"""Python pass register"""
from
inspect
import
isfunction
from
mindspore.common.graph_pattern
import
Pattern
from
mindspore._c_expression
import
PyPassManager_
from
mindspore._c_expression
import
phase
from
mindspore.graph_utils.graph_pattern
import
Pattern
,
NewParameter
from
mindspore._c_expression
import
PyPassManager_
,
phase
__all__
=
[
"registe_pass"
,
"unregiste_pass"
,
"gen_new_parameter"
,
"cancel_new_parameter"
,
"set_renorm"
]
class
PyPassManager
(
PyPassManager_
):
r
"""
Used to registe and unregiste python passes which can be used to alter graphs.
...
...
@@ -30,52 +37,134 @@ class PyPassManager(PyPassManager_):
Raises:
TypeError: If argument has invalid type.
"""
def
__init__
(
self
,
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
,
multi_graph
=
True
):
def
__init__
(
self
,
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expect
ing
phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
raise
TypeError
(
f
"Expect phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
if
not
isinstance
(
run_only_once
,
bool
):
raise
TypeError
(
f
"Expecting bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
if
not
isinstance
(
multi_graph
,
bool
):
raise
TypeError
(
f
"Expecting bool, got : (
{
type
(
multi_graph
)
}
)
{
multi_graph
}
"
)
raise
TypeError
(
f
"Expect bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
PyPassManager_
.
__init__
(
self
)
self
.
phase_
=
pipeline_phase
self
.
run_only_once_
=
run_only_once
self
.
multi_graph_
=
multi_graph
def
registe
(
self
,
py_pass
):
if
not
isfunction
(
py_pass
):
raise
TypeError
(
f
"Expect
ing
function pass, got : (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
raise
TypeError
(
f
"Expect function pass, got : (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
pattern
,
target
=
py_pass
()
pass_name
=
py_pass
.
__name__
if
not
isinstance
(
pattern
,
Pattern
):
raise
TypeError
(
f
"Expect
ing
pattern of Pattern type, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
raise
TypeError
(
f
"Expect pattern of Pattern type, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
if
not
isinstance
(
target
,
Pattern
):
raise
TypeError
(
f
"Expect
ing
target of Pattern type, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
phase_
,
self
.
run_only_once_
,
self
.
multi_graph_
)
raise
TypeError
(
f
"Expect target of Pattern type, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
phase_
,
self
.
run_only_once_
)
def
unregiste
(
self
,
py_pass
,
pipeline_phase
=
phase
.
opt
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expect
ing
phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
raise
TypeError
(
f
"Expect phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
if
isinstance
(
py_pass
,
str
):
super
().
unregiste
(
py_pass
,
pipeline_phase
)
return
if
isfunction
(
py_pass
):
super
().
unregiste
(
py_pass
.
__name__
,
pipeline_phase
)
return
raise
TypeError
(
f
"Expect
ing
py_pass to be string or function, got (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
raise
TypeError
(
f
"Expect py_pass to be string or function, got (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
def
__call__
(
self
,
py_pass
):
self
.
registe
(
py_pass
)
return
py_pass
def
registe_pass
(
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
,
multi_graph
=
True
):
def
gen_new_parameter
(
self
,
pattern
):
if
not
isinstance
(
pattern
,
NewParameter
):
raise
TypeError
(
f
"Expect pattern to be a NewParameter Pattern, got
{
pattern
}
"
)
super
().
gen_new_parameter
(
pattern
)
def
set_renorm
(
self
,
should_renorm
):
if
not
isinstance
(
should_renorm
,
bool
):
raise
TypeError
(
f
"Expect should_renorm to be a bool, got
{
should_renorm
}
"
)
super
().
set_renorm
(
should_renorm
)
def
registe_pass
(
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
registed. Support phase.resolve and phase.opt. Default: phase.opt.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
This function should be used as a decorator, return the decoratorated pass function.
Examples:
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
>>> @registe_pass()
>>> def toy_pass():
>>> pattern = IsPrimTypeOf("ReLU")
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return
PyPassManager
(
pipeline_phase
,
run_only_once
)
def
unregiste_pass
(
py_pass
,
pipeline_phase
=
phase
.
opt
):
"""
Unregiste python pass.
Args:
py_pass(Union(str, function)): target python pass to unregiste.
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
"""
ppm
=
PyPassManager
()
ppm
.
unregiste
(
py_pass
,
pipeline_phase
)
def
gen_new_parameter
(
pattern
):
"""
Generate specified parameter every time a network gets compiled.
NOTE:
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
gen_new_parameter, every pass match would build a new Parameter.
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
cancel_new_parameter(pattern)
Args:
pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes
after gen_new_parameter.
Raises:
TypeError: If argument has invalid type.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> abc = NewParameter("abc")
>>> gen_new_parameter(abc)
"""
ppm
=
PyPassManager
()
ppm
.
gen_new_parameter
(
pattern
)
def
cancel_new_parameter
(
pattern
):
"""
Use with gen_new_parameter to unregiste gen_new_parameter pass.
Args:
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
describes.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> abc = NewParameter("abc")
>>> gen_new_parameter(abs)
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
if
not
isinstance
(
pattern
,
NewParameter
):
raise
TypeError
(
f
"Expect pattern to be a NewParameter Pattern, got
{
pattern
}
"
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
pattern
.
para_name
)
def
set_renorm
(
should_renorm
):
"""
Examples:
>>> @registe_pass()
>>> def toy_pass():
>>> def pattern():
>>> pass
>>> def target():
>>> pass
Set whether or not to do renorm after modified graph in python pass(es).
"""
return
PyPassManager
(
pipeline_phase
,
run_only_once
,
multi_graph
)
ppm
=
PyPassManager
()
ppm
.
set_renorm
(
should_renorm
)
mindspore/ops/primitive.py
浏览文件 @
8d693306
...
...
@@ -152,7 +152,7 @@ class Primitive(Primitive_):
Check if certain inputs should go to the backend. Subclass in need should override this method.
Args:
*
args(Primitive args): Same as arguments of current Primitive.
args(Primitive args): Same as arguments of current Primitive.
Returns:
A tuple consisting of two elements. The first element indicates whether we should filter out current
...
...
tests/ut/python/optimizer/test_python_pass.py
浏览文件 @
8d693306
...
...
@@ -19,10 +19,12 @@ import mindspore.nn as nn
from
mindspore
import
context
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.common.python_pass_register
import
registe_pass
,
PyPassManager
from
mindspore.graph_utils.python_pass
import
registe_pass
,
unregiste_pass
,
set_renorm
,
gen_new_parameter
,
\
cancel_new_parameter
from
mindspore.common.api
import
_generate_pip_args
from
mindspore._c_expression
import
generate_key
,
Executor_
from
mindspore.common.graph_pattern
import
IsIn
,
IsPrimTypeOf
,
CallWith
,
IsNot
,
AnyPattern
,
NewTensor
from
mindspore.graph_utils.graph_pattern
import
IsIn
,
IsPrimTypeOf
,
CallWith
,
IsNot
,
AnyPattern
,
NewTensor
,
\
NewParameter
,
Imm
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
...
...
@@ -56,12 +58,39 @@ def test_softmax_relu():
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
softmax_relu_pass
)
unregiste_pass
(
softmax_relu_pass
)
assert
"ReLU"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_isin_pattern
():
def
test_softmax_relu_sigmoid
():
"""
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
NOTE:
Sigmoid pattern only exists in the target.
"""
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
def
softmax_relu_pass
():
x
=
AnyPattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
pattern
=
CallWith
(
softmax_pattern
,
inputs
=
[
x
])
sigmoid_pattern
=
IsPrimTypeOf
(
P
.
Sigmoid
(),
should_replace
=
False
)
call_sigmoid
=
CallWith
(
sigmoid_pattern
,
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
(),
should_replace
=
False
)
target
=
CallWith
(
relu_pattern
,
inputs
=
[
call_sigmoid
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
3
)
unregiste_pass
(
softmax_relu_pass
)
assert
"ReLU"
in
transformed_repr
assert
"Sigmoid"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_isin_pattern_0
():
"""
Test IsIn pattern which expresses the IsIn/OneOf semantics.
"""
...
...
@@ -81,16 +110,41 @@ def test_isin_pattern():
target
=
CallWith
(
relu6_pattern
,
inputs
=
[
x
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
softmax_relu_pass
)
unregiste_pass
(
softmax_relu_pass
)
assert
"ReLU6"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_isin_pattern_1
():
"""
Test IsIn. IsIn is used as nested inputs for the target in this case.
"""
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
def
softmax_neg_pass
():
x
=
AnyPattern
()
softmax_pattern
=
IsPrimTypeOf
(
P
.
Softmax
())
call_softmax
=
CallWith
(
softmax_pattern
,
inputs
=
[
x
])
relu_pattern
=
IsPrimTypeOf
(
P
.
ReLU
())
call_relu
=
CallWith
(
relu_pattern
,
inputs
=
[
x
])
pattern
=
IsIn
([
call_softmax
,
call_relu
])
neg_ops
=
IsPrimTypeOf
(
P
.
Neg
(),
should_replace
=
False
)
target
=
CallWith
(
neg_ops
,
inputs
=
[
pattern
])
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
4
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_neg_pass
)
assert
"Neg"
in
transformed_repr
assert
"Softmax"
in
transformed_repr
def
test_isnot_pattern_0
():
"""
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm
(
False
)
class
ConvBN
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ConvBN
,
self
).
__init__
()
...
...
@@ -132,11 +186,11 @@ def test_isnot_pattern_0():
return
pattern
,
target
transformed_repr
=
get_func_graph
(
conv_bn_model
,
inputs
).
get_return
().
expanded_str
(
5
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
single_bn_pass
)
ppm
.
unregiste
(
bn_pass
)
unregiste_pass
(
single_bn_pass
)
unregiste_pass
(
bn_pass
)
assert
"ReLU6"
not
in
transformed_repr
assert
"Softmax"
in
transformed_repr
set_renorm
(
True
)
def
test_isnot_pattern_1
():
"""
...
...
@@ -160,12 +214,15 @@ def test_isnot_pattern_1():
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
single_bn_pass
)
unregiste_pass
(
single_bn_pass
)
assert
"ReLU6"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_newtensor_pattern
():
"""
Test NewTensor pattern in the target
"""
set_renorm
(
False
)
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
...
...
@@ -181,7 +238,84 @@ def test_newtensor_pattern():
target
=
CallWith
(
addn_ops
,
inputs
=
[
x
,
new_weight
],
should_replace
=
False
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
2
)
ppm
=
PyPassManager
()
ppm
.
unregiste
(
softmax_addn_pass
)
unregiste_pass
(
softmax_addn_pass
)
assert
"AddN"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
set_renorm
(
True
)
def
test_newparameter_pattern
():
"""
Test NewParameter pattern in the target
"""
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
def
softmax_addn_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
default_tensor0
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
default_tensor1
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para_0
=
NewParameter
(
"Merlin"
,
default_tensor0
)
new_para_1
=
NewParameter
(
"Arthur"
,
default_tensor1
)
target_0
=
CallWith
(
P
.
MatMul
(),
inputs
=
[
new_para_0
,
new_para_1
],
should_replace
=
False
)
target
=
CallWith
(
"make_tuple"
,
inputs
=
[
target_0
],
should_replace
=
False
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_addn_pass
)
assert
"MatMul"
in
transformed_repr
assert
"make_tuple"
in
transformed_repr
assert
"Softmax"
not
in
transformed_repr
def
test_imm_pattern
():
"""
Test NewParameter pattern in the target
"""
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
@
registe_pass
(
run_only_once
=
True
)
def
softmax_addn_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
imm
=
Imm
(
0
)
target_0
=
CallWith
(
"make_tuple"
,
inputs
=
[
pattern
],
should_replace
=
False
)
target
=
CallWith
(
"tuple_getitem"
,
inputs
=
[
target_0
,
imm
],
should_replace
=
False
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
unregiste_pass
(
softmax_addn_pass
)
assert
"make_tuple"
in
transformed_repr
assert
"tuple_getitem"
in
transformed_repr
assert
"Softmax"
in
transformed_repr
def
test_gen_new_parameter
():
"""
Test gen_new_parameter
"""
inputs
=
Tensor
(
np
.
ones
([
42
]),
mindspore
.
float16
)
softmax_model
=
nn
.
Softmax
()
default_tensor
=
Tensor
(
np
.
ones
((
4
,
4
)),
mindspore
.
float32
)
new_para
=
NewParameter
(
"Merlin"
,
default_tensor
,
should_replace
=
True
)
gen_new_parameter
(
new_para
)
@
registe_pass
(
run_only_once
=
True
)
def
softmax_make_tuple_pass
():
x
=
AnyPattern
()
softmax
=
P
.
Softmax
()
pattern
=
CallWith
(
softmax
,
inputs
=
[
x
])
target
=
CallWith
(
"make_tuple"
,
inputs
=
[
pattern
,
new_para
],
should_replace
=
False
)
return
pattern
,
target
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
assert
"Merlin"
in
transformed_repr
unregiste_pass
(
softmax_make_tuple_pass
)
cancel_new_parameter
(
new_para
)
transformed_repr
=
get_func_graph
(
softmax_model
,
inputs
).
get_return
().
expanded_str
(
5
)
print
(
transformed_repr
)
assert
"Merlin"
not
in
transformed_repr
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录