Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
69d1b4c0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
69d1b4c0
编写于
6月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2468 Update Ref Eliminate and TileEliminate to Pattern Matcher
Merge pull request !2468 from Giancarlo/optimizer_update
上级
32f3a657
3277ca56
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
46 addition
and
82 deletion
+46
-82
mindspore/ccsrc/ir/pattern_matcher.h
mindspore/ccsrc/ir/pattern_matcher.h
+4
-0
mindspore/ccsrc/optimizer/irpass/branch_culling.h
mindspore/ccsrc/optimizer/irpass/branch_culling.h
+9
-10
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
+33
-72
未找到文件。
mindspore/ccsrc/ir/pattern_matcher.h
浏览文件 @
69d1b4c0
...
...
@@ -39,6 +39,10 @@ namespace mindspore {
template
<
typename
T
>
class
PBase
{
public:
bool
CheckFunc
(
const
opt
::
PredicateFuncType
&
func
,
const
AnfNodePtr
&
node
)
{
return
func
(
get_object
().
GetNode
(
node
));
}
const
T
&
get_object
()
const
{
return
*
static_cast
<
const
T
*>
(
this
);
}
template
<
typename
TN
>
...
...
mindspore/ccsrc/optimizer/irpass/branch_culling.h
浏览文件 @
69d1b4c0
...
...
@@ -45,7 +45,7 @@ class SwitchSimplify : public OptimizerCaller {
};
MATCH_REPLACE_LAMBDA_IF
(
node
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
SwitchSimplLambda
,
IsValueNode
<
BoolImm
>
(
cond
.
GetNode
(
node
)
));
cond
.
CheckFunc
(
IsValueNode
<
BoolImm
>
,
node
));
return
nullptr
;
}
...
...
@@ -61,7 +61,7 @@ class FloatTupleGetItemSwitch : public OptimizerCaller {
PPrimitive
(
prim
::
kPrimTupleGetItem
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
x
),
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
PPrimitive
(
prim
::
kPrimTupleGetItem
,
true_br
,
x
),
PPrimitive
(
prim
::
kPrimTupleGetItem
,
false_br
,
x
)),
IsVNode
(
x
.
GetNode
(
node
)
));
x
.
CheckFunc
(
IsVNode
,
node
));
return
nullptr
;
}
};
...
...
@@ -72,11 +72,10 @@ class FloatEnvGetItemSwitch : public OptimizerCaller {
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
PatternNode
<
AnfNodePtr
>
cond
,
true_br
,
false_br
,
x
,
x2
;
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimEnvGetItem
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
x
,
x2
),
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
PPrimitive
(
prim
::
kPrimEnvGetItem
,
true_br
,
x
,
x2
),
PPrimitive
(
prim
::
kPrimEnvGetItem
,
false_br
,
x
,
x2
)),
IsNode
(
x
.
GetNode
(
node
))
&&
IsNode
(
x2
.
GetNode
(
node
)));
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimEnvGetItem
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
x
,
x2
),
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
PPrimitive
(
prim
::
kPrimEnvGetItem
,
true_br
,
x
,
x2
),
PPrimitive
(
prim
::
kPrimEnvGetItem
,
false_br
,
x
,
x2
)));
return
nullptr
;
}
...
...
@@ -142,9 +141,9 @@ class ConvertSwitchReplacement : public OptimizerCaller {
return
nnode
;
};
MATCH_REPLACE_LAMBDA_IF
(
node_
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
ConvertSwitchLambda
,
IsNode
(
cond
.
GetNode
(
node_
))
&&
IsValueNode
<
FuncGraph
>
(
true_br
.
GetNode
(
node_
))
&&
IsValueNode
<
FuncGraph
>
(
false_br
.
GetNode
(
node_
)
));
MATCH_REPLACE_LAMBDA_IF
(
node_
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
ConvertSwitchLambda
,
true_br
.
CheckFunc
(
IsValueNode
<
FuncGraph
>
,
node_
)
&&
false_br
.
CheckFunc
(
IsValueNode
<
FuncGraph
>
,
node_
));
return
nullptr
;
}
...
...
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
浏览文件 @
69d1b4c0
...
...
@@ -21,109 +21,70 @@
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
#include "ir/pattern_matcher.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class
MakeRefEliminater
:
public
AnfVisito
r
{
class
MakeRefEliminater
:
public
OptimizerCalle
r
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
y_
=
nullptr
;
auto
gety
=
[
this
](
const
AnfNodePtr
&
node
)
->
bool
{
this
->
y_
=
node
;
return
true
;
};
AnfVisitor
::
Match
(
prim
::
kPrimMakeRef
,
{
IsNode
,
gety
,
IsNode
})(
node
);
return
y_
;
PatternNode
<
AnfNodePtr
>
x
,
y
,
z
;
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimMakeRef
,
x
,
y
,
z
),
y
);
return
nullptr
;
}
void
Visit
(
const
AnfNodePtr
&
)
override
{}
private:
AnfNodePtr
y_
{
nullptr
};
};
// {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class
GetRefParamEliminater
:
public
AnfVisito
r
{
class
GetRefParamEliminater
:
public
OptimizerCalle
r
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
x_
=
nullptr
;
AnfVisitor
::
Match
(
prim
::
kPrimGetRefOrigin
,
{
IsParam
})(
node
);
if
(
x_
!=
nullptr
)
{
return
x_
;
}
AnfVisitor
::
Match
(
prim
::
kPrimGetRefValue
,
{
IsParam
})(
node
);
return
x_
;
PatternNode
<
AnfNodePtr
>
x
;
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimGetRefValue
,
x
),
x
,
x
.
CheckFunc
(
IsParam
,
node
));
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimGetRefOrigin
,
x
),
x
,
x
.
CheckFunc
(
IsParam
,
node
));
return
nullptr
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
x_
=
node
;
}
private:
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class
GetMakeRefEliminater
:
public
AnfVisito
r
{
class
GetMakeRefEliminater
:
public
OptimizerCalle
r
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
==
nullptr
||
cnode
->
size
()
!=
2
)
{
return
nullptr
;
}
// {prim::kPrimGetRefKey/Value, {...}}
auto
ref
=
cnode
->
input
(
1
)
->
cast
<
CNodePtr
>
();
if
(
ref
==
nullptr
||
!
ref
->
IsApply
(
prim
::
kPrimMakeRef
)
||
ref
->
size
()
!=
4
)
{
return
nullptr
;
}
// {prim::kPrimMakeRef, X, Y, Z}
if
(
cnode
->
IsApply
(
prim
::
kPrimGetRefKey
))
{
return
ref
->
input
(
1
);
}
if
(
cnode
->
IsApply
(
prim
::
kPrimGetRefValue
))
{
return
ref
->
input
(
2
);
}
if
(
cnode
->
IsApply
(
prim
::
kPrimGetRefOrigin
))
{
return
ref
->
input
(
3
);
}
PatternNode
<
AnfNodePtr
>
x
,
y
,
z
;
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimGetRefKey
,
PPrimitive
(
prim
::
kPrimMakeRef
,
x
,
y
,
z
)),
x
);
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimGetRefValue
,
PPrimitive
(
prim
::
kPrimMakeRef
,
x
,
y
,
z
)),
y
);
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimGetRefOrigin
,
PPrimitive
(
prim
::
kPrimMakeRef
,
x
,
y
,
z
)),
z
);
return
nullptr
;
}
};
// IsValueNode<RefKey>
class
ReplaceRefkeyByParam
:
public
AnfVisito
r
{
class
ReplaceRefkeyByParam
:
public
OptimizerCalle
r
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
if
(
!
IsValueNode
<
RefKey
>
(
node
))
{
return
nullptr
;
}
auto
refkey
=
GetValueNode
<
RefKeyPtr
>
(
node
);
auto
resource
=
std
::
dynamic_pointer_cast
<
pipeline
::
Resource
>
(
optimizer
->
resource
());
MS_EXCEPTION_IF_NULL
(
resource
);
auto
top_graph
=
resource
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
top_graph
);
for
(
const
auto
&
tnode
:
top_graph
->
parameters
())
{
auto
para
=
tnode
->
cast
<
ParameterPtr
>
();
if
(
para
!=
nullptr
&&
para
->
name
()
==
refkey
->
tag
())
{
return
para
;
auto
RefKeyLambda
=
[
&
node
,
&
optimizer
]()
->
AnfNodePtr
{
auto
refkey
=
GetValueNode
<
RefKeyPtr
>
(
node
);
auto
resource
=
std
::
dynamic_pointer_cast
<
pipeline
::
Resource
>
(
optimizer
->
resource
());
MS_EXCEPTION_IF_NULL
(
resource
);
auto
top_graph
=
resource
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
top_graph
);
for
(
const
auto
&
tnode
:
top_graph
->
parameters
())
{
auto
para
=
tnode
->
cast
<
ParameterPtr
>
();
if
(
para
!=
nullptr
&&
para
->
name
()
==
refkey
->
tag
())
{
return
para
;
}
}
}
return
nullptr
;
};
PatternNode
<
AnfNodePtr
>
x
;
MATCH_REPLACE_LAMBDA_IF
(
node
,
x
,
RefKeyLambda
,
x
.
CheckFunc
(
IsValueNode
<
RefKey
>
,
node
));
return
nullptr
;
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录