magicwindyyd / mindspore (forked from MindSpore / mindspore; in sync with the fork source)
Commit 1ab43007
Authored on April 14, 2020 by mindspore-ci-bot; committed via Gitee on April 14, 2020.
!232 [Auto parallel] Model the memory_cost in cost model
Merge pull request !232 from Xiaoda/model-memory-cost-in-auto-parallel
Parents: 8674e0ad, 0ac50a19
Showing 36 changed files with 401 additions and 177 deletions (+401, -177).
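In outline, this merge extends the auto-parallel cost model so that each candidate Cost carries a per-device peak-memory estimate (memory_with_reuse_) alongside its computation and communication terms, and strategy selection first filters candidates by the device memory capacity before minimizing the weighted training-time objective. The following is a minimal sketch of that selection rule, with simplified types; alpha and beta stand in for the model's costmodel_alpha_ and costmodel_beta_ weights, and this is an illustration, not the project's actual implementation.

#include <memory>
#include <vector>

struct Cost {
  double computation_cost_ = 0.0;
  double communication_with_partial_para_ = 0.0;
  double memory_with_reuse_ = 0.0;  // new in this commit: per-device peak memory
};
using CostPtr = std::shared_ptr<Cost>;

// Keep only candidates that fit in device memory, then pick the one minimizing
// alpha * computation + beta * communication_with_partial_para.
CostPtr SelectCost(const std::vector<CostPtr>& cost_list, double capacity, double alpha, double beta) {
  CostPtr best = nullptr;
  double minimum = 0.0;
  for (const auto& c : cost_list) {
    if (c->memory_with_reuse_ > capacity) continue;  // memory filter comes first
    double total = alpha * c->computation_cost_ + beta * c->communication_with_partial_para_;
    if (best == nullptr || total < minimum) {
      minimum = total;
      best = c;
    }
  }
  return best;  // nullptr if nothing fits
}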
Changed files:

mindspore/ccsrc/parallel/auto_parallel/costmodel.h                 +2    -6
mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc        +7    -8
mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h         +2    -5
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc           +30   -4
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h            +1    -1
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc          +109  -44
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h           +3    -0
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc       +63   -27
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h        +64   -21
mindspore/ccsrc/parallel/ops_info/activation_info.h                +2    -2
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h                +11   -11
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h            +6    -2
mindspore/ccsrc/parallel/ops_info/bias_add_info.h                  +1    -1
mindspore/ccsrc/parallel/ops_info/comparison_function_info.h       +5    -4
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h           +1    -1
mindspore/ccsrc/parallel/ops_info/get_next_info.h                  +1    -1
mindspore/ccsrc/parallel/ops_info/loss_info.h                      +2    -1
mindspore/ccsrc/parallel/ops_info/matmul_info.cc                   +3    -3
mindspore/ccsrc/parallel/ops_info/matmul_info.h                    +1    -1
mindspore/ccsrc/parallel/ops_info/onehot_info.h                    +1    -1
mindspore/ccsrc/parallel/ops_info/operator_info.cc                 +38   -6
mindspore/ccsrc/parallel/ops_info/operator_info.h                  +10   -6
mindspore/ccsrc/parallel/ops_info/prelu_info.h                     +1    -1
mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc            +1    -1
mindspore/ccsrc/parallel/ops_info/reduce_method_info.h             +1    -1
mindspore/ccsrc/parallel/ops_info/reshape_info.h                   +1    -1
mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h              +1    -1
mindspore/ccsrc/parallel/ops_info/transpose_info.h                 +1    -1
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h           +1    -1
mindspore/ccsrc/parallel/step_auto_parallel.cc                     +6    -2
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc    +6    -0
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h     +7    -0
tests/ut/cpp/parallel/ops_info/activation_test.cc                  +4    -4
tests/ut/cpp/parallel/ops_info/matmul_info_test.cc                 +2    -2
tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc             +4    -4
tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc                 +2    -2
mindspore/ccsrc/parallel/auto_parallel/costmodel.h

@@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision {
  */
 struct TriangleEliminationDecision : public Decision {
-  TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
-                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
+  TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
+                              StrategyPtr left_stra, CostPtr l_node_cost)
       : eliminated_op_strategy_(std::move(elimi_stra)),
         eliminated_op_cost_(std::move(elimi_op_cost)),
         left_edge_cost_(std::move(l_edge_cost)),
         right_edge_cost_(std::move(r_edge_cost)),
         left_node_strategy_(std::move(left_stra)),
-        left_node_cost_(std::move(l_node_cost)),
-        right_node_strategy_(std::move(right_stra)),
-        right_node_cost_(std::move(r_node_cost)) {
+        left_node_cost_(std::move(l_node_cost)) {
     type_ = DecisionType::TRIANGLE_ELIMINATION;
   }

@@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision {
   CostPtr right_edge_cost_;
   StrategyPtr left_node_strategy_;
   CostPtr left_node_cost_;
-  StrategyPtr right_node_strategy_;
-  CostPtr right_node_cost_;
   MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
 };
mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc

@@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) {
       auto l_r_edge = triangle_pair.second;

       auto left_node = l_r_edge->prev_operator();
-      auto right_node = l_r_edge->next_operator();
       auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
       auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
       MS_EXCEPTION_IF_NULL(left_edge);

@@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) {
         right_edge = tmp;
       }
       auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
-      auto elimi =
-        std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
+      auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
       eliminations.emplace_back(std::move(elimi));
     }
     auto star_center = graph->CheckStarElimination();

@@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       auto left_edge = elimination->left_edge_;
       auto eliminated_node = elimination->eliminated_node_;
       auto right_edge = elimination->right_edge_;
-      auto right_node = elimination->right_node_;
       auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();

       eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
       left_edge->set_selected_cost(decision->left_edge_cost_);
       right_edge->set_selected_cost(decision->right_edge_cost_);
+      // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
     } else if ((*rit)->isa<StarElimination>()) {
       auto elimination = (*rit)->cast<StarEliminationPtr>();

@@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       for (size_t i = 0; i < succ_edges.size(); ++i) {
         succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
       }
-      for (size_t j = 0; j < succ_nodes.size(); ++j) {
-        succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]);
-      }
+      MS_EXCEPTION_IF_NULL(succ_nodes[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
+      // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
+      succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
       MS_LOG(INFO) << "Recover starElimination succeeded.";
     } else {
       MS_LOG(ERROR) << "Unknown Elimination type.";
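The recovery path above walks the recorded eliminations in reverse (note the reverse iterator rit), undoing the most recent graph contraction first. Because a triangle is now merged into its left_node, the decision no longer needs to carry the right node's strategy: the right node keeps whatever was selected when it was processed earlier. A toy illustration of the stack-like recovery order follows; the string-based Elimination stand-in is hypothetical, not the project's type.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Eliminations are recorded in the order the cost graph was contracted...
  std::vector<std::string> eliminations = {"OpElimination", "TriangleElimination", "StarElimination"};
  // ...and recovered LIFO, so the last contraction is undone first.
  for (auto rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
    std::cout << "Recover " << *rit << '\n';
  }
  return 0;
}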
mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h

@@ -102,20 +102,17 @@ struct ContractElimination : public Elimination {
 // Triangle Elimination
 struct TriangleElimination : public Elimination {
-  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
-                      OperatorInfoPtr r_node)
+  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
       : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
         eliminated_node_(std::move(elim_node)),
         left_edge_(std::move(l_edge)),
         left_node_(std::move(l_node)),
-        right_edge_(std::move(r_edge)),
-        right_node_(std::move(r_node)) {}
+        right_edge_(std::move(r_edge)) {}

   OperatorInfoPtr eliminated_node_;
   EdgePtr left_edge_;
   OperatorInfoPtr left_node_;
   EdgePtr right_edge_;
-  OperatorInfoPtr right_node_;
   MS_DECLARE_PARENT(TriangleElimination, Elimination);
 };
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc

@@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
   double forward_comm_cost = tensor_redistribution.forward_comm_cost();
   double backward_comm_cost = tensor_redistribution.backward_comm_cost();
   double computation_cost = tensor_redistribution.computation_cost();
+  double mem_cost = tensor_redistribution.memory_cost();

   // Now AllGather, ReduceScatter, AlltoAll don't support bool type
   MS_EXCEPTION_IF_NULL(type);

@@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
                       COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
   (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
   (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
+  (*cost)->memory_with_reuse_ = mem_cost;
   return Status::SUCCESS;
 }

@@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);

   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double)> recursive =
-    [&](size_t k, double computation, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
     if (k == edges.size()) {
       auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
       CostPtr new_cost = std::make_shared<Cost>(computation, communication);

@@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
       new_cost->communication_without_parameter_ = communication_without_para;
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      new_cost->memory_with_reuse_ = memory;
       new_cost->decision_ptr_ = decision;
       result.push_back(new_cost);
       return;

@@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
     for (auto& c : all_cost_list[k]) {
       MS_EXCEPTION_IF_NULL(c);
       selected_cost_list[k] = c;
-      recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
+      recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
+                communication + c->communication_cost_,
                 communication_without_para + c->communication_without_parameter_);
     }
   };
-  recursive(0, 0, 0, 0);
+  recursive(0, 0.0, 0.0, 0.0, 0.0);
   SimplifyForDreasingCommunicationWithPartialPara(&result);
   return result;
 }

@@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     double communication_without_para = left_cost->communication_without_parameter_ +
                                         middle_cost->communication_without_parameter_ +
                                         right_cost->communication_without_parameter_;
+    double memory_cost =
+      left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;

     auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
     auto cost = std::make_shared<Cost>(computation, communication, decision);

@@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     cost->communication_without_parameter_ = communication_without_para;
     cost->communication_with_partial_para_ =
       communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+    cost->memory_with_reuse_ = memory_cost;
     ret_cost_list->emplace_back(std::move(cost));
   }
 }

@@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op,
     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
   }
 }
+
+Status Edge::CalculateMemoryCost() {
+  if (is_output_parameter_involve_ == -1) {
+    MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
+    return FAILED;
+  }
+  if (is_output_parameter_involve_ == 0) {
+    // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
+    // unnecessary to keep them in memory.
+    for (auto& cost_kv : cost_map_) {
+      auto& cost_v = cost_kv.second;
+      if (!cost_v.empty()) {
+        cost_v[0]->memory_with_reuse_ = 0;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
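The new Edge::CalculateMemoryCost relies on the convention that a redistribution result only occupies memory across a training iteration when it is parameter-involved; otherwise its memory_with_reuse_ can safely be dropped to zero. Below is a minimal sketch of that zeroing pass over a cost map, with simplified key and cost types rather than the project's CostPtrKey/CostPtrList; it is an illustration of the idea, not the actual implementation.

#include <map>
#include <memory>
#include <vector>

struct Cost { double memory_with_reuse_ = 0.0; };
using CostPtr = std::shared_ptr<Cost>;

// If the edge's output is known to be free of parameter involvement,
// the redistributed tensor need not be held for the backward phase.
void ZeroEdgeMemory(std::map<int, std::vector<CostPtr>>* cost_map, int is_output_parameter_involve) {
  if (is_output_parameter_involve != 0) return;  // -1 means unset, 1 means involved: leave as-is
  for (auto& kv : *cost_map) {
    auto& cost_v = kv.second;
    if (!cost_v.empty()) cost_v[0]->memory_with_reuse_ = 0;
  }
}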
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h

@@ -133,7 +133,7 @@ class Edge {
   void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CalculateMemoryCost() const { return SUCCESS; }
+  Status CalculateMemoryCost();

  private:
   std::string edge_name_;
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc

@@ -248,6 +248,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
         MS_EXCEPTION_IF_NULL(cost2);
         MS_EXCEPTION_IF_NULL(cost3);
         double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
+        double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
         double commmunication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
         double communication_without_para = cost1->communication_without_parameter_ +

@@ -260,6 +261,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
         cost->communication_without_parameter_ = communication_without_para;
         cost->communication_with_partial_para_ =
           communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
+        cost->memory_with_reuse_ = memory;
         ret.push_back(cost);
       }
     }

@@ -288,6 +290,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
       new_cost->communication_with_partial_para_ =
         cost1->communication_without_parameter_ +
         COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
+      new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
       ret.push_back(new_cost);
     }
   }

@@ -297,9 +300,14 @@ }
 CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) {
-  if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) {
-    return nullptr;
+  CostPtrList after_mem_filter;
+  // Filter out the valid costs
+  for (auto& a_cost : cost_list) {
+    if (a_cost->memory_with_reuse_ <= memory) {
+      after_mem_filter.emplace_back(std::move(a_cost));
+    }
   }
   std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
     MS_EXCEPTION_IF_NULL(cost_x);
     if (init == nullptr || cost_x->computation_cost_ < memory) {

@@ -308,7 +316,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list,
     return init;
   };
   CostPtr ret = nullptr;
-  return std::accumulate(cost_list.begin(), cost_list.end(), ret, LocalCompare);
+  return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
 }

 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) {

@@ -318,36 +326,46 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
     MS_LOG(ERROR) << "Final cost list is null.";
     return nullptr;
   }
-  CostPtr ret = cost_list[0];
-  MS_EXCEPTION_IF_NULL(ret);
-  if (ret->computation_cost_ >= memory) {
-    MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
-                  << ", the memory capacity is " << memory << ".";
-    return nullptr;
-  }
+  CostPtrList after_mem_filter;
+  double minimum_memory = DBL_MAX;
+  // Filter out the valid costs.
+  for (auto& a_cost : cost_list) {
+    if (a_cost->memory_with_reuse_ <= memory) {
+      after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
+    }
+  }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
+  CostPtr ret = after_mem_filter[0];
   double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
-  MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
-               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
-               << ", communication_cost_: " << ret->communication_cost_
-               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
-  for (size_t i = 1; i < cost_list.size(); ++i) {
-    MS_EXCEPTION_IF_NULL(cost_list[i]);
-    if (cost_list[i]->computation_cost_ >= memory) {
-      MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
-                   << ", is larger than the memory capacity: " << memory << ".";
-      break;
-    }
-    MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
-                 << ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
-                 << ", communication_cost_: " << cost_list[i]->communication_cost_
-                 << ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
-    auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
-               costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
-    MS_LOG(INFO) << "tmp: " << tmp;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
     if (minimum > tmp) {
       minimum = tmp;
-      ret = cost_list[i];
-      MS_LOG(INFO) << "selected: " << i;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
     }
   }
   return ret;
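The multi-component variant that follows first checks feasibility: for each component's cost list it takes the smallest memory_with_reuse_ and sums these minima, and if even that lower bound exceeds the available memory, no combination of selections can fit. A self-contained sketch of that bound, under simplified types (an illustration, not the project's code):

#include <cfloat>
#include <vector>

struct Cost { double memory_with_reuse_ = 0.0; };

// Lower bound on the total memory of any combination: pick the cheapest
// memory candidate from each component's list and sum them.
double MinTotalMemory(const std::vector<std::vector<Cost>>& all_cost_list) {
  double total = 0.0;
  for (const auto& list : all_cost_list) {
    double best = DBL_MAX;
    for (const auto& c : list) {
      if (c.memory_with_reuse_ < best) best = c.memory_with_reuse_;
    }
    total += best;
  }
  return total;  // if this exceeds capacity, no selection is feasible
}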
@@ -356,17 +374,21 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
 CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList>& all_cost_list,
                                                                  double available_memory) {
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  double minimum = 0.0, total_memory = 0.0;
+  double minimum = DBL_MAX, total_memory = 0.0;
   CostPtrList ret(all_cost_list.size(), nullptr);
+  // Check whether valid costs exist.
   for (size_t i = 0; i < all_cost_list.size(); ++i) {
     if (all_cost_list[i][0] == nullptr) {
       MS_LOG(ERROR) << "The cost list " << i << " is empty.";
       return ret;
     } else {
-      total_memory += all_cost_list[i][0]->computation_cost_;
-      minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ +
-                 costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_;
-      ret[i] = all_cost_list[i][0];
+      double memory_i_cost = DBL_MAX;
+      for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
+        if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
+          memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
+        }
+      }
+      total_memory += memory_i_cost;
     }
   }
   if (total_memory >= available_memory) {

@@ -381,7 +403,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
     double tmp_memory = 0.0, tmp_minimum = 0.0;
     for (size_t i = 0; i < selected_cost_list.size(); ++i) {
       MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
-      tmp_memory += selected_cost_list[i]->computation_cost_;
+      tmp_memory += selected_cost_list[i]->memory_with_reuse_;
       tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
                      costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
     }

@@ -816,6 +838,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       auto& tar_cost = tar_cost_list[k];
       MS_EXCEPTION_IF_NULL(tar_cost);
       double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
+      double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
       double communication_without_para = op_cost->communication_without_parameter_ +

@@ -829,6 +852,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       new_cost->communication_without_parameter_ = communication_without_para;
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      new_cost->memory_with_reuse_ = memory;
       MS_EXCEPTION_IF_NULL(tar_cost_list_new);
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }

@@ -894,6 +918,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
       MS_EXCEPTION_IF_NULL(tar_cost);
       double computation =
         contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
+      double memory =
+        contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
       double communication_without_para = contract_op_cost->communication_without_parameter_ +

@@ -906,6 +932,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
       new_cost->communication_without_parameter_ = communication_without_para;
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      new_cost->memory_with_reuse_ = memory;
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }
   }

@@ -966,23 +993,22 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     for (auto& left_node_cost : left_node_clist_origin) {
       MS_EXCEPTION_IF_NULL(left_node_cost);
       double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
-                               left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ +
-                               right_op_cost->computation_cost_;
+                               left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
+      double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
+                          left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
       double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
-                              left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ +
-                              right_op_cost->communication_cost_;
+                              left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
       double new_commu_without = elimi_op_cost->communication_without_parameter_ +
                                  left_edge_cost->communication_without_parameter_ +
-                                 left_node_cost->communication_without_parameter_ +
-                                 right_edge_cost->communication_without_parameter_ +
-                                 right_op_cost->communication_without_parameter_;
+                                 left_node_cost->communication_without_parameter_ +
+                                 right_edge_cost->communication_without_parameter_;

-      auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
-                                                                    right_edge_cost, left_op_stra, left_node_cost,
-                                                                    right_op_stra, right_op_cost);
+      auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
+                                                                    right_edge_cost, left_op_stra, left_node_cost);

       auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
       new_cost->communication_without_parameter_ = new_commu_without;
       new_cost->communication_with_partial_para_ =
         new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
+      new_cost->memory_with_reuse_ = new_memory;
       left_node_clist_new->emplace_back(std::move(new_cost));
     }
   }

@@ -1085,14 +1111,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
     succ_nodes_costs[0] = first_succ_node_cost;

     double computation_cost = merged_node_cost->computation_cost_,
-           commu_cost = merged_node_cost->communication_cost_,
+           memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
            commu_without = merged_node_cost->communication_without_parameter_;
     for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
       MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
-      computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
-      commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
-      commu_without += succ_edges_costs[i]->communication_without_parameter_ +
-                       succ_nodes_costs[i]->communication_without_parameter_;
+      if (i == 0) {
+        computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
+        memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
+        commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
+        commu_without += succ_edges_costs[i]->communication_without_parameter_ +
+                         succ_nodes_costs[i]->communication_without_parameter_;
+      } else {
+        computation_cost += succ_edges_costs[i]->computation_cost_;
+        memory_cost += succ_edges_costs[i]->memory_with_reuse_;
+        commu_cost += succ_edges_costs[i]->communication_cost_;
+        commu_without += succ_edges_costs[i]->communication_without_parameter_;
+      }
     }

     auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,

@@ -1100,6 +1134,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
     auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
     new_cost->communication_without_parameter_ = commu_without;
     new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
+    new_cost->memory_with_reuse_ = memory_cost;
     first_succ_node_clist_new->emplace_back(std::move(new_cost));
   }
 }

@@ -1259,5 +1294,35 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c
   }
   return nullptr;
 }
+
+Status CostGraph::CorrectOpsMemoryCost() {
+  for (auto& one_op : ops_) {
+    if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
+      if (one_op->GetAliveSuccEdges().size() > 1) {
+        // Filter out the case when the TmpIdentity being used by multiple operators
+        std::map<size_t, int> output_count;
+        for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
+          auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
+          output_count[output_index]++;
+        }
+        for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
+          auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
+          if (output_count[output_index] <= 1) {
+            continue;
+          }
+          auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
+          MS_EXCEPTION_IF_NULL(next_op);
+          auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
+          if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
+            MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
+                          << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
+            return FAILED;
+          }
+          output_count[output_index]--;
+        }
+      }
+    }
+  }
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
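CorrectOpsMemoryCost handles the case where one TmpIdentity output feeds several operators: without correction, the shared parameter's slice would be counted toward peak memory once per consumer. The first pass counts consumers per output index; the second pass asks every consumer beyond the first to subtract the duplicated contribution. A stripped-down sketch of the counting idea follows, using a hypothetical simplified edge type rather than the project's Edge class.

#include <cstddef>
#include <map>
#include <vector>

// Each alive successor edge records which output of the producer it reads.
struct EdgeLite { size_t prev_op_output_index; };

// Outputs consumed by more than one edge are the ones whose memory
// contribution would otherwise be double-counted.
std::map<size_t, int> CountConsumers(const std::vector<EdgeLite>& succ_edges) {
  std::map<size_t, int> output_count;
  for (const auto& e : succ_edges) {
    output_count[e.prev_op_output_index]++;
  }
  return output_count;
}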
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h

@@ -187,6 +187,9 @@ class CostGraph {
   size_t GetNumPairs() const { return edges_.size(); }
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
+  // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
+  // once (instead of multiple times), this method is used to correct this.
+  Status CorrectOpsMemoryCost();
   // Needed by rec_parser
   void add_inputs_tensor_name(const std::vector<std::string>& inputs_tensor_name) {
     inputs_tensor_name_list_.push_back(inputs_tensor_name);
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc

@@ -17,6 +17,7 @@
 #include "parallel/auto_parallel/operator_costmodel.h"
 #include <random>
+#include <algorithm>
 #include "parallel/device_matrix.h"
 #include "parallel/tensor_layout/tensor_redistribution.h"

@@ -24,12 +25,44 @@ namespace mindspore {
 namespace parallel {
 void OperatorCost::set_is_parameter(const std::vector<bool>& is_parameter) { is_parameter_ = is_parameter; }

+void OperatorCost::set_is_parameter_involve(const std::vector<bool>& is_parameter_inv) {
+  is_parameter_involve_ = is_parameter_inv;
+}
+
+void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
+
 void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths,
                                                const std::vector<size_t>& output_lengths) {
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
 }

+double OperatorCost::GetMemoryCost(const std::vector<TensorInfo>& inputs,
+                                   const std::vector<TensorInfo>& outputs) const {
+  double result = 0.0;
+  if (output_parameter_involve_ == 1) {
+    // When this operator has multiple outputs, they all contributes to the memory.
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
+    }
+    bool is_any_para_inv =
+      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
+    if (is_any_para_inv) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        if (is_parameter_[i]) {
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
+          // When the inputs of this operator are related, and they are not parameter-involved, then they are included
+          // in the memory cost.
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        }
+      }
+    }
+  }
+
+  return result;
+}
+
 // return the per device communication cost in the forward phase.
 double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                       const int32_t&) const {
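OperatorCost::GetMemoryCost prices a tensor as the product of its slice-shape dimensions times the byte length of its element type. For example, a float32 output sliced to [128, 64] contributes 128 x 64 x 4 = 32768 bytes. A self-contained sketch of the same arithmetic, with ListProduct reimplemented locally purely for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

// Product of all dimensions of a (sliced) shape.
double ListProduct(const std::vector<int64_t>& shape) {
  double result = 1.0;
  for (int64_t d : shape) result *= static_cast<double>(d);
  return result;
}

int main() {
  std::vector<int64_t> slice_shape = {128, 64};
  size_t type_length = 4;  // bytes per float32 element
  double bytes = ListProduct(slice_shape) * static_cast<double>(type_length);
  std::cout << bytes << " bytes\n";  // prints 32768
  return 0;
}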
@@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
   return result;
 }

-// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                              const std::vector<TensorInfo>& outputs, const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
+  // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
   double result = 0.0;
   TensorInfo output0 = outputs[0];
   Shape input0_slice_shape = inputs[0].slice_shape();

@@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inpu
   return result;
 }

-// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                               const int32_t& stage_id) const {
-  // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
+  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
     TensorInfo input1 = inputs[1];  // tensor B

@@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
   return result;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                                  const int32_t&) const {

@@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>&
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
                                                   const int32_t&) const {

@@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
   return result;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                               const int32_t&) const {
-  // In the forward phase, the memory cost = slice(A)
+  // In the forward phase, the computation cost = slice(A)
   TensorInfo input0 = inputs[0];
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                const std::vector<mindspore::parallel::TensorInfo>&,

@@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
   return 0.0;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
+double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                   const std::vector<mindspore::parallel::TensorInfo>&,
                                                   const int32_t&) const {
-  TensorInfo input0_info = inputs[0];
-  Shape input0_slice_shape = input0_info.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
+  return 0.0;
 }

-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                    const std::vector<mindspore::parallel::TensorInfo>&,

@@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
   return 0.0;
 }

+// Return the per device PEAK memory cost contributed by this operator in a training iteration.
+double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&) const {
+  return 0.0;
+}
+
 double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
                                                     const std::vector<mindspore::parallel::TensorInfo>&,
                                                     const int32_t&) const {

@@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
   return result;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                             const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B)
+  // In forward phase, the computation cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +

@@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& input
   return result;
 }

-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
                                              const std::vector<mindspore::parallel::TensorInfo>&,
                                              const int32_t& stage_id) const {
-  // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
+  // In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
     TensorInfo input1 = inputs[1];  // tensor B

@@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
   return 0.0;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                              const int32_t&) const {
-  // In onehot's forward phase, the memory cost = slice(A)
+  // In onehot's forward phase, the computation cost = slice(A)
   Shape input0_slice_shape = inputs[0].slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
 }

-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
                                               const int32_t&) const {

@@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
   return 0.0;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                                                     const std::vector<TensorInfo>&,
                                                                     const int32_t&) const {
-  // In forward phase, the memory cost = slice(A) + slice(B)
+  // In forward phase, the computation cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
   double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +

@@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v
   return result;
 }

-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
                                                                      const std::vector<TensorInfo>&,

@@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st
   return 0.0;
 }

-// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
 double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
                                               const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const {

@@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
   return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
 }

-// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
 double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
                                                const std::vector<mindspore::parallel::TensorInfo>&,
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h

@@ -43,10 +43,20 @@ double ListProduct(std::vector<T> vec) {
 // entries timing the length of each entry's data type
 class OperatorCost {
  public:
-  OperatorCost() {
+  explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
+    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
+    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
+      is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
+      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+    }
+  }
+  OperatorCost() : inputs_related_(false) {
     // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
     for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
       is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
       inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
       outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
     }
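Every cost class now forwards an is_inputs_related flag to the OperatorCost base, and the retained parameterless constructors pin a per-operator default, true where the operator's inputs feed the same computation and false where they are independent. A brief usage sketch under those defaults; the call sites shown here are hypothetical examples (only the ActivationCost(false) form is visible in the activation_info.h hunk below), and it assumes the header is on the include path.

#include <memory>
#include "parallel/auto_parallel/operator_costmodel.h"

using namespace mindspore::parallel;

int main() {
  // Explicit flag, as the ops_info call sites now pass:
  auto act_cost = std::make_shared<ActivationCost>(false);  // independent inputs
  auto mm_cost = std::make_shared<MatMulCost>(true);        // related inputs
  // The added default constructors keep old call sites compiling:
  auto reshape_cost = std::make_shared<ReshapeCost>();      // defaults to OperatorCost(true)
  return 0;
}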
@@ -54,6 +64,8 @@ class OperatorCost {
virtual
~
OperatorCost
()
=
default
;
void
set_is_parameter
(
const
std
::
vector
<
bool
>&
is_parameter
);
void
set_is_parameter_involve
(
const
std
::
vector
<
bool
>&
);
void
set_output_parameter_involve
(
int
);
void
SetInputAndOutputTypeLength
(
const
std
::
vector
<
size_t
>&
input_lengths
,
const
std
::
vector
<
size_t
>&
output_lengths
);
std
::
vector
<
size_t
>
inputs_type_lengths
()
const
{
return
inputs_type_lengths_
;
}
std
::
vector
<
size_t
>
outputs_type_lengths
()
const
{
return
outputs_type_lengths_
;
}
...
...
@@ -72,8 +84,19 @@ class OperatorCost {
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
=
0
;
virtual
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
=
0
;
// per device PEAK memory cost in a training iteration
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
// plus necessary inputs.
virtual
double
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
)
const
;
protected:
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
// pre-operator that has parameters as input.
std
::
vector
<
bool
>
is_parameter_involve_
;
int
output_parameter_involve_
=
-
1
;
// -1: unset; 0: not parameter_involved; 1: parameter_involved
// Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
// Mul's two inputs are dependent (related).
bool
inputs_related_
;
// for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
std
::
vector
<
bool
>
is_parameter_
;
// for each input and output, the followings record the number of bytes of each element
...
...
@@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr<OperatorCost>;
class
MatMulCost
:
public
OperatorCost
{
public:
MatMulCost
()
=
default
;
explicit
MatMulCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
MatMulCost
()
:
OperatorCost
(
true
)
{}
~
MatMulCost
()
override
=
default
;
// per device communication cost
...
...
@@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost {
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
};
using
MatMulCostPtr
=
std
::
shared_ptr
<
MatMulCost
>
;
class
ActivationCost
:
public
OperatorCost
{
public:
ActivationCost
()
=
default
;
explicit
ActivationCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
ActivationCost
()
:
OperatorCost
(
false
)
{}
~
ActivationCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost {
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
};
using
ActivationCostPtr
=
std
::
shared_ptr
<
ActivationCost
>
;
using
TransposeCost
=
ActivationCost
;
using
TransposeCostPtr
=
std
::
shared_ptr
<
TransposeCost
>
;
class
SoftmaxCost
:
public
OperatorCost
{
public:
SoftmaxCost
()
=
default
;
explicit
SoftmaxCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
SoftmaxCost
()
:
OperatorCost
(
false
)
{}
~
SoftmaxCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost {
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
)
const
override
;
};
using
SoftmaxCostPtr
=
std
::
shared_ptr
<
SoftmaxCost
>
;
class
TmpIdentityCost
:
public
OperatorCost
{
public:
TmpIdentityCost
()
=
default
;
explicit
TmpIdentityCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
TmpIdentityCost
()
:
OperatorCost
(
false
)
{}
~
TmpIdentityCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost {
const
int32_t
&
stage_id
)
const
override
;
double
GetBackwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
const
int32_t
&
stage_id
)
const
override
;
// per device PEAK memory cost in a training iteration
double
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
)
const
override
;
};
using
TmpIdentityCostPtr
=
std
::
shared_ptr
<
TmpIdentityCost
>
;
class
BatchParallelCost
:
public
OperatorCost
{
public:
BatchParallelCost
()
=
default
;
explicit
BatchParallelCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
BatchParallelCost
()
:
OperatorCost
(
false
)
{}
~
BatchParallelCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;
class
VirtualDatasetCost
:
public
OperatorCost
{
public:
VirtualDatasetCost
()
=
default
;
explicit
VirtualDatasetCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
VirtualDatasetCost
()
:
OperatorCost
(
false
)
{}
~
VirtualDatasetCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -244,12 +272,17 @@ class VirtualDatasetCost : public OperatorCost {
const
int32_t
&
)
const
override
{
return
0.0
;
}
// per device PEAK memory cost in a training iteration
double
GetMemoryCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
)
const
override
{
return
0.0
;
}
};
using
VirtualDatasetCostPtr
=
std
::
shared_ptr
<
VirtualDatasetCost
>
;
class
GeneratorBaseCost
:
public
OperatorCost
{
public:
GeneratorBaseCost
()
=
default
;
explicit
GeneratorBaseCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
GeneratorBaseCost
()
:
OperatorCost
(
false
)
{}
~
GeneratorBaseCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
class
PReLUCost
:
public
OperatorCost
{
public:
PReLUCost
()
=
default
;
explicit
PReLUCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
PReLUCost
()
:
OperatorCost
(
true
)
{}
~
PReLUCost
()
override
=
default
;
// per device communication cost
...
...
@@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr<PReLUCost>;
class
OneHotCost
:
public
OperatorCost
{
public:
OneHotCost
()
=
default
;
explicit
OneHotCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
OneHotCost
()
:
OperatorCost
(
true
)
{}
~
OneHotCost
()
override
=
default
;
// per device communication cost
...
...
@@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr<OneHotCost>;
class
SoftmaxCrossEntropyWithLogitsCost
:
public
OperatorCost
{
public:
SoftmaxCrossEntropyWithLogitsCost
()
=
default
;
explicit
SoftmaxCrossEntropyWithLogitsCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
SoftmaxCrossEntropyWithLogitsCost
()
:
OperatorCost
(
false
)
{}
~
SoftmaxCrossEntropyWithLogitsCost
()
override
=
default
;
// per device communication cost
...
...
@@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropy
class
ReshapeCost
:
public
OperatorCost
{
public:
ReshapeCost
()
=
default
;
explicit
ReshapeCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
ReshapeCost
()
:
OperatorCost
(
true
)
{}
~
ReshapeCost
()
override
=
default
;
...
...
@@ -396,7 +433,8 @@ using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
class
ArithmeticCost
:
public
OperatorCost
{
public:
ArithmeticCost
()
=
default
;
explicit
ArithmeticCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
ArithmeticCost
()
:
OperatorCost
(
false
)
{}
~
ArithmeticCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
class
ReduceMethodCost
:
public
OperatorCost
{
public:
ReduceMethodCost
()
=
default
;
explicit
ReduceMethodCost
(
bool
is_inputs_related
)
:
OperatorCost
(
is_inputs_related
)
{}
ReduceMethodCost
()
:
OperatorCost
(
true
)
{}
~
ReduceMethodCost
()
override
=
default
;
double
GetCommCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
class
ReduceMeanCost
:
public
ReduceMethodCost
{
public:
ReduceMeanCost
()
=
default
;
explicit
ReduceMeanCost
(
bool
is_inputs_related
)
:
ReduceMethodCost
(
is_inputs_related
)
{}
ReduceMeanCost
()
:
ReduceMethodCost
(
true
)
{}
~
ReduceMeanCost
()
override
=
default
;
double
GetForwardComputationCost
(
const
std
::
vector
<
TensorInfo
>&
inputs
,
const
std
::
vector
<
TensorInfo
>&
outputs
,
...
...
@@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
 class GetNextCost : public OperatorCost {
  public:
-  GetNextCost() = default;
+  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GetNextCost() : OperatorCost(false) {}
   ~GetNextCost() override = default;
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
...
@@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr<GetNextCost>;
 class DropOutCost : public OperatorCost {
  public:
-  DropOutCost() = default;
+  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  DropOutCost() : OperatorCost(true) {}
   ~DropOutCost() override = default;
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
...
@@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr<DropOutCost>;
 class GatherV2Cost : public OperatorCost {
  public:
-  GatherV2Cost() = default;
+  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GatherV2Cost() : OperatorCost(true) {}
   ~GatherV2Cost() override = default;
   double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
...
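The pattern repeated throughout this header is worth spelling out: every OperatorCost subclass gains an explicit constructor taking is_inputs_related, plus a parameterless constructor that fixes the flag to whether the operator's backward computation reads its forward inputs (true for PReLU, OneHot, Reshape, ReduceMethod, DropOut and GatherV2; false for GeneratorBase, the softmax-loss cost, GetNext and the arithmetic cost, presumably because their gradients do not depend on input values). Below is a minimal, self-contained sketch, not the actual MindSpore class, of how such a flag can feed the peak-memory estimate that GetMemoryCost later computes; SliceInfo, Bytes and the class name are illustrative stand-ins.

#include <numeric>
#include <vector>

struct SliceInfo {                 // stand-in for TensorInfo's slice shape
  std::vector<int> slice_shape;
  double type_length;              // bytes per element
};

class OperatorCostSketch {
 public:
  explicit OperatorCostSketch(bool is_inputs_related) : is_inputs_related_(is_inputs_related) {}

  double GetMemoryCost(const std::vector<SliceInfo>& inputs, const std::vector<SliceInfo>& outputs) const {
    double mem = 0.0;
    for (const auto& out : outputs) mem += Bytes(out);  // outputs are always live
    if (is_inputs_related_) {
      for (const auto& in : inputs) mem += Bytes(in);   // inputs retained until backward
    }
    return mem;
  }

 private:
  static double Bytes(const SliceInfo& t) {
    double n = std::accumulate(t.slice_shape.begin(), t.slice_shape.end(), 1.0,
                               [](double a, int b) { return a * b; });
    return n * t.type_length;
  }
  bool is_inputs_related_;
};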
mindspore/ccsrc/parallel/ops_info/activation_info.h

@@ -51,7 +51,7 @@ class Activation : public ActivationBase {
  public:
   Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
   ~Activation() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
...
@@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
  public:
   explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                    const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
   ~Softmax() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h

@@ -32,8 +32,8 @@ namespace parallel {
 class ArithmeticBase : public OperatorInfo {
  public:
   ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
+                 const PrimitiveAttrs& attrs, OperatorCostPtr cost)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
   ~ArithmeticBase() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
@@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
 class SubInfo : public ArithmeticBase {
  public:
   SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
           const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~SubInfo() override = default;
 };
...
@@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase {
  public:
   TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~TensorAddInfo() override = default;
 };

 class MulInfo : public ArithmeticBase {
  public:
   MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
           const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~MulInfo() override = default;
 };

 class DivInfo : public ArithmeticBase {
  public:
   DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
           const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~DivInfo() override = default;
 };
...
@@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase {
  public:
   RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~RealDivInfo() override = default;
 };
...
@@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase {
  public:
   FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~FloorDivInfo() override = default;
 };

 class PowInfo : public ArithmeticBase {
  public:
   PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
           const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~PowInfo() override = default;
 };
...
@@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase {
  public:
   GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~GreaterInfo() override = default;
 };
...
@@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase {
  public:
   AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~AssignSubInfo() override = default;
 };
 }  // namespace parallel
...
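A note on what this refactor buys: ArithmeticBase no longer hard-codes std::make_shared<ArithmeticCost>(); each subclass injects a cost model constructed with the flag that matches its gradient. Multiplication-like ops pass true, since d(a*b)/da = b means both forward inputs must be kept for the backward pass, while add/sub/compare ops pass false, their gradients being input-independent. A simplified sketch of the injection pattern follows; CostModel and OpInfoBase are illustrative stand-ins, not the real MindSpore signatures.

#include <memory>
#include <string>
#include <utility>

struct CostModel {
  explicit CostModel(bool is_inputs_related) : is_inputs_related(is_inputs_related) {}
  bool is_inputs_related;  // must the forward inputs be kept for backward?
};

class OpInfoBase {
 public:
  OpInfoBase(std::string name, std::shared_ptr<CostModel> cost)
      : name_(std::move(name)), cost_(std::move(cost)) {}
  virtual ~OpInfoBase() = default;
  const CostModel& cost() const { return *cost_; }

 private:
  std::string name_;
  std::shared_ptr<CostModel> cost_;
};

// d(a*b)/da = b: multiplication needs both inputs in backward, hence 'true'.
class MulLikeInfo : public OpInfoBase {
 public:
  explicit MulLikeInfo(std::string name)
      : OpInfoBase(std::move(name), std::make_shared<CostModel>(true)) {}
};

// d(a-b)/da = 1: subtraction's gradient ignores input values, hence 'false'.
class SubLikeInfo : public OpInfoBase {
 public:
  explicit SubLikeInfo(std::string name)
      : OpInfoBase(std::move(name), std::make_shared<CostModel>(false)) {}
};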
mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h

@@ -29,9 +29,13 @@ namespace mindspore {
 namespace parallel {
 class BatchParallelInfo : public OperatorInfo {
  public:
+  BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+                    const PrimitiveAttrs& attrs, OperatorCostPtr cost)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
   BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
+        dev_num_(1) {}
   ~BatchParallelInfo() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
@@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
  public:
   SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
                                           const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
+      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
...
mindspore/ccsrc/parallel/ops_info/bias_add_info.h

@@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
  public:
   BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
   ~BiasAddInfo() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/comparison_function_info.h

@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_

 #include <string>
+#include <memory>
 #include <unordered_map>
 #include <vector>
 #include "ir/value.h"
...
@@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
  public:
   EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~EqualInfo() override = default;
 };
...
@@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase {
  public:
   NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
   ~NotEqualInfo() override = default;
 };
...
@@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase {
  public:
   MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~MaximumInfo() override = default;
 };
...
@@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase {
  public:
   MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
   ~MinimumInfo() override = default;
 };
 }  // namespace parallel
...
mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h

@@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
  public:
   DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
   ~DropoutDoMaskInfo() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/get_next_info.h

@@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
  public:
   GetNextInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
   ~GetNextInfo() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/loss_info.h

@@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  public:
   SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
                                     const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs,
+                     std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
   ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/matmul_info.cc

@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
   double computation_cost =
-      cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-      cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
       result->communication_without_parameter_ +
       COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
...
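For readers tracing the cost algebra here: with comm the total communication cost, comm_fwd = communication_without_parameter_ the forward-only part, and gamma = COST_MODEL_GAMMA, the final statement computes

  communication_with_partial_para_ = comm_fwd + gamma * (comm - comm_fwd)

so gamma interpolates between counting only forward communication (gamma = 0) and the full communication including parameter-related traffic (gamma = 1). The interpolation reading assumes gamma is configured in [0, 1]; the diff itself does not show its value.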
mindspore/ccsrc/parallel/ops_info/matmul_info.h

@@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
  public:
   MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
   ~MatMulBase() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/onehot_info.h

@@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
  public:
   OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
   ~OneHotInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/operator_info.cc

@@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost =
+      operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-      cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+      operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
       result->communication_without_parameter_ +
       COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
...
@@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
     return FAILED;
   }
   is_parameter_ = is_parameter;
-  cost()->set_is_parameter(is_parameter);
+  operator_cost()->set_is_parameter(is_parameter);
   return SUCCESS;
 }

+Status OperatorInfo::CalculateMemoryCost() {
+  // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to
+  // calculate memory cost.
+  if (is_parameter_involve_.size() != is_parameter_.size()) {
+    MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'.";
+    return FAILED;
+  }
+  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
+  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
+  // Set the memory cost in the 'strategy_cost_'
+  for (auto& swc : strategy_cost_) {
+    auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
+    swc->cost_list[0]->memory_with_reuse_ = mem_cost;
+  }
+  return SUCCESS;
+}
+
+Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
+  for (auto& swc : strategy_cost_) {
+    double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
+                                static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
+    swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
+    if (swc->cost_list[0]->memory_with_reuse_ < 0) {
+      MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
+                    << ", the parameter memory cost is: " << parameter_mem_cost;
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
...
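To see the correction arithmetic with concrete (hypothetical) numbers: CorrectMemoryCost removes duplicate contributions of one parameter input from memory_with_reuse_. A standalone sketch follows, with made-up shapes and a simplified ListProduct; none of the numbers come from the commit.

#include <cstdio>
#include <vector>

double ListProductSketch(const std::vector<int>& shape) {
  double p = 1.0;
  for (int d : shape) p *= d;
  return p;
}

int main() {
  const std::vector<int> slice_shape = {512, 128};  // hypothetical parameter slice
  const double type_length = 4.0;                   // bytes per float32 element
  const double parameter_mem_cost = ListProductSketch(slice_shape) * type_length;  // 262144

  // Three consumers each counted the parameter once, so two corrections
  // restore its true single contribution.
  double memory_with_reuse = 3 * parameter_mem_cost;
  for (int i = 0; i < 2; ++i) memory_with_reuse -= parameter_mem_cost;

  std::printf("corrected memory: %.0f bytes\n", memory_with_reuse);  // prints 262144
  return 0;
}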
@@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
   }
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
-  cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
+  operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
   return SUCCESS;
 }
...
@@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
 }

 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 }  // namespace parallel
...
mindspore/ccsrc/parallel/ops_info/operator_info.h

@@ -60,7 +60,7 @@ class OperatorInfo {
         outputs_shape_(std::move(outputs_shape)),
         attrs_(std::move(attrs)),
         is_alive_(true),
-        cost_(cost),
+        operator_cost_(cost),
         outputs_type_() {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
...
@@ -83,8 +83,8 @@ class OperatorInfo {
   // Given the stage_id (which indicates the number of devices),
   // generate all strategies for this operator
   virtual Status GenerateStrategies(int32_t stage_id) = 0;
-  const OperatorCostPtr& cost() const { return cost_; }
-  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
+  const OperatorCostPtr& operator_cost() const { return operator_cost_; }
+  void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; }
   virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;

   virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
...
@@ -98,7 +98,7 @@ class OperatorInfo {
   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CalculateMemoryCost() const { return SUCCESS; }
+  Status CalculateMemoryCost();
   int ComputeOpAndPrevEdgeParameterInvolved();
   ForwardOp forward_op() const { return forward_op_; }
...
@@ -125,7 +125,7 @@ class OperatorInfo {
   void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
-  std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
+  std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
   void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
     selected_strategy_ = s_strategy;
     selected_cost_ = cost;
...
@@ -142,6 +142,10 @@ class OperatorInfo {
   void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; }
   void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
   const std::string& refkey_parameter_name() const { return refkey_parameter_name_; }
+  // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
+  // multiple times. This method is to correct this, and makes the cost is calulated only once.
+  Status CorrectMemoryCost(size_t input_index);
+  int is_output_parameter_involve() const { return is_output_parameter_involve_; }
   int used_devices() const { return used_devices_; }
   // needed by rec_parser
   void set_type(const std::string& type) { type_ = type; }
...
@@ -234,7 +238,7 @@ class OperatorInfo {
   int32_t used_devices_ = -1;

  private:
-  OperatorCostPtr cost_;
+  OperatorCostPtr operator_cost_;
   std::vector<TypePtr> outputs_type_;
 };
...
mindspore/ccsrc/parallel/ops_info/prelu_info.h

@@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
   ~PReLUInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc

@@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() {
     }
     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
   }
-  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
+  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
   if (reducemethodcost == nullptr) {
     MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
     return FAILED;
...
mindspore/ccsrc/parallel/ops_info/reduce_method_info.h

@@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
   ~ReduceMethod() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/reshape_info.h

@@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
  public:
   ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
         dev_num_(0),
         input_layout_set_flag_(false),
         output_layout_set_flag_(false) {}
...
mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h

@@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
  public:
   TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
                   const std::string& name = IDENTITY_INFO)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
   ~TmpIdentityInfo() override = default;

   Status Init(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/transpose_info.h

@@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
  public:
   TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
   ~TransposeInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h

@@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
  public:
   VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                      const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
   ~VirtualDatasetInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
...
mindspore/ccsrc/parallel/step_auto_parallel.cc

@@ -874,11 +874,15 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
     // Calculate operators' memory usage
     if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed.";
+      MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
     }
     // Calculate edges' memory usage
     if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed.";
+      MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
+    }
+    // Correct memory usage caused by TmpIdentity
+    if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
     }
   } else {
     MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
...
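The ordering here matters: parameter involvement is computed first, the two memory passes consume it, and the TmpIdentity correction runs last because it subtracts parameter slices that the earlier passes counted once per consumer. A condensed sketch of the control flow, using a stubbed CostGraphSketch rather than the real class:

#include <stdexcept>

struct CostGraphSketch {
  bool ComputeOpsAndEdgesParameterInvolved() { return true; }
  bool CalculateOpsMemoryCost() { return true; }
  bool CalculateEdgesMemoryCost() { return true; }
  bool CorrectOpsMemoryCost() { return true; }
};

void RunMemoryPasses(CostGraphSketch& g) {
  // Each stage depends on the one before it; failing fast mirrors MS_LOG(EXCEPTION).
  if (!g.ComputeOpsAndEdgesParameterInvolved()) throw std::runtime_error("parameter_involved failed");
  if (!g.CalculateOpsMemoryCost()) throw std::runtime_error("ops memory cost failed");
  if (!g.CalculateEdgesMemoryCost()) throw std::runtime_error("edges memory cost failed");
  if (!g.CorrectOpsMemoryCost()) throw std::runtime_error("memory correction failed");
}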
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc

@@ -159,6 +159,7 @@ Status TensorRedistribution::ComputeCost() {
       backward_comm_cost_ += prod;
       comm_cost_ += 2.0 * prod;
       computation_cost_ += prod;
+      memory_cost_ += prod;
     } else if (str == CONCAT_BY_AXIS) {
       // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
       // computation cost = before_slice_shape
...
@@ -175,20 +176,25 @@ Status TensorRedistribution::ComputeCost() {
       if (concat_dim == 0) {
         // computation cost = all_gather
         computation_cost_ += prod;
+        memory_cost_ += prod * dev_num;
       } else {
         // computation cost = all_gather + split + concat
         computation_cost_ += (prod + prod * dev_num + prod * dev_num);
+        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
       }
     } else {
       // There is only computation cost in SplitByAxis.
       // computation cost = before_slice_shape
       computation_cost_ += prod;
+      // This addtion may be erroneous
+      memory_cost_ += prod;
     }
   }
   if (reshape_flag()) {
     Shape prev_slice_shape = from_.slice_shape().array();
     double prev_prod =
         std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
     computation_cost_ += 2.0 * prev_prod;
+    memory_cost_ += 2.0 * prev_prod;
   }
   return Status::SUCCESS;
 }
...
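To make the new memory terms concrete with hypothetical numbers: take a slice of prod = 1024 elements redistributed across dev_num = 8 devices. In the CONCAT_BY_AXIS branch with concat_dim == 0, only the all-gathered tensor is materialized, so memory_cost_ grows by prod * dev_num = 8192; with concat_dim != 0 the all_gather, split and concat buffers together add prod * dev_num + prod * dev_num + prod = 17408. The SplitByAxis branch adds just prod = 1024, and the code's own comment flags that addition as possibly erroneous.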
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h

@@ -42,6 +42,7 @@ class TensorRedistribution {
         forward_comm_cost_(0.0),
         backward_comm_cost_(0.0),
         computation_cost_(0.0),
+        memory_cost_(0.0),
         construct_op_flag_(construct_op_flag),
         keep_reshape_(keep_reshape) {}
   Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
...
@@ -54,6 +55,7 @@ class TensorRedistribution {
   double computation_cost() const { return computation_cost_; }
   double forward_comm_cost() const { return forward_comm_cost_; }
   double backward_comm_cost() const { return backward_comm_cost_; }
+  double memory_cost() const { return memory_cost_; }

  private:
   Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout,
...
@@ -72,7 +74,12 @@ class TensorRedistribution {
   double forward_comm_cost_;
   // backward communication cost
   double backward_comm_cost_;
+  // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
+  // inputs.
   double computation_cost_;
+  // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
+  // calculated by the outputs.
+  double memory_cost_;
   bool construct_op_flag_;
   bool keep_reshape_;
 };
...
tests/ut/cpp/parallel/ops_info/activation_test.cc

@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
     act_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
...
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
     soft_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
...
tests/ut/cpp/parallel/ops_info/matmul_info_test.cc

@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
     matmul1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(matmul1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
     break;
   }
...
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
     TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
     replica_inputs_info.push_back(replica_input1_info);

-    ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
     break;
   }
...
tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc

@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
     tensor_add->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
-    double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;

-    double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
...
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
     tensor_add1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
-    double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;

-    double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add1->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
...
tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc

@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
     identity_ptr->Init(sp);
     std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
...