Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4cffb0a3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4cffb0a3
编写于
6月 03, 2020
作者:
Z
zhoufeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
New control sink support dynamic loss scale
Signed-off-by:
N
zhoufeng
<
zhoufeng54@huawei.com
>
上级
71dce2f5
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
211 addition
and
57 deletion
+211
-57
mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc
...pre_activate/pass/convert_tuple_input_to_dynamic_input.cc
+1
-1
mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc
...rc/pre_activate/pass/convert_tuple_output_to_maketuple.cc
+3
-2
mindspore/ccsrc/session/ascend_control_parser.cc
mindspore/ccsrc/session/ascend_control_parser.cc
+110
-23
mindspore/ccsrc/session/ascend_control_parser.h
mindspore/ccsrc/session/ascend_control_parser.h
+0
-3
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+1
-10
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+6
-16
mindspore/ccsrc/session/kernel_graph.h
mindspore/ccsrc/session/kernel_graph.h
+1
-1
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+4
-1
mindspore/ccsrc/utils/union_find_set.h
mindspore/ccsrc/utils/union_find_set.h
+85
-0
未找到文件。
mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc
浏览文件 @
4cffb0a3
...
...
@@ -69,7 +69,7 @@ CNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNo
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
inputs
=
cnode
->
inputs
();
(
void
)
std
::
copy
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
back_inserter
(
plant_inputs
));
}
else
if
(
AnfAlgo
::
IsTupleOutput
(
input_node
))
{
}
else
if
(
input_node
->
Type
()
!=
nullptr
&&
AnfAlgo
::
IsTupleOutput
(
input_node
))
{
ConvertTupleOuputToPlantInputs
(
graph
,
input_node
,
&
plant_inputs
,
&
dyn_input_sizes
);
}
else
{
dyn_input_sizes
.
push_back
(
-
1
);
...
...
mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc
浏览文件 @
4cffb0a3
...
...
@@ -68,8 +68,9 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
if
(
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimTupleGetItem
)
||
IsPrimitiveCNode
(
cnode
,
prim
::
kPrimControlDepend
))
{
return
nullptr
;
}
if
(
std
::
any_of
(
cnode
->
inputs
().
begin
()
+
1
,
cnode
->
inputs
().
end
(),
[](
const
AnfNodePtr
&
node
)
{
return
AnfAlgo
::
IsRealKernel
(
node
)
&&
AnfAlgo
::
IsTupleOutput
(
node
);
}))
{
if
(
std
::
any_of
(
cnode
->
inputs
().
begin
()
+
1
,
cnode
->
inputs
().
end
(),
[](
const
AnfNodePtr
&
node
)
{
return
node
->
Type
()
!=
nullptr
&&
AnfAlgo
::
IsRealKernel
(
node
)
&&
AnfAlgo
::
IsTupleOutput
(
node
);
}))
{
return
ConvertTupleInputToMakeTuple
(
func_graph
,
cnode
);
}
return
nullptr
;
...
...
mindspore/ccsrc/session/ascend_control_parser.cc
浏览文件 @
4cffb0a3
...
...
@@ -18,6 +18,7 @@
#include <memory>
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h"
static
constexpr
size_t
kCNodePrim
=
0
;
static
constexpr
size_t
kCNodeCallArg
=
1
;
...
...
@@ -57,6 +58,110 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
}
}
static
void
InitUnionFindSet
(
NotNull
<
KernelGraphPtr
>
kg
,
const
NotNull
<
UnionFindSet
<
AnfNodePtr
>
*>
union_find_set
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
memo
->
find
(
kg
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
kg
.
get
());
const
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
&
real_inputs
=
kg
->
real_inputs
();
for
(
auto
&
iter
:
real_inputs
)
{
auto
&
para
=
iter
.
first
;
if
(
para
->
isa
<
Parameter
>
())
{
union_find_set
->
Add
(
para
);
}
for
(
auto
&
arg
:
iter
.
second
)
{
if
(
!
arg
->
isa
<
Parameter
>
())
{
continue
;
}
union_find_set
->
Add
(
arg
);
}
}
for
(
auto
&
child
:
kg
->
child_graph_order
())
{
InitUnionFindSet
(
NOT_NULL
(
child
),
union_find_set
,
memo
);
}
}
static
void
UnionParentParameter
(
NotNull
<
KernelGraphPtr
>
kg
,
const
NotNull
<
UnionFindSet
<
AnfNodePtr
>
*>
union_find_set
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
memo
->
find
(
kg
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
kg
.
get
());
const
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
&
real_inputs
=
kg
->
real_inputs
();
for
(
auto
&
iter
:
real_inputs
)
{
auto
&
para
=
iter
.
first
;
for
(
auto
&
arg
:
iter
.
second
)
{
if
(
!
arg
->
isa
<
Parameter
>
())
{
continue
;
}
union_find_set
->
Union
(
arg
,
para
);
}
}
for
(
auto
&
child
:
kg
->
child_graph_order
())
{
UnionParentParameter
(
NOT_NULL
(
child
),
union_find_set
,
memo
);
}
}
static
UnionFindSet
<
AnfNodePtr
>
MakeUnionFindSet
(
NotNull
<
KernelGraphPtr
>
root_kg
)
{
UnionFindSet
<
AnfNodePtr
>
result
;
std
::
set
<
KernelGraphPtr
>
memo
;
InitUnionFindSet
(
root_kg
,
NOT_NULL
(
&
result
),
NOT_NULL
(
&
memo
));
memo
.
clear
();
UnionParentParameter
(
root_kg
,
NOT_NULL
(
&
result
),
NOT_NULL
(
&
memo
));
return
result
;
}
static
void
RecursiveReplaceNode
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
main_parameter
,
const
std
::
set
<
AnfNodePtr
>
&
parameter_reuse_set
,
const
NotNull
<
std
::
set
<
KernelGraphPtr
>
*>
memo
)
{
if
(
parameter_reuse_set
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"parameter_reuse_set is empty."
;
}
if
(
memo
->
find
(
kg
.
get
())
!=
memo
->
end
())
{
return
;
}
memo
->
insert
(
kg
.
get
());
for
(
auto
&
para
:
parameter_reuse_set
)
{
if
(
para
==
main_parameter
.
get
())
{
continue
;
}
MS_LOG
(
INFO
)
<<
"Replace "
<<
para
->
DebugString
()
<<
" of graph "
<<
AnfAlgo
::
GetGraphId
(
para
.
get
())
<<
" to "
<<
main_parameter
->
DebugString
()
<<
" of graph "
<<
AnfAlgo
::
GetGraphId
(
main_parameter
.
get
().
get
());
kg
->
ReplaceNode
(
NOT_NULL
(
para
),
main_parameter
);
}
for
(
auto
&
child
:
kg
->
child_graph_order
())
{
RecursiveReplaceNode
(
NOT_NULL
(
child
),
main_parameter
,
parameter_reuse_set
,
memo
);
}
}
static
void
ReuseParameter
(
NotNull
<
KernelGraphPtr
>
root_kg
,
NotNull
<
UnionFindSet
<
AnfNodePtr
>
*>
parameter_set
)
{
auto
parameter_reuse_sets
=
parameter_set
->
GetSets
();
for
(
auto
&
[
key
,
parameter_reuse_set
]
:
parameter_reuse_sets
)
{
if
(
parameter_reuse_set
.
size
()
<=
1
)
{
continue
;
}
AnfNodePtr
main_parameter
=
key
;
std
::
set
<
AnfNodePtr
>
root_inputs_set
;
const
auto
&
root_inputs_vector
=
root_kg
->
inputs
();
root_inputs_set
.
insert
(
root_inputs_vector
.
begin
(),
root_inputs_vector
.
end
());
for
(
auto
&
node
:
parameter_reuse_set
)
{
if
(
root_inputs_set
.
find
(
node
)
==
root_inputs_set
.
end
())
{
continue
;
}
main_parameter
=
node
;
}
std
::
set
<
KernelGraphPtr
>
memo
;
RecursiveReplaceNode
(
root_kg
,
NOT_NULL
(
main_parameter
),
parameter_reuse_set
,
NOT_NULL
(
&
memo
));
}
}
void
AscendControlParser
::
LinkGraph
(
NotNull
<
KernelGraphPtr
>
kg
)
{
std
::
set
<
KernelGraphPtr
>
memo
;
ProcessKernelGraph
(
kg
,
nullptr
,
nullptr
,
NOT_NULL
(
&
memo
));
...
...
@@ -68,6 +173,11 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
}
graph_id_map
[
g
->
graph_id
()]
=
g
;
}
// Make UnionFindSet
UnionFindSet
<
AnfNodePtr
>
parameter_set
=
MakeUnionFindSet
(
kg
);
// Reuse Parameter
ReuseParameter
(
kg
,
NOT_NULL
(
&
parameter_set
));
// Insert Assign
ChildGraphDataAssign
(
graph_id_map
);
}
...
...
@@ -324,29 +434,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph
(
kg
,
NOT_NULL
(
assign_node
));
}
void
AscendControlParser
::
LinkArgsToParam
(
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
KernelGraphPtr
>
target_graph
,
NotNull
<
AnfNodePtr
>
arg
,
NotNull
<
AnfNodePtr
>
param
)
{
if
(
IsPrimitiveCNode
(
arg
,
prim
::
kPrimMakeTuple
)
&&
IsPrimitiveCNode
(
param
,
prim
::
kPrimMakeTuple
))
{
MS_LOG
(
INFO
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" Param "
<<
param
->
DebugString
()
<<
" is a tuple"
;
CNodePtr
cnode_arg
=
arg
.
get
()
->
cast
<
CNodePtr
>
();
CNodePtr
cnode_param
=
param
.
get
()
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode_arg
);
MS_EXCEPTION_IF_NULL
(
cnode_param
);
if
(
cnode_arg
->
size
()
!=
cnode_param
->
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" size "
<<
cnode_arg
->
size
()
<<
" but Param "
<<
param
->
DebugString
()
<<
" size "
<<
cnode_param
->
size
();
}
for
(
size_t
i
=
1
;
i
<
cnode_param
->
size
();
++
i
)
{
LinkArgsToParam
(
to_graph
,
target_graph
,
NOT_NULL
(
cnode_arg
->
input
(
i
)),
NOT_NULL
(
cnode_param
->
input
(
i
)));
}
}
else
if
(
arg
->
isa
<
CNode
>
())
{
InsertAssignToGraph
(
target_graph
,
arg
,
param
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Arg "
<<
arg
->
DebugString
()
<<
" Param "
<<
param
->
DebugString
()
<<
" unknown type."
;
}
}
void
AscendControlParser
::
ExecutorValidate
(
NotNull
<
KernelGraphPtr
>
root_graph
)
{
std
::
set
<
KernelGraphPtr
>
memo
;
(
void
)
RecurseGraph
(
root_graph
,
NOT_NULL
(
&
memo
));
...
...
mindspore/ccsrc/session/ascend_control_parser.h
浏览文件 @
4cffb0a3
...
...
@@ -52,9 +52,6 @@ class AscendControlParser {
const
CNodePtr
&
last_label
);
static
std
::
tuple
<
CNodePtr
,
KernelGraphPtr
>
ParsePartial
(
NotNull
<
AnfNodePtr
>
node
);
static
void
LinkArgsToParam
(
NotNull
<
KernelGraphPtr
>
to_graph
,
NotNull
<
KernelGraphPtr
>
target_graph
,
NotNull
<
AnfNodePtr
>
arg
,
NotNull
<
AnfNodePtr
>
param
);
static
void
InsertAssignToGraph
(
NotNull
<
KernelGraphPtr
>
kg
,
NotNull
<
AnfNodePtr
>
from
,
NotNull
<
AnfNodePtr
>
to
);
static
CNodePtr
GetNextRealKernel
(
const
std
::
vector
<
CNodePtr
>
&
list
,
size_t
start
);
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
4cffb0a3
...
...
@@ -224,14 +224,6 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters,
MS_LOG
(
INFO
)
<<
"Parameter and arg are same"
;
continue
;
}
// if arg is a parameter ,then reuse this parameter
if
(
args
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
INFO
)
<<
"Parameter:"
<<
parameters
[
i
]
->
DebugString
()
<<
" of graph:"
<<
child_graph
->
graph_id
()
<<
" reuse parameter:"
<<
args
[
i
]
->
DebugString
()
<<
" of graph:"
<<
AnfAlgo
::
GetGraphId
(
args
[
i
].
get
());
child_graph
->
ReplaceNode
(
parameters
[
i
],
args
[
i
]);
continue
;
}
child_graph
->
SetRealInput
(
parameters
[
i
],
args
[
i
]);
}
}
...
...
@@ -412,7 +404,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
VectorRef
*
const
outputs
)
{
MS_LOG
(
INFO
)
<<
"start"
;
auto
kernel_graph
=
GetGraph
(
graph_id
);
DumpIR
(
"./run_graph.ir"
,
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// if none of child graph and no anf output exists
if
(
!
kernel_graph
->
executable
())
{
...
...
@@ -1134,7 +1125,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId
MS_EXCEPTION_IF_NULL
(
backend_arg
);
MS_LOG
(
INFO
)
<<
"Reuse node ["
<<
backend_arg
->
DebugString
()
<<
"], old node["
<<
backend_parameter
->
DebugString
()
<<
"] will be replaced."
;
to_graph
->
ReplaceNode
(
backend_parameter
,
backend_arg
);
to_graph
->
ReplaceNode
(
NOT_NULL
(
backend_parameter
),
NOT_NULL
(
backend_arg
)
);
return
;
}
MS_LOG
(
INFO
)
<<
"Assign of node"
<<
backend_arg
->
DebugString
()
<<
" of graph "
<<
from_graph_id
<<
" to node"
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
4cffb0a3
...
...
@@ -587,9 +587,7 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
return
false
;
}
void
KernelGraph
::
ReplaceNode
(
const
AnfNodePtr
&
old_anf_node
,
AnfNodePtr
new_anf_node
)
{
MS_EXCEPTION_IF_NULL
(
old_anf_node
);
MS_EXCEPTION_IF_NULL
(
new_anf_node
);
void
KernelGraph
::
ReplaceNode
(
NotNull
<
AnfNodePtr
>
old_anf_node
,
NotNull
<
AnfNodePtr
>
new_anf_node
)
{
MS_EXCEPTION_IF_NULL
(
inputs_
);
auto
it
=
node_output_edges_
.
find
(
old_anf_node
);
if
(
it
!=
node_output_edges_
.
end
())
{
...
...
@@ -604,16 +602,16 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
continue
;
}
for
(
size_t
i
=
1
;
i
<
output_node_inputs
.
size
();
i
++
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
)
{
if
(
output_node_inputs
[
i
]
==
old_anf_node
.
get
()
)
{
output_cnode
->
set_input
(
i
,
new_anf_node
);
}
}
// update graph inputs
for
(
size_t
i
=
0
;
i
<
inputs_
->
size
();
i
++
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
)
{
if
((
*
inputs_
)[
i
]
==
old_anf_node
.
get
()
)
{
MS_LOG
(
INFO
)
<<
"Replace input of graph:"
<<
graph_id_
<<
", old graph input: "
<<
old_anf_node
->
DebugString
()
<<
",new graph input:"
<<
new_anf_node
->
DebugString
();
(
*
inputs_
)[
i
]
=
new_anf_node
;
(
*
inputs_
)[
i
]
=
new_anf_node
.
get
()
;
break
;
}
}
...
...
@@ -621,7 +619,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
// update front to backend map
FrontBackendlMapUpdate
(
old_anf_node
,
new_anf_node
);
// update output depend relations
node_output_edges_
[
new_anf_node
]
=
it
->
second
;
node_output_edges_
[
new_anf_node
.
get
()
]
=
it
->
second
;
(
void
)
node_output_edges_
.
erase
(
old_anf_node
);
}
// update graph inputs in child graph
...
...
@@ -633,7 +631,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_LOG
(
WARNING
)
<<
new_anf_node
->
DebugString
()
<<
" already exist in real inputs, will be rewrited."
;
iter
->
second
=
it_real_inputs
->
second
;
}
else
{
real_inputs_
[
new_anf_node
]
=
it_real_inputs
->
second
;
real_inputs_
[
new_anf_node
.
get
()
]
=
it_real_inputs
->
second
;
}
// erase old parameter in map
real_inputs_
.
erase
(
old_anf_node
);
...
...
@@ -697,7 +695,6 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
void
KernelGraph
::
UpdateCallRealInput
()
{
MS_LOG
(
INFO
)
<<
"Update graph id: "
<<
graph_id_
;
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
real_inputs_map
;
std
::
vector
<
std
::
pair
<
AnfNodePtr
,
AnfNodePtr
>>
replace_list
;
for
(
auto
&
it
:
real_inputs_
)
{
auto
parameter
=
it
.
first
;
MS_EXCEPTION_IF_NULL
(
parameter
);
...
...
@@ -722,16 +719,9 @@ void KernelGraph::UpdateCallRealInput() {
MS_LOG
(
INFO
)
<<
"paramter: "
<<
parameter
->
DebugString
()
<<
" insert real input:"
<<
new_real_input
->
DebugString
();
(
void
)
real_inputs
.
insert
(
new_real_input
);
if
(
new_real_input
->
isa
<
Parameter
>
())
{
replace_list
.
emplace_back
(
parameter
,
new_real_input
);
parameter
=
new_real_input
;
}
}
real_inputs_map
[
parameter
]
=
real_inputs
;
}
for
(
auto
[
parameter
,
arg
]
:
replace_list
)
{
ReplaceNode
(
parameter
,
arg
);
}
real_inputs_
=
real_inputs_map
;
}
...
...
mindspore/ccsrc/session/kernel_graph.h
浏览文件 @
4cffb0a3
...
...
@@ -99,7 +99,7 @@ class KernelGraph : public FuncGraph {
std
::
vector
<
bool
>
*
MutableValidInputs
()
{
return
&
valid_inputs_
;
}
std
::
vector
<
bool
>
valid_inputs
()
const
{
return
valid_inputs_
;
}
// replace node in graph
void
ReplaceNode
(
const
AnfNodePtr
&
old_anf_node
,
AnfNodePtr
new_anf_node
);
void
ReplaceNode
(
NotNull
<
AnfNodePtr
>
old_anf_node
,
NotNull
<
AnfNodePtr
>
new_anf_node
);
// set stream label of graph
void
set_stream_distinction_label
(
uint32_t
stream_label
)
{
stream_distinction_label_
=
stream_label
;
}
// get stream label of graph
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
4cffb0a3
...
...
@@ -459,6 +459,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
if
(
graph
->
GetBackendAnfByFrontAnf
(
anf
)
!=
nullptr
)
{
cnode_inputs
.
emplace_back
(
graph
->
GetBackendAnfByFrontAnf
(
anf
));
continue
;
}
else
if
(
IsValueNode
<
FuncGraph
>
(
anf
))
{
continue
;
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected input["
<<
anf
->
DebugString
()
<<
"]"
;
}
...
...
@@ -613,6 +615,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
if
(
ExistSummaryNode
(
graph
.
get
()))
{
graph
->
set_summary_node_exist
(
true
);
}
opt
::
BackendCommonOptimization
(
graph
);
return
graph
;
}
...
...
@@ -626,7 +629,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶
auto
backend_parameter
=
graph
->
GetBackendAnfByFrontAnf
(
parameter
);
if
(
backend_parameter
==
nullptr
)
{
// for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter
(
parameter
,
fals
e
,
graph
);
CreateNewParameterFromParameter
(
parameter
,
tru
e
,
graph
);
MS_LOG
(
INFO
)
<<
"Can't find parameter:"
<<
parameter
->
DebugString
();
continue
;
}
...
...
mindspore/ccsrc/utils/union_find_set.h
0 → 100644
浏览文件 @
4cffb0a3
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
#define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
#include <map>
#include <set>
namespace
mindspore
{
template
<
class
T
>
class
UnionFindSet
{
public:
UnionFindSet
()
:
union_find_set_
()
{}
void
Add
(
const
T
&
elem
)
{
if
(
union_find_set_
.
find
(
elem
)
!=
union_find_set_
.
end
())
{
return
;
}
union_find_set_
[
elem
]
=
elem
;
}
T
Find
(
const
T
&
key
)
{
T
key_parent
=
key
;
auto
iter
=
union_find_set_
.
find
(
key_parent
);
if
(
iter
==
union_find_set_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"union_find_set_ cannot find key "
<<
key_parent
;
}
while
(
key_parent
!=
iter
->
second
)
{
key_parent
=
iter
->
second
;
iter
=
union_find_set_
.
find
(
key_parent
);
if
(
iter
==
union_find_set_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"union_find_set_ cannot find key "
<<
key_parent
;
}
}
T
tmp
=
key
;
T
tmp_parent
;
while
(
tmp
!=
key_parent
)
{
iter
=
union_find_set_
.
find
(
tmp
);
if
(
iter
==
union_find_set_
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"union_find_set_ cannot find key "
<<
tmp
;
}
tmp_parent
=
iter
->
second
;
union_find_set_
[
tmp
]
=
key_parent
;
tmp
=
tmp_parent
;
}
return
key_parent
;
}
void
Union
(
const
T
&
left
,
const
T
&
right
)
{
union_find_set_
[
Find
(
left
)]
=
Find
(
right
);
}
std
::
map
<
T
,
std
::
set
<
T
>>
GetSets
()
{
std
::
map
<
T
,
std
::
set
<
T
>>
result
;
for
(
auto
&
iter
:
union_find_set_
)
{
(
void
)
Find
(
iter
.
first
);
}
for
(
auto
&
iter
:
union_find_set_
)
{
T
parent
=
Find
(
iter
.
first
);
result
[
parent
].
insert
(
iter
.
first
);
}
return
result
;
}
private:
std
::
map
<
T
,
T
>
union_find_set_
;
};
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录