Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b26f6b6b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b26f6b6b
编写于
6月 22, 2020
作者:
B
BowenK
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add python pass support
上级
1ea38eb6
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
683 addition
and
5 deletion
+683
-5
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+0
-4
mindspore/ccsrc/optimizer/pass_group.cc
mindspore/ccsrc/optimizer/pass_group.cc
+69
-0
mindspore/ccsrc/optimizer/pass_group.h
mindspore/ccsrc/optimizer/pass_group.h
+61
-0
mindspore/ccsrc/optimizer/py_pass.cc
mindspore/ccsrc/optimizer/py_pass.cc
+236
-0
mindspore/ccsrc/optimizer/py_pass.h
mindspore/ccsrc/optimizer/py_pass.h
+56
-0
mindspore/ccsrc/optimizer/py_pass_manager.cc
mindspore/ccsrc/optimizer/py_pass_manager.cc
+84
-0
mindspore/ccsrc/optimizer/py_pass_manager.h
mindspore/ccsrc/optimizer/py_pass_manager.h
+66
-0
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+27
-0
mindspore/ccsrc/pipeline/pipeline.cc
mindspore/ccsrc/pipeline/pipeline.cc
+2
-0
mindspore/common/python_pass_register.py
mindspore/common/python_pass_register.py
+80
-0
mindspore/nn/layer/basic.py
mindspore/nn/layer/basic.py
+2
-1
未找到文件。
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
b26f6b6b
...
...
@@ -346,10 +346,6 @@ class TensorAddByZero : public AnfVisitor {
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
IsPrimitive
(
node
,
prim
::
kPrimZerosLike
))
{
is_zero_
=
true
;
return
;
}
if
(
node
->
isa
<
ValueNode
>
()
&&
CheckTensorConstant
(
0
).
IsTensorScalarConstant
(
GetValueNode
(
node
)))
{
is_zero_
=
true
;
return
;
...
...
mindspore/ccsrc/optimizer/pass_group.cc
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/pass_group.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
void
PassGroup
::
AddPass
(
const
PythonPassPtr
&
pass
)
{
if
(
pass
!=
nullptr
)
{
passes_
.
push_back
(
pass
);
}
}
bool
PassGroup
::
DeletePass
(
const
std
::
string
&
pass_name
)
{
for
(
auto
iter
=
passes_
.
begin
();
iter
!=
passes_
.
end
();
iter
++
)
{
if
((
*
iter
)
->
name
()
==
pass_name
)
{
*
iter
=
nullptr
;
passes_
.
erase
(
iter
);
return
true
;
}
}
return
false
;
}
bool
PassGroup
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
)
const
{
if
(
func_graph
==
nullptr
)
{
return
false
;
}
bool
changed
=
false
;
for
(
const
auto
&
pass
:
passes
)
{
if
(
pass
!=
nullptr
)
{
if
(
pass
->
Run
(
func_graph
))
{
changed
=
true
;
}
}
}
return
changed
;
}
bool
PassGroup
::
Run
(
const
FuncGraphPtr
&
func_graph
)
const
{
bool
changed
=
false
;
// run all passes
bool
change
=
true
;
while
(
change
)
{
change
=
Run
(
func_graph
,
passes_
);
changed
=
change
||
changed
;
if
(
run_only_once_
)
{
break
;
}
}
return
changed
;
}
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/optimizer/pass_group.h
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include "optimizer/py_pass.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
class
PassGroup
{
public:
explicit
PassGroup
(
const
std
::
string
&
name
=
"pass_group"
,
bool
run_only_once
=
false
)
:
name_
(
name
),
passes_
{},
run_only_once_
(
run_only_once
)
{}
virtual
~
PassGroup
()
=
default
;
// Add graph pass, the pass object will be freed when pass manager freed.
void
AddPass
(
const
PythonPassPtr
&
pass
);
// Delete graph pass before the pass manager is freed.
bool
DeletePass
(
const
std
::
string
&
pass_name
);
// Run passes added in pass manager on the input graph
// @param [inout] graph The graph to be optimized
// @return true, graph changed
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
)
const
;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [in] passes The given graph passes
// @return true, graph changed
// @return false, graph not changed
bool
Run
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
vector
<
PythonPassPtr
>
&
passes
)
const
;
std
::
string
name
()
const
{
return
name_
;
}
private:
const
std
::
string
name_
;
std
::
vector
<
PythonPassPtr
>
passes_
;
bool
run_only_once_
;
};
using
PassGroupPtr
=
std
::
shared_ptr
<
PassGroup
>
;
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
mindspore/ccsrc/optimizer/py_pass.cc
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/py_pass.h"
#include <unordered_set>
#include <deque>
#include <algorithm>
#include <utility>
#include <vector>
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/resource.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
namespace
internal
{
std
::
string
GetNodeRepr
(
AnfNodePtr
node
)
{
if
(
node
!=
nullptr
)
{
if
(
node
->
isa
<
CNode
>
())
{
std
::
string
repr
=
"("
;
auto
const
&
inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
for
(
auto
&
input
:
inputs
)
{
repr
+=
" "
;
repr
+=
GetNodeRepr
(
input
);
repr
+=
" "
;
}
repr
+=
")"
;
return
repr
;
}
if
(
node
->
isa
<
ValueNode
>
())
{
return
GetValueNode
(
node
)
->
ToString
();
}
return
node
->
ToString
();
}
return
""
;
}
void
ResolveFuncGraph_
(
const
FuncGraphPtr
&
fg
)
{
auto
manager
=
Manage
(
fg
,
false
);
parse
::
python_adapter
::
set_use_signature_in_resolve
(
false
);
parse
::
ResolveAll
(
manager
);
}
bool
Match
(
const
AnfNodePtr
&
pattern
,
const
AnfNodePtr
&
node
,
const
NodeEquivPtr
&
equiv_ptr
)
{
if
(
node
==
nullptr
)
{
return
false
;
}
MS_EXCEPTION_IF_NULL
(
pattern
);
if
(
pattern
->
isa
<
ValueNode
>
())
{
if
(
!
node
->
isa
<
ValueNode
>
())
{
return
false
;
}
if
(
GetNodeRepr
(
pattern
)
==
GetNodeRepr
(
node
))
{
// add to equiv_ptr
equiv_ptr
->
insert
(
std
::
make_pair
(
GetValueNode
(
pattern
)
->
ToString
(),
node
));
return
true
;
}
return
false
;
}
else
if
(
pattern
->
isa
<
Parameter
>
())
{
MS_LOG
(
DEBUG
)
<<
pattern
->
ToString
()
+
"
\n
"
;
// add to equiv_ptr
equiv_ptr
->
insert
(
std
::
make_pair
(
pattern
->
ToString
(),
node
));
return
true
;
}
else
if
(
pattern
->
isa
<
CNode
>
())
{
// match every single sub ANode
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
auto
pattern_inputs
=
pattern
->
cast
<
CNodePtr
>
()
->
inputs
();
auto
node_inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
pattern_inputs
.
size
()
!=
node_inputs
.
size
())
{
return
false
;
}
for
(
auto
p_item
=
pattern_inputs
.
begin
(),
node_item
=
node_inputs
.
begin
();
p_item
!=
pattern_inputs
.
end
();
p_item
++
,
node_item
++
)
{
auto
res
=
Match
(
*
p_item
,
*
node_item
,
equiv_ptr
);
if
(
!
res
)
{
return
false
;
}
}
return
true
;
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected condition, ("
+
pattern
->
ToString
()
+
" , "
+
node
->
ToString
()
+
")
\n
"
;
}
AnfNodePtr
BuildTarget
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
cur_raw_dst_node_
,
const
NodeEquivPtr
&
equiv_ptr
)
{
if
(
cur_raw_dst_node_
->
isa
<
Parameter
>
())
{
auto
sub_pair
=
equiv_ptr
->
find
(
cur_raw_dst_node_
->
ToString
());
if
(
sub_pair
!=
equiv_ptr
->
end
())
{
return
sub_pair
->
second
;
}
MS_LOG
(
EXCEPTION
)
<<
"cur_raw_dst_node_ : "
+
internal
::
GetNodeRepr
(
cur_raw_dst_node_
)
+
"
\n
"
;
}
else
if
(
cur_raw_dst_node_
->
isa
<
ValueNode
>
())
{
// check primitive ValueNode
auto
sub_pair
=
equiv_ptr
->
find
(
cur_raw_dst_node_
->
cast
<
ValueNodePtr
>
()
->
value
()
->
ToString
());
if
(
sub_pair
!=
equiv_ptr
->
end
())
{
return
sub_pair
->
second
;
}
return
cur_raw_dst_node_
;
}
else
if
(
cur_raw_dst_node_
->
isa
<
CNode
>
())
{
std
::
vector
<
AnfNodePtr
>
new_inputs
;
auto
inputs
=
cur_raw_dst_node_
->
cast
<
CNodePtr
>
()
->
inputs
();
for
(
auto
sub_node
=
inputs
.
begin
();
sub_node
!=
inputs
.
end
();
sub_node
++
)
{
auto
subed
=
internal
::
BuildTarget
(
func_graph
,
*
sub_node
,
equiv_ptr
);
new_inputs
.
push_back
(
subed
);
}
return
func_graph
->
NewCNode
(
new_inputs
);
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected node type, got : "
+
internal
::
GetNodeRepr
(
cur_raw_dst_node_
);
}
bool
isTraversable
(
const
AnfNodePtr
&
node
)
{
if
(
node
==
nullptr
)
{
return
false
;
}
if
(
node
->
isa
<
CNode
>
()
||
node
->
isa
<
Parameter
>
())
{
return
true
;
}
if
(
IsValueNode
<
FuncGraph
>
(
node
)
||
IsValueNode
<
RefKey
>
(
node
))
{
return
true
;
}
return
false
;
}
}
// namespace internal
void
PythonPass
::
Build
(
const
py
::
function
&
src
,
const
py
::
function
&
dst
)
{
// 1. get FuncGraph from py::function
auto
src_fg_
=
parse
::
ParsePythonCode
(
src
);
auto
dst_fg_
=
parse
::
ParsePythonCode
(
dst
);
if
(
src_fg_
==
nullptr
||
dst_fg_
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to parse python code.
\n
"
;
}
// 2. Resolve
internal
::
ResolveFuncGraph_
(
src_fg_
);
internal
::
ResolveFuncGraph_
(
dst_fg_
);
// 3. from FuncGraphPtr to ValueNode
src_node_
=
src_fg_
->
output
();
dst_node_
=
dst_fg_
->
output
();
}
PythonPass
::
PythonPass
(
const
std
::
string
&
name
,
const
py
::
function
&
src
,
const
py
::
function
&
dst
,
bool
run_only_once
,
bool
multigraph
)
:
name_
(
name
),
run_only_once_
(
run_only_once
),
multigraph_
(
multigraph
)
{
Build
(
src
,
dst
);
}
AnfNodePtr
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
{
auto
equiv_ptr
=
std
::
make_shared
<
NodeEquiv
>
();
bool
is_a_match
=
internal
::
Match
(
src_node_
,
node
,
equiv_ptr
);
if
(
is_a_match
)
{
auto
new_node
=
internal
::
BuildTarget
(
func_graph
,
dst_node_
,
equiv_ptr
);
MS_LOG
(
DEBUG
)
<<
"To be replaced node: "
+
internal
::
GetNodeRepr
(
new_node
)
+
"
\n
"
;
return
new_node
;
}
return
nullptr
;
}
bool
PythonPass
::
Run
(
const
FuncGraphPtr
&
func_graph
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
func_graph
);
auto
seen
=
NewSeenGeneration
();
// 1024 is for the initial capacity of deque
std
::
deque
<
AnfNodePtr
>
todo
(
1024
);
todo
.
push_back
(
func_graph
->
output
());
bool
changes
=
false
;
auto
&
all_nodes
=
manager
->
all_nodes
();
while
(
!
todo
.
empty
())
{
AnfNodePtr
node
=
todo
.
front
();
todo
.
pop_front
();
// check whether this node has been matched.
if
(
node
==
nullptr
||
node
->
seen_
==
seen
||
!
internal
::
isTraversable
(
node
)
||
!
all_nodes
.
contains
(
node
))
{
continue
;
}
node
->
seen_
=
seen
;
// select nodes that this transform can be applied.
AnfNodePtr
new_node
=
Run
(
func_graph
,
node
);
bool
change
=
(
new_node
!=
nullptr
);
if
(
new_node
!=
nullptr
&&
new_node
!=
node
)
{
(
void
)
manager
->
Replace
(
node
,
new_node
);
}
else
if
(
new_node
==
nullptr
)
{
new_node
=
node
;
}
if
(
run_only_once_
)
{
return
change
;
}
// find success, and add them to todo list
if
(
IsValueNode
<
FuncGraph
>
(
node
))
{
todo
.
push_back
(
GetValueNode
<
FuncGraphPtr
>
(
node
)
->
output
());
}
if
(
node
->
isa
<
CNode
>
())
{
auto
&
inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
(
void
)
std
::
copy
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
todo
));
}
auto
&
node_users
=
manager
->
node_users
();
if
(
change
&&
node_users
.
find
(
node
)
!=
node_users
.
end
())
{
for
(
auto
&
use
:
node_users
[
node
])
{
auto
use_node
=
use
.
first
;
if
(
use_node
==
nullptr
)
{
continue
;
}
todo
.
push_back
(
use_node
);
if
(
use_node
->
seen_
==
seen
)
{
use_node
->
seen_
--
;
}
}
}
}
return
changes
;
}
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/optimizer/py_pass.h
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
#include <string>
#include <memory>
#include <unordered_map>
#include "ir/anf.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
class
PythonPass
;
using
PythonPassPtr
=
std
::
shared_ptr
<
PythonPass
>
;
using
NodeEquiv
=
std
::
unordered_map
<
std
::
string
,
AnfNodePtr
>
;
using
NodeEquivPtr
=
std
::
shared_ptr
<
NodeEquiv
>
;
class
PythonPass
{
public:
explicit
PythonPass
(
const
std
::
string
&
name
,
const
py
::
function
&
src
,
const
py
::
function
&
dst
,
bool
run_only_once
=
false
,
bool
multigraph
=
true
);
~
PythonPass
()
=
default
;
bool
Run
(
const
FuncGraphPtr
&
func_graph
);
std
::
string
name
()
const
{
return
name_
;
}
AnfNodePtr
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
);
private:
void
Build
(
const
py
::
function
&
src
,
const
py
::
function
&
dst
);
AnfNodePtr
src_node_
=
nullptr
;
AnfNodePtr
dst_node_
=
nullptr
;
const
std
::
string
name_
;
bool
run_only_once_
;
bool
multigraph_
=
true
;
};
using
PythonPassPtr
=
std
::
shared_ptr
<
PythonPass
>
;
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
mindspore/ccsrc/optimizer/py_pass_manager.cc
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/py_pass_manager.h"
#include <functional>
#include <algorithm>
#include <utility>
#include <initializer_list>
#include "ir/manager.h"
#include "optimizer/pass_group.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
PyPassManagerPtr
PyPassManager
::
global_instance
=
nullptr
;
std
::
unordered_map
<
Phase
,
PassGroupPtr
>
PyPassManager
::
phase_to_group_
;
PassGroupPtr
PyPassManager
::
GetPassGroup
(
Phase
phase
)
{
auto
pm
=
phase_to_group_
.
find
(
phase
);
if
(
pm
==
phase_to_group_
.
end
())
{
return
nullptr
;
}
return
pm
->
second
;
}
PyPassManagerPtr
PyPassManager
::
GetInstance
()
{
if
(
global_instance
==
nullptr
)
{
global_instance
=
std
::
shared_ptr
<
PyPassManager
>
(
new
(
std
::
nothrow
)
PyPassManager
());
}
return
global_instance
;
}
PyPassManager
::
PyPassManager
()
{
phase_to_group_
[
Phase
::
RESOLVE
]
=
std
::
make_shared
<
PassGroup
>
();
phase_to_group_
[
Phase
::
OPT
]
=
std
::
make_shared
<
PassGroup
>
();
}
void
PyPassManager
::
Registe
(
const
std
::
string
&
pass_name
,
const
py
::
function
&
pattern
,
const
py
::
function
&
target
,
Phase
phase
,
bool
run_only_once
,
bool
multigraph
)
{
auto
cur_pm
=
GetPassGroup
(
phase
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
PythonPassPtr
new_pass
=
std
::
make_shared
<
PythonPass
>
(
pass_name
,
pattern
,
target
,
run_only_once
,
multigraph
);
cur_pm
->
AddPass
(
new_pass
);
}
void
PyPassManager
::
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
)
{
auto
cur_pm
=
GetPassGroup
(
phase
);
MS_EXCEPTION_IF_NULL
(
cur_pm
);
if
(
!
cur_pm
->
DeletePass
(
pass_name
))
{
MS_LOG
(
WARNING
)
<<
"No such pass : "
+
pass_name
+
"
\n
"
;
}
}
void
PyPassManager
::
ClearRes
()
{
MS_LOG
(
INFO
)
<<
"Clear PyPassManager resources!"
;
global_instance
=
nullptr
;
phase_to_group_
.
clear
();
}
REGISTER_PYBIND_DEFINE
(
PyPassManager_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
Phase
>
(
*
m
,
"phase"
,
py
::
arithmetic
()).
value
(
"resolve"
,
Phase
::
RESOLVE
).
value
(
"opt"
,
Phase
::
OPT
);
(
void
)
py
::
class_
<
PyPassManager
,
std
::
shared_ptr
<
PyPassManager
>>
(
*
m
,
"PyPassManager_"
)
.
def
(
py
::
init
([]()
{
return
PyPassManager
::
GetInstance
();
}))
.
def
(
"registe"
,
&
PyPassManager
::
Registe
,
"Registe python pass"
)
.
def
(
"unregiste"
,
&
PyPassManager
::
Unregiste
,
"Delete Python Pass"
);
}));
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/optimizer/py_pass_manager.h
0 → 100644
浏览文件 @
b26f6b6b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "utils/graph_utils.h"
#include "common/utils.h"
#include "pipeline/parse/resolve.h"
#include "optimizer/py_pass.h"
#include "optimizer/pass_group.h"
namespace
mindspore
{
namespace
opt
{
namespace
python_pass
{
class
PyPassManager
;
using
PyPassManagerPtr
=
std
::
shared_ptr
<
PyPassManager
>
;
enum
Phase
{
RESOLVE
,
OPT
};
class
PyPassManager
{
protected:
PyPassManager
();
static
PyPassManagerPtr
global_instance
;
public:
// Singletons should not be cloneable and assignable
PyPassManager
(
const
PyPassManager
&
other
)
=
delete
;
void
operator
=
(
const
PyPassManager
&
)
=
delete
;
// Access the only global instance
static
PyPassManagerPtr
GetInstance
();
virtual
~
PyPassManager
()
=
default
;
void
Registe
(
const
std
::
string
&
pass_name
,
const
py
::
function
&
pattern
,
const
py
::
function
&
target
,
Phase
phase
=
Phase
::
RESOLVE
,
bool
run_only_once
=
false
,
bool
multigraph
=
true
);
void
Unregiste
(
const
std
::
string
&
pass_name
,
Phase
phase
);
PassGroupPtr
GetPassGroup
(
Phase
phase
);
void
ClearRes
();
private:
static
std
::
unordered_map
<
Phase
,
PassGroupPtr
>
phase_to_group_
;
};
}
// namespace python_pass
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
b26f6b6b
...
...
@@ -39,6 +39,7 @@
#include "optimizer/optimizer.h"
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "optimizer/py_pass_manager.h"
namespace
mindspore
{
namespace
pipeline
{
...
...
@@ -420,6 +421,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
bool
ValidateAction
(
const
ResourcePtr
&
res
)
{
return
ValidatePass
(
res
);
}
void
ActionPyStub
(
const
ResourcePtr
&
res
,
opt
::
python_pass
::
Phase
phase
)
{
MS_EXCEPTION_IF_NULL
(
res
->
manager
());
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
auto
ppm
=
opt
::
python_pass
::
PyPassManager
::
GetInstance
();
if
(
!
ppm
->
GetPassGroup
(
phase
)
->
Run
(
res
->
func_graph
()))
{
MS_LOG
(
DEBUG
)
<<
"No match.
\n
"
;
}
}
bool
ResolveActionPyStub
(
const
ResourcePtr
&
res
)
{
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
RESOLVE
);
return
true
;
}
bool
OptActionPyStub
(
const
ResourcePtr
&
res
)
{
ActionPyStub
(
res
,
opt
::
python_pass
::
Phase
::
RESOLVE
);
return
true
;
}
static
std
::
vector
<
ActionItem
>
CommonPipeline
()
{
std
::
vector
<
ActionItem
>
actions
;
...
...
@@ -432,6 +452,8 @@ static std::vector<ActionItem> CommonPipeline() {
if
(
!
multi_graphs
)
{
actions
.
emplace_back
(
std
::
make_pair
(
"combine_like_graphs"
,
CombineLikeGraphs
));
}
// Add resolve-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_resolve"
,
ResolveActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"inference_opt_prepare"
,
InferenceOptPrepareAction
));
// Evaluate type and shape, and specialize
actions
.
emplace_back
(
std
::
make_pair
(
"abstract_specialize"
,
AbstractSpecializeAction
));
...
...
@@ -443,6 +465,8 @@ std::vector<ActionItem> GePipeline() {
auto
actions
=
CommonPipeline
();
// optimize
actions
.
emplace_back
(
std
::
make_pair
(
"optimize"
,
GeOptimizeAction
));
// Add opt-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_opt"
,
OptActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"remove_value_node_duplications"
,
RemoveValueNodeDuplicationsAction
));
actions
.
emplace_back
(
std
::
make_pair
(
"validate"
,
ValidateAction
));
return
actions
;
...
...
@@ -454,6 +478,9 @@ std::vector<ActionItem> VmPipeline() {
// optimize
actions
.
emplace_back
(
std
::
make_pair
(
"optimize"
,
VmOptimizeAction
));
// Add opt-stage python pass stub
actions
.
emplace_back
(
std
::
make_pair
(
"py_opt"
,
OptActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"validate"
,
ValidateAction
));
// compile the ANF graph
...
...
mindspore/ccsrc/pipeline/pipeline.cc
浏览文件 @
b26f6b6b
...
...
@@ -39,6 +39,7 @@
#include "device/kernel_runtime_manager.h"
#include "debug/trace.h"
#include "pynative/pynative_execute.h"
#include "optimizer/py_pass_manager.h"
#if (ENABLE_GE || ENABLE_D)
#include "pipeline/pipeline_ge.h"
...
...
@@ -964,6 +965,7 @@ void ClearResAtexit() {
pipeline
::
ExecutorPy
::
ClearRes
();
pipeline
::
ReclaimOptimizer
();
pynative
::
PynativeExecutor
::
GetInstance
()
->
ClearRes
();
opt
::
python_pass
::
PyPassManager
::
GetInstance
()
->
ClearRes
();
#ifdef ENABLE_GE
transform
::
DfGraphManager
::
GetInstance
().
ClearGraph
();
transform
::
DfGraphConvertor
::
get_adpt_map
().
clear
();
...
...
mindspore/common/python_pass_register.py
0 → 100644
浏览文件 @
b26f6b6b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
from
inspect
import
isfunction
from
mindspore._c_expression
import
PyPassManager_
from
mindspore._c_expression
import
phase
class
PyPassManager
(
PyPassManager_
):
r
"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
Raises:
TypeError: If argument has invalid type.
"""
def
__init__
(
self
,
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
,
multi_graph
=
True
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expecting phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
if
not
isinstance
(
run_only_once
,
bool
):
raise
TypeError
(
f
"Expecting bool, got : (
{
type
(
run_only_once
)
}
)
{
run_only_once
}
"
)
if
not
isinstance
(
multi_graph
,
bool
):
raise
TypeError
(
f
"Expecting bool, got : (
{
type
(
multi_graph
)
}
)
{
multi_graph
}
"
)
PyPassManager_
.
__init__
(
self
)
self
.
phase_
=
pipeline_phase
self
.
run_only_once_
=
run_only_once
self
.
multi_graph_
=
multi_graph
def
registe
(
self
,
py_pass
):
if
not
isfunction
(
py_pass
):
raise
TypeError
(
f
"Expecting function pass, got : (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
pattern
,
target
=
py_pass
()
pass_name
=
py_pass
.
__name__
if
not
isfunction
(
pattern
):
raise
TypeError
(
f
"Expecting function pattern, got : (
{
type
(
pattern
)
}
)
{
pattern
}
"
)
if
not
isfunction
(
target
):
raise
TypeError
(
f
"Expecting function target, got : (
{
type
(
target
)
}
)
{
target
}
"
)
super
().
registe
(
pass_name
,
pattern
,
target
,
self
.
phase_
,
self
.
run_only_once_
,
self
.
multi_graph_
)
def
unregiste
(
self
,
py_pass
,
pipeline_phase
=
phase
.
opt
):
if
not
isinstance
(
pipeline_phase
,
phase
):
raise
TypeError
(
f
"Expecting phase, got : (
{
type
(
pipeline_phase
)
}
)
{
pipeline_phase
}
"
)
if
isinstance
(
py_pass
,
str
):
super
().
unregiste
(
py_pass
,
pipeline_phase
)
return
if
isfunction
(
py_pass
):
super
().
unregiste
(
py_pass
.
__name__
,
pipeline_phase
)
return
raise
TypeError
(
f
"Expecting py_pass to be string or function, got (
{
type
(
py_pass
)
}
)
{
py_pass
}
"
)
def
__call__
(
self
,
py_pass
):
self
.
registe
(
py_pass
)
return
py_pass
def
registe_pass
(
pipeline_phase
=
phase
.
opt
,
run_only_once
=
False
,
multi_graph
=
True
):
"""
Examples:
>>> @registe_pass()
>>> def toy_pass():
>>> def pattern():
>>> pass
>>> def target():
>>> pass
"""
return
PyPassManager
(
pipeline_phase
,
run_only_once
,
multi_graph
)
mindspore/nn/layer/basic.py
浏览文件 @
b26f6b6b
...
...
@@ -170,7 +170,8 @@ class Dense(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'.
Default: None.
Raises:
ValueError: If weight_init or bias_init shape is incorrect.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录