magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit def85732
Authored April 29, 2020 by Xiaoda Zhang

implementing-searching-strategy-for-inference

Parent: 5a03bd80
Showing 15 changed files with 255 additions and 45 deletions (+255, -45)
mindspore/ccsrc/parallel/auto_parallel/costmodel.cc                  +14   -4
mindspore/ccsrc/parallel/auto_parallel/costmodel.h                    +8   -3
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc             +13   -6
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc           +123  -25
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h              +6   -1
mindspore/ccsrc/parallel/costmodel_context.cc                         +3   -0
mindspore/ccsrc/parallel/costmodel_context.h                          +6   -0
mindspore/ccsrc/parallel/ops_info/matmul_info.cc                      +1   -0
mindspore/ccsrc/parallel/ops_info/operator_info.cc                    +1   -0
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h        +3   -3
mindspore/ccsrc/pipeline/init.cc                                      +2   -0
mindspore/parallel/_cost_model_context.py                            +32   -2
tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc           +1   -1
tests/ut/python/parallel/__init__.py                                  +6   -0
tests/ut/python/parallel/test_auto_parallel_inference.py             +36   -0
mindspore/ccsrc/parallel/auto_parallel/costmodel.cc

@@ -23,8 +23,17 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList *clist_ptrs) {
-  // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
-  // excludes the cost with greater computation_cost_ and greater communication_cost.
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
+  } else {
+    // inference phase
+    SimplifyForDecreasingCommunicationForward(clist_ptrs);
+  }
+}
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
+  // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
+  // excludes the cost with greater computation_cost_ and greater communication_forward.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;

@@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
-    if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
+    if ((ret.size() == size_t(0)) ||
+        (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
       ret.emplace_back(std::move(clist_ptrs->at(id[i])));
     }
   }
   *clist_ptrs = std::move(ret);
 }
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
   // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
   // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
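For reference, the pruning rule applied by SimplifyForDecreasingCommunicationForward can be sketched in a few lines of Python. This is an illustrative sketch over plain (computation_cost, communication_forward) pairs rather than MindSpore's CostPtr objects, and it ignores the COST_MODEL_SIMPLIFY_CALCULATION guard shown above.

```python
def simplify_for_decreasing_communication_forward(costs):
    """Keep only costs that are not dominated in both dimensions.

    `costs` is a list of (computation_cost, communication_forward) pairs.
    Sort by computation cost (ascending); a candidate is kept only if its
    forward communication is strictly lower than that of the last kept cost.
    """
    kept = []
    for comp, comm_forward in sorted(costs, key=lambda c: c[0]):
        if not kept or comm_forward < kept[-1][1]:
            kept.append((comp, comm_forward))
    return kept

# Example from the comment in the diff: {<100, 20>, <200, 10>, <300, 50>}.
# The <300, 50> entry has both higher computation and higher forward
# communication than a kept entry, so it is discarded.
print(simplify_for_decreasing_communication_forward([(100, 20), (200, 10), (300, 50)]))
```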
mindspore/ccsrc/parallel/auto_parallel/costmodel.h

@@ -51,18 +51,22 @@ struct Cost {
     communication_with_partial_para_ = 0.0;
     communication_redis_forward_ = 0.0;
     communication_redis_backward_ = 0.0;
+    communication_forward_ = 0.0;
   }
   // 'memory_with_reuse_' calculates the peak memory usage in a training phase
   double memory_with_reuse_;
-  // 'computation_cost_' models the training time of an iteration in a training phase
+  // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
+  // by ONLY forward phase
   double computation_cost_;
-  // 'communication_cost_' includes communications from operators (forward and backward) and edges
+  // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
   double communication_without_parameter_;
   // communication_with_partial_para_ =
   // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
   double communication_with_partial_para_;
+  // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
+  double communication_forward_;
   double communication_redis_forward_;
   double communication_redis_backward_;
   std::shared_ptr<Decision> decision_ptr_;

@@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
 using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;

 void Simplify(CostPtrList *clist);
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
 void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
 }  // namespace parallel
 }  // namespace mindspore
mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc

@@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
                     << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
       // refine communication cost calculation for practice
       RefineForPracticalCost(cost, true);
+      cost->communication_forward_ = cost->communication_redis_forward_;
       CostPtrKey ck = {target_output_str, target_input_str};
       CostPtrList cl;
       cl.push_back(cost);

@@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double, double)> recursive =
-    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
+        double communication_forward) {
     if (k == edges.size()) {
       auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
       CostPtr new_cost = std::make_shared<Cost>(computation, communication);

@@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
       new_cost->decision_ptr_ = decision;
       result.push_back(new_cost);
       return;

@@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
       selected_cost_list[k] = c;
       recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                 communication + c->communication_cost_,
-                communication_without_para + c->communication_without_parameter_);
+                communication_without_para + c->communication_without_parameter_,
+                communication_forward + c->communication_forward_);
     }
   };
-  recursive(0, 0.0, 0.0, 0.0, 0.0);
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
+  Simplify(&result);
   return result;
 }

@@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
         left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
       double communication =
         left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
+      double communication_forward =
+        left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
       double communication_without_para = left_cost->communication_without_parameter_ +
                                           middle_cost->communication_without_parameter_ +
                                           right_cost->communication_without_parameter_;

@@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
       cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       cost->memory_with_reuse_ = memory_cost;
+      cost->communication_forward_ = communication_forward;
       ret_cost_list->emplace_back(std::move(cost));
     }
   }

@@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
     CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
                                    op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  Simplify(&result);
   return result;
 }
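Across these elimination routines the pattern is the same: component costs are summed field by field, communication_with_partial_para_ is recomputed with the COST_MODEL_GAMMA formula, and the new communication_forward_ field is carried along. The following Python sketch of that bookkeeping uses plain dicts in place of CostPtr and a gamma argument in place of COST_MODEL_GAMMA; it is an illustration, not MindSpore code.

```python
def combine_costs(costs, gamma):
    """Accumulate a list of cost dicts the way the elimination routines do."""
    computation = sum(c["computation_cost_"] for c in costs)
    memory = sum(c["memory_with_reuse_"] for c in costs)
    communication = sum(c["communication_cost_"] for c in costs)
    communication_without_para = sum(c["communication_without_parameter_"] for c in costs)
    communication_forward = sum(c["communication_forward_"] for c in costs)
    return {
        "computation_cost_": computation,
        "memory_with_reuse_": memory,
        "communication_cost_": communication,
        "communication_without_parameter_": communication_without_para,
        # Same formula as in the C++ code above.
        "communication_with_partial_para_":
            communication_without_para + gamma * (communication - communication_without_para),
        # New in this commit: forward-only communication is propagated through eliminations.
        "communication_forward_": communication_forward,
    }
```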
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc

@@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
 size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
 bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
+int32_t RUN_PHASE = DEFAULT_RUN_PHASE;

 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());

@@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
   } else {
     MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
   }
+
+  // MULTI_SUBGRAPHS
+  auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
+  MULTI_SUBGRAPHS = multi_subgraphs;
+  if (MULTI_SUBGRAPHS) {
+    MS_LOG(INFO) << "multi_subgraphs: true.";
+  } else {
+    MS_LOG(INFO) << "multi_subgraphs: false.";
+  }
+
+  // RUN_PHASE
+  auto phase = CostModelContext::GetInstance()->run_phase();
+  if (phase != 0 && phase != 1) {
+    MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
+  }
+  RUN_PHASE = phase;
+  MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
 }

 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {

@@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
           MS_EXCEPTION_IF_NULL(cost3);
           double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
           double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
-          double commmunication =
-            cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+          double communication =
+            cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+          double communication_forward =
+            cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
           double communication_without_para = cost1->communication_without_parameter_ +
                                               cost2->communication_without_parameter_ +
                                               cost3->communication_without_parameter_;
           auto decision =
             std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
-          auto cost = std::make_shared<Cost>(computation, commmunication, decision);
+          auto cost = std::make_shared<Cost>(computation, communication, decision);
           MS_EXCEPTION_IF_NULL(cost);
           cost->communication_without_parameter_ = communication_without_para;
           cost->communication_with_partial_para_ =
-            communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
+            communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
           cost->memory_with_reuse_ = memory;
+          cost->communication_forward_ = communication_forward;
           ret.push_back(cost);
         }
       }

@@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }

@@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
         cost1->communication_without_parameter_ +
         COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
       new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
+      new_cost->communication_forward_ = cost1->communication_forward_;
       ret.push_back(new_cost);
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }

-CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
+CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
+  // Select the cost with minimum inference time. Currently, the inference time is modeled as =
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
   if (cost_list.empty()) {
     MS_LOG(ERROR) << "Final cost list is null.";
     return nullptr;
   }
   CostPtrList after_mem_filter;
-  // Filter out the valid costs
+  double minimum_memory = DBL_MAX;
+  // Filter out the valid costs.
   for (auto &a_cost : cost_list) {
     if (a_cost->memory_with_reuse_ <= memory) {
       after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
     }
   }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
+  CostPtr ret = after_mem_filter[0];
-  std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
-    MS_EXCEPTION_IF_NULL(cost_x);
-    if (init == nullptr || cost_x->computation_cost_ < memory) {
-      init = cost_x;
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_forward_: " << ret->communication_forward_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
+    if (minimum > tmp) {
+      minimum = tmp;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
+    }
-    return init;
-  };
-  CostPtr ret = nullptr;
-  return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
+  }
+  return ret;
 }

 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {

@@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
   });

   if (alive_ops.size() > 2) {
-    return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION)
+        << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
+    }
   } else if (alive_ops.size() == 1) {
     MS_LOG(INFO) << "There are 1 single node in the final graph.";
     OperatorInfoPtr u = alive_ops[0];
     auto cost_list = CreateFinalSingleCostList(u);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      // inference phase
+      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;

@@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
       auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
       all_list.push_back(cost_list);
     }
-    auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+    CostPtrList selected_cost_list;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
+                           "phase is not supported.";
+    }
     for (size_t k = 0; k < selected_cost_list.size(); ++k) {
       auto selected_cost = selected_cost_list[k];
       if (selected_cost == nullptr) {

@@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() {
     auto e = u->GetAliveSuccEdges()[0];
     MS_EXCEPTION_IF_NULL(e);
     auto cost_list = CreateFinalCostList(u, e, v);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
+                           "phase is not supported.";
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;

@@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+      double communication_forward =
+        op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
       double communication_without_para = op_cost->communication_without_parameter_ +
                                           edge_cost->communication_without_parameter_ +
                                           tar_cost->communication_without_parameter_;

@@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
       MS_EXCEPTION_IF_NULL(tar_cost_list_new);
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }

@@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
       CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+  Simplify(&tar_clist_new);
   // Set the new costlist w.r.t the strategy
   tar_stra_cost->cost_list = tar_clist_new;
   if ((!valid) && (!tar_clist_new.empty())) {

@@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
         contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+      double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
+                                     tar_cost->communication_forward_;
       double communication_without_para = contract_op_cost->communication_without_parameter_ +
                                           edge_cost->communication_without_parameter_ +
                                           tar_cost->communication_without_parameter_;

@@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }
   }

@@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
       CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+  Simplify(&tar_clist_new);
   // Set the new costlist w.r.t the strategy
   tar_stra_cost->cost_list = tar_clist_new;
   if ((!valid) && (!tar_clist_new.empty())) {

@@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
                         left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
     double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
                             left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
+    double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
+                               left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
     double new_commu_without = elimi_op_cost->communication_without_parameter_ +
                                left_edge_cost->communication_without_parameter_ +
                                left_node_cost->communication_without_parameter_ +
                                right_edge_cost->communication_without_parameter_;

@@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     new_cost->communication_with_partial_para_ =
       new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
     new_cost->memory_with_reuse_ = new_memory;
+    new_cost->communication_forward_ = new_commu_forward;
     left_node_clist_new->emplace_back(std::move(new_cost));
   }
 }

@@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
                                        &left_node_clist_new);
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
+  Simplify(&left_node_clist_new);
   // Set the new costlist w.r.t the strategy
   left_node_stra_cost->cost_list = left_node_clist_new;
   if ((!valid) && (!left_node_clist_new.empty())) {

@@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     double computation_cost = merged_node_cost->computation_cost_, memory_cost = merged_node_cost->memory_with_reuse_,
-           commu_cost = merged_node_cost->communication_cost_,
-           commu_without = merged_node_cost->communication_without_parameter_;
+           commu_cost = merged_node_cost->communication_cost_,
+           commu_without = merged_node_cost->communication_without_parameter_,
+           commu_forward = merged_node_cost->communication_forward_;
     for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
       MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
       if (i == 0) {
         computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_ +
                          succ_nodes_costs[i]->communication_without_parameter_;
       } else {
         computation_cost += succ_edges_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_;
       }
     }

@@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     new_cost->communication_without_parameter_ = commu_without;
     new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
     new_cost->memory_with_reuse_ = memory_cost;
+    new_cost->communication_forward_ = commu_forward;
     first_succ_node_clist_new->emplace_back(std::move(new_cost));
   }
 }

@@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
     CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
                                   merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
+  Simplify(&first_succ_node_clist_new);
   // Set the new costlist w.r.t the strategy
   first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
   if ((!valid) && (!first_succ_node_clist_new.empty())) {
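The newly added SelectCostWithMinInferenceTime first filters candidate strategies by the device memory capacity and then minimizes costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_forward_. A compact Python sketch of that selection rule is shown below, using plain dicts and alpha/beta arguments in place of the cost-model members; the real method additionally logs every candidate it examines.

```python
def select_cost_with_min_inference_time(cost_list, memory, alpha, beta):
    """Pick the memory-feasible cost with the smallest modeled inference time."""
    feasible = [c for c in cost_list if c["memory_with_reuse_"] <= memory]
    if not feasible:
        return None  # mirrors the "No available cost" error path
    return min(feasible,
               key=lambda c: alpha * c["computation_cost_"] + beta * c["communication_forward_"])
```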
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h

@@ -45,6 +45,9 @@ namespace parallel {
 #define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
 #define DEFAULT_IS_MULTI_SUBGRAPHS false
+#define DEFAULT_RUN_PHASE 0
+#define TRAINING_PHASE 0
+#define INFERENCE_PHASE 1

 class CostGraph;
 using CostGraphPtr = std::shared_ptr<CostGraph>;

@@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
 extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
 extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
+extern bool MULTI_SUBGRAPHS;
+extern int32_t RUN_PHASE;

 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have

@@ -98,7 +103,7 @@ class CostGraph {
   CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
   CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
-  CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
+  CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
   CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
   CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
   Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
mindspore/ccsrc/parallel/costmodel_context.cc

@@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
   is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
+  run_phase_ = DEFAULT_RUN_PHASE;
   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;

@@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
 void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }
+
+void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
 }  // namespace parallel
 }  // namespace mindspore
mindspore/ccsrc/parallel/costmodel_context.h

@@ -113,6 +113,9 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }

+  void set_run_phase(int32_t);
+  int32_t run_phase() const { return run_phase_; }
+
  private:
   CostModelContext();
   static std::shared_ptr<CostModelContext> cm_context_inst_;

@@ -141,8 +144,11 @@ class CostModelContext {
   // COST_MODEL_COMMUNI_BIAS
   double costmodel_communi_bias_;

   // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;

+  int32_t run_phase_;  // 0: 'training', 1: 'inference'
+
   int32_t costmodel_allreduce_fusion_algorithm_;

   int32_t costmodel_allreduce_fusion_times_;
mindspore/ccsrc/parallel/ops_info/matmul_info.cc

@@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
                   << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
     // refine communication cost calculation for practice
     RefineForPracticalCost(result, false);
+    result->communication_forward_ = result->communication_without_parameter_;

     std::shared_ptr<StrategyWithCost> swc =
       std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
mindspore/ccsrc/parallel/ops_info/operator_info.cc

@@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
     BreakingTiesForPerferringDataParallel(strategy, result);
     // refine communication cost calculation for practice
     RefineForPracticalCost(result, false);
+    result->communication_forward_ = result->communication_without_parameter_;

     std::shared_ptr<StrategyWithCost> swc =
       std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h

@@ -69,16 +69,16 @@ class TensorRedistribution {
   RankList dev_list_;
   OperatorList operator_list_;
   bool reshape_flag_;
-  // communication cost
+  // communication cost, which is the sum of forward communication cost and backward communication cost
   double comm_cost_;
   // forward communication cost
   double forward_comm_cost_;
   // backward communication cost
   double backward_comm_cost_;
   // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
-  // inputs.
+  // inputs. This is calculated ONLY for forward phase.
   double computation_cost_;
-  // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
+  // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
   // calculated by the outputs.
   double memory_cost_;
   bool construct_op_flag_;
mindspore/ccsrc/pipeline/init.cc

@@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get the parameter cost_model_communi_bias of the DP algorithm.")
     .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
     .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
+    .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
+    .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
     .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
          "Set the parameter gradient AllReduce fusion algorithm.")
     .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
mindspore/parallel/_cost_model_context.py

@@ -239,6 +239,33 @@ class _CostModelContext:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle.get_multi_subgraphs()

+    def set_run_phase(self, phase):
+        """
+        Set the flag of running phase: training (0) or inference (1)
+
+        Args:
+            phase (int): A parameter indicating which phase is running.
+
+        Raises:
+            ValueError: If context handle is none, or phase is not in {0, 1}.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        if phase not in (0, 1):
+            raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
+        self._context_handle.set_run_phase(phase)
+
+    def get_run_phase(self):
+        """
+        Get the flag of running phase.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        return self._context_handle.get_run_phase()
+
    def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
        """
        Set costmodel allreduce fusion algorithm.

@@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
    "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
+    "run_phase": cost_model_context().set_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,

@@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
    "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs(),
+    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
+    "run_phase": cost_model_context().get_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,

@@ -488,7 +517,7 @@ get_cost_model_context_func_map = {
 @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                  costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
-                 multi_subgraphs=bool,
+                 multi_subgraphs=bool, run_phase=int,
                  costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                  costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                  costmodel_allreduce_fusion_allreduce_inherent_time=float,

@@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
        costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
+        run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
        costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
            0: bypass allreduce fusion;
            1: only use backward computation time to group allreduce;
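With the wrapper methods above in place, switching the cost model to the inference phase is a one-line call before auto-parallel compilation, which is exactly what the new unit test below does. A minimal usage sketch, assuming both set_cost_model_context and cost_model_context are importable from mindspore.parallel._cost_model_context:

```python
from mindspore.parallel._cost_model_context import set_cost_model_context, cost_model_context

# 0: training (default), 1: inference
set_cost_model_context(run_phase=1)
print(cost_model_context().get_run_phase())  # -> 1
```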
tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc

@@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
   ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
   cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
   auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
-  cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
+  cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
 }

 TEST_F(TestCostGraph, test_EliminationOp) {
tests/ut/python/parallel/__init__.py

@@ -14,15 +14,21 @@
 import mindspore.context as context
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.parallel._cost_model_context import reset_cost_model_context
 from mindspore.parallel.algo_parameter_config import reset_algo_parameters
 from mindspore.parallel._utils import _reset_op_id


 def setup_module(module):
     auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    reset_cost_model_context()
     reset_algo_parameters()
     _reset_op_id()


 def teardown_module():
     context.reset_auto_parallel_context()
+    reset_cost_model_context()
+    reset_algo_parameters()
+    _reset_op_id()
tests/ut/python/parallel/test_auto_parallel_inference.py (new file, mode 100644)

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum
from mindspore.parallel._cost_model_context import set_cost_model_context


class Net(nn.Cell):
    def __init__(self, input_ch, out_ch):
        super(Net, self).__init__()
        self.dense = nn.Dense(input_ch, out_ch)
        self.relu = P.ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.relu(x)
        return x


def test_inference_phase():
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    set_cost_model_context(run_phase=1)

    net = Net(512, 128)
    predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
    label = Tensor(np.ones([64, 128]).astype(np.float32))

    loss = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network.set_train()

    output = train_network(predict, label)
\ No newline at end of file