Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
60a9fb00
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
60a9fb00
编写于
8月 05, 2020
作者:
Y
yao_yf
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add_tensor_layout_in_stra_ckpt
上级
57fd31b2
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
115 addition
and
13 deletion
+115
-13
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
...spore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
+2
-0
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+5
-1
mindspore/ccsrc/frontend/parallel/step_parallel.cc
mindspore/ccsrc/frontend/parallel/step_parallel.cc
+40
-8
mindspore/ccsrc/frontend/parallel/step_parallel.h
mindspore/ccsrc/frontend/parallel/step_parallel.h
+1
-1
mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
...allel/strategy_checkpoint/parallel_strategy_checkpoint.cc
+29
-1
mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
...rallel/strategy_checkpoint/parallel_strategy_checkpoint.h
+7
-1
mindspore/ccsrc/utils/node_strategy.proto
mindspore/ccsrc/utils/node_strategy.proto
+29
-0
tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
..._strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
+2
-1
未找到文件。
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
浏览文件 @
60a9fb00
...
@@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo {
...
@@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo {
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
SetCostUnderStrategy
(
const
StrategyPtr
&
strategy
)
override
;
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
cnode
)
override
;
ReplaceGraphPtr
replace_graph
(
const
CNodePtr
&
cnode
)
override
;
std
::
shared_ptr
<
Strategys
>
GenerateBatchStrategies
()
override
;
std
::
shared_ptr
<
Strategys
>
GenerateBatchStrategies
()
override
;
const
std
::
vector
<
int64_t
>
&
param_split_shapes
()
const
{
return
param_split_shapes_
;
}
const
std
::
vector
<
int64_t
>
&
index_offsets
()
const
{
return
index_offsets_
;
}
protected:
protected:
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
Status
CheckStrategy
(
const
StrategyPtr
&
strategy
)
override
;
...
...
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
浏览文件 @
60a9fb00
...
@@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
...
@@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info
->
set_outputs_dtype
(
cnode
->
Type
());
operator_info
->
set_outputs_dtype
(
cnode
->
Type
());
operator_info
->
set_cnode
(
cnode
);
operator_info
->
set_cnode
(
cnode
);
// key of strategy map
// key of strategy map
std
::
string
strategy_key_name
=
NodeParameterName
(
cnode
);
std
::
string
strategy_key_name
=
""
;
auto
param_names
=
NodeParameterName
(
cnode
);
if
(
!
param_names
.
empty
())
{
strategy_key_name
=
param_names
[
0
].
first
;
}
bool
load_strategy_from_ckpt
=
bool
load_strategy_from_ckpt
=
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
->
find
(
strategy_key_name
)
!=
stra_map
->
end
();
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
->
find
(
strategy_key_name
)
!=
stra_map
->
end
();
// If no strategy has been configured for this operator, then candidate strategies are generated for
// If no strategy has been configured for this operator, then candidate strategies are generated for
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.cc
浏览文件 @
60a9fb00
...
@@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
...
@@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
}
// load strategy checkpoint
// load strategy checkpoint
// key of strategy map
// key of strategy map
std
::
string
strategy_key_name
=
NodeParameterName
(
cnode
);
std
::
string
strategy_key_name
=
""
;
auto
param_names
=
NodeParameterName
(
cnode
);
if
(
!
param_names
.
empty
())
{
strategy_key_name
=
param_names
[
0
].
first
;
}
bool
load_strategy_from_ckpt
=
bool
load_strategy_from_ckpt
=
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
.
find
(
strategy_key_name
)
!=
stra_map
.
end
();
StrategyCheckpoint
::
GetInstance
().
LoadCheckPointOn
()
&&
stra_map
.
find
(
strategy_key_name
)
!=
stra_map
.
end
();
if
(
!
StrategyFound
(
attrs
)
&&
!
load_strategy_from_ckpt
)
{
if
(
!
StrategyFound
(
attrs
)
&&
!
load_strategy_from_ckpt
)
{
...
@@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
...
@@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
}
}
}
}
std
::
string
NodeParameterName
(
const
CNodePtr
&
node
)
{
std
::
vector
<
std
::
pair
<
std
::
string
,
int
>>
NodeParameterName
(
const
CNodePtr
&
node
)
{
std
::
vector
<
AnfNodePtr
>
node_inputs
{
node
->
inputs
()};
std
::
vector
<
AnfNodePtr
>
node_inputs
{
node
->
inputs
()};
for
(
auto
input
:
node_inputs
)
{
std
::
vector
<
std
::
pair
<
std
::
string
,
int
>>
param_names
;
for
(
int
i
=
0
;
i
<
UintToInt
(
node_inputs
.
size
());
++
i
)
{
auto
input
=
node_inputs
[
i
];
if
(
input
->
isa
<
Parameter
>
())
{
if
(
input
->
isa
<
Parameter
>
())
{
auto
input_parameter
=
input
->
cast
<
ParameterPtr
>
();
auto
input_parameter
=
input
->
cast
<
ParameterPtr
>
();
if
(
input_parameter
->
has_default
())
{
if
(
input_parameter
->
has_default
())
{
input_parameter
->
name
();
if
(
ParameterRequireGrad
(
input_parameter
))
{
param_names
.
push_back
({
input_parameter
->
name
(),
i
});
}
}
}
}
}
}
}
return
""
;
return
param_names
;
}
}
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
)
{
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_LOG
(
DEBUG
)
<<
"Save strategy to checkpoint begin"
;
MS_LOG
(
DEBUG
)
<<
"Save strategy to checkpoint begin"
;
StrategyMap
stra_map
;
StrategyMap
stra_map
;
TensorInfoMap
tensor_info_map
;
ManualShapeMap
manual_shape_map
;
auto
ret
=
func_graph
->
get_return
();
auto
ret
=
func_graph
->
get_return
();
auto
all_nodes
=
DeepScopedGraphSearch
(
ret
);
auto
all_nodes
=
DeepScopedGraphSearch
(
ret
);
for
(
auto
&
node
:
all_nodes
)
{
for
(
auto
&
node
:
all_nodes
)
{
...
@@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
...
@@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
if
((
cnode
==
nullptr
)
||
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
if
((
cnode
==
nullptr
)
||
!
IsValueNode
<
Primitive
>
(
cnode
->
input
(
0
)))
{
continue
;
continue
;
}
}
std
::
string
param_name
=
NodeParameterName
(
cnode
);
auto
param_names
=
NodeParameterName
(
cnode
);
if
(
param_name
.
empty
())
{
if
(
param_name
s
.
empty
())
{
continue
;
continue
;
}
}
string
param_name
=
param_names
[
0
].
first
;
PrimitivePtr
prim
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
PrimitivePtr
prim
=
GetValueNode
<
PrimitivePtr
>
(
cnode
->
input
(
0
));
MS_EXCEPTION_IF_NULL
(
prim
);
MS_EXCEPTION_IF_NULL
(
prim
);
OperatorInfoPtr
operator_info
=
cnode
->
user_data
<
OperatorInfo
>
();
OperatorInfoPtr
operator_info
=
cnode
->
user_data
<
OperatorInfo
>
();
...
@@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
...
@@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
if
(
operator_info
->
name
().
find
(
RESHAPEINFO
)
!=
std
::
string
::
npos
)
{
if
(
operator_info
->
name
().
find
(
RESHAPEINFO
)
!=
std
::
string
::
npos
)
{
continue
;
continue
;
}
}
std
::
vector
<
TensorInfo
>
input_tensor_info
=
operator_info
->
inputs_tensor_info
();
StrategyPtr
strategyPtr
=
operator_info
->
strategy
();
StrategyPtr
strategyPtr
=
operator_info
->
strategy
();
MS_EXCEPTION_IF_NULL
(
node
->
scope
());
MS_EXCEPTION_IF_NULL
(
node
->
scope
());
stra_map
[
param_name
]
=
strategyPtr
;
stra_map
[
param_name
]
=
strategyPtr
;
for
(
auto
param_name_pair
:
param_names
)
{
if
(
param_name_pair
.
second
-
1
>=
UintToInt
(
input_tensor_info
.
size
()))
{
continue
;
}
tensor_info_map
[
param_name_pair
.
first
]
=
input_tensor_info
[
param_name_pair
.
second
-
1
];
}
if
(
operator_info
->
name
().
find
(
EMBEDDING_LOOKUP
)
!=
std
::
string
::
npos
||
operator_info
->
name
().
find
(
GATHERV2
)
!=
std
::
string
::
npos
)
{
auto
gatherv2_info
=
std
::
dynamic_pointer_cast
<
GatherV2PInfo
>
(
operator_info
);
auto
param_split_shapes
=
gatherv2_info
->
param_split_shapes
();
auto
index_offsets
=
gatherv2_info
->
index_offsets
();
if
(
param_split_shapes
.
size
()
!=
index_offsets
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"In manual split, the param_split_shapes and index_offsets lenght should be same."
;
}
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>
manual_shape
;
for
(
int
i
=
0
;
i
<
UintToInt
(
param_split_shapes
.
size
());
++
i
)
{
manual_shape
.
push_back
({
param_split_shapes
[
i
],
index_offsets
[
i
]});
}
manual_shape_map
[
param_name
]
=
manual_shape
;
}
}
}
}
}
if
(
StrategyCheckpoint
::
GetInstance
().
Save
(
stra_map
)
!=
SUCCESS
)
{
if
(
StrategyCheckpoint
::
GetInstance
().
Save
(
stra_map
,
tensor_info_map
,
&
manual_shape_map
)
!=
SUCCESS
)
{
MS_LOG
(
EXCEPTION
)
<<
"Save strategy checkpoint failed"
;
MS_LOG
(
EXCEPTION
)
<<
"Save strategy checkpoint failed"
;
}
}
}
}
...
...
mindspore/ccsrc/frontend/parallel/step_parallel.h
浏览文件 @
60a9fb00
...
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
...
@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void
ParallelCommunication
(
const
FuncGraphPtr
&
root
,
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
,
void
ParallelCommunication
(
const
FuncGraphPtr
&
root
,
const
std
::
vector
<
AnfNodePtr
>
&
all_nodes
,
const
FuncGraphManagerPtr
&
manager
);
const
FuncGraphManagerPtr
&
manager
);
std
::
string
NodeParameterName
(
const
CNodePtr
&
node
);
std
::
vector
<
std
::
pair
<
std
::
string
,
int
>>
NodeParameterName
(
const
CNodePtr
&
node
);
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
);
void
CheckpointStrategy
(
const
FuncGraphPtr
&
func_graph
);
...
...
mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
浏览文件 @
60a9fb00
...
@@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
...
@@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
return
SUCCESS
;
return
SUCCESS
;
}
}
Status
StrategyCheckpoint
::
Save
(
const
StrategyMap
&
strategy_map
)
{
Status
StrategyCheckpoint
::
Save
(
const
StrategyMap
&
strategy_map
,
const
TensorInfoMap
&
tensor_info_map
,
ManualShapeMap
*
manual_shape_map
)
{
straspb
::
ParallelStrategyMap
parallel_strategy_map
;
straspb
::
ParallelStrategyMap
parallel_strategy_map
;
parallel_strategy_map
.
set_current_stage
(
IntToUint
(
++
current_stage_
));
parallel_strategy_map
.
set_current_stage
(
IntToUint
(
++
current_stage_
));
for
(
auto
&
node_stra
:
strategy_map
)
{
for
(
auto
&
node_stra
:
strategy_map
)
{
...
@@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
...
@@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
}
}
}
}
}
}
for
(
auto
&
node_tensor_info
:
tensor_info_map
)
{
TensorInfo
tensor_info
=
node_tensor_info
.
second
;
TensorLayout
tensor_layout
=
tensor_info
.
tensor_layout
();
straspb
::
ParallelLayoutItem
*
parallel_layout_item
=
parallel_strategy_map
.
add_parallel_layout_item
();
MS_EXCEPTION_IF_NULL
(
parallel_layout_item
);
parallel_layout_item
->
set_param_name
(
node_tensor_info
.
first
);
straspb
::
ParallelLayouts
*
parallel_layouts
=
parallel_layout_item
->
mutable_parallel_layouts
();
straspb
::
DevMatrix
*
dev_matrix
=
parallel_layouts
->
add_dev_matrix
();
MS_EXCEPTION_IF_NULL
(
dev_matrix
);
for
(
auto
dim
:
tensor_layout
.
device_arrangement
().
array
())
{
dev_matrix
->
add_dim
(
IntToUint
(
dim
));
}
straspb
::
TensorMap
*
tensor_map
=
parallel_layouts
->
add_tensor_map
();
MS_EXCEPTION_IF_NULL
(
tensor_map
);
for
(
auto
dim
:
tensor_layout
.
tensor_map
().
array
())
{
tensor_map
->
add_dim
(
dim
);
}
straspb
::
ParamSplitShape
*
param_split_shape
=
parallel_layouts
->
add_param_split_shape
();
straspb
::
IndicesOffset
*
indices_offset
=
parallel_layouts
->
add_indices_offset
();
MS_EXCEPTION_IF_NULL
(
manual_shape_map
);
auto
manual_shape
=
(
*
manual_shape_map
)[
node_tensor_info
.
first
];
for
(
auto
dim_pair
:
manual_shape
)
{
param_split_shape
->
add_dim
(
dim_pair
.
first
);
indices_offset
->
add_dim
(
dim_pair
.
second
);
}
}
std
::
fstream
output
(
save_file_
,
std
::
ios
::
out
|
std
::
ios
::
trunc
|
std
::
ios
::
binary
);
std
::
fstream
output
(
save_file_
,
std
::
ios
::
out
|
std
::
ios
::
trunc
|
std
::
ios
::
binary
);
if
(
!
parallel_strategy_map
.
SerializeToOstream
(
&
output
))
{
if
(
!
parallel_strategy_map
.
SerializeToOstream
(
&
output
))
{
MS_LOG
(
ERROR
)
<<
"Save strategy file failed"
;
MS_LOG
(
ERROR
)
<<
"Save strategy file failed"
;
...
...
mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
浏览文件 @
60a9fb00
...
@@ -19,13 +19,19 @@
...
@@ -19,13 +19,19 @@
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
using
StrategyMap
=
std
::
unordered_map
<
std
::
string
,
StrategyPtr
>
;
using
StrategyMap
=
std
::
unordered_map
<
std
::
string
,
StrategyPtr
>
;
using
TensorInfoMap
=
std
::
unordered_map
<
std
::
string
,
TensorInfo
>
;
using
ManualShapeMap
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>>
;
class
StrategyCheckpoint
{
class
StrategyCheckpoint
{
public:
public:
StrategyCheckpoint
()
{
StrategyCheckpoint
()
{
...
@@ -38,7 +44,7 @@ class StrategyCheckpoint {
...
@@ -38,7 +44,7 @@ class StrategyCheckpoint {
~
StrategyCheckpoint
()
=
default
;
~
StrategyCheckpoint
()
=
default
;
Status
Load
(
StrategyMap
*
strategy_map
);
Status
Load
(
StrategyMap
*
strategy_map
);
Status
Save
(
const
StrategyMap
&
strategy_map
);
Status
Save
(
const
StrategyMap
&
strategy_map
,
const
TensorInfoMap
&
tensor_info_map
,
ManualShapeMap
*
manual_shape_map
);
static
StrategyCheckpoint
&
GetInstance
();
static
StrategyCheckpoint
&
GetInstance
();
bool
LoadCheckPointOn
()
const
{
return
load_checkpoint_on_
;
}
bool
LoadCheckPointOn
()
const
{
return
load_checkpoint_on_
;
}
...
...
mindspore/ccsrc/utils/node_strategy.proto
浏览文件 @
60a9fb00
...
@@ -32,7 +32,36 @@ message ParallelStrategyItem {
...
@@ -32,7 +32,36 @@ message ParallelStrategyItem {
required
ParallelStrategys
parallel_strategys
=
2
;
required
ParallelStrategys
parallel_strategys
=
2
;
}
}
message
DevMatrix
{
repeated
uint32
dim
=
1
;
}
message
TensorMap
{
repeated
int32
dim
=
1
;
}
message
ParamSplitShape
{
repeated
int64
dim
=
1
;
}
message
IndicesOffset
{
repeated
int64
dim
=
1
;
}
message
ParallelLayouts
{
repeated
DevMatrix
dev_matrix
=
1
;
repeated
TensorMap
tensor_map
=
2
;
repeated
ParamSplitShape
param_split_shape
=
3
;
repeated
IndicesOffset
indices_offset
=
4
;
}
message
ParallelLayoutItem
{
required
string
param_name
=
1
;
required
ParallelLayouts
parallel_layouts
=
2
;
}
message
ParallelStrategyMap
{
message
ParallelStrategyMap
{
required
uint32
current_stage
=
1
;
required
uint32
current_stage
=
1
;
repeated
ParallelStrategyItem
parallel_strategy_item
=
2
;
repeated
ParallelStrategyItem
parallel_strategy_item
=
2
;
repeated
ParallelLayoutItem
parallel_layout_item
=
3
;
}
}
\ No newline at end of file
tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
浏览文件 @
60a9fb00
...
@@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
...
@@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f
Status
StrategyCheckpoint
::
Load
(
StrategyMap
*
strategy_map
)
{
return
SUCCESS
;
}
Status
StrategyCheckpoint
::
Load
(
StrategyMap
*
strategy_map
)
{
return
SUCCESS
;
}
Status
StrategyCheckpoint
::
Save
(
const
StrategyMap
&
strategy_map
)
{
return
SUCCESS
;
}
Status
StrategyCheckpoint
::
Save
(
const
StrategyMap
&
strategy_map
,
const
TensorInfoMap
&
tensor_info_map
,
ManualShapeMap
*
manual_shape_map
)
{
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace parallel
}
// namespace mindspore
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录