Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7f759c2a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7f759c2a
编写于
7月 06, 2020
作者:
G
Giancarlo Colmenares
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Updated arithmetic simplify to use Pattern Matcher
上级
faa1084b
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
467 addition
and
755 deletion
+467
-755
mindspore/ccsrc/ir/pattern_matcher.h
mindspore/ccsrc/ir/pattern_matcher.h
+412
-18
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc
+50
-546
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+4
-188
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+0
-2
tests/ut/cpp/optimizer/opt_test.cc
tests/ut/cpp/optimizer/opt_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/ir/pattern_matcher.h
浏览文件 @
7f759c2a
此差异已折叠。
点击以展开。
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc
浏览文件 @
7f759c2a
...
...
@@ -14,542 +14,67 @@
* limitations under the License.
*/
#include <algorithm>
#include <memory>
#include <vector>
#include <functional>
#include "optimizer/irpass/arithmetic_simplify.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
AnfNodePtr
MultiplyByZeroOrOne
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimScalarMul
)(
node
);
if
(
is_zero_
)
{
return
NewValueNode
(
zero_
);
}
if
(
is_one_
)
{
return
x_
;
}
return
nullptr
;
}
void
MultiplyByZeroOrOne
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
is_one_
||
node
->
isa
<
CNode
>
())
{
x_
=
node
;
return
;
}
AnfVisitor
::
Visit
(
node
);
if
(
!
is_one_
)
{
x_
=
node
;
}
}
void
MultiplyByZeroOrOne
::
Visit
(
const
ValueNodePtr
&
vnode
)
{
auto
value
=
vnode
->
value
();
if
(
*
value
==
*
zero_
)
{
is_zero_
=
true
;
}
else
if
(
*
value
==
*
one_
)
{
is_one_
=
true
;
}
}
void
MultiplyByZeroOrOne
::
Reset
()
{
x_
=
nullptr
;
is_one_
=
false
;
is_zero_
=
false
;
}
// Support class used for checking if all values of a Tensor are equal `check_value_`
// Supported data types: double, float/float32, int/int32
bool
CheckTensorConstant
::
IsTensorConstant
(
const
ValuePtr
&
value
)
{
if
(
!
value
->
isa
<
tensor
::
Tensor
>
())
{
return
false
;
}
auto
tensor_ptr
=
dyn_cast
<
tensor
::
Tensor
>
(
value
);
TypeId
tensor_type
=
tensor_ptr
->
Dtype
()
->
type_id
();
if
((
tensor_type
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeFloat
))
{
float
*
data2
=
reinterpret_cast
<
float
*>
(
tensor_ptr
->
data_c
());
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
FLT_EPSILON
)
{
return
false
;
}
}
return
true
;
}
else
if
(
tensor_type
==
TypeId
::
kNumberTypeFloat64
)
{
double
*
data2
=
reinterpret_cast
<
double
*>
(
tensor_ptr
->
data_c
());
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
DBL_EPSILON
)
{
return
false
;
}
}
return
true
;
}
else
if
((
tensor_type
==
TypeId
::
kNumberTypeInt32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeInt
))
{
int
*
data2
=
reinterpret_cast
<
int
*>
(
tensor_ptr
->
data_c
());
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
if
(
data2
[
i
]
!=
check_value_
)
{
return
false
;
}
}
return
true
;
}
// input Data Types is not supported
return
false
;
}
bool
CheckTensorConstant
::
IsTensorScalarConstant
(
const
ValuePtr
&
value
)
{
if
(
!
value
->
isa
<
tensor
::
Tensor
>
())
{
return
false
;
}
auto
tensor_ptr
=
dyn_cast
<
tensor
::
Tensor
>
(
value
);
if
((
tensor_ptr
->
DataSize
()
>
1
)
||
(
tensor_ptr
->
DataDim
()
>
0
))
{
return
false
;
}
return
IsTensorConstant
(
value
);
}
void
*
TensorMultiplyBase
::
GetPointerToTensorData
(
const
AnfNodePtr
&
node
,
bool
writable
)
{
if
(
!
node
->
isa
<
ValueNode
>
())
{
return
nullptr
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
!
value
->
isa
<
tensor
::
Tensor
>
())
{
return
nullptr
;
}
tensor
::
TensorPtr
tensor_ptr
=
dyn_cast
<
tensor
::
Tensor
>
(
value
);
return
tensor_ptr
->
data_c
();
}
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr
TensorMultiplyBase
::
NewTensorFilledWithData
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
x
)
{
if
((
node
->
abstract
()
==
nullptr
)
||
!
node
->
abstract
()
->
isa
<
abstract
::
AbstractTensor
>
())
{
return
nullptr
;
}
auto
tensor_abstract
=
node
->
abstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
TypePtr
tensor_type_ptr
=
tensor_abstract
->
element
()
->
BuildType
();
std
::
vector
<
int
>
tensor_shape
=
tensor_abstract
->
shape
()
->
shape
();
auto
new_tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
tensor_type_ptr
->
type_id
(),
tensor_shape
);
size_t
mem_size
=
GetTypeByte
(
tensor_type_ptr
)
*
IntToSize
(
new_tensor_ptr
->
ElementsNum
());
char
*
data
=
reinterpret_cast
<
char
*>
(
new_tensor_ptr
->
data_c
());
if
(
x
==
nullptr
)
{
std
::
memset
(
data
,
0
,
mem_size
);
auto
new_vnode
=
NewValueNode
(
new_tensor_ptr
);
new_vnode
->
set_abstract
(
new_tensor_ptr
->
ToAbstract
());
return
new_vnode
;
}
// x is not nullptr
if
(
x
->
isa
<
CNode
>
())
{
if
((
x
->
abstract
()
==
nullptr
)
||
!
x
->
abstract
()
->
isa
<
abstract
::
AbstractTensor
>
())
{
return
nullptr
;
}
auto
x_abstract
=
x
->
abstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
std
::
vector
<
int
>
x_shape
=
x_abstract
->
shape
()
->
shape
();
if
(
x_shape
!=
tensor_shape
)
{
return
nullptr
;
}
return
x
;
}
if
(
!
x
->
isa
<
ValueNode
>
())
{
return
nullptr
;
}
auto
x_value
=
x
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
!
x_value
->
isa
<
tensor
::
Tensor
>
())
{
return
nullptr
;
}
auto
x_tensor_ptr
=
dyn_cast
<
tensor
::
Tensor
>
(
x_value
);
if
((
x_tensor_ptr
->
DataSize
()
>
1
)
&&
(
x_tensor_ptr
->
DataSize
()
!=
new_tensor_ptr
->
DataSize
()))
{
return
nullptr
;
}
char
*
source_data
=
reinterpret_cast
<
char
*>
(
GetPointerToTensorData
(
x
));
if
(
x_tensor_ptr
->
DataSize
()
==
1
)
{
for
(
int
i
=
0
;
i
<
new_tensor_ptr
->
ElementsNum
();
i
++
)
{
memcpy
(
data
+
i
*
GetTypeByte
(
tensor_type_ptr
),
source_data
,
GetTypeByte
(
tensor_type_ptr
));
}
}
else
{
memcpy
(
data
,
source_data
,
mem_size
);
}
auto
new_vnode
=
NewValueNode
(
new_tensor_ptr
);
new_vnode
->
set_abstract
(
new_tensor_ptr
->
ToAbstract
());
return
new_vnode
;
}
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
AnfNodePtr
TensorMultiplyByZero
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
if
(
is_zero_
)
{
if
(
x_
->
func_graph
()
!=
node
->
func_graph
())
{
return
nullptr
;
}
return
NewTensorFilledWithData
(
node
);
}
return
nullptr
;
}
void
TensorMultiplyByZero
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
is_zero_
)
{
x_
=
node
;
return
;
}
if
(
IsParam
(
node
))
{
x_
=
node
;
return
;
}
if
(
IsCNode
(
node
))
{
CNodePtr
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
IsPrimitive
(
cnode
->
input
(
0
),
prim
::
kPrimZerosLike
))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
return
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
}
void
TensorMultiplyByZero
::
Visit
(
const
ValueNodePtr
&
vnode
)
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
x_
=
vnode
;
}
void
TensorMultiplyByZero
::
Reset
()
{
x_
=
nullptr
;
is_zero_
=
false
;
}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
AnfNodePtr
TensorMultiplyByOne
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
if
(
is_one_
)
{
return
NewTensorFilledWithData
(
node
,
x_
);
}
return
nullptr
;
}
void
TensorMultiplyByOne
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
is_one_
)
{
x_
=
node
;
return
;
}
if
(
IsParam
(
node
)
||
IsCNode
(
node
))
{
x_
=
node
;
return
;
}
auto
value
=
node
->
cast
<
ValueNodePtr
>
()
->
value
();
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
node
;
}
void
TensorMultiplyByOne
::
Visit
(
const
ValueNodePtr
&
vnode
)
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
1
).
IsTensorConstant
(
value
))
{
is_one_
=
true
;
return
;
}
x_
=
vnode
;
}
void
TensorMultiplyByOne
::
Reset
()
{
x_
=
nullptr
;
is_one_
=
false
;
}
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
AnfNodePtr
AddByZero
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimScalarAdd
)(
node
);
if
(
is_zero_
)
{
return
x_
;
}
return
nullptr
;
}
void
AddByZero
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
node
->
isa
<
ValueNode
>
()
&&
((
*
GetValueNode
(
node
)
==
*
zero_
)
||
CheckTensorConstant
(
0
).
IsTensorScalarConstant
(
GetValueNode
(
node
))))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
}
void
AddByZero
::
Reset
()
{
x_
=
nullptr
;
is_zero_
=
false
;
}
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
AnfNodePtr
TensorAddByZero
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
AnfVisitor
::
Match
(
prim
::
kPrimTensorAdd
)(
node
);
AnfNodePtr
ArithmeticSimplify
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
PatternNode
x
,
y
,
z
,
xs
;
PConstant
one_
(
node
,
false
,
1
);
PConstant
one_scalar_
(
node
,
false
,
1
,
true
);
PConstant
zero_
(
node
,
false
,
0
);
PConstant
zero_scalar_
(
node
,
false
,
0
,
true
);
PConstant
const_
(
node
);
PConstant
const_2
(
node
);
PConstant
any_const
(
node
);
MATCH_REPLACE
(
node
,
x
+
zero_
,
x
);
// Add by zero
MATCH_REPLACE
(
node
,
x
+
zero_scalar_
,
x
);
// Add by zero
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarAdd
,
zero_scalar_
,
x
),
x
);
// Scalar Add by zero
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarAdd
,
x
,
zero_scalar_
),
x
);
// Scalar Add by zero
MATCH_REPLACE_IF
(
node
,
x
*
one_
,
any_const
.
WithValueOf
(
x
),
x
.
CheckFunc
(
IsVNode
,
node
));
// Multiply by one
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarMul
,
one_scalar_
,
x
),
x
);
// Scalar Mul by one
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarMul
,
x
,
one_scalar_
),
x
);
// Scalar Mul by one
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarMul
,
zero_scalar_
,
x
),
zero_
.
NewValue
());
// Scalar Mul by zero
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimScalarMul
,
x
,
zero_scalar_
),
zero_
.
NewValue
());
// Scalar Mul by zero
// Prim Eliminate (identity)
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimIdentity
,
x
),
x
);
// ConstantDuplicateMul
auto
const_dup_lambda
=
[
&
node
,
&
x
,
&
const_
,
&
const_2
]()
->
AnfNodePtr
{
auto
new_mul_tensor
=
const_
.
MulByPatternConst
(
const_2
,
x
.
GetNode
(
node
));
auto
mul_node
=
node
->
cast
<
CNodePtr
>
()
->
inputs
()[
0
];
if
(
new_mul_tensor
==
nullptr
)
{
auto
ttmul
=
NewCNode
({
mul_node
,
const_
.
GetNode
(
node
),
const_2
.
GetNode
(
node
)},
node
->
func_graph
());
return
NewCNode
({
mul_node
,
x
.
GetNode
(
node
),
ttmul
},
node
->
func_graph
());
}
return
NewCNode
({
mul_node
,
x
.
GetNode
(
node
),
new_mul_tensor
},
node
->
func_graph
());
};
MATCH_REPLACE_LAMBDA
(
node
,
const_
*
(
const_2
*
x
),
const_dup_lambda
);
if
(
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
// OptUpdateZeroTensor
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimMomentum
,
PPrimitive
(
prim
::
kPrimZerosLike
,
x
),
y
,
z
,
xs
),
PPrimitive
(
prim
::
kPrimMakeTuple
,
z
,
y
));
// PowerOneEliminate
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimPow
,
x
,
one_scalar_
),
x
);
if
(
is_zero_
)
{
return
x_
;
}
return
nullptr
;
}
void
TensorAddByZero
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
node
->
isa
<
ValueNode
>
()
&&
CheckTensorConstant
(
0
).
IsTensorScalarConstant
(
GetValueNode
(
node
)))
{
is_zero_
=
true
;
return
;
}
x_
=
node
;
}
void
TensorAddByZero
::
Visit
(
const
ValueNodePtr
&
vnode
)
{
auto
value
=
vnode
->
value
();
if
(
CheckTensorConstant
(
0
).
IsTensorConstant
(
value
))
{
is_zero_
=
true
;
return
;
}
}
void
TensorAddByZero
::
Reset
()
{
x_
=
nullptr
;
is_zero_
=
false
;
}
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
AnfNodePtr
OptUpdateZeroTensor
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimMomentum
)
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
// {PrimMomentum, {...}, Y, Z, Xs}
auto
&
inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
inputs
.
size
()
<
4
||
!
IsPrimitiveCNode
(
inputs
[
1
],
prim
::
kPrimZerosLike
))
{
return
nullptr
;
}
auto
y
=
inputs
[
2
];
auto
z
=
inputs
[
3
];
// {kPrimZerosLike, X}
if
(
inputs
[
1
]
->
cast
<
CNodePtr
>
()
->
size
()
!=
2
)
{
return
nullptr
;
}
// {prim::kPrimMakeTuple, Z, Y}
return
node
->
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
::
kPrimMakeTuple
),
z
,
y
});
}
// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
// Support function to multiply two constant tensors: partially support broadcasting shapes
template
<
typename
T
>
void
ConstantDuplicateMul
::
Multiply
(
void
*
in_data_1
,
int
in_data_1_size
,
void
*
in_data_2
,
int
in_data_2_size
,
void
**
out_data
,
int
out_data_size
)
{
T
*
data_1
=
reinterpret_cast
<
T
*>
(
in_data_1
);
T
*
data_2
=
reinterpret_cast
<
T
*>
(
in_data_2
);
T
*
data_out
=
new
T
[
out_data_size
];
if
(
in_data_1_size
==
1
)
{
for
(
int
i
=
0
;
i
<
out_data_size
;
i
++
)
{
data_out
[
i
]
=
data_1
[
0
];
}
}
else
{
for
(
int
i
=
0
;
i
<
out_data_size
;
i
++
)
{
data_out
[
i
]
=
data_1
[
i
];
}
}
if
(
in_data_2_size
==
1
)
{
for
(
int
i
=
0
;
i
<
out_data_size
;
i
++
)
{
data_out
[
i
]
*=
data_2
[
0
];
}
}
else
{
for
(
int
i
=
0
;
i
<
out_data_size
;
i
++
)
{
data_out
[
i
]
*=
data_2
[
i
];
}
}
*
out_data
=
reinterpret_cast
<
void
*>
(
data_out
);
return
;
}
AnfNodePtr
ConstantDuplicateMul
::
MulConstantTensors
(
const
AnfNodePtr
&
vnode_1
,
const
AnfNodePtr
&
vnode_2
,
const
AnfNodePtr
&
node_3
)
{
if
(
!
vnode_1
->
isa
<
ValueNode
>
()
||
!
vnode_2
->
isa
<
ValueNode
>
()
||
(
vnode_1
->
abstract
()
==
nullptr
)
||
(
vnode_2
->
abstract
()
==
nullptr
)
||
(
node_3
->
abstract
()
==
nullptr
))
{
return
nullptr
;
}
auto
value_1
=
GetValueNode
(
vnode_1
);
auto
value_2
=
GetValueNode
(
vnode_2
);
if
(
!
value_1
->
isa
<
tensor
::
Tensor
>
()
||
!
value_2
->
isa
<
tensor
::
Tensor
>
())
{
return
nullptr
;
}
auto
tensor_ptr_1
=
dyn_cast
<
tensor
::
Tensor
>
(
value_1
);
auto
tensor_ptr_2
=
dyn_cast
<
tensor
::
Tensor
>
(
value_2
);
auto
tensor_1_abstract
=
vnode_1
->
abstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_2_abstract
=
vnode_1
->
abstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
auto
tensor_3_abstract
=
node_3
->
abstract
()
->
cast
<
abstract
::
AbstractTensorPtr
>
();
TypePtr
tensor_1_type_ptr
=
tensor_1_abstract
->
element
()
->
BuildType
();
TypePtr
tensor_2_type_ptr
=
tensor_2_abstract
->
element
()
->
BuildType
();
TypePtr
tensor_3_type_ptr
=
tensor_3_abstract
->
element
()
->
BuildType
();
if
((
tensor_1_type_ptr
->
type_id
()
!=
tensor_3_type_ptr
->
type_id
())
||
(
tensor_2_type_ptr
->
type_id
()
!=
tensor_3_type_ptr
->
type_id
()))
{
return
nullptr
;
}
AnfNodePtr
ArithmeticSimplify2
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
PatternNode
x
,
y
;
PConstant
zero_
(
node
,
false
,
0
);
std
::
vector
<
int
>
tensor_out_shape
=
tensor_3_abstract
->
shape
()
->
shape
();
MATCH_REPLACE
(
node
,
x
*
zero_
,
zero_
);
// Multiply by zero
MATCH_REPLACE
(
node
,
x
*
PPrimitive
(
prim
::
kPrimZerosLike
,
y
),
zero_
);
// Multiply by zero
int
data_out_size
=
std
::
accumulate
(
tensor_out_shape
.
begin
(),
tensor_out_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
((
tensor_ptr_1
->
DataSize
()
>
1
)
&&
(
tensor_ptr_1
->
DataSize
()
!=
data_out_size
))
{
return
nullptr
;
}
if
((
tensor_ptr_2
->
DataSize
()
>
1
)
&&
(
tensor_ptr_2
->
DataSize
()
!=
data_out_size
))
{
return
nullptr
;
}
void
*
data_out
;
if
((
tensor_3_type_ptr
->
type_id
()
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_3_type_ptr
->
type_id
()
==
TypeId
::
kNumberTypeFloat
))
{
Multiply
<
float
>
(
tensor_ptr_1
->
data_c
(),
tensor_ptr_1
->
DataSize
(),
tensor_ptr_2
->
data_c
(),
tensor_ptr_2
->
DataSize
(),
&
data_out
,
data_out_size
);
}
else
{
if
(
tensor_3_type_ptr
->
type_id
()
==
TypeId
::
kNumberTypeFloat64
)
{
Multiply
<
double
>
(
tensor_ptr_1
->
data_c
(),
tensor_ptr_1
->
DataSize
(),
tensor_ptr_2
->
data_c
(),
tensor_ptr_2
->
DataSize
(),
&
data_out
,
data_out_size
);
}
else
{
if
((
tensor_3_type_ptr
->
type_id
()
==
TypeId
::
kNumberTypeInt32
)
||
(
tensor_3_type_ptr
->
type_id
()
==
TypeId
::
kNumberTypeInt
))
{
Multiply
<
int
>
(
tensor_ptr_1
->
data_c
(),
tensor_ptr_1
->
DataSize
(),
tensor_ptr_2
->
data_c
(),
tensor_ptr_2
->
DataSize
(),
&
data_out
,
data_out_size
);
}
else
{
// Un-support data types
return
nullptr
;
}
}
}
auto
new_tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
tensor_3_type_ptr
->
type_id
(),
tensor_out_shape
);
size_t
mem_size
=
GetTypeByte
(
tensor_3_type_ptr
)
*
IntToSize
(
new_tensor_ptr
->
ElementsNum
());
char
*
data
=
reinterpret_cast
<
char
*>
(
new_tensor_ptr
->
data_c
());
memcpy
(
data
,
data_out
,
mem_size
);
auto
new_vnode
=
NewValueNode
(
new_tensor_ptr
);
new_vnode
->
set_abstract
(
new_tensor_ptr
->
ToAbstract
());
return
new_vnode
;
}
AnfNodePtr
ConstantDuplicateMul
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
// {prim::kPrimMul, Tensor1, {...}}
AnfVisitor
::
Match
(
prim
::
kPrimMul
,
{
IsNode
,
IsNode
})(
node
);
if
(
vnode_
==
nullptr
||
c_p_node_
==
nullptr
)
{
return
nullptr
;
}
if
(
!
IsCNode
(
c_p_node_
))
{
return
nullptr
;
}
auto
tensor1
=
vnode_
;
auto
mul
=
c_p_node_
->
cast
<
CNodePtr
>
();
Reset
();
// {prim::kPrimMul, Tensor2, {...}}
AnfVisitor
::
Match
(
prim
::
kPrimMul
,
{
IsNode
,
IsNode
})(
mul
);
if
(
vnode_
==
nullptr
||
c_p_node_
==
nullptr
)
{
return
nullptr
;
}
auto
tensor2
=
vnode_
;
auto
c_p_node
=
c_p_node_
;
auto
PrimMul
=
GetValueNode
<
PrimitivePtr
>
(
mul
->
input
(
0
));
auto
fg
=
node
->
func_graph
();
auto
new_mul_tensor
=
MulConstantTensors
(
tensor1
,
tensor2
,
c_p_node
);
if
(
new_mul_tensor
==
nullptr
)
{
auto
ttmul
=
NewCNode
({
NewValueNode
(
PrimMul
),
tensor1
,
tensor2
},
fg
);
return
NewCNode
({
NewValueNode
(
PrimMul
),
c_p_node
,
ttmul
},
fg
);
}
return
NewCNode
({
NewValueNode
(
PrimMul
),
c_p_node
,
new_mul_tensor
},
fg
);
}
void
ConstantDuplicateMul
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
IsValueNode
<
tensor
::
Tensor
>
(
node
))
{
vnode_
=
node
;
}
if
(
IsCNode
(
node
)
||
IsParam
(
node
))
{
c_p_node_
=
node
;
}
}
void
ConstantDuplicateMul
::
Reset
()
{
vnode_
=
nullptr
;
c_p_node_
=
nullptr
;
}
AnfNodePtr
PowerOneEliminate
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimPow
)
||
node
->
func_graph
()
==
nullptr
)
{
return
nullptr
;
}
auto
&
inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
!
IsValueNode
<
Scalar
>
(
inputs
[
2
]))
{
return
nullptr
;
}
auto
scalar
=
GetValueNode
<
ScalarPtr
>
(
inputs
[
2
]);
if
(
scalar
->
isa
<
FloatImm
>
()
&&
GetValue
<
float
>
(
scalar
)
==
1.0
)
{
return
inputs
[
1
];
}
else
if
(
scalar
->
isa
<
IntergerImm
>
()
&&
GetValue
<
int
>
(
scalar
)
==
1
)
{
return
inputs
[
1
];
}
return
nullptr
;
}
...
...
@@ -654,27 +179,6 @@ void AdjustAllReduceMulAdd::Reset() {
all_reduce_fg_
=
nullptr
;
}
AnfNodePtr
ArithmeticSimplify
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
(
*
eliminater
)(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
AnfNodePtr
ArithmeticSimplify2
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
(
*
eliminater
)(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
7f759c2a
...
...
@@ -22,158 +22,14 @@
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
class
MultiplyByZeroOrOne
:
public
AnfVisitor
{
public:
MultiplyByZeroOrOne
()
:
zero_
(
MakeValue
(
0
)),
one_
(
MakeValue
(
1
))
{}
~
MultiplyByZeroOrOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
},
is_one_
{
false
};
ValuePtr
zero_
,
one_
;
AnfNodePtr
x_
{
nullptr
};
};
// Support class used for checking if all values of a Tensor are equal `check_value_`
// Supported data types: double, float/float32, int/int32
class
CheckTensorConstant
{
public:
explicit
CheckTensorConstant
(
int
_check_value
=
0
)
:
check_value_
(
_check_value
)
{}
~
CheckTensorConstant
()
=
default
;
bool
IsTensorConstant
(
const
ValuePtr
&
value
);
bool
IsTensorScalarConstant
(
const
ValuePtr
&
value
);
private:
int
check_value_
;
};
class
TensorMultiplyBase
:
public
AnfVisitor
{
protected:
void
*
GetPointerToTensorData
(
const
AnfNodePtr
&
node
,
bool
writable
=
false
);
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr
NewTensorFilledWithData
(
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
x
=
nullptr
);
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class
TensorMultiplyByZero
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByZero
()
:
zero_
(
MakeValue
(
0
))
{}
~
TensorMultiplyByZero
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
ValuePtr
zero_
;
};
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class
TensorMultiplyByOne
:
public
TensorMultiplyBase
{
public:
TensorMultiplyByOne
()
{}
~
TensorMultiplyByOne
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_one_
{
false
};
};
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
class
AddByZero
:
public
AnfVisitor
{
public:
AddByZero
()
:
zero_
(
MakeValue
(
0
))
{}
~
AddByZero
()
override
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
ValuePtr
zero_
;
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class
TensorAddByZero
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
;
void
Reset
();
private:
bool
is_zero_
{
false
};
AnfNodePtr
x_
{
nullptr
};
};
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class
OptUpdateZeroTensor
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class
ConstantDuplicateMul
:
public
AnfVisitor
{
public:
// Support function to multiply two constant tensors: partially support broadcasting shapes
template
<
typename
T
>
void
Multiply
(
void
*
in_data_1
,
int
in_data_1_size
,
void
*
in_data_2
,
int
in_data_2_size
,
void
**
out_data
,
int
out_data_size
);
AnfNodePtr
MulConstantTensors
(
const
AnfNodePtr
&
vnode_1
,
const
AnfNodePtr
&
vnode_2
,
const
AnfNodePtr
&
node_3
);
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Reset
();
private:
AnfNodePtr
vnode_
;
AnfNodePtr
c_p_node_
;
};
class
PowerOneEliminate
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
...
...
@@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
class
ArithmeticSimplify
:
public
OptimizerCaller
{
public:
ArithmeticSimplify
()
:
multiply_by_zero_or_one_
(
std
::
make_shared
<
MultiplyByZeroOrOne
>
()),
tensor_multiply_by_one_
(
std
::
make_shared
<
TensorMultiplyByOne
>
()),
add_by_zero_
(
std
::
make_shared
<
AddByZero
>
()),
tensor_add_by_zero_
(
std
::
make_shared
<
TensorAddByZero
>
()),
identity_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimIdentity
)),
opt_update_zero_tensor_
(
std
::
make_shared
<
OptUpdateZeroTensor
>
()),
constant_duplicate_mul_
(
std
::
make_shared
<
ConstantDuplicateMul
>
()),
power_one_
(
std
::
make_shared
<
PowerOneEliminate
>
())
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
eliminaters_
.
emplace_back
(
tensor_add_by_zero_
);
eliminaters_
.
emplace_back
(
identity_
);
eliminaters_
.
emplace_back
(
opt_update_zero_tensor_
);
eliminaters_
.
emplace_back
(
constant_duplicate_mul_
);
eliminaters_
.
emplace_back
(
power_one_
);
}
~
ArithmeticSimplify
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
;
private:
OptimizerCallerPtr
multiply_by_zero_or_one_
;
OptimizerCallerPtr
tensor_multiply_by_one_
;
OptimizerCallerPtr
add_by_zero_
;
OptimizerCallerPtr
tensor_add_by_zero_
;
OptimizerCallerPtr
identity_
;
OptimizerCallerPtr
opt_update_zero_tensor_
;
OptimizerCallerPtr
constant_duplicate_mul_
;
OptimizerCallerPtr
power_one_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
// Arithmetic Simplifications should be done after step_parallel.
...
...
@@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller {
// ArithmeticSimplify and deferred until step_parallel.
class
ArithmeticSimplify2
:
public
OptimizerCaller
{
public:
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
(
std
::
make_shared
<
TensorMultiplyByZero
>
())
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
~
ArithmeticSimplify2
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
;
private:
OptimizerCallerPtr
tensor_multiply_by_zero_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
浏览文件 @
7f759c2a
...
...
@@ -25,10 +25,8 @@
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
...
...
tests/ut/cpp/optimizer/opt_test.cc
浏览文件 @
7f759c2a
...
...
@@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common {
};
void
SetUp
()
{
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
A
ddByZero
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
A
rithmeticSimplify
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_R
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
PrimEliminater
>
(
R
),
"elim_R"
,
R
);
idempotent_P
=
MakeSubstitution
(
std
::
make_shared
<
IdempotentEliminater
>
(),
"idempotent_P"
,
P
);
Qct_to_P
=
MakeSubstitution
(
std
::
make_shared
<
QctToP
>
(),
"Qct_to_P"
,
Q
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录