Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b2ec296f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b2ec296f
编写于
6月 22, 2020
作者:
Z
zhousiyi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add opt pass for tuple_getitem with constant input
上级
5b14292f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
214 addition
and
76 deletion
+214
-76
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+1
-0
mindspore/ccsrc/optimizer/irpass.h
mindspore/ccsrc/optimizer/irpass.h
+1
-0
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+145
-75
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
+46
-1
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+1
-0
tests/ut/cpp/optimizer/lib_test.cc
tests/ut/cpp/optimizer/lib_test.cc
+20
-0
未找到文件。
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
b2ec296f
...
...
@@ -51,6 +51,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_
=
MakeSubstitution
(
ArithmeticSimplify
(),
"arithmetic_simplify"
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
,
prim
::
kPrimPow
});
arithmetic_simplify2_
=
MakeSubstitution
(
ArithmeticSimplify2
(),
"arithmetic_simplify2"
,
{
prim
::
kPrimMul
});
special_op_eliminate_
=
MakeSubstitution
(
SpecialOpEliminater
(),
"special_op_eliminate"
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrimStopGradient
,
prim
::
kPrimHookBackward
,
...
...
mindspore/ccsrc/optimizer/irpass.h
浏览文件 @
b2ec296f
...
...
@@ -33,6 +33,7 @@ class OptimizeIRPassLib {
~
OptimizeIRPassLib
()
=
default
;
SubstitutionPtr
arithmetic_simplify_
;
SubstitutionPtr
arithmetic_simplify2_
;
SubstitutionPtr
special_op_eliminate_
;
SubstitutionPtr
zero_like_fill_zero_
;
SubstitutionPtr
adjust_all_reduce_mul_add_
;
...
...
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
b2ec296f
...
...
@@ -139,76 +139,8 @@ class CheckTensorConstant {
int
check_value_
;
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class
TensorMultiplyByZeroOrOne
:
public
AnfVisitor
{
public:
TensorMultiplyByZeroOrOne
()
:
zero_
(
MakeValue
(
0
))
{}
~
TensorMultiplyByZeroOrOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
if
(
is_zero_
)
{
if
(
x_
->
func_graph
()
!=
node
->
func_graph
())
{
return
nullptr
;
}
return
NewTensorFilledWithData
(
node
);
}
if
(
is_one_
)
{
return
NewTensorFilledWithData
(
node
,
x_
);
}
return
nullptr
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
is_zero_
||
is_one_
)
{
x_
=
node
;
return
;
}
if
(
IsParam
(
node
))
{
x_
=
node
;
return
;
}
if
(
IsCNode
(
node
))
{
CNodePtr
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
IsPrimitive
(
cnode
->
input
(
0
),
prim
::
kPrimZerosLike
))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
return
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
else
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
node
;
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
else
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
vnode
;
}
void
Reset
()
{
x_
=
nullptr
;
is_one_
=
false
;
is_zero_
=
false
;
}
class
TensorMultiplyBase
:
public
AnfVisitor
{
protected:
void
*
GetPointerToTensorData
(
const
AnfNodePtr
&
node
,
bool
writable
=
false
)
{
if
(
!
node
->
isa
<
ValueNode
>
())
{
return
nullptr
;
...
...
@@ -287,10 +219,122 @@ class TensorMultiplyByZeroOrOne : public AnfVisitor {
return
new_vnode
;
}
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class
TensorMultiplyByZero
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByZero
()
:
zero_
(
MakeValue
(
0
))
{}
~
TensorMultiplyByZero
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
if
(
is_zero_
)
{
if
(
x_
->
func_graph
()
!=
node
->
func_graph
())
{
return
nullptr
;
}
return
NewTensorFilledWithData
(
node
);
}
return
nullptr
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
is_zero_
)
{
x_
=
node
;
return
;
}
if
(
IsParam
(
node
))
{
x_
=
node
;
return
;
}
if
(
IsCNode
(
node
))
{
CNodePtr
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
IsPrimitive
(
cnode
->
input
(
0
),
prim
::
kPrimZerosLike
))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
return
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
x_
=
vnode
;
}
void
Reset
()
{
x_
=
nullptr
;
is_zero_
=
false
;
}
private:
bool
is_zero_
{
false
}
,
is_one_
{
false
}
;
bool
is_zero_
{
false
};
ValuePtr
zero_
;
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class
TensorMultiplyByOne
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByOne
()
{}
~
TensorMultiplyByOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
if
(
is_one_
)
{
return
NewTensorFilledWithData
(
node
,
x_
);
}
return
nullptr
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
is_one_
)
{
x_
=
node
;
return
;
}
if
(
IsParam
(
node
)
||
IsCNode
(
node
))
{
x_
=
node
;
return
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
node
;
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
vnode
;
}
void
Reset
()
{
x_
=
nullptr
;
is_one_
=
false
;
}
private:
bool
is_one_
{
false
};
};
// {prim::kPrimScalarAdd, X, 0}
...
...
@@ -699,7 +743,7 @@ class ArithmeticSimplify {
public:
ArithmeticSimplify
()
:
multiply_by_zero_or_one_
(),
tensor_multiply_by_
zero_or_
one_
(),
tensor_multiply_by_one_
(),
add_by_zero_
(),
tensor_add_by_zero_
(),
identity_
(
prim
::
kPrimIdentity
),
...
...
@@ -707,7 +751,7 @@ class ArithmeticSimplify {
constant_duplicate_mul_
(),
power_one_
()
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_
zero_or_
one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
eliminaters_
.
emplace_back
(
tensor_add_by_zero_
);
eliminaters_
.
emplace_back
(
identity_
);
...
...
@@ -730,7 +774,7 @@ class ArithmeticSimplify {
private:
MultiplyByZeroOrOne
multiply_by_zero_or_one_
;
TensorMultiplyBy
ZeroOrOne
tensor_multiply_by_zero_or
_one_
;
TensorMultiplyBy
One
tensor_multiply_by
_one_
;
AddByZero
add_by_zero_
;
TensorAddByZero
tensor_add_by_zero_
;
PrimEliminater
identity_
;
...
...
@@ -739,6 +783,32 @@ class ArithmeticSimplify {
PowerOneEliminate
power_one_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
// Arithmetic Simplifications should be done after step_parallel.
// eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel.
class
ArithmeticSimplify2
{
public:
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
()
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
~
ArithmeticSimplify2
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
private:
TensorMultiplyByZero
tensor_multiply_by_zero_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
浏览文件 @
b2ec296f
...
...
@@ -70,6 +70,45 @@ class GetitemEliminater : public AnfVisitor {
CNodePtr
tuple_
{
nullptr
};
};
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C}
class
GetitemConstEliminater
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimTupleGetItem
,
{
IsVNode
,
IsVNode
})(
node
);
if
(
is_match_
)
{
return
NewValueNode
((
*
tuple_
)[
id_
]);
}
return
nullptr
;
}
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
if
(
IsValueNode
<
ValueTuple
>
(
vnode
))
{
tuple_
=
GetValueNode
<
ValueTuplePtr
>
(
vnode
);
}
if
(
tuple_
!=
nullptr
&&
IsValueNode
<
Int32Imm
>
(
vnode
))
{
id_
=
IntToSize
(
GetValue
<
int
>
(
vnode
->
value
()));
if
(
tuple_
->
size
()
>
id_
)
{
is_match_
=
true
;
}
}
}
void
Reset
()
{
id_
=
0
;
tuple_
=
nullptr
;
is_match_
=
false
;
}
private:
bool
is_match_
{
false
};
size_t
id_
{
0
};
ValueTuplePtr
tuple_
{
nullptr
};
};
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
...
...
@@ -225,8 +264,13 @@ class GetitemDependReorder : public AnfVisitor {
class
ItemTupleEliminater
{
public:
ItemTupleEliminater
()
:
get_item_eliminater_
(),
set_item_eliminater_
(),
get_set_item_eliminater_
(),
get_item_depend_reorder_
()
{
:
get_item_eliminater_
(),
get_item_const_eliminater_
(),
set_item_eliminater_
(),
get_set_item_eliminater_
(),
get_item_depend_reorder_
()
{
eliminaters_
.
emplace_back
(
get_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_const_eliminater_
);
eliminaters_
.
emplace_back
(
set_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_set_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_depend_reorder_
);
...
...
@@ -246,6 +290,7 @@ class ItemTupleEliminater {
private:
GetitemEliminater
get_item_eliminater_
;
GetitemConstEliminater
get_item_const_eliminater_
;
SetitemEliminater
set_item_eliminater_
;
GetSetitemEliminater
get_set_item_eliminater_
;
GetitemDependReorder
get_item_depend_reorder_
;
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
b2ec296f
...
...
@@ -114,6 +114,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass
.
depend_value_elim_
,
});
opt
::
OptPassConfig
a_3
=
opt
::
OptPassConfig
({
irpass
.
arithmetic_simplify2_
,
irpass
.
same_eliminate_
,
irpass
.
check_bprop_eliminate_
,
irpass
.
replace_applicator_
,
...
...
tests/ut/cpp/optimizer/lib_test.cc
浏览文件 @
b2ec296f
...
...
@@ -20,9 +20,12 @@
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "pipeline/resource.h"
#include "debug/draw.h"
...
...
@@ -343,9 +346,26 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr
after_0
=
getPyFun
.
CallAndParseRet
(
"test_tuple_getitem"
,
"after_0"
);
FuncGraphPtr
after_1
=
getPyFun
.
CallAndParseRet
(
"test_tuple_getitem"
,
"after_1"
);
FuncGraphPtr
make_get_const
=
std
::
make_shared
<
FuncGraph
>
();
auto
value_node_1
=
NewValueNode
(
1
);
auto
value_node_2
=
NewValueNode
(
2
);
std
::
vector
<
int
>
vec
{
1
,
2
};
auto
value_node_tuple
=
NewValueNode
(
MakeValue
(
vec
));
std
::
vector
<
AnfNodePtr
>
node_list
{
NewValueNode
(
prim
::
kPrimTupleGetItem
),
value_node_tuple
,
value_node_1
};
auto
get_item
=
make_get_const
->
NewCNode
(
node_list
);
make_get_const
->
set_output
(
get_item
);
FuncGraphPtr
after_2
=
std
::
make_shared
<
FuncGraph
>
();
after_2
->
set_output
(
value_node_2
);
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
item_tuple_eliminate_
});
ASSERT_TRUE
(
CheckOpt
(
make_get_0
,
after_0
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
make_get_1
,
after_1
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
make_get_const
,
after_2
,
patterns
));
}
TEST_F
(
TestOptLib
,
test_tuple_setitem
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录