Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
28f873e9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
28f873e9
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!3203 GPU fix cast fusion bug
Merge pull request !3203 from VectorSL/fix-cast-fusion
上级
fd9619bb
80ed8e0e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
58 addition
and
48 deletion
+58
-48
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
...ore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
+26
-25
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
...csrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
+32
-23
未找到文件。
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
浏览文件 @
28f873e9
...
...
@@ -30,8 +30,7 @@ const BaseRef ReplaceBNCastFusion::DefinePattern() const {
VectorRef
in_cast
=
VectorRef
({
prim
::
kPrimCast
,
x_
});
VectorRef
fbn2
=
VectorRef
({
prim
::
kPrimFusedBatchNorm
,
in_cast
,
scale_
,
bias_
,
mean_
,
var_
});
VectorRef
tupleget
=
VectorRef
({
prim
::
kPrimTupleGetItem
,
fbn2
,
index_
});
VectorRef
out_cast
=
VectorRef
({
prim
::
kPrimCast
,
tupleget
});
return
out_cast
;
return
tupleget
;
}
const
AnfNodePtr
ReplaceBNCastFusion
::
Process
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
...
...
@@ -40,19 +39,9 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
equiv
);
auto
tuple
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
node
),
0
);
auto
index_node
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
tuple
),
1
);
MS_EXCEPTION_IF_NULL
(
index_node
);
auto
value_node
=
index_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
int
item_idx
=
GetValue
<
int
>
(
value_node
->
value
());
auto
fbn2
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
tuple
),
0
);
auto
fbn2
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
node
),
0
);
auto
x_after
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2
),
0
);
auto
x_before
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
x_after
),
0
);
if
(
item_idx
!=
0
)
{
return
nullptr
;
}
auto
scale
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2
),
1
);
auto
bias
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2
),
2
);
auto
mean
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2
),
3
);
...
...
@@ -65,14 +54,32 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL
(
bias
);
MS_EXCEPTION_IF_NULL
(
mean
);
MS_EXCEPTION_IF_NULL
(
var
);
std
::
vector
<
TypeId
>
outputs_type
;
std
::
vector
<
std
::
vector
<
size_t
>>
outputs_shape
;
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
x_after
),
utils
::
cast
<
CNodePtr
>
(
x_before
));
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
node
),
utils
::
cast
<
CNodePtr
>
(
tuple
));
std
::
vector
<
TypeId
>
outputs_type
;
std
::
vector
<
std
::
vector
<
size_t
>>
outputs_shape
;
auto
outlist
=
GetRealNodeUsedList
(
graph
,
fbn2
);
for
(
size_t
i
=
0
;
i
<
outlist
->
size
();
i
++
)
{
auto
index_node
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
outlist
->
at
(
i
).
first
),
1
);
auto
value_node
=
index_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
int
item_idx
=
GetValue
<
int
>
(
value_node
->
value
());
if
(
item_idx
==
0
)
{
auto
cast
=
GetRealNodeUsedList
(
graph
,
outlist
->
at
(
i
).
first
);
if
(
AnfAlgo
::
GetCNodeName
(
cast
->
at
(
0
).
first
)
!=
"Cast"
)
{
return
nullptr
;
}
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
cast
->
at
(
0
).
first
),
utils
::
cast
<
CNodePtr
>
(
outlist
->
at
(
i
).
first
));
outputs_type
.
push_back
(
kNumberTypeFloat16
);
outputs_shape
.
push_back
(
AnfAlgo
::
GetOutputInferShape
(
outlist
->
at
(
i
).
first
,
0
));
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
outlist
->
at
(
i
).
first
.
get
());
}
}
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
x_after
),
utils
::
cast
<
CNodePtr
>
(
x_before
));
outputs_type
.
clear
();
outputs_shape
.
clear
();
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
fbn2
);
for
(
size_t
i
=
0
;
i
<
output_num
;
i
++
)
{
outputs_type
.
push_back
(
AnfAlgo
::
GetOutputInferDataType
(
fbn2
,
i
));
...
...
@@ -80,13 +87,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
}
outputs_type
[
0
]
=
kNumberTypeFloat16
;
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
fbn2
.
get
());
outputs_type
.
clear
();
outputs_shape
.
clear
();
outputs_type
.
push_back
(
kNumberTypeFloat16
);
outputs_shape
.
push_back
(
AnfAlgo
::
GetOutputInferShape
(
tuple
,
0
));
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
tuple
.
get
());
return
tuple
;
return
node
;
}
}
// namespace opt
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
浏览文件 @
28f873e9
...
...
@@ -30,8 +30,7 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
VectorRef
dy_cast
=
VectorRef
({
prim
::
kPrimCast
,
dy_
});
VectorRef
fbn2g
=
VectorRef
({
prim
::
kPrimFusedBatchNormGrad
,
dy_cast
,
x_
,
scale_
,
mean_
,
var_
});
VectorRef
tupleget
=
VectorRef
({
prim
::
kPrimTupleGetItem
,
fbn2g
,
index_
});
VectorRef
out_cast
=
VectorRef
({
prim
::
kPrimCast
,
tupleget
});
return
out_cast
;
return
tupleget
;
}
const
AnfNodePtr
ReplaceBNGradCastFusion
::
Process
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
...
...
@@ -40,21 +39,16 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
equiv
);
auto
tuple
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
node
),
0
);
auto
index_node
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
tuple
),
1
);
MS_EXCEPTION_IF_NULL
(
index_node
);
auto
value_node
=
index_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
int
item_idx
=
GetValue
<
int
>
(
value_node
->
value
());
if
(
item_idx
!=
0
)
{
return
nullptr
;
}
auto
fbn2g
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
tuple
),
0
);
auto
fbn2g
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
node
),
0
);
auto
dy_after
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2g
),
0
);
auto
dy_before
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
dy_after
),
0
);
auto
x_
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2g
),
1
);
auto
x_type
=
AnfAlgo
::
GetOutputInferDataType
(
x_
,
0
);
// if x_type is fp32, the cast is nessery.
if
(
x_type
==
kNumberTypeFloat32
)
{
return
nullptr
;
}
auto
scale
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2g
),
2
);
auto
mean
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2g
),
3
);
auto
var
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
fbn2g
),
4
);
...
...
@@ -66,13 +60,32 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL
(
x_
);
MS_EXCEPTION_IF_NULL
(
mean
);
MS_EXCEPTION_IF_NULL
(
var
);
std
::
vector
<
TypeId
>
outputs_type
;
std
::
vector
<
std
::
vector
<
size_t
>>
outputs_shape
;
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
outlist
=
GetRealNodeUsedList
(
graph
,
fbn2g
);
for
(
size_t
i
=
0
;
i
<
outlist
->
size
();
i
++
)
{
auto
index_node
=
AnfAlgo
::
GetInputNode
(
utils
::
cast
<
CNodePtr
>
(
outlist
->
at
(
i
).
first
),
1
);
auto
value_node
=
index_node
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
int
item_idx
=
GetValue
<
int
>
(
value_node
->
value
());
if
(
item_idx
==
0
)
{
auto
cast
=
GetRealNodeUsedList
(
graph
,
outlist
->
at
(
i
).
first
);
if
(
AnfAlgo
::
GetCNodeName
(
cast
->
at
(
0
).
first
)
!=
"Cast"
)
{
return
nullptr
;
}
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
cast
->
at
(
0
).
first
),
utils
::
cast
<
CNodePtr
>
(
outlist
->
at
(
i
).
first
));
outputs_type
.
push_back
(
kNumberTypeFloat16
);
outputs_shape
.
push_back
(
AnfAlgo
::
GetOutputInferShape
(
outlist
->
at
(
i
).
first
,
0
));
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
outlist
->
at
(
i
).
first
.
get
());
}
}
outputs_type
.
clear
();
outputs_shape
.
clear
();
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
dy_after
),
utils
::
cast
<
CNodePtr
>
(
dy_before
));
manager
->
Replace
(
utils
::
cast
<
CNodePtr
>
(
node
),
utils
::
cast
<
CNodePtr
>
(
tuple
));
std
::
vector
<
TypeId
>
outputs_type
;
std
::
vector
<
std
::
vector
<
size_t
>>
outputs_shape
;
auto
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
fbn2g
);
for
(
size_t
i
=
0
;
i
<
output_num
;
i
++
)
{
outputs_type
.
push_back
(
AnfAlgo
::
GetOutputInferDataType
(
fbn2g
,
i
));
...
...
@@ -80,12 +93,8 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
}
outputs_type
[
0
]
=
kNumberTypeFloat16
;
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
fbn2g
.
get
());
outputs_type
.
clear
();
outputs_shape
.
clear
();
outputs_type
.
push_back
(
kNumberTypeFloat16
);
outputs_shape
.
push_back
(
AnfAlgo
::
GetOutputInferShape
(
tuple
,
0
));
AnfAlgo
::
SetOutputInferTypeAndShape
(
outputs_type
,
outputs_shape
,
tuple
.
get
());
return
tuple
;
return
node
;
}
}
// namespace opt
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录