Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0a87face
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a87face
编写于
8月 05, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 05, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3720 New optimization pass to remove redundant Min/Max ops
Merge pull request !3720 from thlinh/dev_July29_remove_redundant_minmax
上级
25da0c3f
eae5f282
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
86 addition
and
4 deletion
+86
-4
mindspore/ccsrc/frontend/optimizer/irpass.cc
mindspore/ccsrc/frontend/optimizer/irpass.cc
+2
-2
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
.../ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
+80
-0
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
...e/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
+2
-0
mindspore/core/ir/pattern_matcher.h
mindspore/core/ir/pattern_matcher.h
+2
-2
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass.cc
浏览文件 @
0a87face
...
@@ -168,8 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
...
@@ -168,8 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{
prim
::
kPrimSparseTensorGetIndices
,
prim
::
kPrimSparseTensorGetValues
,
prim
::
kPrimSparseTensorGetDenseShape
});
{
prim
::
kPrimSparseTensorGetIndices
,
prim
::
kPrimSparseTensorGetValues
,
prim
::
kPrimSparseTensorGetDenseShape
});
// Value_Based Eliminate
// Value_Based Eliminate
value_based_eliminate_
=
value_based_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
ValueBasedEliminate
>
(),
"value_based_eliminate"
,
MakeSubstitution
(
std
::
make_shared
<
ValueBasedEliminate
>
(),
"value_based_eliminate"
,
{
prim
::
kPrimSelect
});
{
prim
::
kPrimSelect
,
prim
::
kPrimMinimum
,
prim
::
kPrimMaximum
});
}
}
ResolveIRPassLib
::
ResolveIRPassLib
()
{
ResolveIRPassLib
::
ResolveIRPassLib
()
{
...
...
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
浏览文件 @
0a87face
...
@@ -19,6 +19,9 @@
...
@@ -19,6 +19,9 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
opt
{
namespace
opt
{
namespace
irpass
{
namespace
irpass
{
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
bool
IsCNodePositive
(
const
AnfNodePtr
&
node
)
{
bool
IsCNodePositive
(
const
AnfNodePtr
&
node
)
{
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimReduceSum
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimSqueeze
))
{
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimReduceSum
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimSqueeze
))
{
return
IsCNodePositive
(
node
->
cast
<
CNodePtr
>
()
->
input
(
1
));
return
IsCNodePositive
(
node
->
cast
<
CNodePtr
>
()
->
input
(
1
));
...
@@ -29,17 +32,94 @@ bool IsCNodePositive(const AnfNodePtr &node) {
...
@@ -29,17 +32,94 @@ bool IsCNodePositive(const AnfNodePtr &node) {
return
false
;
return
false
;
}
}
// check if a value is bigger than UPPER_FLT_LIMIT
bool
IsNodeScalarMaxFLT
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
false
;
}
auto
value
=
value_node
->
value
();
if
(
value
==
nullptr
)
{
return
false
;
}
auto
scalar
=
value
->
cast
<
ScalarPtr
>
();
if
(
scalar
!=
nullptr
)
{
if
(
scalar
->
isa
<
FloatImm
>
())
{
return
GetValue
<
float
>
(
scalar
)
>
UPPER_FLT_LIMIT
;
}
}
// Check for Tensor [] or Tensor [1]
auto
tensor_ptr
=
value
->
cast
<
tensor
::
TensorPtr
>
();
if
(
tensor_ptr
==
nullptr
)
{
return
false
;
}
if
(
tensor_ptr
->
DataSize
()
>
1
)
{
return
false
;
}
TypeId
tensor_type
=
tensor_ptr
->
Dtype
()
->
type_id
();
if
((
tensor_type
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeFloat
))
{
float
*
data
=
reinterpret_cast
<
float
*>
(
tensor_ptr
->
data_c
());
return
data
[
0
]
>
UPPER_FLT_LIMIT
;
}
return
false
;
}
// check if a value is smaller than LOWER_FLT_LIMIT
bool
IsNodeScalarMinFLT
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
if
(
value_node
==
nullptr
)
{
return
false
;
}
auto
value
=
value_node
->
value
();
if
(
value
==
nullptr
)
{
return
false
;
}
auto
scalar
=
value
->
cast
<
ScalarPtr
>
();
if
(
scalar
!=
nullptr
)
{
if
(
scalar
->
isa
<
FloatImm
>
())
{
return
GetValue
<
float
>
(
scalar
)
<
LOWER_FLT_LIMIT
;
}
}
// Check for Tensor [] or Tensor [1]
auto
tensor_ptr
=
value
->
cast
<
tensor
::
TensorPtr
>
();
if
(
tensor_ptr
==
nullptr
)
{
return
false
;
}
if
(
tensor_ptr
->
DataSize
()
>
1
)
{
return
false
;
}
TypeId
tensor_type
=
tensor_ptr
->
Dtype
()
->
type_id
();
if
((
tensor_type
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeFloat
))
{
float
*
data
=
reinterpret_cast
<
float
*>
(
tensor_ptr
->
data_c
());
return
data
[
0
]
<
LOWER_FLT_LIMIT
;
}
return
false
;
}
AnfNodePtr
ValueBasedEliminate
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
ValueBasedEliminate
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
PatternNode
x
,
y
,
z
;
PatternNode
x
,
y
,
z
;
PConstant
zero_
(
node
,
false
,
0
);
PConstant
zero_
(
node
,
false
,
0
);
PConstant
zero_scalar_
(
node
,
false
,
0
,
true
);
PConstant
zero_scalar_
(
node
,
false
,
0
,
true
);
// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimSelect
,
PPrimitive
(
prim
::
kPrimGreater
,
x
,
zero_
),
y
,
z
),
y
,
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimSelect
,
PPrimitive
(
prim
::
kPrimGreater
,
x
,
zero_
),
y
,
z
),
y
,
IsCNodePositive
(
x
.
GetNode
(
node
)));
IsCNodePositive
(
x
.
GetNode
(
node
)));
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimSelect
,
PPrimitive
(
prim
::
kPrimGreater
,
x
,
zero_scalar_
),
y
,
z
),
y
,
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimSelect
,
PPrimitive
(
prim
::
kPrimGreater
,
x
,
zero_scalar_
),
y
,
z
),
y
,
IsCNodePositive
(
x
.
GetNode
(
node
)));
IsCNodePositive
(
x
.
GetNode
(
node
)));
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimMaximum
,
x
,
y
),
x
,
IsNodeScalarMinFLT
(
y
.
GetNode
(
node
)));
MATCH_REPLACE_IF
(
node
,
PPrimitive
(
prim
::
kPrimMinimum
,
x
,
y
),
x
,
IsNodeScalarMaxFLT
(
y
.
GetNode
(
node
)));
return
nullptr
;
return
nullptr
;
}
}
...
...
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
浏览文件 @
0a87face
...
@@ -32,6 +32,8 @@ namespace opt {
...
@@ -32,6 +32,8 @@ namespace opt {
namespace
irpass
{
namespace
irpass
{
// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
// {prim::kPrimMaximum, X, Y} -> X when Y is smaller than LOWER_FLT_LIMIT
// {prim::kPrimMinimum, X, Y} -> X when Y is greater than UPPER_FLT_LIMIT
class
ValueBasedEliminate
:
public
OptimizerCaller
{
class
ValueBasedEliminate
:
public
OptimizerCaller
{
public:
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
...
...
mindspore/core/ir/pattern_matcher.h
浏览文件 @
0a87face
...
@@ -487,7 +487,7 @@ class PConstant : public PBase<PConstant<T> > {
...
@@ -487,7 +487,7 @@ class PConstant : public PBase<PConstant<T> > {
TypeId
tensor_type
=
tensor_ptr
->
Dtype
()
->
type_id
();
TypeId
tensor_type
=
tensor_ptr
->
Dtype
()
->
type_id
();
if
((
tensor_type
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeFloat
))
{
if
((
tensor_type
==
TypeId
::
kNumberTypeFloat32
)
||
(
tensor_type
==
TypeId
::
kNumberTypeFloat
))
{
float
*
data2
=
reinterpret_cast
<
float
*>
(
tensor_ptr
->
data_c
());
float
*
data2
=
reinterpret_cast
<
float
*>
(
tensor_ptr
->
data_c
());
auto
threshold
=
FLT_
EPSILON
*
FLT_EPSILO
N
;
auto
threshold
=
FLT_
MI
N
;
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
threshold
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
threshold
)
{
return
false
;
return
false
;
...
@@ -496,7 +496,7 @@ class PConstant : public PBase<PConstant<T> > {
...
@@ -496,7 +496,7 @@ class PConstant : public PBase<PConstant<T> > {
return
true
;
return
true
;
}
else
if
(
tensor_type
==
TypeId
::
kNumberTypeFloat64
)
{
}
else
if
(
tensor_type
==
TypeId
::
kNumberTypeFloat64
)
{
double
*
data2
=
reinterpret_cast
<
double
*>
(
tensor_ptr
->
data_c
());
double
*
data2
=
reinterpret_cast
<
double
*>
(
tensor_ptr
->
data_c
());
auto
threshold
=
DBL_
EPSILON
*
DBL_EPSILO
N
;
auto
threshold
=
DBL_
MI
N
;
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
for
(
int
i
=
0
;
i
<
tensor_ptr
->
DataSize
();
i
++
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
threshold
)
{
if
(
fabs
(
data2
[
i
]
-
check_value_
)
>
threshold
)
{
return
false
;
return
false
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录