Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
db80643e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db80643e
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!3177 Update Arithmetic Simplify to use Pattern Matcher (2)
Merge pull request !3177 from Giancarlo/pm_arithmetic_simplify
上级
f30df6e3
cfbfaddf
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
510 addition
and
757 deletion
+510
-757
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
...re/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
+57
-547
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
...ore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
+6
-190
mindspore/core/ir/pattern_matcher.h
mindspore/core/ir/pattern_matcher.h
+446
-19
tests/ut/cpp/optimizer/opt_test.cc
tests/ut/cpp/optimizer/opt_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
浏览文件 @
db80643e
此差异已折叠。
点击以展开。
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
db80643e
...
...
@@ -21,159 +21,15 @@
#include <memory>
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
class
MultiplyByZeroOrOne
:
public
AnfVisitor
{
public:
MultiplyByZeroOrOne
()
:
zero_
(
MakeValue
(
0
)),
one_
(
MakeValue
(
1
))
{}
~
MultiplyByZeroOrOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
},
is_one_
{
false
};
ValuePtr
zero_
,
one_
;
AnfNodePtr
x_
{
nullptr
};
};
// Support class used for checking if all values of a Tensor are equal `check_value_`
// Supported data types: double, float/float32, int/int32
class
CheckTensorConstant
{
public:
explicit
CheckTensorConstant
(
int
_check_value
=
0
)
:
check_value_
(
_check_value
)
{}
~
CheckTensorConstant
()
=
default
;
bool
IsTensorConstant
(
const
ValuePtr
&
value
);
bool
IsTensorScalarConstant
(
const
ValuePtr
&
value
);
private:
int
check_value_
;
};
class
TensorMultiplyBase
:
public
AnfVisitor
{
protected:
void
*
GetPointerToTensorData
(
const
AnfNodePtr
&
node
,
bool
writable
=
false
);
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr
NewTensorFilledWithData
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
x
=
nullptr
);
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class
TensorMultiplyByZero
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByZero
()
:
zero_
(
MakeValue
(
0
))
{}
~
TensorMultiplyByZero
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
ValuePtr
zero_
;
};
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class
TensorMultiplyByOne
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByOne
()
{}
~
TensorMultiplyByOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_one_
{
false
};
};
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
class
AddByZero
:
public
AnfVisitor
{
public:
AddByZero
()
:
zero_
(
MakeValue
(
0
))
{}
~
AddByZero
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
ValuePtr
zero_
;
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class
TensorAddByZero
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
AnfNodePtr
x_
{
nullptr
};
};
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class
OptUpdateZeroTensor
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class
ConstantDuplicateMul
:
public
AnfVisitor
{
public:
// Support function to multiply two constant tensors: partially support broadcasting shapes
template
<
typename
T
>
void
Multiply
(
void
*
in_data_1
,
int
in_data_1_size
,
void
*
in_data_2
,
int
in_data_2_size
,
void
**
out_data
,
int
out_data_size
);
AnfNodePtr
MulConstantTensors
(
const
AnfNodePtr
&
vnode_1
,
const
AnfNodePtr
&
vnode_2
,
const
AnfNodePtr
&
node_3
);
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Reset
();
private:
AnfNodePtr
vnode_
;
AnfNodePtr
c_p_node_
;
};
class
PowerOneEliminate
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
...
...
@@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
class
ArithmeticSimplify
:
public
OptimizerCaller
{
public:
ArithmeticSimplify
()
:
multiply_by_zero_or_one_
(
std
::
make_shared
<
MultiplyByZeroOrOne
>
()),
tensor_multiply_by_one_
(
std
::
make_shared
<
TensorMultiplyByOne
>
()),
add_by_zero_
(
std
::
make_shared
<
AddByZero
>
()),
tensor_add_by_zero_
(
std
::
make_shared
<
TensorAddByZero
>
()),
identity_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimIdentity
)),
opt_update_zero_tensor_
(
std
::
make_shared
<
OptUpdateZeroTensor
>
()),
constant_duplicate_mul_
(
std
::
make_shared
<
ConstantDuplicateMul
>
()),
power_one_
(
std
::
make_shared
<
PowerOneEliminate
>
())
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
eliminaters_
.
emplace_back
(
tensor_add_by_zero_
);
eliminaters_
.
emplace_back
(
identity_
);
eliminaters_
.
emplace_back
(
opt_update_zero_tensor_
);
eliminaters_
.
emplace_back
(
constant_duplicate_mul_
);
eliminaters_
.
emplace_back
(
power_one_
);
}
~
ArithmeticSimplify
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
;
private:
OptimizerCallerPtr
multiply_by_zero_or_one_
;
OptimizerCallerPtr
tensor_multiply_by_one_
;
OptimizerCallerPtr
add_by_zero_
;
OptimizerCallerPtr
tensor_add_by_zero_
;
OptimizerCallerPtr
identity_
;
OptimizerCallerPtr
opt_update_zero_tensor_
;
OptimizerCallerPtr
constant_duplicate_mul_
;
OptimizerCallerPtr
power_one_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// Arithmetic Simplifications should be done after step_parallel.
...
...
@@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller {
// ArithmeticSimplify and deferred until step_parallel.
class
ArithmeticSimplify2
:
public
OptimizerCaller
{
public:
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
(
std
::
make_shared
<
TensorMultiplyByZero
>
())
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
~
ArithmeticSimplify2
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
;
private:
OptimizerCallerPtr
tensor_multiply_by_zero_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
This diff is collapsed.
Click to expand it.
mindspore/core/ir/pattern_matcher.h
浏览文件 @
db80643e
此差异已折叠。
点击以展开。
tests/ut/cpp/optimizer/opt_test.cc
浏览文件 @
db80643e
...
...
@@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common {
};
void
SetUp
()
{
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
A
ddByZero
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
A
rithmeticSimplify
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_R
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
PrimEliminater
>
(
R
),
"elim_R"
,
R
);
idempotent_P
=
MakeSubstitution
(
std
::
make_shared
<
IdempotentEliminater
>
(),
"idempotent_P"
,
P
);
Qct_to_P
=
MakeSubstitution
(
std
::
make_shared
<
QctToP
>
(),
"Qct_to_P"
,
Q
);
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部