Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8463731b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8463731b
编写于
6月 18, 2020
作者:
H
huanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make those AdamXX and LambXX fusion pass not work for unexpect data type
上级
ef698a93
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
41 addition
and
1 deletion
+41
-1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
...rc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
...tivate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
+3
-1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
.../ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
...activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
...ivate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
...src/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc
...ivate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc
+3
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc
...c/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc
+3
-0
mindspore/ccsrc/pre_activate/common/helper.cc
mindspore/ccsrc/pre_activate/common/helper.cc
+10
-0
mindspore/ccsrc/pre_activate/common/helper.h
mindspore/ccsrc/pre_activate/common/helper.h
+4
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+3
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
浏览文件 @
8463731b
...
...
@@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
auto
new_node
=
CreateAdamApplyOneNode
(
func_graph
,
equiv
);
MS_EXCEPTION_IF_NULL
(
new_node
);
new_node
->
set_scope
(
node
->
scope
());
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
浏览文件 @
8463731b
...
...
@@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
if
(
graph
==
nullptr
||
node
==
nullptr
||
equiv
==
nullptr
)
{
return
nullptr
;
}
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
inputs
=
GetFusionNodeInputs
(
equiv
);
auto
fusion_node
=
graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
fusion_node
);
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
浏览文件 @
8463731b
...
...
@@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2
const
AnfNodePtr
LambNextMVRule
::
Process
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
equiv
)
const
{
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
old_pattern_outputs
;
if
(
!
IsRuleMatched
(
func_graph
,
node
,
equiv
,
&
old_pattern_outputs
))
{
return
nullptr
;
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
浏览文件 @
8463731b
...
...
@@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
AnfNodePtr
mul4
=
GetAnfNodeByVar
(
equiv
,
mul4_var_
);
MS_EXCEPTION_IF_NULL
(
mul4
);
// Get add3 and match the add3 pattern
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
浏览文件 @
8463731b
...
...
@@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
if
(
func_graph
==
nullptr
||
node
==
nullptr
||
equiv
==
nullptr
)
{
return
nullptr
;
}
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
AnfNodePtr
mul4
=
nullptr
;
AnfNodePtr
real_div0
=
nullptr
;
AnfNodePtr
real_div1
=
nullptr
;
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
浏览文件 @
8463731b
...
...
@@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
auto
new_node
=
CreateLambNextRightNode
(
func_graph
,
equiv
);
MS_EXCEPTION_IF_NULL
(
new_node
);
// Set abstract of new node
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc
浏览文件 @
8463731b
...
...
@@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
equiv
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
auto
input0
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input0_
]);
auto
input1
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input1_
]);
auto
input2
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input2_
]);
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc
浏览文件 @
8463731b
...
...
@@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
equiv
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
auto
prim
=
std
::
make_shared
<
Primitive
>
(
kLambUpdateWithLrV2OpName
);
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
prim
)};
(
void
)
std
::
transform
(
input_varptr_
.
begin
(),
input_varptr_
.
end
(),
std
::
back_inserter
(
inputs
),
...
...
mindspore/ccsrc/pre_activate/common/helper.cc
浏览文件 @
8463731b
...
...
@@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
MS_EXCEPTION_IF_NULL
(
cnode
);
return
AnfAlgo
::
HasNodeAttr
(
attr_name
,
cnode
)
&&
AnfAlgo
::
GetNodeAttr
<
bool
>
(
node
,
attr_name
);
}
bool
CheckSupportDataType
(
const
AnfNodePtr
&
node
,
const
std
::
set
<
TypeId
>
&
supported_data_type_set
)
{
MS_EXCEPTION_IF_NULL
(
node
);
TypeId
data_type
=
AnfAlgo
::
GetOutputInferDataType
(
node
,
0
);
if
(
supported_data_type_set
.
find
(
data_type
)
!=
supported_data_type_set
.
end
())
{
return
true
;
}
MS_LOG
(
DEBUG
)
<<
"Not supported data type. Node:"
<<
node
->
DebugString
();
return
false
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/common/helper.h
浏览文件 @
8463731b
...
...
@@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include <string>
#include <set>
#include <unordered_set>
#include "ir/func_graph.h"
#include "session/kernel_graph.h"
...
...
@@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
// Get attr which is bool from cnode
bool
GetBoolAttr
(
const
AnfNodePtr
&
node
,
const
std
::
string
&
attr_name
);
// Check node's data type is in supported data type set
bool
CheckSupportDataType
(
const
AnfNodePtr
&
node
,
const
std
::
set
<
TypeId
>
&
supported_data_type_set
);
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
mindspore/ccsrc/utils/utils.h
浏览文件 @
8463731b
...
...
@@ -25,6 +25,7 @@
#include <set>
#include "utils/log_adapter.h"
#include "ir/dtype/type.h"
namespace
mindspore
{
// op name. Op which not exists in operator/ops.h, so define it's name here
...
...
@@ -270,6 +271,8 @@ const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo
kOpFormat_FRAC_NZ
,
kOpFormat_C1HWNCoC0
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
const
std
::
set
<
TypeId
>
kFloatDataTypeSet
=
{
kNumberTypeFloat16
,
kNumberTypeFloat32
};
static
inline
void
ChangeFileMode
(
const
std
::
string
&
file_name
,
mode_t
mode
)
{
try
{
if
(
chmod
(
file_name
.
c_str
(),
mode
)
!=
0
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录