Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a05aa21c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a05aa21c
编写于
5月 08, 2020
作者:
X
Xiaoda Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
calculating PEAK memory cost in the inference phase
上级
552fc5c9
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
324 addition
and
33 deletion
+324
-33
mindspore/ccsrc/parallel/auto_parallel/costmodel.h
mindspore/ccsrc/parallel/auto_parallel/costmodel.h
+1
-1
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
+15
-0
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h
+10
-2
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
+194
-2
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
+15
-3
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+16
-0
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+6
-0
mindspore/ccsrc/parallel/ops_info/operator_info.cc
mindspore/ccsrc/parallel/ops_info/operator_info.cc
+34
-0
mindspore/ccsrc/parallel/ops_info/operator_info.h
mindspore/ccsrc/parallel/ops_info/operator_info.h
+20
-3
mindspore/ccsrc/parallel/step_auto_parallel.cc
mindspore/ccsrc/parallel/step_auto_parallel.cc
+12
-22
tests/ut/python/parallel/test_auto_parallel_inference.py
tests/ut/python/parallel/test_auto_parallel_inference.py
+1
-0
未找到文件。
mindspore/ccsrc/parallel/auto_parallel/costmodel.h
浏览文件 @
a05aa21c
...
@@ -53,7 +53,7 @@ struct Cost {
...
@@ -53,7 +53,7 @@ struct Cost {
communication_redis_backward_
=
0.0
;
communication_redis_backward_
=
0.0
;
communication_forward_
=
0.0
;
communication_forward_
=
0.0
;
}
}
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
// 'memory_with_reuse_' calculates the peak memory usage in a training
(or inference)
phase
double
memory_with_reuse_
;
double
memory_with_reuse_
;
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
// by ONLY forward phase
// by ONLY forward phase
...
...
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
浏览文件 @
a05aa21c
...
@@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() {
...
@@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
Edge
::
CalculateMemoryCostForInference
()
{
// Currently, memory cost is NOT calculated for redistribution
if
((
is_output_critical_
!=
0
)
&&
(
is_output_critical_
!=
1
))
{
MS_LOG
(
ERROR
)
<<
"Failure: unexpected output critical flag value: "
<<
is_output_critical_
;
return
FAILED
;
}
for
(
auto
&
cost_kv
:
cost_map_
)
{
auto
&
cost_v
=
cost_kv
.
second
;
if
(
!
cost_v
.
empty
())
{
cost_v
[
0
]
->
memory_with_reuse_
=
0
;
}
}
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h
浏览文件 @
a05aa21c
...
@@ -131,9 +131,13 @@ class Edge {
...
@@ -131,9 +131,13 @@ class Edge {
void
set_selected_cost
(
const
CostPtr
&
cost
)
{
selected_cost_
=
cost
;
}
void
set_selected_cost
(
const
CostPtr
&
cost
)
{
selected_cost_
=
cost
;
}
const
CostPtr
&
selected_cost
()
const
{
return
selected_cost_
;
}
const
CostPtr
&
selected_cost
()
const
{
return
selected_cost_
;
}
void
set_parameter_involve
(
int
para_invol
)
{
is_output_parameter_involve_
=
para_invol
;
}
void
set_parameter_involve
(
int
para_invol
)
{
is_output_parameter_involve_
=
para_invol
;
}
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
// at the end of forward phase.
Status
CalculateMemoryCost
();
Status
CalculateMemoryCost
();
// In the inference phase,
Status
CalculateMemoryCostForInference
();
void
mark_output_critical
()
{
is_output_critical_
=
1
;
}
private:
private:
std
::
string
edge_name_
;
std
::
string
edge_name_
;
...
@@ -156,7 +160,11 @@ class Edge {
...
@@ -156,7 +160,11 @@ class Edge {
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
bool
is_identity_edge
;
bool
is_identity_edge
;
CostPtr
selected_cost_
;
CostPtr
selected_cost_
;
// In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator
// is parameter-involved
int
is_output_parameter_involve_
=
-
1
;
// -1: unset; 0: not parameter_involved; 1: parameter_involved
int
is_output_parameter_involve_
=
-
1
;
// -1: unset; 0: not parameter_involved; 1: parameter_involved
// In the inference phase, this is used to mark whether the output of the previous operator is critical.
int
is_output_critical_
=
0
;
};
};
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
浏览文件 @
a05aa21c
...
@@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
...
@@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
<<
", communication_with_partial_para_: "
<<
ret
->
communication_with_partial_para_
<<
", communication_with_partial_para_: "
<<
ret
->
communication_with_partial_para_
<<
", communication_cost_: "
<<
ret
->
communication_cost_
<<
", communication_cost_: "
<<
ret
->
communication_cost_
<<
", communication_without_parameter_: "
<<
ret
->
communication_without_parameter_
<<
"."
;
<<
", communication_without_parameter_: "
<<
ret
->
communication_without_parameter_
<<
"."
;
MS_LOG
(
INFO
)
<<
"Cost 0: tot
o
al_cost: "
<<
minimum
;
MS_LOG
(
INFO
)
<<
"Cost 0: total_cost: "
<<
minimum
;
for
(
size_t
i
=
1
;
i
<
after_mem_filter
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
after_mem_filter
.
size
();
++
i
)
{
MS_EXCEPTION_IF_NULL
(
after_mem_filter
[
i
]);
MS_EXCEPTION_IF_NULL
(
after_mem_filter
[
i
]);
MS_LOG
(
INFO
)
<<
"Cost "
<<
i
<<
": memory_cost: "
<<
after_mem_filter
[
i
]
->
memory_with_reuse_
MS_LOG
(
INFO
)
<<
"Cost "
<<
i
<<
": memory_cost: "
<<
after_mem_filter
[
i
]
->
memory_with_reuse_
...
@@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
...
@@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
<<
", communication_with_partial_para_: "
<<
ret
->
communication_with_partial_para_
<<
", communication_with_partial_para_: "
<<
ret
->
communication_with_partial_para_
<<
", communication_cost_: "
<<
ret
->
communication_cost_
<<
", communication_cost_: "
<<
ret
->
communication_cost_
<<
", communication_without_parameter_: "
<<
ret
->
communication_without_parameter_
<<
"."
;
<<
", communication_without_parameter_: "
<<
ret
->
communication_without_parameter_
<<
"."
;
MS_LOG
(
INFO
)
<<
"Cost 0: tot
o
al_cost: "
<<
minimum
;
MS_LOG
(
INFO
)
<<
"Cost 0: total_cost: "
<<
minimum
;
for
(
size_t
i
=
1
;
i
<
after_mem_filter
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
after_mem_filter
.
size
();
++
i
)
{
MS_EXCEPTION_IF_NULL
(
after_mem_filter
[
i
]);
MS_EXCEPTION_IF_NULL
(
after_mem_filter
[
i
]);
MS_LOG
(
INFO
)
<<
"Cost "
<<
i
<<
": memory_cost: "
<<
after_mem_filter
[
i
]
->
memory_with_reuse_
MS_LOG
(
INFO
)
<<
"Cost "
<<
i
<<
": memory_cost: "
<<
after_mem_filter
[
i
]
->
memory_with_reuse_
...
@@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
...
@@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
return
succ_edges
;
return
succ_edges
;
}
}
size_t
CostGraph
::
GetNumEdges
()
const
{
size_t
sum
=
0
;
for
(
const
auto
&
kv
:
edges_
)
{
auto
&
edges
=
kv
.
second
;
sum
+=
edges
.
size
();
}
return
sum
;
}
Status
CostGraph
::
InitSelectedStrategy
()
{
Status
CostGraph
::
InitSelectedStrategy
()
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
MS_EXCEPTION_IF_NULL
(
op
);
MS_EXCEPTION_IF_NULL
(
op
);
...
@@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
...
@@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
return
SUCCESS
;
return
SUCCESS
;
}
}
void
CostGraph
::
DFSForTopoOrder
(
const
OperatorInfoPtr
&
current_op
,
std
::
map
<
OperatorInfoPtr
,
bool
>
*
visited
,
std
::
vector
<
OperatorInfoPtr
>
*
topo_order
)
{
MS_EXCEPTION_IF_NULL
(
current_op
);
MS_EXCEPTION_IF_NULL
(
visited
);
MS_EXCEPTION_IF_NULL
(
topo_order
);
visited
->
at
(
current_op
)
=
true
;
for
(
const
auto
&
s_edge
:
current_op
->
succ_edges
())
{
if
(
!
visited
->
at
(
s_edge
->
next_operator
()))
{
DFSForTopoOrder
(
s_edge
->
next_operator
(),
visited
,
topo_order
);
}
}
topo_order
->
push_back
(
current_op
);
}
// Compute a topological order of the costgraph
void
CostGraph
::
TopologyOrder
(
std
::
vector
<
OperatorInfoPtr
>
*
topo_order
)
{
std
::
map
<
OperatorInfoPtr
,
bool
>
visited
;
for
(
auto
&
op
:
ops_
)
{
visited
[
op
]
=
false
;
}
for
(
auto
&
op
:
ops_
)
{
if
(
!
visited
[
op
])
{
DFSForTopoOrder
(
op
,
&
visited
,
topo_order
);
}
}
}
void
CostGraph
::
MarkCriticalOpsAndEdges
(
const
std
::
map
<
OperatorInfoPtr
,
int
>
&
candidate_ops
)
{
for
(
auto
&
op
:
ops_
)
{
auto
search
=
candidate_ops
.
find
(
op
);
if
(
search
!=
candidate_ops
.
end
())
{
// Mark the critical operators
op
->
mark_output_critical
();
// Mark the successive edges
for
(
auto
&
s_edge
:
op
->
succ_edges
())
{
s_edge
->
mark_output_critical
();
}
}
else
{
op
->
mark_output_not_critical
();
}
}
}
Status
CostGraph
::
DetermineCriticalOps
(
const
std
::
vector
<
OperatorInfoPtr
>
&
topo_order
)
{
if
(
topo_order
.
size
()
==
0
)
{
MS_LOG
(
ERROR
)
<<
"0 operator in costgraph."
;
return
FAILED
;
}
auto
&
first_op
=
topo_order
[
0
];
if
(
first_op
->
prev_edges
().
size
()
>
0
)
{
MS_LOG
(
ERROR
)
<<
"The first operator in the first of topological order of "
"costgraph should have 0 incoming edge, but has "
<<
first_op
->
prev_edges
()
<<
"edges."
;
return
FAILED
;
}
// The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
// of the output of OperatorInfo that currently has not been used
std
::
map
<
OperatorInfoPtr
,
int
>
curr_memory_state
;
(
void
)
curr_memory_state
.
emplace
(
std
::
make_pair
(
first_op
,
SizeToInt
(
first_op
->
succ_edges
().
size
())));
std
::
map
<
OperatorInfoPtr
,
int
>
max_memory_state
=
curr_memory_state
;
// The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
// not been used
double
curr_memory_size
=
first_op
->
GetOutputsTotalSize
();
double
max_memory_size
=
curr_memory_size
;
for
(
size_t
finished
=
1
;
finished
<
topo_order
.
size
();
++
finished
)
{
// Produce
(
void
)
curr_memory_state
.
emplace
(
std
::
make_pair
(
topo_order
[
finished
],
SizeToInt
(
topo_order
[
finished
]
->
succ_edges
().
size
())));
curr_memory_size
+=
topo_order
[
finished
]
->
GetOutputsTotalSize
();
// Consume
for
(
const
auto
&
prev_edge
:
topo_order
[
finished
]
->
prev_edges
())
{
const
auto
&
prev_op
=
prev_edge
->
prev_operator
();
curr_memory_state
[
prev_op
]
--
;
}
for
(
const
auto
&
prev_edge
:
topo_order
[
finished
]
->
prev_edges
())
{
const
auto
&
prev_op
=
prev_edge
->
prev_operator
();
if
(
curr_memory_state
[
prev_op
]
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Failure: "
<<
prev_op
->
name
()
<<
"'s current output count: "
<<
curr_memory_state
[
prev_op
];
return
FAILED
;
}
else
if
(
curr_memory_state
[
prev_op
]
==
0
)
{
curr_memory_state
.
erase
(
prev_op
);
curr_memory_size
-=
prev_op
->
GetOutputsTotalSize
();
}
}
if
(
curr_memory_size
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Memory size calculation failed: "
<<
curr_memory_size
;
}
// Modify the max
if
(
curr_memory_size
>
max_memory_size
)
{
max_memory_size
=
curr_memory_size
;
max_memory_state
=
curr_memory_state
;
}
}
// Mark those critical operators
MarkCriticalOpsAndEdges
(
max_memory_state
);
return
SUCCESS
;
}
Status
CostGraph
::
ComputeOpsAndEdgesOutputCritical
()
{
// Two steps to do:
// 1. Compute a topological order of the costgraph
// 2. Determine and mark the operators (and necessary edges) that are critical
std
::
vector
<
OperatorInfoPtr
>
topo_order
;
TopologyOrder
(
&
topo_order
);
std
::
reverse
(
std
::
begin
(
topo_order
),
std
::
end
(
topo_order
));
if
(
DetermineCriticalOps
(
topo_order
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Determining critical operators failed."
;
return
FAILED
;
}
return
SUCCESS
;
}
Status
CostGraph
::
CalculateOpsMemoryCost
()
{
Status
CostGraph
::
CalculateOpsMemoryCost
()
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
MS_EXCEPTION_IF_NULL
(
op
);
MS_EXCEPTION_IF_NULL
(
op
);
...
@@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() {
...
@@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
CostGraph
::
CalculateOpsMemoryCostForInference
()
{
for
(
auto
&
op
:
ops_
)
{
MS_EXCEPTION_IF_NULL
(
op
);
if
(
op
->
CalculateMemoryCostForInference
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculate Operator: "
<<
op
->
name
()
<<
" cost for memory usage failed."
;
return
FAILED
;
}
}
return
SUCCESS
;
}
Status
CostGraph
::
CalculateEdgesMemoryCost
()
{
Status
CostGraph
::
CalculateEdgesMemoryCost
()
{
for
(
auto
&
edge_pair
:
edges_
)
{
for
(
auto
&
edge_pair
:
edges_
)
{
const
auto
&
edges
=
edge_pair
.
second
;
const
auto
&
edges
=
edge_pair
.
second
;
...
@@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() {
...
@@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
CostGraph
::
CalculateEdgesMemoryCostForInference
()
{
for
(
auto
&
edge_pair
:
edges_
)
{
const
auto
&
edges
=
edge_pair
.
second
;
for
(
auto
&
one_edge
:
edges
)
{
if
(
one_edge
->
CalculateMemoryCostForInference
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculate Edge: "
<<
one_edge
->
edge_name
()
<<
" cost for memory usage failed."
;
return
FAILED
;
}
}
}
return
SUCCESS
;
}
OperatorInfoPtr
CostGraph
::
FindTmpIdentityByParameterName
(
std
::
string
&
p_name
)
const
{
OperatorInfoPtr
CostGraph
::
FindTmpIdentityByParameterName
(
std
::
string
&
p_name
)
const
{
for
(
auto
one_op
:
ops_
)
{
for
(
auto
one_op
:
ops_
)
{
if
(
one_op
->
name
().
find
(
IDENTITY_INFO
)
!=
std
::
string
::
npos
)
{
if
(
one_op
->
name
().
find
(
IDENTITY_INFO
)
!=
std
::
string
::
npos
)
{
...
@@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() {
...
@@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() {
}
}
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
CostGraph
::
CalculateMemoryCost
()
{
if
(
RUN_PHASE
==
TRAINING_PHASE
)
{
// training phase
if
(
ComputeOpsAndEdgesParameterInvolved
()
==
SUCCESS
)
{
// Calculate operators' memory usage
if
(
CalculateOpsMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculating operators' cost for memory cost failed."
;
return
FAILED
;
}
// Calculate edges' memory usage
if
(
CalculateEdgesMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculating edges' cost for memory cost failed."
;
return
FAILED
;
}
// Correct memory usage caused by TmpIdentity
if
(
CorrectOpsMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Correcting operators' cost for memory cost failed."
;
return
FAILED
;
}
}
else
{
MS_LOG
(
ERROR
)
<<
"Computing operators' parameter_involved failed."
;
return
FAILED
;
}
}
else
{
// inference phase
if
(
ComputeOpsAndEdgesOutputCritical
()
==
SUCCESS
)
{
// Calculate operators' memory usage
if
(
CalculateOpsMemoryCostForInference
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculating operators' memory cost for inference failed."
;
return
FAILED
;
}
// Calculate edges's memory usage
if
(
CalculateEdgesMemoryCostForInference
()
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Calculating operators' memory cost for inference failed."
;
return
FAILED
;
}
}
else
{
MS_LOG
(
ERROR
)
<<
"Computing operators' critical flag failed."
;
return
FAILED
;
}
}
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
浏览文件 @
a05aa21c
...
@@ -179,16 +179,24 @@ class CostGraph {
...
@@ -179,16 +179,24 @@ class CostGraph {
void
CreateStarEliminationSubCostList
(
const
StrategyPtr
&
,
const
CostPtrList
&
,
const
CostPtrList
&
,
void
CreateStarEliminationSubCostList
(
const
StrategyPtr
&
,
const
CostPtrList
&
,
const
CostPtrList
&
,
const
StrategyPtr
&
,
const
CostPtrList
&
,
std
::
vector
<
StrategyPtr
>
,
const
StrategyPtr
&
,
const
CostPtrList
&
,
std
::
vector
<
StrategyPtr
>
,
CostPtrList
&
,
CostPtrList
&
,
CostPtrList
*
);
CostPtrList
&
,
CostPtrList
&
,
CostPtrList
*
);
// Calculate memory cost for training phase or inference phase.
Status
CalculateMemoryCost
();
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// the memory cost can be resused.
// the memory cost can be resused.
This is used to calculate memory in the training phase.
Status
CalculateOpsMemoryCost
();
Status
CalculateOpsMemoryCost
();
// When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// the memory cost can be re
sused
.
// the memory cost can be re
used. This is used to calculate memory in the training phase
.
Status
CalculateEdgesMemoryCost
();
Status
CalculateEdgesMemoryCost
();
// Calculate memory cost of operators in the inference phase.
Status
CalculateOpsMemoryCostForInference
();
// Calculate memory cost of edges in the inference phase.
Status
CalculateEdgesMemoryCostForInference
();
Status
ComputeOpsAndEdgesParameterInvolved
();
Status
ComputeOpsAndEdgesParameterInvolved
();
// Compute for each operator whether the output is critical.
Status
ComputeOpsAndEdgesOutputCritical
();
std
::
vector
<
OperatorInfoPtr
>
GetOperators
()
const
{
return
ops_
;
}
std
::
vector
<
OperatorInfoPtr
>
GetOperators
()
const
{
return
ops_
;
}
size_t
GetNum
Pairs
()
const
{
return
edges_
.
size
();
}
size_t
GetNum
Edges
()
const
;
Status
InitSelectedStrategy
();
Status
InitSelectedStrategy
();
OperatorInfoPtr
FindTmpIdentityByParameterName
(
std
::
string
&
)
const
;
OperatorInfoPtr
FindTmpIdentityByParameterName
(
std
::
string
&
)
const
;
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
...
@@ -208,6 +216,10 @@ class CostGraph {
...
@@ -208,6 +216,10 @@ class CostGraph {
const
std
::
map
<
std
::
string
,
std
::
string
>
get_tuple_getitem_list
()
const
{
return
tuple_getitem_list_
;
}
const
std
::
map
<
std
::
string
,
std
::
string
>
get_tuple_getitem_list
()
const
{
return
tuple_getitem_list_
;
}
private:
private:
void
TopologyOrder
(
std
::
vector
<
OperatorInfoPtr
>
*
);
void
DFSForTopoOrder
(
const
OperatorInfoPtr
&
,
std
::
map
<
OperatorInfoPtr
,
bool
>
*
,
std
::
vector
<
OperatorInfoPtr
>
*
);
Status
DetermineCriticalOps
(
const
std
::
vector
<
OperatorInfoPtr
>
&
);
void
MarkCriticalOpsAndEdges
(
const
std
::
map
<
OperatorInfoPtr
,
int
>
&
);
// Needed by rec_parser
// Needed by rec_parser
std
::
vector
<
std
::
vector
<
std
::
string
>>
inputs_tensor_name_list_
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
inputs_tensor_name_list_
;
std
::
map
<
std
::
string
,
std
::
string
>
tuple_getitem_list_
;
std
::
map
<
std
::
string
,
std
::
string
>
tuple_getitem_list_
;
...
...
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
浏览文件 @
a05aa21c
...
@@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_
...
@@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_
outputs_type_lengths_
=
output_lengths
;
outputs_type_lengths_
=
output_lengths
;
}
}
void
OperatorCost
::
set_output_critical
(
int
critical
)
{
is_outputs_critical_
=
critical
;
}
double
OperatorCost
::
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
double
OperatorCost
::
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
const
std
::
vector
<
TensorInfo
>
&
outputs
)
const
{
const
std
::
vector
<
TensorInfo
>
&
outputs
)
const
{
double
result
=
0.0
;
double
result
=
0.0
;
...
@@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
...
@@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
return
result
;
return
result
;
}
}
double
OperatorCost
::
GetMemoryCostForInference
(
const
std
::
vector
<
TensorInfo
>
&
,
const
std
::
vector
<
TensorInfo
>
&
outputs
)
const
{
double
result
=
0.0
;
if
(
is_outputs_critical_
==
-
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"The critical flag is not set."
;
}
if
(
is_outputs_critical_
==
1
)
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
result
+=
ListProduct
(
outputs
[
i
].
slice_shape
())
*
static_cast
<
double
>
(
outputs_type_lengths_
[
i
]);
}
}
return
result
;
}
// return the per device communication cost in the forward phase.
// return the per device communication cost in the forward phase.
double
MatMulCost
::
GetForwardCommCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
const
std
::
vector
<
TensorInfo
>
&
outputs
,
double
MatMulCost
::
GetForwardCommCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
const
std
::
vector
<
TensorInfo
>
&
outputs
,
int32_t
)
const
{
int32_t
)
const
{
...
...
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
浏览文件 @
a05aa21c
...
@@ -70,6 +70,7 @@ class OperatorCost {
...
@@ -70,6 +70,7 @@ class OperatorCost {
void
set_is_parameter
(
const
std
::
vector
<
bool
>
&
is_parameter
);
void
set_is_parameter
(
const
std
::
vector
<
bool
>
&
is_parameter
);
void
set_is_parameter_involve
(
const
std
::
vector
<
bool
>
&
);
void
set_is_parameter_involve
(
const
std
::
vector
<
bool
>
&
);
void
set_output_parameter_involve
(
int
);
void
set_output_parameter_involve
(
int
);
void
set_output_critical
(
int
);
void
SetInputAndOutputTypeLength
(
const
std
::
vector
<
size_t
>
&
input_lengths
,
const
std
::
vector
<
size_t
>
&
output_lengths
);
void
SetInputAndOutputTypeLength
(
const
std
::
vector
<
size_t
>
&
input_lengths
,
const
std
::
vector
<
size_t
>
&
output_lengths
);
std
::
vector
<
size_t
>
inputs_type_lengths
()
const
{
return
inputs_type_lengths_
;
}
std
::
vector
<
size_t
>
inputs_type_lengths
()
const
{
return
inputs_type_lengths_
;
}
std
::
vector
<
size_t
>
outputs_type_lengths
()
const
{
return
outputs_type_lengths_
;
}
std
::
vector
<
size_t
>
outputs_type_lengths
()
const
{
return
outputs_type_lengths_
;
}
...
@@ -92,6 +93,8 @@ class OperatorCost {
...
@@ -92,6 +93,8 @@ class OperatorCost {
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
// plus necessary inputs.
// plus necessary inputs.
virtual
double
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
const
std
::
vector
<
TensorInfo
>
&
outputs
)
const
;
virtual
double
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>
&
inputs
,
const
std
::
vector
<
TensorInfo
>
&
outputs
)
const
;
// per device memory cost in a inference phase
double
GetMemoryCostForInference
(
const
std
::
vector
<
TensorInfo
>
&
,
const
std
::
vector
<
TensorInfo
>
&
)
const
;
protected:
protected:
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
...
@@ -106,6 +109,9 @@ class OperatorCost {
...
@@ -106,6 +109,9 @@ class OperatorCost {
// for each input and output, the followings record the number of bytes of each element
// for each input and output, the followings record the number of bytes of each element
std
::
vector
<
size_t
>
inputs_type_lengths_
;
std
::
vector
<
size_t
>
inputs_type_lengths_
;
std
::
vector
<
size_t
>
outputs_type_lengths_
;
std
::
vector
<
size_t
>
outputs_type_lengths_
;
// Whether the output is critical, which means that this output is included in calculating peak memory cost
// in the inference phase.
int
is_outputs_critical_
=
-
1
;
};
};
using
OperatorCostPtr
=
std
::
shared_ptr
<
OperatorCost
>
;
using
OperatorCostPtr
=
std
::
shared_ptr
<
OperatorCost
>
;
...
...
mindspore/ccsrc/parallel/ops_info/operator_info.cc
浏览文件 @
a05aa21c
...
@@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() {
...
@@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
OperatorInfo
::
CalculateMemoryCostForInference
()
{
// First, set the 'is_outputs_critical_' flag into OperatorCost.
if
(
is_output_critical_
==
-
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"The critical flag is not set."
;
return
FAILED
;
}
operator_cost
()
->
set_output_critical
(
is_output_critical_
);
// Set the memory cost in the 'strategy_cost_'
for
(
auto
&
swc
:
strategy_cost_
)
{
auto
mem_cost
=
operator_cost
()
->
GetMemoryCostForInference
(
swc
->
inputs_ptr
,
swc
->
outputs_ptr
);
swc
->
cost_list
[
0
]
->
memory_with_reuse_
=
mem_cost
;
}
return
SUCCESS
;
}
Status
OperatorInfo
::
CorrectMemoryCost
(
size_t
input_index
)
{
Status
OperatorInfo
::
CorrectMemoryCost
(
size_t
input_index
)
{
for
(
auto
&
swc
:
strategy_cost_
)
{
for
(
auto
&
swc
:
strategy_cost_
)
{
double
parameter_mem_cost
=
ListProduct
(
swc
->
inputs_ptr
[
input_index
].
slice_shape
())
*
double
parameter_mem_cost
=
ListProduct
(
swc
->
inputs_ptr
[
input_index
].
slice_shape
())
*
...
@@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu
...
@@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu
return
SUCCESS
;
return
SUCCESS
;
}
}
double
OperatorInfo
::
GetOutputsTotalSize
()
{
if
(
is_calculated_outputs_size_
)
{
return
outputs_total_size_
;
}
if
(
outputs_type_lengths_
.
size
()
!=
outputs_shape_
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Output_lengths: "
<<
outputs_type_lengths_
.
size
()
<<
" do not have the same number of outputs shape: "
<<
outputs_shape_
.
size
();
}
double
sum
=
0.0
;
for
(
size_t
i
=
0
;
i
<
outputs_type_lengths_
.
size
();
++
i
)
{
auto
size
=
std
::
accumulate
(
outputs_shape_
[
i
].
begin
(),
outputs_shape_
[
i
].
end
(),
static_cast
<
double
>
(
1.0
),
std
::
multiplies
<
double
>
());
sum
+=
size
*
static_cast
<
double
>
(
outputs_type_lengths_
[
i
]);
}
is_calculated_outputs_size_
=
true
;
outputs_total_size_
=
sum
;
return
outputs_total_size_
;
}
Status
OperatorInfo
::
set_outputs_type
(
const
std
::
vector
<
TypePtr
>
&
outputs_type
)
{
Status
OperatorInfo
::
set_outputs_type
(
const
std
::
vector
<
TypePtr
>
&
outputs_type
)
{
if
(
outputs_type
.
size
()
!=
outputs_shape_
.
size
())
{
if
(
outputs_type
.
size
()
!=
outputs_shape_
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Outputs type: "
<<
outputs_type
.
size
()
MS_LOG
(
ERROR
)
<<
"Outputs type: "
<<
outputs_type
.
size
()
...
...
mindspore/ccsrc/parallel/ops_info/operator_info.h
浏览文件 @
a05aa21c
...
@@ -72,6 +72,7 @@ class OperatorInfo {
...
@@ -72,6 +72,7 @@ class OperatorInfo {
Status
set_is_parameter
(
const
std
::
vector
<
bool
>
&
is_parameter
);
Status
set_is_parameter
(
const
std
::
vector
<
bool
>
&
is_parameter
);
Status
SetInputAndOutputTypeLength
(
const
std
::
vector
<
size_t
>
&
input_lengths
,
Status
SetInputAndOutputTypeLength
(
const
std
::
vector
<
size_t
>
&
input_lengths
,
const
std
::
vector
<
size_t
>
&
output_lengths
);
const
std
::
vector
<
size_t
>
&
output_lengths
);
double
GetOutputsTotalSize
();
// Set outputs dtype.
// Set outputs dtype.
// If only one output, outputs_type.size() is 1.
// If only one output, outputs_type.size() is 1.
// If output is tuple, outputs_type.size() is greater than 1.
// If output is tuple, outputs_type.size() is greater than 1.
...
@@ -96,9 +97,13 @@ class OperatorInfo {
...
@@ -96,9 +97,13 @@ class OperatorInfo {
// is checked
// is checked
Status
SetCostUnderStrategyBase
(
const
StrategyPtr
&
strategy
);
Status
SetCostUnderStrategyBase
(
const
StrategyPtr
&
strategy
);
std
::
vector
<
std
::
shared_ptr
<
StrategyWithCost
>>
GetStrategyCost
()
{
return
strategy_cost_
;
}
std
::
vector
<
std
::
shared_ptr
<
StrategyWithCost
>>
GetStrategyCost
()
{
return
strategy_cost_
;
}
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
// at the end of forward phase.
Status
CalculateMemoryCost
();
Status
CalculateMemoryCost
();
// In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
// by the output
Status
CalculateMemoryCostForInference
();
int
ComputeOpAndPrevEdgeParameterInvolved
();
int
ComputeOpAndPrevEdgeParameterInvolved
();
ForwardOp
forward_op
()
const
{
return
forward_op_
;
}
ForwardOp
forward_op
()
const
{
return
forward_op_
;
}
...
@@ -147,6 +152,9 @@ class OperatorInfo {
...
@@ -147,6 +152,9 @@ class OperatorInfo {
// multiple times. This method is to correct this, and makes the cost is calulated only once.
// multiple times. This method is to correct this, and makes the cost is calulated only once.
Status
CorrectMemoryCost
(
size_t
input_index
);
Status
CorrectMemoryCost
(
size_t
input_index
);
int
is_output_parameter_involve
()
const
{
return
is_output_parameter_involve_
;
}
int
is_output_parameter_involve
()
const
{
return
is_output_parameter_involve_
;
}
int
is_output_critical
()
const
{
return
is_output_critical_
;
}
void
mark_output_critical
()
{
is_output_critical_
=
1
;
}
void
mark_output_not_critical
()
{
is_output_critical_
=
0
;
}
int
used_devices
()
const
{
return
used_devices_
;
}
int
used_devices
()
const
{
return
used_devices_
;
}
// needed by rec_parser
// needed by rec_parser
void
set_type
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
void
set_type
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
...
@@ -220,7 +228,16 @@ class OperatorInfo {
...
@@ -220,7 +228,16 @@ class OperatorInfo {
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
// pre-operator that has parameters as input.
// pre-operator that has parameters as input.
std
::
vector
<
bool
>
is_parameter_involve_
;
std
::
vector
<
bool
>
is_parameter_involve_
;
int
is_output_parameter_involve_
=
-
1
;
// -1: unset; 0: not parameter_involved; 1: parameter_involved
// If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
// peak memory cost in the training phase.
// -1: unset; 0: not parameter_involved; 1: parameter_involved
int
is_output_parameter_involve_
=
-
1
;
// Whether this output is critical, which means that this output is included in calculating peak memory cost
// in the inference phase.
// -1 : unset; 0: not critical; 1: critical
int
is_output_critical_
=
-
1
;
double
outputs_total_size_
=
0.0
;
bool
is_calculated_outputs_size_
=
false
;
// for each input and output, the followings record the number of bytes of each element
// for each input and output, the followings record the number of bytes of each element
std
::
vector
<
size_t
>
inputs_type_lengths_
;
std
::
vector
<
size_t
>
inputs_type_lengths_
;
std
::
vector
<
size_t
>
outputs_type_lengths_
;
std
::
vector
<
size_t
>
outputs_type_lengths_
;
...
...
mindspore/ccsrc/parallel/step_auto_parallel.cc
浏览文件 @
a05aa21c
...
@@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
...
@@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
// create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
// create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
// for each OperatorInfo;
// for each OperatorInfo;
// Step 1.1: Deal with 'Reshape':
// For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
// layout as its output layout.
// Step 2: Traverse the ANF graph, and create EDGES for costgraph:
// Step 2: Traverse the ANF graph, and create EDGES for costgraph:
// create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
// create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
// for each edge, based on the strategies of two OperatorInfos;
// for each edge, based on the strategies of two OperatorInfos;
...
@@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
...
@@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
// taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
// operator for this Parameter, and add an edge for the use of this Parameter by each
// operator for this Parameter, and add an edge for the use of this Parameter by each
// subsequent operator;
// subsequent operator;
// Step 3.1: Calculate memory usage
// Step 3.1: Calculate memory usage:
// note the memory usage calculation is different in training phase and inference phase.
// Step 4: Run the Dynamic Programming algorithm:
// Step 4: Run the Dynamic Programming algorithm:
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
...
@@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
...
@@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
MS_LOG
(
EXCEPTION
)
<<
"Constructing nodes for cost graph failed."
;
MS_LOG
(
EXCEPTION
)
<<
"Constructing nodes for cost graph failed."
;
}
}
}
}
// reshape operator needs the next node's input_layout as its output_layout.
// Step 1.1
// and needs the previous node's output_layout as its input_layout.
ReshapeCostCompute
(
all_nodes
);
ReshapeCostCompute
(
all_nodes
);
// Step 2
// Step 2
ConstructCostGraphEdges
(
all_nodes
);
ConstructCostGraphEdges
(
all_nodes
);
MS_LOG
(
INFO
)
<<
"Constructing edges for cost graph succeeded. There are "
<<
entire_costgraph
->
GetOperators
().
size
()
MS_LOG
(
INFO
)
<<
"Constructing edges for cost graph succeeded. There are "
<<
entire_costgraph
->
GetOperators
().
size
()
<<
" operators, and "
<<
entire_costgraph
->
GetNum
Pairs
()
<<
" edges."
,
<<
" operators, and "
<<
entire_costgraph
->
GetNum
Edges
()
<<
" edges."
;
// Step 3: Augment the costgraph.
// Step 3: Augment the costgraph.
AugmentCostGraph
(
all_nodes
);
AugmentCostGraph
(
all_nodes
);
MS_LOG
(
INFO
)
<<
"After the augmenting procedure, there are "
<<
entire_costgraph
->
GetOperators
().
size
()
MS_LOG
(
INFO
)
<<
"After the augmenting procedure, there are "
<<
entire_costgraph
->
GetOperators
().
size
()
<<
" operators, and "
<<
entire_costgraph
->
GetNum
Pair
s
()
<<
" edges."
;
<<
" operators, and "
<<
entire_costgraph
->
GetNum
Edge
s
()
<<
" edges."
;
// Step 3.1: Calculate the memory usage
// Step 3.1: Calculate the memory usage
if
(
entire_costgraph
->
ComputeOpsAndEdgesParameterInvolved
()
==
SUCCESS
)
{
if
(
entire_costgraph
->
CalculateMemoryCost
()
!=
SUCCESS
)
{
// Calculate operators' memory usage
MS_LOG
(
EXCEPTION
)
<<
"Calculating memory cost failed."
;
if
(
entire_costgraph
->
CalculateOpsMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
EXCEPTION
)
<<
"Calculating operators' cost for memory cost failed."
;
}
// Calculate edges' memory usage
if
(
entire_costgraph
->
CalculateEdgesMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
EXCEPTION
)
<<
"Calculating edges' cost for memory cost failed."
;
}
// Correct memory usage caused by TmpIdentity
if
(
entire_costgraph
->
CorrectOpsMemoryCost
()
!=
SUCCESS
)
{
MS_LOG
(
EXCEPTION
)
<<
"Correcting operators' cost for memory cost failed."
;
}
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Computing operators' parameter_involved failed."
;
}
}
// Step 4: run DP algorithm on the costgraph.
// Step 4: run DP algorithm on the costgraph.
...
...
tests/ut/python/parallel/test_auto_parallel_inference.py
浏览文件 @
a05aa21c
...
@@ -32,5 +32,6 @@ def test_inference_phase():
...
@@ -32,5 +32,6 @@ def test_inference_phase():
net_with_loss
=
WithLossCell
(
net
,
loss
)
net_with_loss
=
WithLossCell
(
net
,
loss
)
train_network
=
TrainOneStepCell
(
net_with_loss
,
optimizer
)
train_network
=
TrainOneStepCell
(
net_with_loss
,
optimizer
)
train_network
.
set_train
()
train_network
.
set_train
()
train_network
.
set_auto_parallel
()
output
=
train_network
(
predict
,
label
)
output
=
train_network
(
predict
,
label
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录