Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
ee79023e
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ee79023e
编写于
6月 19, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean pclint warning
上级
e9670f3c
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
53 addition
and
38 deletion
+53
-38
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
...uffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
+0
-3
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
...ctivate/ascend/format_type/rectify_do_mask_kernel_info.cc
+21
-12
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h
...activate/ascend/format_type/rectify_do_mask_kernel_info.h
+1
-0
mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc
mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc
+30
-23
mindspore/ccsrc/pre_activate/pass/optimize_dependence.h
mindspore/ccsrc/pre_activate/pass/optimize_dependence.h
+1
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
浏览文件 @
ee79023e
...
@@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
...
@@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
MS_EXCEPTION_IF_NULL
(
manager
);
MS_EXCEPTION_IF_NULL
(
manager
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
};
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
};
auto
write_input
=
cnode
->
input
(
1
);
auto
write_input
=
cnode
->
input
(
1
);
if
(
CheckEltWiseNode
(
manager
.
get
(),
write_input
))
{
if
(
CheckEltWiseNode
(
manager
.
get
(),
write_input
))
{
(
void
)
record
.
insert
(
write_input
);
(
void
)
record
.
insert
(
write_input
);
auto
input_cnode
=
write_input
->
cast
<
CNodePtr
>
();
auto
input_cnode
=
write_input
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
input_cnode
);
MS_EXCEPTION_IF_NULL
(
input_cnode
);
write_input
=
input_cnode
->
input
(
1
);
write_input
=
input_cnode
->
input
(
1
);
}
}
MS_EXCEPTION_IF_NULL
(
write_input
);
MS_EXCEPTION_IF_NULL
(
write_input
);
if
(
!
write_input
->
isa
<
CNode
>
()
||
!
AnfAlgo
::
IsRealCNodeKernel
(
write_input
)
||
if
(
!
write_input
->
isa
<
CNode
>
()
||
!
AnfAlgo
::
IsRealCNodeKernel
(
write_input
)
||
fusion_id_allocator
->
HasFusionIdAttr
(
write_input
))
{
fusion_id_allocator
->
HasFusionIdAttr
(
write_input
))
{
...
@@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
...
@@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
fusion_id_allocator
->
HasFusionIdAttr
(
conv_input
))
{
fusion_id_allocator
->
HasFusionIdAttr
(
conv_input
))
{
return
;
return
;
}
}
if
(
AnfAlgo
::
GetCNodeName
(
conv_input
)
==
kStridedReadOpName
)
{
if
(
AnfAlgo
::
GetCNodeName
(
conv_input
)
==
kStridedReadOpName
)
{
(
void
)
record
.
insert
(
conv_input
);
(
void
)
record
.
insert
(
conv_input
);
candidate_fusion
->
push_back
(
record
);
candidate_fusion
->
push_back
(
record
);
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
浏览文件 @
ee79023e
...
@@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
...
@@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto
ms_context
=
MsContext
::
GetInstance
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
)
{
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
)
{
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
!=
prim
::
kPrimDropoutDoMask
->
name
())
{
return
RectifyKernelInfoInPynativeProcess
(
node
);
return
nullptr
;
}
auto
do_mask_input_format
=
AnfAlgo
::
GetInputFormat
(
node
,
0
);
if
(
do_mask_input_format
!=
kOpFormat_DEFAULT
)
{
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
AnfAlgo
::
GetSelectKernelBuildInfo
(
node
));
builder
->
SetInputFormat
(
kOpFormat_DEFAULT
,
0
);
builder
->
SetOutputFormat
(
kOpFormat_DEFAULT
,
0
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
node
.
get
());
}
return
nullptr
;
}
}
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
!=
prim
::
kPrimDropoutGenMask
->
name
())
{
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
!=
prim
::
kPrimDropoutGenMask
->
name
())
{
return
nullptr
;
return
nullptr
;
...
@@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
...
@@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
}
}
return
convert_format
;
return
convert_format
;
}
}
void
RectifyDoMaskKernelInfo
::
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
void
RectifyDoMaskKernelInfo
::
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
{
const
std
::
string
&
format
)
const
{
for
(
const
auto
&
do_mask
:
do_mask_node_list
)
{
for
(
const
auto
&
do_mask
:
do_mask_node_list
)
{
...
@@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C
...
@@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C
}
}
}
}
AnfNodePtr
RectifyDoMaskKernelInfo
::
RectifyKernelInfoInPynativeProcess
(
const
AnfNodePtr
&
node
)
const
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
==
nullptr
)
{
return
nullptr
;
}
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
!=
prim
::
kPrimDropoutDoMask
->
name
())
{
return
nullptr
;
}
auto
do_mask_input_format
=
AnfAlgo
::
GetInputFormat
(
node
,
0
);
if
(
do_mask_input_format
!=
kOpFormat_DEFAULT
)
{
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
AnfAlgo
::
GetSelectKernelBuildInfo
(
node
));
builder
->
SetInputFormat
(
kOpFormat_DEFAULT
,
0
);
builder
->
SetOutputFormat
(
kOpFormat_DEFAULT
,
0
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
node
.
get
());
}
return
nullptr
;
}
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h
浏览文件 @
ee79023e
...
@@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
...
@@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass {
private:
private:
void
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
)
const
;
void
RectifyKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
)
const
;
AnfNodePtr
RectifyKernelInfoInPynativeProcess
(
const
AnfNodePtr
&
node
)
const
;
std
::
string
GetConvertFormat
(
const
std
::
map
<
std
::
string
,
size_t
>
&
format_counter
)
const
;
std
::
string
GetConvertFormat
(
const
std
::
map
<
std
::
string
,
size_t
>
&
format_counter
)
const
;
void
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
;
void
RectifyDropOutDoMaskKernelInfo
(
const
std
::
vector
<
CNodePtr
>
&
do_mask_node_list
,
const
std
::
string
&
format
)
const
;
};
};
...
...
mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc
浏览文件 @
ee79023e
...
@@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
...
@@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
}
}
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
depend_cnode
);
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
depend_cnode
);
while
(
index
<
input_num
)
{
while
(
index
<
input_num
)
{
auto
replacing_node
=
AnfAlgo
::
GetInputNode
(
depend_cnode
,
index
);
auto
replace_node
=
GetConvertNode
(
func_graph
,
node
,
index
);
++
index
;
MS_EXCEPTION_IF_NULL
(
replace_node
);
MS_EXCEPTION_IF_NULL
(
replacing_node
);
if
(
!
replacing_node
->
isa
<
CNode
>
())
{
new_depend_inputs
.
push_back
(
replacing_node
);
continue
;
}
auto
replacing_cnode
=
replacing_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
replacing_cnode
);
// Deal with the make_tuple with TransData or Cast inputs.
auto
make_tuple_replace_node
=
ReplaceMakeTuple
(
func_graph
,
replacing_cnode
);
if
(
make_tuple_replace_node
!=
nullptr
)
{
new_depend_inputs
.
push_back
(
make_tuple_replace_node
);
continue
;
}
AnfNodePtr
replace_node
=
GetReplaceNode
(
replacing_cnode
);
if
(
replace_node
==
nullptr
)
{
new_depend_inputs
.
push_back
(
replacing_node
);
MS_LOG
(
DEBUG
)
<<
"Can not find the TransData or Cast with single output node. Depend node: "
<<
node
->
DebugString
();
continue
;
}
new_depend_inputs
.
push_back
(
replace_node
);
new_depend_inputs
.
push_back
(
replace_node
);
++
index
;
}
}
auto
kernel_graph
=
func_graph
->
cast
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
();
auto
kernel_graph
=
func_graph
->
cast
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
();
CNodePtr
new_depend
;
CNodePtr
new_depend
=
nullptr
;
if
(
kernel_graph
==
nullptr
)
{
if
(
kernel_graph
==
nullptr
)
{
new_depend
=
func_graph
->
NewCNode
(
new_depend_inputs
);
new_depend
=
func_graph
->
NewCNode
(
new_depend_inputs
);
MS_EXCEPTION_IF_NULL
(
new_depend
);
MS_EXCEPTION_IF_NULL
(
new_depend
);
...
@@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
...
@@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
}
}
return
new_depend
;
return
new_depend
;
}
}
const
AnfNodePtr
OptimizeDependence
::
GetConvertNode
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
const
size_t
index
)
const
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
depend_cnode
=
node
->
cast
<
CNodePtr
>
();
auto
replacing_node
=
AnfAlgo
::
GetInputNode
(
depend_cnode
,
index
);
MS_EXCEPTION_IF_NULL
(
replacing_node
);
if
(
!
replacing_node
->
isa
<
CNode
>
())
{
return
replacing_node
;
}
auto
replacing_cnode
=
replacing_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
replacing_cnode
);
// Deal with the make_tuple with TransData or Cast inputs.
auto
make_tuple_replace_node
=
ReplaceMakeTuple
(
graph
,
replacing_cnode
);
if
(
make_tuple_replace_node
!=
nullptr
)
{
return
make_tuple_replace_node
;
}
AnfNodePtr
replace_node
=
GetReplaceNode
(
replacing_cnode
);
if
(
replace_node
==
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"Can not find the TransData or Cast with single output node. Depend node: "
<<
node
->
DebugString
();
return
replacing_node
;
}
return
replace_node
;
}
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/pre_activate/pass/optimize_dependence.h
浏览文件 @
ee79023e
...
@@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass {
...
@@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass {
~
OptimizeDependence
()
override
=
default
;
~
OptimizeDependence
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
const
AnfNodePtr
GetConvertNode
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node
,
const
size_t
index
)
const
;
};
};
}
// namespace opt
}
// namespace opt
}
// namespace mindspore
}
// namespace mindspore
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录