Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
097b77c3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
097b77c3
编写于
7月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3273 Optimized checkpoint save slice tensor
Merge pull request !3273 from changzherui/save_slice_tensor
上级
c84d4bbd
d37398cd
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
53 additions
and
29 deletions
+53
-29
mindspore/train/serialization.py
mindspore/train/serialization.py
+53
-29
未找到文件。
mindspore/train/serialization.py
浏览文件 @
097b77c3
...
...
@@ -15,6 +15,7 @@
"""Model and parameters serialization."""
import
os
import
stat
import
math
from
threading
import
Thread
,
Lock
import
numpy
as
np
...
...
@@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Float16"
:
np
.
float16
,
"Float32"
:
np
.
float32
,
"Float64"
:
np
.
float64
,
"Bool"
:
np
.
bool_
}
# Lock held by `_exec_save` while it writes a checkpoint file.
_ckpt_mutex = Lock()
# Threshold in bytes (512 MiB): `_exec_save` splits a parameter whose `nbytes`
# exceeds this value into several slices before serializing it.
SLICE_SIZE = 512 * 1024 * 1024
def
_special_process_par
(
par
,
new_par
):
"""
...
...
@@ -105,26 +108,38 @@ def _update_param(param, new_param):
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    Every parameter is serialized as one or more `Checkpoint` messages that
    are appended to the file. A parameter whose raw data is larger than
    `SLICE_SIZE` bytes is cut into several slices; each slice is written as
    its own message carrying the same tag, dims and tensor type, so a reader
    has to concatenate consecutive messages that share a tag.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write. An
            existing file at this path is replaced.
        data_list (dict): Maps a parameter name to an indexable triple
            `(dims, tensor_type, data)`, where `data` is a numpy array.

    Raises:
        RuntimeError: If building or writing the checkpoint fails; the
            original exception is attached as the cause.
    """
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                # A finished save leaves the file read-only (see the chmod
                # below). Restore write permission first so the removal also
                # works where read-only files cannot be deleted.
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    dims, tensor_type, data = value[0], value[1], value[2]

                    if data.nbytes > SLICE_SIZE:
                        slice_count = math.ceil(data.nbytes / SLICE_SIZE)
                        # Split the flattened data rather than axis 0: a
                        # tensor with a short leading axis would otherwise
                        # yield empty slices and slices that are still larger
                        # than SLICE_SIZE. The byte stream is unchanged
                        # because both forms serialize in C order.
                        param_slice_list = np.array_split(data.reshape(-1), slice_count)
                    else:
                        param_slice_list = [data]

                    for param_slice in param_slice_list:
                        # One message per slice keeps each serialized
                        # protobuf payload bounded in size.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(dims)
                        param_tensor.tensor_type = tensor_type
                        # tobytes() replaces the deprecated tostring() alias.
                        param_tensor.tensor_content = param_slice.tobytes()

                        f.write(checkpoint_list.SerializeToString())

            # Mark the finished checkpoint as owner-read-only.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(e.__str__()) from e
def
save_checkpoint
(
parameter_list
,
ckpt_file_name
,
async_save
=
False
):
"""
Saves checkpoint info to a specified file.
...
...
@@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None):
parameter_dict
=
{}
try
:
element_id
=
0
param_data_list
=
[]
for
element
in
checkpoint_list
.
value
:
data
=
element
.
tensor
.
tensor_content
data_type
=
element
.
tensor
.
tensor_type
np_type
=
tensor_to_np_type
[
data_type
]
ms_type
=
tensor_to_ms_type
[
data_type
]
param_data
=
np
.
fromstring
(
data
,
np_type
)
dims
=
element
.
tensor
.
dims
if
dims
==
[
0
]:
if
'Float'
in
data_type
:
param_data
=
float
(
param_data
[
0
])
elif
'Int'
in
data_type
:
param_data
=
int
(
param_data
[
0
])
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_data
,
ms_type
),
name
=
element
.
tag
)
elif
dims
==
[
1
]:
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_data
,
ms_type
),
name
=
element
.
tag
)
else
:
param_dim
=
[]
for
dim
in
dims
:
param_dim
.
append
(
dim
)
param_value
=
param_data
.
reshape
(
param_dim
)
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_value
,
ms_type
),
name
=
element
.
tag
)
element_data
=
np
.
frombuffer
(
data
,
np_type
)
param_data_list
.
append
(
element_data
)
if
(
element_id
==
len
(
checkpoint_list
.
value
)
-
1
)
or
\
(
element
.
tag
!=
checkpoint_list
.
value
[
element_id
+
1
].
tag
):
param_data
=
np
.
concatenate
((
param_data_list
),
axis
=
0
)
param_data_list
.
clear
()
dims
=
element
.
tensor
.
dims
if
dims
==
[
0
]:
if
'Float'
in
data_type
:
param_data
=
float
(
param_data
[
0
])
elif
'Int'
in
data_type
:
param_data
=
int
(
param_data
[
0
])
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_data
,
ms_type
),
name
=
element
.
tag
)
elif
dims
==
[
1
]:
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_data
,
ms_type
),
name
=
element
.
tag
)
else
:
param_dim
=
[]
for
dim
in
dims
:
param_dim
.
append
(
dim
)
param_value
=
param_data
.
reshape
(
param_dim
)
parameter_dict
[
element
.
tag
]
=
Parameter
(
Tensor
(
param_value
,
ms_type
),
name
=
element
.
tag
)
element_id
+=
1
logger
.
info
(
"Load checkpoint process finish."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录