Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e5c67b90
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e5c67b90
编写于
4月 13, 2020
作者:
Y
YuJianfeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add cnode to equal map when opt matching
上级
a4cf9028
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
182 addition
and
186 deletion
+182
-186
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
...rc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
+12
-39
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
...src/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
+5
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
...tivate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
+14
-40
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h
...ctivate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h
+5
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
...src/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
+6
-29
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h
...csrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h
+5
-1
mindspore/ccsrc/pre_activate/common/optimizer.cc
mindspore/ccsrc/pre_activate/common/optimizer.cc
+21
-9
mindspore/ccsrc/pre_activate/common/optimizer.h
mindspore/ccsrc/pre_activate/common/optimizer.h
+2
-0
mindspore/ccsrc/pre_activate/common/pattern_engine.cc
mindspore/ccsrc/pre_activate/common/pattern_engine.cc
+55
-24
mindspore/ccsrc/pre_activate/common/pattern_engine.h
mindspore/ccsrc/pre_activate/common/pattern_engine.h
+41
-30
tests/ut/cpp/pre_activate/common/pattern_engine_test.cc
tests/ut/cpp/pre_activate/common/pattern_engine_test.cc
+16
-14
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
浏览文件 @
e5c67b90
...
...
@@ -15,43 +15,9 @@
*/
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/common/helper.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
void
GetAdd0AndAdd1
(
const
AnfNodePtr
&
sub0
,
AnfNodePtr
*
add0
,
AnfNodePtr
*
add1
)
{
MS_EXCEPTION_IF_NULL
(
sub0
);
MS_EXCEPTION_IF_NULL
(
add0
);
MS_EXCEPTION_IF_NULL
(
add1
);
auto
sub0_cnode
=
sub0
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
sub0_cnode
);
CheckCNodeInputSize
(
sub0_cnode
,
kSubInputNum
);
AnfNodePtr
mul4
=
sub0_cnode
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
mul4
);
auto
mul4_cnode
=
mul4
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
mul4_cnode
);
CheckCNodeInputSize
(
mul4_cnode
,
kMulInputNum
);
AnfNodePtr
true_div0
=
mul4_cnode
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
true_div0
);
auto
true_div0_cnode
=
true_div0
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
true_div0_cnode
);
CheckCNodeInputSize
(
true_div0_cnode
,
kRealDivInputNum
);
*
add0
=
true_div0_cnode
->
input
(
1
);
AnfNodePtr
add2
=
true_div0_cnode
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
add2
);
auto
add2_cnode
=
add2
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
add2_cnode
);
CheckCNodeInputSize
(
add2_cnode
,
kAddInputNum
);
AnfNodePtr
sqrt0
=
add2_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
sqrt0
);
auto
sqrt0_cnode
=
sqrt0
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
sqrt0_cnode
);
CheckCNodeInputSize
(
sqrt0_cnode
,
kSqrtInputNum
);
*
add1
=
sqrt0_cnode
->
input
(
1
);
}
}
// namespace
AnfNodePtr
AdamApplyOneFusion
::
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
equiv
);
...
...
@@ -79,10 +45,10 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const {
const
auto
prim_deal_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
3
],
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]})});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
VectorRef
({
prim
::
kPrimTensorAdd
,
mul2
,
mul3
})});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
VectorRef
({
add1_var_
,
mul2
,
mul3
})});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
prim
::
kPrimTensorAdd
,
mul0
,
mul1
});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_deal_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
sqrt0
,
add2_y_
})});
return
VectorRef
({
prim
::
kPrimSub
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
input_vars_
[
4
],
true_div0
})});
}
...
...
@@ -96,10 +62,17 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
new_node
->
set_scope
(
node
->
scope
());
// Set abstract of new node
AbstractBasePtrList
new_node_abstract_list
;
AnfNodePtr
add0
=
nullptr
;
AnfNodePtr
add1
=
nullptr
;
GetAdd0AndAdd1
(
node
,
&
add0
,
&
add1
);
auto
iter_add0
=
(
*
equiv
).
find
(
add0_var_
);
if
(
iter_add0
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the add0 var after matched."
;
}
auto
iter_add1
=
(
*
equiv
).
find
(
add1_var_
);
if
(
iter_add1
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the add1 var after matched."
;
}
auto
add0
=
utils
::
cast
<
AnfNodePtr
>
(
iter_add0
->
second
);
MS_EXCEPTION_IF_NULL
(
add0
);
auto
add1
=
utils
::
cast
<
AnfNodePtr
>
(
iter_add1
->
second
);
MS_EXCEPTION_IF_NULL
(
add1
);
new_node_abstract_list
.
push_back
(
add1
->
abstract
());
new_node_abstract_list
.
push_back
(
add0
->
abstract
());
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
浏览文件 @
e5c67b90
...
...
@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -35,6 +36,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
mul_x_input_vars_
.
push_back
(
std
::
make_shared
<
Var
>
());
}
add2_y_
=
std
::
make_shared
<
Var
>
();
add0_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
add1_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
}
~
AdamApplyOneFusion
()
override
=
default
;
...
...
@@ -46,6 +49,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
std
::
vector
<
VarPtr
>
input_vars_
;
std
::
vector
<
VarPtr
>
mul_x_input_vars_
;
VarPtr
add2_y_
;
VarPtr
add0_var_
;
VarPtr
add1_var_
;
};
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
浏览文件 @
e5c67b90
...
...
@@ -17,48 +17,13 @@
#include <memory>
#include <vector>
#include <tuple>
#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
std
::
tuple
<
AnfNodePtr
,
AnfNodePtr
>
GetAdd0Add1Node
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
sub0
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
sub0
);
auto
mul5_anf
=
sub0
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
mul5_anf
);
auto
mul5
=
mul5_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
mul5
);
auto
add3_anf
=
mul5
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
add3_anf
);
auto
add3
=
add3_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
add3
);
auto
real_div0_anf
=
add3
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
real_div0_anf
);
auto
real_div0
=
real_div0_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
real_div0
);
auto
add0_anf
=
real_div0
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
add0_anf
);
auto
add2_anf
=
real_div0
->
input
(
2
);
MS_EXCEPTION_IF_NULL
(
add2_anf
);
auto
add2
=
add2_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
add2
);
auto
sqrt0_anf
=
add2
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
sqrt0_anf
);
auto
sqrt0
=
sqrt0_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
sqrt0
);
auto
add1_anf
=
sqrt0
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
add1_anf
);
return
std
::
make_tuple
(
add0_anf
,
add1_anf
);
}
}
// namespace
std
::
vector
<
AnfNodePtr
>
AdamApplyOneWithDecayRule
::
GetFusionNodeInputs
(
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
equiv
);
auto
input0
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input0_
]);
...
...
@@ -82,10 +47,10 @@ const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const {
VectorRef
mul0_pattern
({
prim
::
kPrimMul
,
mul0_x_
,
input2_
});
VectorRef
mul1_pattern
({
prim
::
kPrimMul
,
mul1_x_
,
input0_
});
VectorRef
square0_pattern
({
prim
::
kPrimSquare
,
input0_
});
VectorRef
add0_pattern
({
prim
::
kPrimTensorAdd
,
mul0_pattern
,
mul1_pattern
});
VectorRef
add0_pattern
({
add0_var_
,
mul0_pattern
,
mul1_pattern
});
VectorRef
mul2_pattern
({
prim
::
kPrimMul
,
mul2_x_
,
input1_
});
VectorRef
mul3_pattern
({
prim
::
kPrimMul
,
mul3_x_
,
square0_pattern
});
VectorRef
add1_pattern
({
prim
::
kPrimTensorAdd
,
mul2_pattern
,
mul3_pattern
});
VectorRef
add1_pattern
({
add1_var_
,
mul2_pattern
,
mul3_pattern
});
VectorRef
sqrt0_pattern
({
sqrt
,
add1_pattern
});
VectorRef
add2_pattern
({
prim
::
kPrimTensorAdd
,
sqrt0_pattern
,
add2_y_
});
VectorRef
mul4_pattern
({
prim
::
kPrimMul
,
mul4_x_
,
input3_
});
...
...
@@ -107,9 +72,18 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL
(
fusion_node
);
fusion_node
->
set_scope
(
node
->
scope
());
AnfNodePtr
add0
=
nullptr
;
AnfNodePtr
add1
=
nullptr
;
std
::
tie
(
add0
,
add1
)
=
GetAdd0Add1Node
(
node
);
auto
iter_add0
=
(
*
equiv
).
find
(
add0_var_
);
if
(
iter_add0
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the add0 var after matched."
;
}
auto
iter_add1
=
(
*
equiv
).
find
(
add1_var_
);
if
(
iter_add1
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the add1 var after matched."
;
}
auto
add0
=
utils
::
cast
<
AnfNodePtr
>
(
iter_add0
->
second
);
MS_EXCEPTION_IF_NULL
(
add0
);
auto
add1
=
utils
::
cast
<
AnfNodePtr
>
(
iter_add1
->
second
);
MS_EXCEPTION_IF_NULL
(
add1
);
auto
types
=
{
AnfAlgo
::
GetOutputInferDataType
(
add1
,
0
),
AnfAlgo
::
GetOutputInferDataType
(
add0
,
0
),
AnfAlgo
::
GetOutputInferDataType
(
node
,
0
)};
auto
shapes
=
{
AnfAlgo
::
GetOutputInferShape
(
add1
,
0
),
AnfAlgo
::
GetOutputInferShape
(
add0
,
0
),
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h
浏览文件 @
e5c67b90
...
...
@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
class
AdamApplyOneWithDecayRule
:
public
PatternProcessPass
{
...
...
@@ -36,6 +37,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
mul3_x_
=
std
::
make_shared
<
Var
>
();
mul4_x_
=
std
::
make_shared
<
Var
>
();
add2_y_
=
std
::
make_shared
<
Var
>
();
add0_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
add1_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
}
~
AdamApplyOneWithDecayRule
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
...
...
@@ -54,6 +57,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
VarPtr
mul3_x_
;
VarPtr
mul4_x_
;
VarPtr
add2_y_
;
VarPtr
add0_var_
;
VarPtr
add1_var_
;
};
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
浏览文件 @
e5c67b90
...
...
@@ -16,36 +16,9 @@
#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h"
#include <vector>
#include "pre_activate/common/helper.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
AnfNodePtr
GetAdd1Node
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
add2_cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
add2_cnode
);
if
(
add2_cnode
->
inputs
().
size
()
!=
kAddInputNum
)
{
MS_LOG
(
ERROR
)
<<
"The input size of Add2 is not equal to "
<<
kAddInputNum
;
}
AnfNodePtr
sqrt0
=
add2_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
sqrt0
);
auto
sqrt0_cnode
=
sqrt0
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
sqrt0_cnode
);
if
(
sqrt0_cnode
->
inputs
().
size
()
!=
kSqrtInputNum
)
{
MS_LOG
(
ERROR
)
<<
"The input size of Sqrt0 is not equal to "
<<
kSqrtInputNum
;
}
AnfNodePtr
real_div1
=
sqrt0_cnode
->
input
(
1
);
MS_EXCEPTION_IF_NULL
(
real_div1
);
auto
real_div1_cnode
=
real_div1
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
real_div1_cnode
);
if
(
real_div1_cnode
->
inputs
().
size
()
!=
kMulInputNum
)
{
MS_LOG
(
ERROR
)
<<
"The input size of RealDiv1 is not equal to "
<<
kMulInputNum
;
}
return
real_div1_cnode
->
input
(
1
);
}
}
// namespace
AnfNodePtr
LambNextRightRule
::
CreateLambNextRightNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
equiv
);
...
...
@@ -79,7 +52,7 @@ const BaseRef LambNextRightRule::DefinePattern() const {
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
MS_EXCEPTION_IF_NULL
(
prim_sqrt
);
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul3_x_
,
VectorRef
({
prim
::
kPrimSquare
,
input0_
})});
VectorRef
add1
=
VectorRef
({
prim
::
kPrimTensorAdd
,
VectorRef
({
prim
::
kPrimMul
,
mul2_x_
,
input1_
}),
mul3
});
VectorRef
add1
=
VectorRef
({
add1_var_
,
VectorRef
({
prim
::
kPrimMul
,
mul2_x_
,
input1_
}),
mul3
});
return
VectorRef
(
{
prim
::
kPrimTensorAdd
,
VectorRef
({
prim_sqrt
,
VectorRef
({
prim
::
kPrimMul
,
add1
,
true_div1_recip_
})}),
add2_y_
});
}
...
...
@@ -91,7 +64,11 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
auto
new_node
=
CreateLambNextRightNode
(
func_graph
,
equiv
);
MS_EXCEPTION_IF_NULL
(
new_node
);
// Set abstract of new node
AnfNodePtr
add1
=
GetAdd1Node
(
node
);
auto
iter_add1
=
(
*
equiv
).
find
(
add1_var_
);
if
(
iter_add1
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the add1 var after matched."
;
}
auto
add1
=
utils
::
cast
<
AnfNodePtr
>
(
iter_add1
->
second
);
MS_EXCEPTION_IF_NULL
(
add1
);
AbstractBasePtrList
new_node_abstract_list
;
new_node_abstract_list
.
push_back
(
add1
->
abstract
());
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h
浏览文件 @
e5c67b90
...
...
@@ -18,6 +18,8 @@
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace
mindspore
{
namespace
opt
{
class
LambNextRightRule
:
public
PatternProcessPass
{
...
...
@@ -29,7 +31,8 @@ class LambNextRightRule : public PatternProcessPass {
mul2_x_
(
std
::
make_shared
<
Var
>
()),
mul3_x_
(
std
::
make_shared
<
Var
>
()),
true_div1_recip_
(
std
::
make_shared
<
Var
>
()),
add2_y_
(
std
::
make_shared
<
Var
>
())
{}
add2_y_
(
std
::
make_shared
<
Var
>
()),
add1_var_
(
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
())))
{}
~
LambNextRightRule
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
...
...
@@ -44,6 +47,7 @@ class LambNextRightRule : public PatternProcessPass {
VarPtr
mul3_x_
;
VarPtr
true_div1_recip_
;
VarPtr
add2_y_
;
VarPtr
add1_var_
;
};
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/pre_activate/common/optimizer.cc
浏览文件 @
e5c67b90
...
...
@@ -30,7 +30,8 @@
namespace
mindspore
{
namespace
opt
{
namespace
{
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
bool
multigraph
);
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
);
ValueNodePtr
CreateValueNodeWithSexp
(
const
BaseRef
&
sexp
)
{
if
(
utils
::
isa
<
int
>
(
sexp
))
{
...
...
@@ -71,12 +72,20 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
return
nullptr
;
}
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
bool
multigraph
=
false
)
{
AnfNodePtr
SexpToNode
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
=
false
)
{
MS_LOG
(
DEBUG
)
<<
"SexpToNode sexp: "
+
sexp
.
ToString
()
+
", graph "
+
graph
.
ToString
();
MS_EXCEPTION_IF_NULL
(
primitive_vars
);
if
(
utils
::
isa
<
VectorRef
>
(
sexp
))
{
return
HandleSexpVector
(
sexp
,
graph
,
multigraph
);
return
HandleSexpVector
(
sexp
,
graph
,
primitive_vars
,
multigraph
);
}
if
(
utils
::
isa
<
VarPtr
>
(
sexp
))
{
auto
var_ptr
=
utils
::
cast
<
VarPtr
>
(
sexp
);
MS_EXCEPTION_IF_NULL
(
var_ptr
);
if
(
var_ptr
->
primitive
())
{
(
*
primitive_vars
)[
var_ptr
->
primitive
()]
=
var_ptr
;
return
NewValueNode
(
var_ptr
->
primitive
());
}
return
CreateVarNodeWithSexp
(
sexp
,
graph
);
}
if
(
utils
::
isa
<
AnfNodePtr
>
(
sexp
))
{
...
...
@@ -89,13 +98,14 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph
return
value_node
;
}
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
bool
multigraph
)
{
AnfNodePtr
HandleSexpVector
(
const
BaseRef
&
sexp
,
const
BaseRef
&
graph
,
PrimitiveVarMap
*
primitive_vars
,
bool
multigraph
)
{
MS_LOG
(
DEBUG
)
<<
"HandleSexpVector sexp: "
+
sexp
.
ToString
()
+
", graph "
+
graph
.
ToString
();
std
::
vector
<
AnfNodePtr
>
input_nodes
;
const
auto
&
tuple
=
utils
::
cast
<
VectorRef
>
(
sexp
);
if
(
multigraph
&&
utils
::
isa
<
VarPtr
>
(
graph
))
{
for
(
auto
&
x
:
tuple
)
{
AnfNodePtr
node
=
SexpToNode
(
x
,
std
::
make_shared
<
Var
>
(
"G"
),
true
);
AnfNodePtr
node
=
SexpToNode
(
x
,
std
::
make_shared
<
Var
>
(
"G"
),
primitive_vars
,
true
);
input_nodes
.
push_back
(
node
);
}
VarPtr
var_ptr
=
utils
::
cast
<
VarPtr
>
(
graph
);
...
...
@@ -103,7 +113,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool mult
}
for
(
auto
&
x
:
tuple
)
{
AnfNodePtr
node
=
SexpToNode
(
x
,
graph
,
multigraph
);
AnfNodePtr
node
=
SexpToNode
(
x
,
graph
,
primitive_vars
,
multigraph
);
input_nodes
.
push_back
(
node
);
}
return
CreateCNodeWithGraph
(
input_nodes
,
graph
);
...
...
@@ -166,7 +176,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
multigraph_
(
multigraph
),
pattern_engine_
(
PatternEngine
(
std
::
make_shared
<
DefaultVisitor
>
(),
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
(
AnfEqual
),
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
(
CNodeTypeEqual
)))
{}
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
(
CNodeTypeEqual
))),
primitive_vars_
(
std
::
make_shared
<
PrimitiveVarMap
>
())
{}
const
BaseRef
PatternProcessPass
::
DefinePattern
()
const
{
VarPtr
X
=
std
::
make_shared
<
Var
>
();
...
...
@@ -176,7 +187,7 @@ const BaseRef PatternProcessPass::DefinePattern() const {
void
PatternProcessPass
::
Build
()
{
VarPtr
fg
=
std
::
make_shared
<
Var
>
(
"RootG"
);
BaseRef
pattern
=
std
::
move
(
DefinePattern
());
pattern_
=
SexpToNode
(
pattern
,
fg
,
multigraph_
);
pattern_
=
SexpToNode
(
pattern
,
fg
,
primitive_vars_
.
get
(),
multigraph_
);
}
AnfNodePtr
PatternProcessPass
::
Run
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
)
{
...
...
@@ -185,7 +196,8 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
}
auto
empty_equiv
=
std
::
make_shared
<
Equiv
>
();
EquivPtr
equiv
=
pattern_engine_
.
Match
(
pattern_
,
node
,
empty_equiv
);
MS_EXCEPTION_IF_NULL
(
primitive_vars_
);
EquivPtr
equiv
=
pattern_engine_
.
Match
(
pattern_
,
node
,
*
primitive_vars_
,
empty_equiv
);
if
(
equiv
!=
nullptr
&&
!
equiv
->
empty
())
{
return
Process
(
func_graph
,
node
,
equiv
);
}
...
...
mindspore/ccsrc/pre_activate/common/optimizer.h
浏览文件 @
e5c67b90
...
...
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "ir/anf.h"
#include "ir/func_graph.h"
...
...
@@ -46,6 +47,7 @@ class PatternProcessPass : public NodePass {
AnfNodePtr
pattern_
=
nullptr
;
bool
multigraph_
=
true
;
PatternEngine
pattern_engine_
;
PrimitiveVarMapPtr
primitive_vars_
;
};
class
GraphOptimizer
{
...
...
mindspore/ccsrc/pre_activate/common/pattern_engine.cc
浏览文件 @
e5c67b90
...
...
@@ -42,7 +42,7 @@ void Var::EnsureTag() {
}
}
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
if
(
lhs
->
isa
<
CondVar
>
()
&&
rhs
->
isa
<
CondVar
>
())
{
CondVarPtr
v1
=
dyn_cast
<
CondVar
>
(
lhs
);
CondVarPtr
v2
=
dyn_cast
<
CondVar
>
(
rhs
);
...
...
@@ -63,7 +63,7 @@ std::string SeqVar::ToString() const {
return
buffer
.
str
();
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
)
{
if
(
var
==
nullptr
)
{
os
<<
""
;
}
else
{
...
...
@@ -73,10 +73,10 @@ std::ostream& operator<<(std::ostream& os, const VarPtr& var) {
}
template
<
>
std
::
ostream
&
operator
<<<
VarPtr
,
BaseRef
>
(
std
::
ostream
&
os
,
const
Equiv
&
equiv
)
{
std
::
ostream
&
operator
<<<
VarPtr
,
BaseRef
>
(
std
::
ostream
&
os
,
const
Equiv
&
equiv
)
{
os
<<
"[Equiv]"
<<
"
\n
"
;
for
(
auto
&
equiv_item
:
equiv
)
{
for
(
auto
&
equiv_item
:
equiv
)
{
auto
k
=
equiv_item
.
first
;
os
<<
k
<<
":"
;
BaseRef
x
=
equiv_item
.
second
;
...
...
@@ -104,7 +104,7 @@ std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv)
return
os
;
}
static
BaseRef
GetVar
(
const
BaseRef
&
x
)
{
static
BaseRef
GetVar
(
const
BaseRef
&
x
)
{
MS_LOG
(
DEBUG
)
<<
"getVar start :%s"
+
x
.
ToString
();
if
(
utils
::
isa
<
AnfNodePtr
>
(
x
))
{
auto
node
=
utils
::
cast
<
AnfNodePtr
>
(
x
);
...
...
@@ -129,7 +129,7 @@ static BaseRef GetVar(const BaseRef& x) {
return
x
;
}
EquivPtr
MatchOnVar
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
EquivPtr
equiv
)
{
EquivPtr
MatchOnVar
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
EquivPtr
equiv
)
{
MS_LOG
(
DEBUG
)
<<
"MatchOnVar pattern "
+
pattern
.
ToString
()
+
" expr: "
+
expr
.
ToString
();
MS_EXCEPTION_IF_NULL
(
equiv
);
if
(
utils
::
isa
<
VarPtr
>
(
pattern
))
{
...
...
@@ -144,8 +144,8 @@ EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv)
return
nullptr
;
}
bool
PatternEngine
::
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
bool
PatternEngine
::
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
MS_EXCEPTION_IF_NULL
(
values_expr
);
if
(
utils
::
isa
<
SeqPtr
>
(
pattern_ref
))
{
*
values_pattern
=
pattern_ref
;
...
...
@@ -155,12 +155,12 @@ bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr
return
false
;
}
bool
PatternEngine
::
ToVector
(
const
BaseRef
&
pattern_ref
,
const
BaseRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
bool
PatternEngine
::
ToVector
(
const
BaseRef
&
pattern_ref
,
const
BaseRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
{
MS_EXCEPTION_IF_NULL
(
values_expr
);
// visitor to visite the list
auto
appender_pattern
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
auto
appender_pattern
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
values
.
push_back
(
GetVar
(
u
));
return
u
;
};
...
...
@@ -174,8 +174,8 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
return
false
;
}
auto
appender_expr
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
auto
appender_expr
=
[](
VectorRef
&
values
)
{
std
::
function
<
BaseRef
(
const
BaseRef
&
)
>
fn
=
[
&
](
const
BaseRef
&
u
)
{
values
.
push_back
(
u
);
return
u
;
};
...
...
@@ -187,10 +187,10 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
return
visitor_
->
Visit
(
expr_ref
,
nullptr
);
}
static
int
GetSVarStartIndex
(
const
VectorRef
&
values
)
{
static
int
GetSVarStartIndex
(
const
VectorRef
&
values
)
{
int
index
=
-
1
;
int
count
=
0
;
for
(
auto
&
value
:
values
)
{
for
(
auto
&
value
:
values
)
{
if
(
utils
::
isa
<
VarPtr
>
(
value
)
&&
utils
::
cast
<
VarPtr
>
(
value
)
->
isa
<
SeqVar
>
())
{
if
(
index
!=
-
1
)
{
MS_LOG
(
DEBUG
)
<<
"Multiple SVars in sequence"
;
...
...
@@ -203,7 +203,35 @@ static int GetSVarStartIndex(const VectorRef& values) {
return
index
;
}
EquivPtr
PatternEngine
::
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
EquivPtr
equiv
)
const
{
void
UpdateEquivMap
(
const
VectorRef
&
values_pattern
,
const
BaseRef
&
expr_ref
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
{
if
(
equiv
==
nullptr
||
values_pattern
.
empty
()
||
!
utils
::
isa
<
AnfNodePtr
>
(
values_pattern
[
0
])
||
!
utils
::
isa
<
AnfNodePtr
>
(
expr_ref
))
{
return
;
}
auto
real_node
=
utils
::
cast
<
AnfNodePtr
>
(
expr_ref
);
MS_EXCEPTION_IF_NULL
(
real_node
);
if
(
!
real_node
->
isa
<
CNode
>
())
{
return
;
}
auto
prim_node
=
utils
::
cast
<
AnfNodePtr
>
(
values_pattern
[
0
]);
MS_EXCEPTION_IF_NULL
(
prim_node
);
if
(
!
IsValueNode
<
Primitive
>
(
prim_node
))
{
return
;
}
ValuePtr
value
=
GetValueNode
(
prim_node
);
MS_EXCEPTION_IF_NULL
(
value
);
auto
prim
=
value
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
prim
);
auto
iter
=
primitive_vars
.
find
(
prim
);
if
(
iter
==
primitive_vars
.
end
())
{
return
;
}
(
*
equiv
)[
iter
->
second
]
=
real_node
;
}
EquivPtr
PatternEngine
::
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
{
int
svar_index
=
GetSVarStartIndex
(
values_pattern
);
if
(
svar_index
==
kInvalidVarIndex
)
{
return
nullptr
;
...
...
@@ -229,12 +257,12 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
if
(
svar_index
!=
-
1
&&
i
==
IntToSize
(
svar_index
))
{
auto
seq
=
std
::
vector
<
BaseRef
>
(
values_expr
.
begin
()
+
svar_index
,
values_expr
.
begin
()
+
svar_index
+
SizeToInt
(
diff
));
equiv
=
Match
(
values_pattern
[
svar_index
],
seq
,
equiv
);
equiv
=
Match
(
values_pattern
[
svar_index
],
seq
,
primitive_vars
,
equiv
);
}
else
{
if
(
svar_index
!=
-
1
&&
i
>
IntToSize
(
svar_index
))
{
expr_i
=
i
+
diff
-
1
;
}
equiv
=
Match
(
values_pattern
[
i
],
values_expr
[
expr_i
],
equiv
);
equiv
=
Match
(
values_pattern
[
i
],
values_expr
[
expr_i
],
primitive_vars
,
equiv
);
}
if
(
equiv
==
nullptr
)
{
return
nullptr
;
...
...
@@ -243,7 +271,8 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
return
equiv
;
}
EquivPtr
PatternEngine
::
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
EquivPtr
equiv
)
const
{
EquivPtr
PatternEngine
::
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
{
MS_LOG
(
DEBUG
)
<<
"-----[in Match]"
;
MS_LOG
(
DEBUG
)
<<
"GetVar w"
;
BaseRef
pattern_ref
=
GetVar
(
pattern
);
...
...
@@ -292,10 +321,12 @@ EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, Equiv
// 6. if any svar in both side, find the SeqVar index,
// try to pack the Var s in std::vector to a Seq and match elements one by one.
// check svar
return
AlignSVar
(
values_pattern
,
values_expr
,
equiv
);
equiv
=
AlignSVar
(
values_pattern
,
values_expr
,
primitive_vars
,
equiv
);
UpdateEquivMap
(
values_pattern
,
expr_ref
,
primitive_vars
,
equiv
);
return
equiv
;
}
BaseRef
PatternEngine
::
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
{
BaseRef
PatternEngine
::
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
equiv
);
MS_LOG
(
DEBUG
)
<<
"-----[in Replace]"
;
BaseRef
ref
=
GetVar
(
pattern
);
...
...
@@ -304,7 +335,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
// w is var
if
(
utils
::
isa
<
VarPtr
>
(
ref
))
{
const
VarPtr
&
var
=
utils
::
cast
<
VarPtr
>
(
ref
);
const
VarPtr
&
var
=
utils
::
cast
<
VarPtr
>
(
ref
);
auto
iter
=
equiv
->
find
(
var
);
if
(
iter
!=
equiv
->
end
())
{
out
=
iter
->
second
;
...
...
@@ -316,7 +347,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
}
// visitor to visit the list
std
::
function
<
BaseRef
(
BaseRef
)
>
fn
=
[
&
,
this
,
equiv
](
const
BaseRef
&
u
)
{
return
Replace
(
u
,
equiv
);
};
std
::
function
<
BaseRef
(
BaseRef
)
>
fn
=
[
&
,
this
,
equiv
](
const
BaseRef
&
u
)
{
return
Replace
(
u
,
equiv
);
};
visitor_
->
SetFn
(
fn
);
BaseRef
visit_out
;
...
...
mindspore/ccsrc/pre_activate/common/pattern_engine.h
浏览文件 @
e5c67b90
...
...
@@ -31,6 +31,7 @@
#include <map>
#include <stdexcept>
#include <list>
#include <utility>
#include "pre_activate/common/visit.h"
#include "ir/base.h"
...
...
@@ -44,16 +45,19 @@ using CondVarPtr = std::shared_ptr<CondVar>;
using
SVarPtr
=
std
::
shared_ptr
<
SeqVar
>
;
const
int
kInvalidVarIndex
=
-
2
;
using
ConditionFunc
=
std
::
function
<
bool
(
const
BaseRef
&
)
>
;
using
ConditionFunc
=
std
::
function
<
bool
(
const
BaseRef
&
)
>
;
// Base wildcard variable which could match any anf node.
class
Var
:
public
Base
{
friend
class
VarHasher
;
public:
explicit
Var
(
const
std
::
string
&
tag
=
""
)
:
tag_
(
tag
)
{
EnsureTag
();
}
Var
(
const
Var
&
other
)
:
Base
(
other
),
tag_
(
other
.
tag_
)
{}
virtual
Var
&
operator
=
(
const
Var
&
other
)
{
explicit
Var
(
std
::
string
tag
=
""
)
:
tag_
(
std
::
move
(
tag
)),
primitive_
(
nullptr
)
{
EnsureTag
();
}
explicit
Var
(
const
PrimitivePtr
&
primitive
,
std
::
string
tag
=
""
)
:
tag_
(
std
::
move
(
tag
)),
primitive_
(
primitive
)
{
EnsureTag
();
}
Var
(
const
Var
&
other
)
:
Base
(
other
),
tag_
(
other
.
tag_
)
{}
virtual
Var
&
operator
=
(
const
Var
&
other
)
{
if
(
&
other
==
this
)
{
return
*
this
;
}
...
...
@@ -63,12 +67,13 @@ class Var : public Base {
~
Var
()
override
=
default
;
MS_DECLARE_PARENT
(
Var
,
Base
);
virtual
bool
matches
(
const
BaseRef
&
)
{
return
true
;
}
virtual
bool
matches
(
const
BaseRef
&
)
{
return
true
;
}
virtual
bool
operator
==
(
const
Var
&
other
)
const
{
return
tag_
==
other
.
tag_
;
}
bool
operator
!=
(
const
Var
&
other
)
const
{
return
!
(
&
other
==
this
);
}
virtual
bool
operator
==
(
const
Var
&
other
)
const
{
return
tag_
==
other
.
tag_
;
}
bool
operator
!=
(
const
Var
&
other
)
const
{
return
!
(
&
other
==
this
);
}
std
::
string
tag
()
const
{
return
tag_
;
}
PrimitivePtr
primitive
()
const
{
return
primitive_
;
}
std
::
string
ToString
()
const
override
{
std
::
ostringstream
buffer
;
buffer
<<
"Var("
<<
tag_
<<
")"
;
...
...
@@ -80,12 +85,13 @@ class Var : public Base {
void
EnsureTag
();
std
::
string
tag_
;
PrimitivePtr
primitive_
;
};
// VarNode means variable node, a subclass of AnfNode
class
VarNode
:
public
AnfNode
{
public:
VarNode
(
const
VarPtr
&
value
,
const
FuncGraphPtr
&
func_graph
)
:
AnfNode
(
func_graph
),
var_
(
value
)
{}
VarNode
(
const
VarPtr
&
value
,
const
FuncGraphPtr
&
func_graph
)
:
AnfNode
(
func_graph
),
var_
(
value
)
{}
~
VarNode
()
override
=
default
;
MS_DECLARE_PARENT
(
VarNode
,
AnfNode
);
...
...
@@ -95,16 +101,16 @@ using VarNodePtr = std::shared_ptr<VarNode>;
class
VarHasher
{
public:
std
::
size_t
operator
()(
const
Var
&
var
)
const
{
return
var
.
hash
();
}
std
::
size_t
operator
()(
const
Var
&
var
)
const
{
return
var
.
hash
();
}
};
// Condition Var, match an anf node when condition function return true.
class
CondVar
:
public
Var
{
public:
explicit
CondVar
(
const
ConditionFunc
&
cond
)
:
cond_fn_
(
cond
)
{}
explicit
CondVar
(
const
ConditionFunc
&
cond
)
:
cond_fn_
(
cond
)
{}
~
CondVar
()
override
=
default
;
MS_DECLARE_PARENT
(
CondVar
,
Var
);
bool
matches
(
const
BaseRef
&
value
)
override
{
bool
matches
(
const
BaseRef
&
value
)
override
{
MS_LOG
(
DEBUG
)
<<
"CondVarPtr match: "
+
value
.
ToString
();
if
(
utils
::
isa
<
Var
>
(
value
))
{
return
false
;
...
...
@@ -124,55 +130,60 @@ class SeqVar : public Var {
~
SeqVar
()
override
=
default
;
MS_DECLARE_PARENT
(
SeqVar
,
Var
);
explicit
SeqVar
(
const
VarPtr
subvar
)
:
subvar_
(
nullptr
)
{
subvar_
=
subvar
;
}
bool
matches
(
const
BaseRef
&
value
)
override
{
bool
matches
(
const
BaseRef
&
value
)
override
{
// match Seq.
if
(
utils
::
isa
<
Seq
>
(
value
))
{
const
Seq
&
seq
=
utils
::
cast
<
Seq
>
(
value
);
return
std
::
all_of
(
seq
.
begin
(),
seq
.
end
(),
[
this
](
const
BaseRef
&
v
)
{
const
Seq
&
seq
=
utils
::
cast
<
Seq
>
(
value
);
return
std
::
all_of
(
seq
.
begin
(),
seq
.
end
(),
[
this
](
const
BaseRef
&
v
)
{
auto
eq
=
subvar_
->
matches
(
v
);
return
eq
;
});
}
return
false
;
}
bool
operator
==
(
const
SeqVar
&
other
)
const
{
return
*
subvar_
==
*
other
.
subvar_
;
}
bool
operator
==
(
const
SeqVar
&
other
)
const
{
return
*
subvar_
==
*
other
.
subvar_
;
}
std
::
string
ToString
()
const
override
;
private:
VarPtr
subvar_
;
};
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
);
bool
operator
==
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
);
inline
bool
operator
!=
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
inline
bool
operator
!=
(
const
VarPtr
&
lhs
,
const
VarPtr
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
VarPtr
&
var
);
using
Equiv
=
std
::
map
<
VarPtr
,
BaseRef
>
;
using
EquivPtr
=
std
::
shared_ptr
<
Equiv
>
;
using
PrimitiveVarMap
=
std
::
unordered_map
<
PrimitivePtr
,
VarPtr
>
;
using
PrimitiveVarMapPtr
=
std
::
shared_ptr
<
PrimitiveVarMap
>
;
inline
bool
DefaultTypeEq
(
const
BaseRef
&
x
,
const
BaseRef
&
y
)
{
return
x
.
type
()
==
y
.
type
();
}
inline
bool
DefaultTypeEq
(
const
BaseRef
&
x
,
const
BaseRef
&
y
)
{
return
x
.
type
()
==
y
.
type
();
}
class
PatternEngine
{
public:
PatternEngine
(
const
std
::
shared_ptr
<
Visitor
>&
visitor
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>&
eq
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>&
type_eq
=
DefaultTypeEq
)
PatternEngine
(
const
std
::
shared_ptr
<
Visitor
>
&
visitor
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
&
eq
,
const
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
&
type_eq
=
DefaultTypeEq
)
:
visitor_
(
visitor
),
eq_
(
eq
),
type_eq_
(
type_eq
)
{}
~
PatternEngine
()
=
default
;
EquivPtr
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
EquivPtr
equiv
)
const
;
EquivPtr
Match
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
;
// Replace pattern with equivalent
BaseRef
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
;
BaseRef
Replace
(
const
BaseRef
&
pattern
,
const
EquivPtr
&
equiv
)
const
;
private:
EquivPtr
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
EquivPtr
equiv
)
const
;
bool
ToVector
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
bool
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
EquivPtr
AlignSVar
(
const
VectorRef
&
values_pattern
,
const
VectorRef
&
values_expr
,
const
PrimitiveVarMap
&
primitive_vars
,
EquivPtr
equiv
)
const
;
bool
ToVector
(
const
BaseRef
&
pattern
,
const
BaseRef
&
expr
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
bool
ToVector
(
const
VectorRef
&
pattern_ref
,
const
VectorRef
&
expr_ref
,
VectorRef
*
const
values_pattern
,
VectorRef
*
const
values_expr
)
const
;
std
::
shared_ptr
<
Visitor
>
visitor_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
eq_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
type_eq_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
eq_
;
std
::
function
<
bool
(
const
BaseRef
&
,
const
BaseRef
&
)
>
type_eq_
;
};
}
// namespace mindspore
namespace
std
{
...
...
tests/ut/cpp/pre_activate/common/pattern_engine_test.cc
浏览文件 @
e5c67b90
...
...
@@ -40,6 +40,7 @@ class TestMatchEngine : public UT::Common {
public:
PatternEngine
TU
;
EquivPtr
equiv_null
;
PrimitiveVarMap
primitive_vars_null
;
};
TEST_F
(
TestMatchEngine
,
Var
)
{
...
...
@@ -106,30 +107,30 @@ TEST_F(TestMatchEngine, MatchRaw_Var) {
// common
equiv_null
->
clear
();
d
=
TU
.
Match
(
v1
,
1
,
equiv_null
);
d
=
TU
.
Match
(
v1
,
1
,
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
((
*
d
)[
v1
],
1
);
equiv_null
->
clear
();
(
*
equiv_null
)[
v1
]
=
v2
;
d
=
TU
.
Match
(
v1
,
1
,
equiv_null
);
d
=
TU
.
Match
(
v1
,
1
,
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
count
(
v2
),
std
::
size_t
(
1
));
ASSERT_EQ
((
*
d
)[
v2
],
1
);
equiv_null
->
clear
();
(
*
equiv_null
)[
v1
]
=
v2
;
(
*
equiv_null
)[
v3
]
=
1
;
d
=
TU
.
Match
(
v1
,
1
,
equiv_null
);
d
=
TU
.
Match
(
v1
,
1
,
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
count
(
v2
),
std
::
size_t
(
1
));
ASSERT_EQ
((
*
d
)[
v2
],
1
);
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
v1
}),
VectorRef
({
1
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
v1
}),
VectorRef
({
1
}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
size
(),
std
::
size_t
(
1
));
ASSERT_EQ
(
d
->
count
(
v1
),
std
::
size_t
(
1
));
ASSERT_EQ
((
*
d
)[
v1
],
1
);
equiv_null
->
clear
();
ASSERT_EQ
(
TU
.
Match
(
1
,
2
,
equiv_null
),
nullptr
);
ASSERT_EQ
(
TU
.
Match
(
1
,
2
,
primitive_vars_null
,
equiv_null
),
nullptr
);
}
TEST_F
(
TestMatchEngine
,
MatchRaw_SVar
)
{
...
...
@@ -139,22 +140,22 @@ TEST_F(TestMatchEngine, MatchRaw_SVar) {
EquivPtr
d
;
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
sv1
}),
VectorRef
({
1
,
2
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
sv1
}),
VectorRef
({
1
,
2
}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
size
(),
std
::
size_t
(
1
));
ASSERT_EQ
(
d
->
count
(
sv1
),
std
::
size_t
(
1
));
ASSERT_EQ
(
utils
::
cast
<
Seq
>
((
*
d
)[
sv1
]),
Seq
({
1
,
2
}));
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
v1
,
sv1
}),
VectorRef
({
1
,
2
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
v1
,
sv1
}),
VectorRef
({
1
,
2
}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
size
(),
std
::
size_t
(
2
));
ASSERT_EQ
(
utils
::
cast
<
Seq
>
((
*
d
)[
sv1
]),
Seq
({
2
}));
equiv_null
->
clear
();
ASSERT_EQ
(
TU
.
Match
(
VectorRef
({
sv1
,
sv2
}),
VectorRef
({
1
,
2
}),
equiv_null
),
nullptr
);
ASSERT_EQ
(
TU
.
Match
(
VectorRef
({
sv1
,
sv2
}),
VectorRef
({
1
,
2
}),
primitive_vars_null
,
equiv_null
),
nullptr
);
equiv_null
->
clear
();
(
*
equiv_null
)[
sv1
]
=
std
::
make_shared
<
Seq
>
(
PatternListType
{
1
,
2
});
d
=
TU
.
Match
(
VectorRef
({
v1
,
sv1
}),
VectorRef
({
1
,
1
,
2
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
v1
,
sv1
}),
VectorRef
({
1
,
1
,
2
}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
size
(),
std
::
size_t
(
2
));
ASSERT_EQ
((
*
d
)[
v1
],
1
);
}
...
...
@@ -167,13 +168,13 @@ TEST_F(TestMatchEngine, Match) {
EquivPtr
d
;
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
v1
,
v1
,
v2
}),
VectorRef
({
1
,
1
,
2
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
v1
,
v1
,
v2
}),
VectorRef
({
1
,
1
,
2
}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
->
size
(),
std
::
size_t
(
2
));
ASSERT_EQ
((
*
d
)[
v1
],
1
);
ASSERT_EQ
((
*
d
)[
v2
],
2
);
equiv_null
->
clear
();
d
=
TU
.
Match
(
static_cast
<
int
>
(
1
),
static_cast
<
float
>
(
1
),
equiv_null
);
d
=
TU
.
Match
(
static_cast
<
int
>
(
1
),
static_cast
<
float
>
(
1
),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
,
nullptr
);
}
...
...
@@ -197,18 +198,19 @@ TEST_F(TestMatchEngine, Match_CondVar) {
EquivPtr
d
;
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
static_cast
<
float
>
(
1.0
),
-
1
}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
static_cast
<
float
>
(
1.0
),
-
1
}),
primitive_vars_null
,
equiv_null
);
ASSERT_GE
(
d
->
size
(),
std
::
size_t
(
0
));
auto
vfn
=
(
*
d
)[
vf
];
ASSERT_EQ
((
*
d
)[
vf
],
static_cast
<
float
>
(
1.0
));
ASSERT_EQ
((
*
d
)[
vn
],
-
1
);
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
1
,
static_cast
<
float
>
(
-
1.0
)}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
1
,
static_cast
<
float
>
(
-
1.0
)}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
,
nullptr
);
equiv_null
->
clear
();
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
static_cast
<
float
>
(
1.0
),
static_cast
<
int
>
(
1
)}),
equiv_null
);
d
=
TU
.
Match
(
VectorRef
({
vf
,
vn
}),
VectorRef
({
static_cast
<
float
>
(
1.0
),
static_cast
<
int
>
(
1
)}),
primitive_vars_null
,
equiv_null
);
ASSERT_EQ
(
d
,
nullptr
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录