Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
60e38491
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
60e38491
编写于
7月 03, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
reselect the domask's child node after rectify the node domask
上级
17319d8d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
66 addition
and
36 deletion
+66
-36
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+2
-0
mindspore/ccsrc/kernel/kernel_build_info.h
mindspore/ccsrc/kernel/kernel_build_info.h
+2
-0
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
...ctivate/ascend/format_type/rectify_do_mask_kernel_info.cc
+54
-33
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h
...activate/ascend/format_type/rectify_do_mask_kernel_info.h
+8
-3
未找到文件。
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
60e38491
...
...
@@ -119,6 +119,8 @@ bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_
bool
KernelBuildInfo
::
IsOutputDefaultPadding
()
const
{
return
output_reshape_type_
.
empty
();
}
bool
KernelBuildInfo
::
operator
!=
(
const
KernelBuildInfo
&
other
)
const
{
return
!
((
*
this
)
==
other
);
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetKernelType
(
const
KernelType
&
kernel_type
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
kernel_type_
=
kernel_type
;
...
...
mindspore/ccsrc/kernel/kernel_build_info.h
浏览文件 @
60e38491
...
...
@@ -85,6 +85,8 @@ class KernelBuildInfo {
bool
operator
==
(
const
KernelBuildInfo
&
other
)
const
;
bool
operator
!=
(
const
KernelBuildInfo
&
other
)
const
;
public:
static
auto
constexpr
kInvalidFormat
=
"InvalidFormat"
;
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
浏览文件 @
60e38491
...
...
@@ -26,6 +26,7 @@
#include "utils/utils.h"
#include "kernel/common_utils.h"
#include "utils/context/ms_context.h"
#include "pre_activate/common/helper.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -50,16 +51,11 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
return
nullptr
;
}
std
::
vector
<
CNodePtr
>
do_mask_node_list
;
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
node_map
=
manager
->
node_users
();
auto
iter
=
node_map
.
find
(
node
);
if
(
iter
==
node_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot find the node "
<<
node
->
DebugString
()
<<
" in the graph manager!"
;
}
auto
gen_mask_output_nodes
=
iter
->
second
;
for
(
const
auto
&
output_node
:
gen_mask_output_nodes
)
{
auto
gen_mask_output_nodes
=
GetRealNodeUsedList
(
graph
,
cnode
);
MS_EXCEPTION_IF_NULL
(
gen_mask_output_nodes
);
for
(
const
auto
&
output_node
:
*
gen_mask_output_nodes
)
{
if
(
AnfAlgo
::
GetCNodeName
(
output_node
.
first
)
==
prim
::
kPrimDropoutDoMask
->
name
())
{
MS_EXCEPTION_IF_NULL
(
output_node
.
first
);
auto
output_cnode
=
output_node
.
first
->
cast
<
CNodePtr
>
();
do_mask_node_list
.
push_back
(
output_cnode
);
}
...
...
@@ -76,11 +72,12 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
<<
" GenMask "
<<
node
->
DebugString
();
}
}
RectifyKernelInfo
(
do_mask_node_list
);
RectifyKernelInfo
(
do_mask_node_list
,
graph
);
return
nullptr
;
}
void
RectifyDoMaskKernelInfo
::
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
)
const
{
void
RectifyDoMaskKernelInfo
::
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
FuncGraphPtr
&
graph
)
const
{
std
::
map
<
std
::
string
,
size_t
>
format_counter
;
std
::
string
special_format
;
std
::
string
convert_format
;
...
...
@@ -94,17 +91,6 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
}
else
{
format_counter
[
do_mask_data_format
]
=
format_counter
[
do_mask_data_format
]
+
1
;
}
// if has two or more special format we need change all domask's format to default that can avoid insert more
// transdata
if
(
format_counter
.
size
()
>
2
)
{
convert_format
=
kOpFormat_DEFAULT
;
break
;
}
if
(
kHWSpecialFormatSet
.
find
(
do_mask_data_format
)
!=
kHWSpecialFormatSet
.
end
()
&&
special_format
!=
do_mask_data_format
)
{
convert_format
=
kOpFormat_DEFAULT
;
break
;
}
}
if
(
format_counter
.
size
()
==
1
)
{
return
;
...
...
@@ -112,17 +98,23 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
if
(
convert_format
.
empty
())
{
convert_format
=
GetConvertFormat
(
format_counter
);
}
RectifyDropOutDoMaskKernelInfo
(
do_mask_node_list
,
convert_format
);
RectifyDropOutDoMaskKernelInfo
(
do_mask_node_list
,
convert_format
,
graph
);
}
std
::
string
RectifyDoMaskKernelInfo
::
GetConvertFormat
(
const
std
::
map
<
std
::
string
,
size_t
>
&
format_counter
)
const
{
std
::
string
convert_format
;
const
size_t
counter
=
0
;
std
::
string
convert_format
=
kOpFormat_DEFAULT
;
size_t
counter
=
0
;
if
(
format_counter
.
size
()
>
2
)
{
return
kOpFormat_DEFAULT
;
}
if
(
format_counter
.
size
()
==
2
&&
format_counter
.
find
(
kOpFormat_DEFAULT
)
==
format_counter
.
end
())
{
return
kOpFormat_DEFAULT
;
}
for
(
const
auto
&
iter
:
format_counter
)
{
if
(
counter
<
iter
.
second
)
{
convert_format
=
iter
.
first
;
}
if
(
counter
==
iter
.
second
&&
kHWSpecialFormatSet
.
find
(
convert_format
)
=
=
kHWSpecialFormatSet
.
end
())
{
counter
=
iter
.
second
;
}
else
if
(
counter
==
iter
.
second
&&
kHWSpecialFormatSet
.
find
(
iter
.
first
)
!
=
kHWSpecialFormatSet
.
end
())
{
convert_format
=
iter
.
first
;
}
}
...
...
@@ -130,13 +122,17 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
}
void
RectifyDoMaskKernelInfo
::
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
{
const
std
::
string
&
format
,
const
FuncGraphPtr
&
graph
)
const
{
for
(
const
auto
&
do_mask
:
do_mask_node_list
)
{
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
AnfAlgo
::
GetSelectKernelBuildInfo
(
do_mask
));
builder
->
SetInputFormat
(
format
,
0
);
builder
->
SetOutputFormat
(
format
,
0
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
do_mask
.
get
());
if
(
AnfAlgo
::
GetInputFormat
(
do_mask
,
0
)
!=
format
)
{
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
AnfAlgo
::
GetSelectKernelBuildInfo
(
do_mask
));
builder
->
SetInputFormat
(
format
,
0
);
builder
->
SetOutputFormat
(
format
,
0
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
do_mask
.
get
());
ReSelecChildNodeKernelInfo
(
do_mask
,
graph
);
}
}
}
...
...
@@ -159,5 +155,30 @@ AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const Anf
}
return
nullptr
;
}
void
RectifyDoMaskKernelInfo
::
ReSelecChildNodeKernelInfo
(
const
CNodePtr
&
cnode
,
const
FuncGraphPtr
&
graph
)
const
{
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
output_node_list
=
GetRealNodeUsedList
(
graph
,
cnode
);
MS_EXCEPTION_IF_NULL
(
output_node_list
);
for
(
const
auto
&
out_node_info
:
*
output_node_list
)
{
MS_EXCEPTION_IF_NULL
(
out_node_info
.
first
);
auto
out_node
=
out_node_info
.
first
->
cast
<
CNodePtr
>
();
if
(
AnfAlgo
::
IsRealKernel
(
out_node_info
.
first
))
{
auto
ori_build_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
out_node
);
kernel_selecter
->
SelectKernel
(
out_node
);
auto
new_build_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
out_node
);
MS_EXCEPTION_IF_NULL
(
new_build_info
);
MS_EXCEPTION_IF_NULL
(
ori_build_info
);
if
((
*
new_build_info
)
!=
(
*
ori_build_info
))
{
ReSelecChildNodeKernelInfo
(
out_node
,
graph
);
}
}
else
if
(
AnfAlgo
::
GetCNodeName
(
out_node
)
==
prim
::
kPrimTupleGetItem
->
name
()
||
AnfAlgo
::
GetCNodeName
(
out_node
)
==
prim
::
kPrimDepend
->
name
())
{
ReSelecChildNodeKernelInfo
(
out_node
,
graph
);
}
else
{
MS_LOG
(
INFO
)
<<
"Reselected the node "
<<
cnode
->
DebugString
()
<<
" failed"
;
}
}
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h
浏览文件 @
60e38491
...
...
@@ -19,23 +19,28 @@
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
class
RectifyDoMaskKernelInfo
:
public
PatternProcessPass
{
public:
explicit
RectifyDoMaskKernelInfo
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"batch_norm_bert_fission"
,
multigraph
)
{}
:
PatternProcessPass
(
"batch_norm_bert_fission"
,
multigraph
)
,
kernel_selecter
(
std
::
make_shared
<
KernelSelect
>
())
{}
~
RectifyDoMaskKernelInfo
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
void
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
)
const
;
void
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
FuncGraphPtr
&
graph
)
const
;
AnfNodePtr
RectifyKernelInfoInPynativeProcess
(
const
AnfNodePtr
&
node
)
const
;
std
::
string
GetConvertFormat
(
const
std
::
map
<
std
::
string
,
size_t
>
&
format_counter
)
const
;
void
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
;
void
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
,
const
FuncGraphPtr
&
graph
)
const
;
void
ReSelecChildNodeKernelInfo
(
const
CNodePtr
&
cnode
,
const
FuncGraphPtr
&
graph
)
const
;
KernelSelectPtr
kernel_selecter
;
};
}
// namespace opt
}
// namespace mindspore
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录