Repository: magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 21d936e6
Authored on Apr 28, 2020 by mindspore-ci-bot; committed via Gitee on Apr 28, 2020.

!728 auto parallel strategy checkpoint full

Merge pull request !728 from yao_yf/strategy_checkpoint_extend

Parent commits: c553a70a, 6cde5f6d

Showing 15 changed files with 303 additions and 248 deletions (+303 -248).
Changed files:

  mindspore/ccsrc/ir/primitive.h                                                        +5   -1
  mindspore/ccsrc/parallel/context.cc                                                   +10  -0
  mindspore/ccsrc/parallel/context.h                                                    +7   -0
  mindspore/ccsrc/parallel/step_auto_parallel.cc                                        +32  -8
  mindspore/ccsrc/parallel/step_parallel.cc                                             +36  -49
  mindspore/ccsrc/parallel/step_parallel.h                                              +1   -1
  mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc          +15  -13
  mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h           +16  -22
  mindspore/ccsrc/pipeline/init.cc                                                      +6   -0
  mindspore/ccsrc/utils/node_strategy.proto                                             +1   -1
  mindspore/context.py                                                                  +7   -1
  mindspore/ops/primitive.py                                                            +2   -0
  mindspore/parallel/_auto_parallel_context.py                                          +42  -3
  tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc   +1   -3
  tests/ut/python/parallel/test_strategy_checkpoint.py                                  +122 -146
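In summary, this merge request replaces the environment-variable driven strategy checkpoint (the removed PARALLEL_CHECKPOINT_ON and PARALLEL_TRAIN_TIMES switches) with two auto-parallel context options: strategy_ckpt_save_file records the parallel strategy chosen for each parameter-bearing operator during one training run, and strategy_ckpt_load_file replays those strategies in a later run. A minimal usage sketch, assembled from the docstring examples and unit tests added in this commit (the file name and device settings are illustrative):

from mindspore import context

# First training run: record the parallel strategy chosen for each operator.
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel",
                                  strategy_ckpt_save_file="./strategy_stage1.ckpt")
# ... build and compile the stage-1 network; strategies are written to the file ...

# Later training run (e.g. a further stage of multi-train): reuse the saved strategies.
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel",
                                  strategy_ckpt_load_file="./strategy_stage1.ckpt")
# ... build and compile the follow-up network; matching operators pick up the checkpointed strategies ...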
mindspore/ccsrc/ir/primitive.h  (view file @ 21d936e6)

@@ -52,7 +52,11 @@ class Primitive : public Named {
       : Named(name), signatures_(), prim_type_(prim_type) {}
   Primitive(const Primitive &prim)
-      : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
+      : Named(prim),
+        attrs_(prim.attrs_),
+        signatures_(prim.signatures_),
+        instance_name_(prim.instance_name_),
+        prim_type_(prim.prim_type_) {}
   MS_DECLARE_PARENT(Primitive, Named);
mindspore/ccsrc/parallel/context.cc  (view file @ 21d936e6)

@@ -56,6 +56,8 @@ void ParallelContext::Reset() {
   parameter_broadcast_ = false;
   parameter_broadcast_is_set_ = false;
   enable_all_reduce_fusion_ = false;
+  strategy_ckpt_load_file_ = "";
+  strategy_ckpt_save_file_ = "";
 }
 
 void ParallelContext::set_device_num(int32_t device_num) {

@@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
   parameter_broadcast_is_set_ = true;
 }
 
+void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
+  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
+}
+
+void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
+  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
+}
+
 void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
   all_reduce_fusion_split_indices_ = indices;
 }
mindspore/ccsrc/parallel/context.h  (view file @ 21d936e6)

@@ -85,6 +85,11 @@ class ParallelContext {
   }
   bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
 
+  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
+  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
+  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
+  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+
   void Reset();
 
  private:

@@ -105,6 +110,8 @@ class ParallelContext {
   bool enable_all_reduce_fusion_;
   std::vector<uint32_t> all_reduce_fusion_split_indices_;
   std::vector<uint32_t> all_reduce_fusion_split_sizes_;
+  std::string strategy_ckpt_load_file_;
+  std::string strategy_ckpt_save_file_;
 };
 }  // namespace parallel
 }  // namespace mindspore
mindspore/ccsrc/parallel/step_auto_parallel.cc  (view file @ 21d936e6)

@@ -40,6 +40,7 @@
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
 #include "parallel/step_parallel.h"
+#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/pipeline.h"

@@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
   return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
 }
 
-OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
+OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
   MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();

@@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_input_value(input_value);
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
+  // key of strategy map
+  std::string instance_name = prim->instance_name();
+  std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+  bool load_strategy_from_ckpt =
+    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy
-  if (!StrategyFound(attrs) || prim->name() == CAST) {
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
+  // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();

@@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     }
   } else {
     // In this case, the configured strategy should be extracted to help setting cost
-    StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs);
+    StrategyPtr strategyPtr;
+    if (load_strategy_from_ckpt) {
+      strategyPtr = (*stra_map)[strategy_key_name];
+    } else {
+      strategyPtr = parallel::ExtractStrategy(attrs);
+    }
     if (strategyPtr != nullptr) {
       if (prim->name() == RESHAPE) {
         MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";

@@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
   entire_costgraph->SetDeviceMemoryAndCostParameter();
 
   // The map from CNode's UniqueId to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   // Step 1
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators

@@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
     auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
     if (search_cnode == from_cnode_to_info.end()) {
-      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
       if (operator_info == nullptr) {
         return FAILED;
       }

@@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   entire_costgraph->SetDeviceMemoryAndCostParameter();
 
   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract strategy from checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();

@@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
     if (search_cnode == from_cnode_to_info.end()) {
       // In this case, the corresponding OperatorInfo is not created, create the new one.
-      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
       if (operator_info == nullptr) {
         return FAILED;
       }
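The step_auto_parallel.cc changes above thread a StrategyMap loaded from the checkpoint into CreateTheOperatorInfo, which looks each operator up under the key scope name + CONNSYMBOL + primitive instance name and prefers the checkpointed strategy over both the user-configured one and auto-generated candidates. A rough Python model of that precedence (illustrative only, not MindSpore API; the helper name and arguments are invented for this sketch):

# Rough model of the strategy selection order in CreateTheOperatorInfo;
# 'strategy_key' stands for scope name + CONNSYMBOL + instance name.
def choose_strategy(strategy_key, attrs_strategy, ckpt_map, load_ckpt_on, is_cast):
    load_from_ckpt = load_ckpt_on and strategy_key in ckpt_map
    if load_from_ckpt:
        return ckpt_map[strategy_key]      # checkpointed strategy takes precedence
    if attrs_strategy is not None and not is_cast:
        return attrs_strategy              # strategy configured on the primitive
    return None                            # fall back to generating candidate strategies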
mindspore/ccsrc/parallel/step_parallel.cc  (view file @ 21d936e6)

@@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
 }
 
 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
+  // load strategy map from checkpoint
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {

@@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       (void)cnode->set_operator_info(operator_);
       continue;
     }
-    if (!StrategyFound(attrs)) {
+    // load strategy checkpoint
+    // key of strategy map
+    std::string instance_name = prim->instance_name();
+    std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
+    bool load_strategy_from_ckpt =
+      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
+    if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                    << " is empty, using batch parallel";
       std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies();

@@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
       MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
                    << attrs[GEN_STRATEGY]->ToString();
       strategyPtr = NewStrategy(0, *strategy_v_ptr);
+    } else if (load_strategy_from_ckpt) {
+      strategyPtr = stra_map[strategy_key_name];
     } else {
       strategyPtr = ExtractStrategy(attrs);
     }

@@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
   }
 }
 
-void CheckpointStrategy(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Save strategy to checkpoint begin";
-  StrategyMap straMap;
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
-  for (auto &node : all_nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-      continue;
-    }
-    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->operator_info();
-    if (operator_info) {
-      if (prim->instance_name().empty()) {
-        continue;
+bool NodeWithParameter(const CNodePtr &node) {
+  std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  for (auto input : node_inputs) {
+    if (input->isa<Parameter>()) {
+      auto input_parameter = input->cast<ParameterPtr>();
+      if (input_parameter->has_default()) {
+        return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
       }
-      std::string instance_name = prim->instance_name();
-      StrategyPtr strategyPtr = operator_info->strategy();
-      MS_EXCEPTION_IF_NULL(node->scope());
-      std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      straMap[node_name] = strategyPtr;
     }
   }
-  if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
-  }
+  return false;
 }
 
-void RestoreStrategy(const FuncGraphPtr &func_graph) {
+void CheckpointStrategy(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  MS_LOG(INFO) << "Extract strategy from checkpoint begin";
-  StrategyMap straMap;
-  if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-  }
-  if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed";
-  }
+  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
+  StrategyMap stra_map;
+  auto ret = func_graph->get_return();
+  auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
       continue;
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));

@@ -2076,18 +2068,18 @@ void RestoreStrategy(const FuncGraphPtr &func_graph) {
     OperatorInfoPtr operator_info = cnode->operator_info();
     if (operator_info) {
       if (prim->instance_name().empty()) {
-        continue;
+        MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
       }
       std::string instance_name = prim->instance_name();
+      StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
       std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
-      MS_LOG(INFO) << "Node name is " << node_name;
-      if (straMap.find(node_name) != straMap.end()) {
-        StrategyPtr strategyPtr = straMap[node_name];
-        operator_info->set_strategy(strategyPtr);
-      }
+      stra_map[node_name] = strategyPtr;
     }
   }
+  if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
+  }
 }
 
 void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {

@@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     // extract shape and strategy, set operator_info
     ExtractInformation(all_nodes);
     ReshapeInit(all_nodes);
-    // extract strategy from checkpoint for multi-train
-    if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) {
-      RestoreStrategy(root);
-    }
   }
   // save strategy as checkpoint for multi-train
-  if (StrategyCheckpoint::GetInstance().CheckPointOn() &&
-      StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) {
+  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
     CheckpointStrategy(root);
   }
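On the save side, step_parallel.cc now writes a strategy entry only for CNodes that take a trainable parameter as input (the new NodeWithParameter check reads requires_grad from the parameter's default value), and such a node must carry a primitive instance name, otherwise CheckpointStrategy raises an exception. A rough Python model of that filter (illustrative only; the node attributes and the CONNSYMBOL value are stand-ins for the C++ originals):

CONNSYMBOL = "-"  # stand-in for the separator constant defined in parallel/ops_info/ops_utils.h

def nodes_to_checkpoint(cnodes):
    recorded = {}
    for cnode in cnodes:
        # mirror of NodeWithParameter: at least one input is a parameter with requires_grad
        if not any(p.has_default and p.requires_grad for p in cnode.parameter_inputs):
            continue
        if not cnode.instance_name:
            raise RuntimeError("Node with parameter to checkpoint strategy needs instance name")
        key = cnode.scope_name + CONNSYMBOL + cnode.instance_name
        recorded[key] = cnode.strategy
    return recorded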
mindspore/ccsrc/parallel/step_parallel.h  (view file @ 21d936e6)

@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);
 
-void RestoreStrategy(const FuncGraphPtr &func_graph);
+bool NodeWithParameter(const CNodePtr &node);
 
 void CheckpointStrategy(const FuncGraphPtr &func_graph);
mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc  (view file @ 21d936e6)

@@ -29,30 +29,32 @@ namespace mindspore {
 namespace parallel {
 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
   static StrategyCheckpoint instance = StrategyCheckpoint();
+  if (ParallelContext::GetInstance() != nullptr) {
+    instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
+    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
+    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
+    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+  }
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const {
-  std::ifstream fin(path_);
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
+  std::ifstream fin(path);
   if (fin) {
     return true;
   }
   return false;
 }
 
-Status StrategyCheckpoint::RemoveCheckPoint() const {
-  if (std::remove(common::SafeCStr(path_)) == 0) {
-    return SUCCESS;
-  }
-  return FAILED;
-}
-
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (strategy_map == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
   }
+  if (!CheckPointExit(load_file_)) {
+    MS_LOG(EXCEPTION) << "CheckPoint file is not found";
+  }
   straspb::ParallelStrategyMap parallel_strategy_map;
-  std::fstream input(path_, std::ios::in | std::ios::binary);
+  std::fstream input(load_file_, std::ios::in | std::ios::binary);
   if (!parallel_strategy_map.ParseFromIstream(&input)) {
     MS_LOG(ERROR) << "Load strategy file failed";
     return FAILED;

@@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
     (*strategy_map)[node_name] = strategy;
-    current_train_time_ = (int32_t) parallel_strategy_map.train_time();
+    current_stage_ = (int32_t) parallel_strategy_map.current_stage();
   }
   return SUCCESS;
 }
 
 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
   straspb::ParallelStrategyMap parallel_strategy_map;
-  parallel_strategy_map.set_train_time(IntToUint(++current_train_time_));
+  parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
   for (auto &node_stra : strategy_map) {
     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
     MS_EXCEPTION_IF_NULL(parallel_strategy_item);

@@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
       }
     }
   }
-  std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary);
+  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!parallel_strategy_map.SerializeToOstream(&output)) {
     MS_LOG(ERROR) << "Save strategy file failed";
     return FAILED;
mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h  (view file @ 21d936e6)

@@ -21,43 +21,37 @@
 #include <unordered_map>
 #include "parallel/ops_info/ops_utils.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"
 
 namespace mindspore {
 namespace parallel {
-constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt";
-
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
 class StrategyCheckpoint {
  public:
-  StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) {
-    train_times_ = 1;
-    checkpoint_on_ = false;
-    const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES");
-    if (train_times_str != nullptr && std::stoi(train_times_str) > 0) {
-      train_times_ = std::stoi(train_times_str);
-    }
-    const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON");
-    if (checkpoint_on_str != nullptr) {
-      checkpoint_on_ = (std::string(checkpoint_on_str) == "on");
-    }
+  StrategyCheckpoint() {
+    current_stage_ = 0;
+    load_file_ = "";
+    load_checkpoint_on_ = false;
+    save_file_ = "";
+    save_checkpoint_on_ = false;
   }
   ~StrategyCheckpoint() = default;
-  bool CheckPointExit() const;
-  Status RemoveCheckPoint() const;
   Status Load(StrategyMap *strategy_map);
   Status Save(const StrategyMap &strategy_map);
   static StrategyCheckpoint &GetInstance();
-  int32_t GetTrainTimes() const { return train_times_; }
-  int32_t GetCurrentTrainTime() const { return current_train_time_; }
-  bool CheckPointOn() const { return checkpoint_on_; }
+  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
+  bool SaveCheckPointOn() const { return save_checkpoint_on_; }
 
  private:
-  std::string path_;
-  bool checkpoint_on_;
-  // total train times for a train, get from Environmental variable:TRAIN_TIME, please export it
-  int32_t train_times_;
-  int32_t current_train_time_;
+  std::string load_file_;
+  std::string save_file_;
+  bool load_checkpoint_on_;
+  bool save_checkpoint_on_;
+  bool CheckPointExit(const std::string path) const;
+  int32_t current_stage_;
 };
 }  // namespace parallel
 }  // namespace mindspore
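StrategyCheckpoint therefore no longer keeps a fixed path or reads PARALLEL_CHECKPOINT_ON / PARALLEL_TRAIN_TIMES; each call to GetInstance() refreshes load_file_ and save_file_ from ParallelContext, and a non-empty path is what switches loading or saving on. In user terms, a small sketch (the path is only an example, echoing the removed default name ./strategys.ckpt):

from mindspore import context

context.set_auto_parallel_context(strategy_ckpt_save_file="./strategys.ckpt")  # saving enabled
context.set_auto_parallel_context(strategy_ckpt_load_file="")                  # loading stays disabled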
mindspore/ccsrc/pipeline/init.cc  (view file @ 21d936e6)

@@ -189,6 +189,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
          "Get parameter broadcast is set.")
     .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
+    .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
+         "Set strategy checkpoint load file.")
+    .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
+         "Set strategy checkpoint save file.")
+    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
+    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
mindspore/ccsrc/utils/node_strategy.proto  (view file @ 21d936e6)

@@ -33,6 +33,6 @@ message ParallelStrategyItem {
 }
 
 message ParallelStrategyMap {
-    required uint32 train_time = 1;
+    required uint32 current_stage = 1;
     repeated ParallelStrategyItem parallel_strategy_item = 2;
 }
\ No newline at end of file
mindspore/context.py  (view file @ 21d936e6)

@@ -396,7 +396,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

@@ -428,6 +428,8 @@ def set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

@@ -439,6 +441,8 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(cast_before_mirror=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
+        >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
     """
     _set_auto_parallel_context(**kwargs)

@@ -469,6 +473,8 @@ def reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: "".
+    - strategy_ckpt_save_file: "".
     """
     _reset_auto_parallel_context()
mindspore/ops/primitive.py  (view file @ 21d936e6)

@@ -88,6 +88,8 @@ class Primitive(Primitive_):
         for name in self.attrs:
             value = self.attrs[name]
             cloned.add_prim_attr(name, value)
+        if hasattr(self, 'instance_name'):
+            cloned.set_prim_instance_name(self.instance_name)
         return cloned
 
     def add_prim_attr(self, name, value):
mindspore/parallel/_auto_parallel_context.py  (view file @ 21d936e6)

@@ -208,6 +208,36 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_parameter_broadcast()
 
+    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
+        """
+        Set strategy checkpoint load path.
+
+        Args:
+            strategy_ckpt_load_file (bool): Path to load parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
+
+    def get_strategy_ckpt_load_file(self):
+        """Get strategy checkpoint load path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_load_file()
+
+    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
+        """
+        Set strategy checkpoint save path.
+
+        Args:
+            strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
+        """
+        self.check_context_handle()
+        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
+
+    def get_strategy_ckpt_save_file(self):
+        """Get strategy checkpoint save path."""
+        self.check_context_handle()
+        return self._context_handle.get_strategy_ckpt_save_file()
+
     def get_parameter_broadcast_is_set(self):
         """Get parameter broadcast is set or not."""
         self.check_context_handle()

@@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
 
 
 _get_auto_parallel_context_func_map = {

@@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = {
     "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
-    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
+    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
+    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
+                 loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

@@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs):
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
+        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
+        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

@@ -400,5 +437,7 @@ def _reset_auto_parallel_context():
     - cast_before_mirror: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
+    - strategy_ckpt_load_file: ""
+    - strategy_ckpt_save_file: ""
     """
     auto_parallel_context().reset()
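Because _set_auto_parallel_context simply dispatches each keyword argument through _set_auto_parallel_context_func_map, registering the two setters above (plus the args_type_check entries) is all that is needed for the new keys to reach the C++ context. A minimal sketch of that dispatch pattern, assuming the module layout shown here (simplified, not the actual implementation):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

_setters = {
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
}

def set_strategy_ckpt_paths(**kwargs):
    for key, value in kwargs.items():
        if key not in _setters:
            raise ValueError("Unknown auto parallel context key: " + key)
        _setters[key](value)  # forwards to the pybind-exposed ParallelContext setter

set_strategy_ckpt_paths(strategy_ckpt_save_file="./strategy_stage1.ckpt")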
tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc  (view file @ 21d936e6)

@@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() {
   return instance;
 }
 
-bool StrategyCheckpoint::CheckPointExit() const { return false; }
-
-Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; }
+bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; }
 
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { return SUCCESS; }
tests/ut/python/parallel/test_strategy_checkpoint.py  (view file @ 21d936e6)

@@ -14,10 +14,10 @@
 import numpy as np
 from mindspore import context
-from mindspore.context import set_auto_parallel_context
+from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 import mindspore as ms
 from mindspore.common.api import _executor

@@ -25,17 +25,15 @@ from mindspore.ops import composite as C
 # model_parallel test
-# export PARALLEL_CHECKPOINT_ON=on
-# export PARALLEL_TRAIN_TIMES=4
-def test_six_matmul():
+def test_six_matmul_save():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)

@@ -44,8 +42,8 @@ def test_six_matmul():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):

@@ -56,45 +54,46 @@ def test_six_matmul():
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
-        def construct(self, x1, x2, x3, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((1, 8), (8, 1))
     strategy3 = ((2, 2), (2, 2))
-    strategy4 = ((4, 2), (2, 4))
-    strategy5 = ((2, 4), (4, 2))
-    strategy6 = ((4, 4), (4, 4))
+    strategy4 = ((1, 1), (1, 8))
+    strategy5 = ((4, 2), (2, 1))
+    strategy6 = ((4, 1), (1, 2))
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x3 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x3, x4, x5, x6, x7)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
 
-# remove matmul2
-def test_six_matmul_repeated1():
+# remove matmul2, add matmul7
+def test_six_matmul_load():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            predict = self.network(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)

@@ -103,53 +102,58 @@ def test_six_matmul_repeated1():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
             self.matmul6 = P.MatMul().set_strategy(strategy6)
-
-        def construct(self, x1, x2, x4, x5, x6, x7):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
+            self.matmul7 = P.MatMul().set_strategy(strategy7)
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
    strategy3 = ((8, 1), (1, 1))
     strategy4 = ((8, 1), (1, 1))
     strategy5 = ((8, 1), (1, 1))
     strategy6 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6)))
+    strategy7 = ((8, 1), (1, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
     x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7)
+    _executor.compile(net, x1, x6, x7)
 
-# add matmul7
-def test_six_matmul_repeated2():
+# model_parallel test
+def test_six_matmul_save_auto():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            predict = self.network(x1, x6)
             return self.loss(predict)

@@ -158,60 +162,52 @@ def test_six_matmul_repeated2():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)
+        def construct(self, x1, x6):
+            return C.grad_all(self.network)(x1, x6)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self):
             super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul3 = P.MatMul().set_strategy(strategy3)
-            self.matmul4 = P.MatMul().set_strategy(strategy4)
-            self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
+            self.matmul1 = P.MatMul()
+            self.matmul2 = P.MatMul()
+            self.matmul3 = P.MatMul()
+            self.matmul4 = P.MatMul()
+            self.matmul5 = P.MatMul()
+            self.matmul6 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
+        def construct(self, x1, x6):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul2(out, self.weight2)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8)
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6)
 
-# add scope2
-def test_six_matmul_repeated3():
+# remove matmul2, add matmul7
+def test_six_matmul_load_auto():
     class NetWithLoss(nn.Cell):
-        def __init__(self, network1, network2):
+        def __init__(self, network):
             super(NetWithLoss, self).__init__()
             self.loss = VirtualLoss()
-            self.network = network1
-            self.network2 = network2
+            self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            predict = self.network(x1, x2, x4, x5, x6, x7, x8)
-            predict = self.network2(predict, x9, x10)
+        def construct(self, x1, x6, x7):
+            predict = self.network(x1, x6, x7)
             return self.loss(predict)

@@ -220,62 +216,42 @@ def test_six_matmul_repeated3():
             super(GradWrap, self).__init__()
             self.network = network
 
-        def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
-            return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)
+        def construct(self, x1, x6, x7):
+            return C.grad_all(self.network)(x1, x6, x7)
 
     class Net(nn.Cell):
-        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
+        def __init__(self, strategy1, strategy3, strategy4, strategy5):
             super().__init__()
             self.matmul1 = P.MatMul().set_strategy(strategy1)
             self.matmul3 = P.MatMul().set_strategy(strategy3)
             self.matmul4 = P.MatMul().set_strategy(strategy4)
             self.matmul5 = P.MatMul().set_strategy(strategy5)
-            self.matmul6 = P.MatMul().set_strategy(strategy6)
-            self.matmul7 = P.MatMul().set_strategy(strategy7)
-
-        def construct(self, x1, x2, x4, x5, x6, x7, x8):
-            out = self.matmul1(x1, x2)
-            out = self.matmul3(out, x4)
-            out = self.matmul4(out, x5)
-            out = self.matmul5(out, x6)
-            out = self.matmul6(out, x7)
-            out = self.matmul7(out, x8)
-            return out
-
-    class Net1(nn.Cell):
-        def __init__(self, strategy1, strategy2):
-            super().__init__()
-            self.matmul1 = P.MatMul().set_strategy(strategy1)
-            self.matmul2 = P.MatMul().set_strategy(strategy2)
-
-        def construct(self, x1, x2, x3):
-            out = self.matmul1(x1, x2)
-            out = self.matmul2(out, x3)
+            self.matmul6 = P.MatMul()
+            self.matmul7 = P.MatMul()
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
+            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
+            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
 
+        def construct(self, x1, x6, x7):
+            out = self.matmul1(x1, self.weight1)
+            out = self.matmul3(out, self.weight3)
+            out = self.matmul4(out, self.weight4)
+            out = self.matmul5(out, self.weight5)
+            out = self.matmul6(out, x6)
+            out = self.matmul7(out, x7)
             return out
 
-    set_auto_parallel_context(device_num=512, global_rank=0)
-    strategy1 = ((8, 1), (1, 1))
-    strategy3 = ((8, 1), (1, 1))
-    strategy4 = ((8, 1), (1, 1))
-    strategy5 = ((8, 1), (1, 1))
-    strategy6 = ((8, 1), (1, 1))
-    strategy7 = ((8, 1), (1, 1))
-    strategy8 = ((8, 1), (1, 1))
-    strategy9 = ((8, 1), (1, 1))
-    net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)
-    net2 = Net1(strategy8, strategy9)
-    net = GradWrap(NetWithLoss(net1, net2))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    reset_auto_parallel_context()
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
+    strategy1 = ((2, 2), (2, 2))
+    strategy3 = ((2, 2), (2, 2))
+    strategy4 = ((2, 2), (2, 2))
+    strategy5 = ((2, 2), (2, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
 
-    x1 = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    x2 = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    x4 = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    x5 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x6 = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-    x8 = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    x9 = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    x10 = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10)
+    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
+    _executor.compile(net, x1, x6, x7)
\ No newline at end of file