Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
067616d0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
067616d0
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1833 Pattern Matcher class for optimizations
Merge pull request !1833 from Giancarlo/pattern_matcher
上级
1cc43008
8eaea744
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
377 addition
and
166 deletion
+377
-166
mindspore/ccsrc/ir/pattern_matcher.h
mindspore/ccsrc/ir/pattern_matcher.h
+306
-0
mindspore/ccsrc/optimizer/irpass/branch_culling.h
mindspore/ccsrc/optimizer/irpass/branch_culling.h
+71
-166
未找到文件。
mindspore/ccsrc/ir/pattern_matcher.h
0 → 100644
浏览文件 @
067616d0
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#include <tuple>
#include <vector>
#include "ir/anf.h"
#include "operator/ops.h"
namespace
mindspore
{
///
/// Base class for all recognizable patterns.
/// We implement an Expression Template approach using static polymorphism based on
/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
/// to the use of virtual functions without the costs..." as described in:
/// https://en.wikipedia.org/wiki/Expression_templates and
/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
/// The TryCapture function tries to capture the pattern with the given node.
/// The GetNode function builds a new node using the captured values.
///
template
<
typename
T
>
class
PBase
{
public:
const
T
&
get_object
()
const
{
return
*
static_cast
<
const
T
*>
(
this
);
}
template
<
typename
TN
>
bool
TryCapture
(
const
TN
&
value
)
const
{
get_object
().
Reset
();
return
get_object
().
TryCapture_
(
value
);
}
using
Internal
=
T
;
};
template
<
typename
T
>
class
PIsEqual
{
public:
bool
operator
()(
const
T
&
lhs
,
const
T
&
rhs
)
const
{
return
lhs
==
rhs
;
}
};
template
<
typename
T
>
class
PatternNode
:
public
PBase
<
PatternNode
<
T
>
>
{
public:
T
GetNode
(
const
AnfNodePtr
&
node
)
const
{
if
(
!
captured_
)
{
MS_EXCEPTION
(
ValueError
)
<<
"A Pattern wasn't captured for this Token before the call to GetNode."
;
}
return
captured_node_
;
}
bool
TryCapture_
(
const
T
&
node
)
const
{
if
(
!
captured_
)
{
captured_node_
=
node
;
captured_
=
true
;
return
true
;
}
return
PIsEqual
<
T
>
()(
captured_node_
,
node
);
}
void
Reset
()
const
{
captured_
=
false
;
}
using
Internal
=
const
PatternNode
<
T
>
&
;
protected:
mutable
T
captured_node_
;
mutable
bool
captured_
{
false
};
};
template
<
typename
T
,
typename
T2
>
class
PBinOperation
:
public
PBase
<
PBinOperation
<
T
,
T2
>
>
{
public:
PBinOperation
(
const
PrimitivePtr
&
prim
,
const
T
&
x
,
const
T2
&
y
)
:
prim_
(
prim
),
x_
(
x
),
y_
(
y
)
{}
AnfNodePtr
GetNode
(
const
AnfNodePtr
&
node
)
const
{
AnfNodePtr
lhs
=
x_
.
GetNode
(
node
->
func_graph
());
AnfNodePtr
rhs
=
y_
.
GetNode
(
node
->
func_graph
());
AnfNodePtrList
list
=
{
prim_
->
cast
<
AnfNodePtr
>
(),
lhs
,
rhs
};
return
NewCNode
(
list
,
node
->
func_graph
());
}
bool
TryCapture_
(
const
AnfNodePtr
&
node
)
const
{
if
(
IsPrimitiveCNode
(
node
,
prim_
))
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
inputs
=
cnode
->
inputs
();
if
(
inputs
.
size
()
==
3
)
{
// Binary Prim assumes only two inputs
if
(
!
x_
.
TryCapture_
(
inputs
[
1
])
||
!
y_
.
TryCapture_
(
inputs
[
2
]))
{
return
false
;
}
return
true
;
}
}
return
false
;
}
void
Reset
()
const
{
x_
.
Reset
();
y_
.
Reset
();
}
private:
const
PrimitivePtr
prim_
;
typename
T
::
Internal
x_
;
typename
T2
::
Internal
y_
;
};
///
/// Helper functions to apply a pattern function on all elements of a tuple
///
namespace
tuple_utils
{
template
<
bool
stop
,
size_t
Index
,
typename
Func
>
struct
apply_func_tuple_item
{
template
<
typename
TTuple
>
static
void
apply
(
Func
*
func
,
const
TTuple
&
tuple
)
{
(
*
func
)(
Index
,
std
::
get
<
Index
>
(
tuple
));
apply_func_tuple_item
<
(
Index
+
1
)
==
std
::
tuple_size
<
TTuple
>::
value
,
(
Index
+
1
),
Func
>::
apply
(
func
,
tuple
);
}
};
template
<
size_t
Index
,
typename
Func
>
struct
apply_func_tuple_item
<
true
,
Index
,
Func
>
{
template
<
typename
TTuple
>
static
void
apply
(
Func
*
func
,
const
TTuple
&
tuple
)
{}
};
template
<
typename
Func
,
typename
TTuple
>
inline
void
apply_func_tuple
(
Func
*
func
,
const
TTuple
&
tuple
)
{
apply_func_tuple_item
<
std
::
tuple_size
<
TTuple
>::
value
==
0
,
0
,
Func
>::
apply
(
func
,
tuple
);
}
struct
PTupleResetCapture
{
template
<
typename
T
>
void
operator
()(
size_t
i
,
const
T
&
pattern
)
const
{
pattern
.
Reset
();
}
};
struct
PTupleCapture
{
explicit
PTupleCapture
(
const
AnfNodePtrList
tuple
)
:
tuple_
(
tuple
)
{}
template
<
typename
TPattern
>
void
operator
()(
size_t
i
,
const
TPattern
&
pattern
)
{
// Check if the first node is a Primitive
if
(
i
==
0
&&
tuple_
[
i
]
->
isa
<
Primitive
>
())
{
auto
prim
=
tuple_
[
i
]
->
cast
<
PrimitivePtr
>
();
if
(
tuple_
[
i
]
!=
pattern
.
GetNode
(
tuple_
[
i
]))
{
captured_
=
false
;
}
}
else
{
captured_
=
captured_
&&
pattern
.
TryCapture_
(
tuple_
[
i
]);
}
}
const
AnfNodePtrList
tuple_
;
bool
captured_
{
true
};
};
struct
PTupleGetNode
{
explicit
PTupleGetNode
(
const
AnfNodePtr
&
node
)
:
node_
(
node
)
{}
template
<
typename
TPattern
>
void
operator
()(
size_t
,
const
TPattern
&
pattern
)
{
args_
.
push_back
(
pattern
.
GetNode
(
node_
));
}
const
AnfNodePtr
&
node_
;
std
::
vector
<
AnfNodePtr
>
args_
;
};
}
// namespace tuple_utils
template
<
typename
...
TArgs
>
class
PCNode
:
public
PBase
<
PCNode
<
TArgs
...
>
>
{
public:
explicit
PCNode
(
const
TArgs
&
...
args
)
:
args_
(
args
...)
{}
AnfNodePtr
GetNode
(
const
AnfNodePtr
&
node
)
const
{
tuple_utils
::
PTupleGetNode
get_node
(
node
);
tuple_utils
::
apply_func_tuple
(
&
get_node
,
args_
);
return
NewCNode
(
get_node
.
args_
,
node
->
func_graph
());
}
bool
TryCapture_
(
const
AnfNodePtr
&
node
)
const
{
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
inputs
=
cnode
->
inputs
();
if
(
inputs
.
size
()
!=
sizeof
...(
TArgs
))
{
return
false
;
}
tuple_utils
::
PTupleCapture
capture_func
(
inputs
);
tuple_utils
::
apply_func_tuple
(
&
capture_func
,
args_
);
return
capture_func
.
captured_
;
}
return
false
;
}
void
Reset
()
const
{
tuple_utils
::
PTupleResetCapture
reset
;
tuple_utils
::
apply_func_tuple
(
&
reset
,
args_
);
}
private:
std
::
tuple
<
typename
TArgs
::
Internal
...
>
args_
;
};
template
<
typename
...
TArgs
>
class
PPrimitive
:
public
PBase
<
PPrimitive
<
TArgs
...
>
>
{
public:
explicit
PPrimitive
(
const
PrimitivePtr
&
prim
,
const
TArgs
&
...
args
)
:
prim_
(
prim
),
args_
(
args
...)
{}
AnfNodePtr
GetNode
(
const
AnfNodePtr
&
node
)
const
{
tuple_utils
::
PTupleGetNode
get_node
(
node
);
tuple_utils
::
apply_func_tuple
(
&
get_node
,
args_
);
auto
prim_cnode
=
get_node
.
args_
;
prim_cnode
.
insert
(
prim_cnode
.
begin
(),
NewValueNode
(
prim_
));
return
NewCNode
(
prim_cnode
,
node
->
func_graph
());
}
bool
TryCapture_
(
const
AnfNodePtr
&
node
)
const
{
if
(
IsPrimitiveCNode
(
node
,
prim_
))
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
inputs
=
cnode
->
inputs
();
if
((
inputs
.
size
()
-
1
)
!=
sizeof
...(
TArgs
))
{
return
false
;
}
AnfNodePtrList
rest
(
inputs
.
begin
()
+
1
,
inputs
.
end
());
tuple_utils
::
PTupleCapture
capture_func
(
rest
);
tuple_utils
::
apply_func_tuple
(
&
capture_func
,
args_
);
return
capture_func
.
captured_
;
}
return
false
;
}
void
Reset
()
const
{
tuple_utils
::
PTupleResetCapture
reset
;
tuple_utils
::
apply_func_tuple
(
&
reset
,
args_
);
}
private:
const
PrimitivePtr
prim_
;
std
::
tuple
<
typename
TArgs
::
Internal
...
>
args_
;
};
// Macro for binary operation functions
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \
template <typename T, typename T2> \
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \
}
// Arithmetic operations
BIN_OPERATION_PATTERN
(
operator
+
,
prim
::
kPrimTensorAdd
);
BIN_OPERATION_PATTERN
(
operator
*
,
prim
::
kPrimMul
);
// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
if ((CaptureNode).TryCapture(OrigNode)) { \
if ((Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
} \
return (ElseNode).GetNode(OrigNode); \
}
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (Lambda)(); \
}
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (Lambda)(); \
}
}
// namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
mindspore/ccsrc/optimizer/irpass/branch_culling.h
浏览文件 @
067616d0
...
@@ -26,141 +26,61 @@
...
@@ -26,141 +26,61 @@
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
namespace
irpass
{
namespace
irpass
{
// {prim::kPrimSwitch, true, X, Y}
// {prim::kPrimSwitch, true, X, Y}
// {prim::kPrimSwitch, false, X, Y}
// {prim::kPrimSwitch, false, X, Y}
class
SwitchSimplify
:
public
AnfVisitor
{
class
SwitchSimplify
{
public:
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
PatternNode
<
AnfNodePtr
>
cond
,
true_br
,
false_br
;
auto
getx
=
[
this
](
const
AnfNodePtr
&
node
)
->
bool
{
auto
SwitchSimplLambda
=
[
&
node
,
&
cond
,
&
true_br
,
&
false_br
]()
->
AnfNodePtr
{
this
->
x_
=
node
;
auto
cond_value_
=
GetValue
<
bool
>
(
GetValueNode
(
cond
.
GetNode
(
node
)));
return
true
;
if
(
cond_value_
)
{
};
return
true_br
.
GetNode
(
node
);
auto
gety
=
[
this
](
const
AnfNodePtr
&
node
)
->
bool
{
}
this
->
y_
=
node
;
return
false_br
.
GetNode
(
node
);
return
true
;
};
};
AnfVisitor
::
Match
(
prim
::
kPrimSwitch
,
{
IsValueNode
<
BoolImm
>
,
getx
,
gety
})(
node
);
// simplify the switch
MATCH_REPLACE_LAMBDA_IF
(
node
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
SwitchSimplLambda
,
if
(
is_match_
)
{
IsValueNode
<
BoolImm
>
(
cond
.
GetNode
(
node
)));
if
(
cond_
)
{
return
x_
;
}
return
y_
;
}
return
nullptr
;
return
nullptr
;
}
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
!
is_match_
&&
IsValueNode
<
BoolImm
>
(
node
))
{
cond_
=
GetValue
<
bool
>
(
GetValueNode
(
node
));
is_match_
=
true
;
}
}
void
Reset
()
{
x_
=
nullptr
;
y_
=
nullptr
;
cond_
=
false
;
is_match_
=
false
;
}
private:
bool
is_match_
{
false
},
cond_
{
false
};
AnfNodePtr
x_
{
nullptr
},
y_
{
nullptr
};
};
};
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
class
FloatTupleGetItemSwitch
:
public
AnfVisitor
{
class
FloatTupleGetItemSwitch
{
public:
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
PatternNode
<
AnfNodePtr
>
cond
,
true_br
,
false_br
,
x
;
AnfVisitor
::
Match
(
prim
::
kPrimTupleGetItem
,
{
IsCNode
,
IsVNode
})(
node
);
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimTupleGetItem
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
x
),
auto
fg
=
node
->
func_graph
();
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
PPrimitive
(
prim
::
kPrimTupleGetItem
,
true_br
,
x
),
if
(
Xs_
.
empty
()
||
c_
==
nullptr
||
fg
==
nullptr
)
{
PPrimitive
(
prim
::
kPrimTupleGetItem
,
false_br
,
x
)),
return
nullptr
;
IsVNode
(
x
.
GetNode
(
node
)));
}
return
nullptr
;
auto
true_node
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimTupleGetItem
),
Xs_
[
1
],
c_
});
auto
false_node
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimTupleGetItem
),
Xs_
[
2
],
c_
});
return
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitch
),
Xs_
[
0
],
true_node
,
false_node
});
}
void
Visit
(
const
CNodePtr
&
cnode
)
override
{
// {prim::kPrimSwith, X1, X2, X3}
if
(
!
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimSwitch
)
||
cnode
->
size
()
!=
4
)
{
return
;
}
// copy X1, X2, X3
auto
&
inputs
=
cnode
->
inputs
();
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
Xs_
));
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
c_
=
vnode
;
}
void
Reset
()
{
Xs_
.
clear
();
c_
=
nullptr
;
}
}
private:
AnfNodePtr
c_
{
nullptr
};
std
::
vector
<
AnfNodePtr
>
Xs_
{};
};
};
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
class
FloatEnvGetItemSwitch
:
public
AnfVisitor
{
class
FloatEnvGetItemSwitch
{
public:
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
is_match_
=
false
;
PatternNode
<
AnfNodePtr
>
cond
,
true_br
,
false_br
,
x
,
x2
;
AnfVisitor
::
Match
(
prim
::
kPrimEnvGetItem
,
{
IsCNode
,
IsNode
,
IsNode
})(
node
);
MATCH_REPLACE_IF
(
node
,
if
(
!
is_match_
)
{
PPrimitive
(
prim
::
kPrimEnvGetItem
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
x
,
x2
),
return
nullptr
;
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
PPrimitive
(
prim
::
kPrimEnvGetItem
,
true_br
,
x
,
x2
),
}
PPrimitive
(
prim
::
kPrimEnvGetItem
,
false_br
,
x
,
x2
)),
IsNode
(
x
.
GetNode
(
node
))
&&
IsNode
(
x2
.
GetNode
(
node
)));
// {prim::kPrimEnvGetItem, {...}, X4, X5}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
sw_node
=
cnode
->
input
(
1
)
->
cast
<
CNodePtr
>
();
auto
x4
=
cnode
->
input
(
2
);
auto
x5
=
cnode
->
input
(
3
);
is_match_
=
false
;
return
nullptr
;
AnfVisitor
::
Match
(
prim
::
kPrimSwitch
,
{
IsNode
,
IsNode
,
IsNode
})(
sw_node
);
if
(
!
is_match_
)
{
return
nullptr
;
}
// {prim::kPrimSwitch, X1, X2, X3}
auto
x1
=
sw_node
->
input
(
1
);
auto
x2
=
sw_node
->
input
(
2
);
auto
x3
=
sw_node
->
input
(
3
);
auto
fg
=
node
->
func_graph
();
if
(
fg
==
nullptr
)
{
return
nullptr
;
}
auto
true_node
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimEnvGetItem
),
x2
,
x4
,
x5
});
auto
false_node
=
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimEnvGetItem
),
x3
,
x4
,
x5
});
return
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimSwitch
),
x1
,
true_node
,
false_node
});
}
}
void
Visit
(
const
AnfNodePtr
&
)
override
{
is_match_
=
true
;
}
private:
bool
is_match_
{
false
};
};
};
namespace
internal
{
namespace
internal
{
...
@@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
...
@@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
}
// namespace internal
}
// namespace internal
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
class
ConvertSwitchReplacement
:
public
AnfVisitor
{
class
ConvertSwitchReplacement
{
public:
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
!
node
->
isa
<
CNode
>
()
||
node
->
func_graph
()
==
nullptr
)
{
if
(
!
node
->
isa
<
CNode
>
()
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
return
nullptr
;
}
}
Reset
();
auto
cnode_
=
node
->
cast
<
CNodePtr
>
();
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode_
->
size
()
<
1
)
{
if
(
cnode
->
size
()
<
1
)
{
return
nullptr
;
return
nullptr
;
}
}
// {prim::kPrimSwitch, X, G1, G2}
auto
node_
=
cnode_
->
input
(
0
);
AnfVisitor
::
Match
(
prim
::
kPrimSwitch
,
{
IsNode
,
IsValueNode
<
FuncGraph
>
,
IsValueNode
<
FuncGraph
>
})(
cnode
->
input
(
0
));
if
(
g2_
==
nullptr
||
g1_
->
output
()
==
nullptr
||
g2_
->
output
()
==
nullptr
)
{
PatternNode
<
AnfNodePtr
>
cond
,
true_br
,
false_br
;
return
nullptr
;
}
auto
ConvertSwitchLambda
=
[
&
node_
,
&
cond
,
&
true_br
,
&
false_br
]()
->
AnfNodePtr
{
// for switch replace method, only graphs without graph inside can be replaced
auto
g1_
=
GetValueNode
<
FuncGraphPtr
>
(
true_br
.
GetNode
(
node_
));
for
(
auto
&
item
:
g1_
->
value_nodes
())
{
auto
g2_
=
GetValueNode
<
FuncGraphPtr
>
(
false_br
.
GetNode
(
node_
));
auto
value_node
=
item
.
first
;
auto
x_
=
cond
.
GetNode
(
node_
);
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
return
nullptr
;
// for switch replace method, only graphs without graph inside can be replaced
for
(
auto
&
item
:
g1_
->
value_nodes
())
{
auto
value_node
=
item
.
first
;
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
return
nullptr
;
}
}
}
}
for
(
auto
&
item
:
g2_
->
value_nodes
())
{
for
(
auto
&
item
:
g2_
->
value_nodes
())
{
auto
value_node
=
item
.
first
;
auto
value_node
=
item
.
first
;
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
if
(
IsValueNode
<
FuncGraph
>
(
value_node
))
{
return
nullptr
;
return
nullptr
;
}
}
}
}
auto
true_output
=
g1_
->
output
()
->
abstract
();
auto
true_output
=
g1_
->
output
()
->
abstract
();
auto
false_output
=
g2_
->
output
()
->
abstract
();
auto
false_output
=
g2_
->
output
()
->
abstract
();
auto
trans_g1
=
internal
::
TransformGraphCondTrueBranchNodes
(
g1_
,
x_
);
auto
trans_g1
=
internal
::
TransformGraphCondTrueBranchNodes
(
g1_
,
x_
);
auto
trans_g2
=
internal
::
TransformGraphCondFalseBranchNodes
(
g2_
,
x_
);
auto
trans_g2
=
internal
::
TransformGraphCondFalseBranchNodes
(
g2_
,
x_
);
std
::
vector
<
AnfNodePtr
>
params
;
auto
fg
=
node
->
func_graph
();
auto
cloned_g1
=
InlineClone
(
trans_g1
,
fg
,
params
);
auto
cloned_g2
=
InlineClone
(
trans_g2
,
fg
,
params
);
auto
nnode
=
internal
::
TransformMergeBranches
(
cloned_g1
,
cloned_g2
,
true_output
,
false_output
,
x_
,
fg
);
return
nnode
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
std
::
vector
<
AnfNodePtr
>
params
;
if
(
x_
==
nullptr
)
{
auto
fg
=
node_
->
func_graph
();
x_
=
node
;
auto
cloned_g1
=
InlineClone
(
trans_g1
,
fg
,
params
);
return
;
auto
cloned_g2
=
InlineClone
(
trans_g2
,
fg
,
params
);
}
auto
nnode
=
internal
::
TransformMergeBranches
(
cloned_g1
,
cloned_g2
,
true_output
,
false_output
,
x_
,
fg
);
AnfVisitor
::
Visit
(
node
);
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
return
nnode
;
auto
g
=
GetValueNode
<
FuncGraphPtr
>
(
vnode
);
};
if
(
g1_
==
nullptr
)
{
g1_
=
g
;
}
else
{
g2_
=
g
;
}
}
void
Reset
()
{
MATCH_REPLACE_LAMBDA_IF
(
node_
,
PPrimitive
(
prim
::
kPrimSwitch
,
cond
,
true_br
,
false_br
),
ConvertSwitchLambda
,
x_
=
nullptr
;
IsNode
(
cond
.
GetNode
(
node_
))
&&
IsValueNode
<
FuncGraph
>
(
true_br
.
GetNode
(
node_
))
&&
g1_
=
nullptr
;
IsValueNode
<
FuncGraph
>
(
false_br
.
GetNode
(
node_
)));
g2_
=
nullptr
;
}
private:
return
nullptr
;
AnfNodePtr
x_
{
nullptr
};
}
FuncGraphPtr
g1_
{
nullptr
},
g2_
{
nullptr
};
};
};
}
// namespace irpass
}
// namespace irpass
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录