Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
28521e0f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
28521e0f
编写于
6月 15, 2021
作者:
W
WeiXin
提交者:
GitHub
6月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Save all the information of 'ParamBase' in 'Layer'. (#33500)
* Save all the information of 'ParamBase' in 'Layer'. * edit unittest
上级
009a163c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
53 addition
and
14 deletion
+53
-14
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+12
-0
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
+4
-8
python/paddle/framework/io.py
python/paddle/framework/io.py
+37
-6
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
28521e0f
...
...
@@ -5540,6 +5540,18 @@ class ParamBase(core.VarBase):
core
.
varbase_copy
(
self
,
new_param
,
device
,
blocking
)
return
new_param
def
__reduce__
(
self
):
value
=
self
.
numpy
()
state
=
(
self
.
name
,
self
.
persistable
,
self
.
stop_gradient
)
return
ParamBase
,
(
self
.
shape
,
self
.
dtype
),
(
self
.
__dict__
,
value
,
state
)
def
__setstate__
(
self
,
state
):
self
.
__dict__
.
update
(
state
[
0
])
t
=
self
.
value
().
get_tensor
()
t
.
set
(
state
[
1
],
_current_expected_place
())
self
.
name
,
self
.
persistable
,
self
.
stop_gradient
=
state
[
2
]
__repr__
=
__str__
...
...
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
浏览文件 @
28521e0f
...
...
@@ -935,21 +935,17 @@ class TestSaveLoadLayer(unittest.TestCase):
layer2
=
LinearNet
()
layer1
.
eval
()
layer2
.
eval
()
origin_layer
=
(
layer1
,
layer2
)
origin
=
(
layer1
(
inps
),
layer2
(
inps
))
path
=
"test_save_load_layer_/layer.pdmodel"
paddle
.
save
((
layer1
,
layer2
),
path
)
# static
paddle
.
enable_static
()
with
self
.
assertRaises
(
ValueError
):
paddle
.
load
(
path
)
# dygraph
paddle
.
disable_static
()
paddle
.
save
(
origin_layer
,
path
)
loaded_layer
=
paddle
.
load
(
path
)
loaded_result
=
[
l
(
inps
)
for
l
in
loaded_layer
]
for
i
in
range
(
len
(
origin
)):
self
.
assertTrue
((
origin
[
i
]
-
loaded_result
[
i
]).
abs
().
max
()
<
1e-10
)
for
k
,
v
in
origin_layer
[
i
].
_linear
.
weight
.
__dict__
.
items
():
self
.
assertTrue
(
v
==
loaded_layer
[
i
].
_linear
.
weight
.
__dict__
[
k
])
if
__name__
==
'__main__'
:
...
...
python/paddle/framework/io.py
浏览文件 @
28521e0f
...
...
@@ -233,9 +233,13 @@ def _pickle_save(obj, f, protocol):
raise
ValueError
(
"Expected 1<'protocol'<5, but received protocol={}"
.
format
(
protocol
))
def
reudce_varbase
(
self
):
list_params
=
set
()
def
reduce_varbase
(
self
):
data
=
self
.
numpy
()
name
=
self
.
name
if
name
in
list_params
:
return
self
.
__reduce__
()
return
(
tuple
,
((
name
,
data
),
))
...
...
@@ -244,16 +248,43 @@ def _pickle_save(obj, f, protocol):
return
(
eval
,
(
'data'
,
{
'data'
:
data
}))
def
reduce_Layer
(
self
):
is_param_or_layer
=
lambda
v
:
isinstance
(
v
,
ParamBase
)
or
isinstance
(
v
,
core
.
Layer
)
def
collect_params
(
param_or_layer
):
if
isinstance
(
param_or_layer
,
ParamBase
):
list_params
.
add
(
param_or_layer
.
name
)
else
:
# param_or_layer is layer
_parse_every_object
(
param_or_layer
.
__dict__
,
is_param_or_layer
,
collect_params
)
return
param_or_layer
_parse_every_object
(
self
.
__dict__
,
is_param_or_layer
,
collect_params
)
return
self
.
__reduce_ex__
(
protocol
)
dispatch_table_layer
=
dict
()
def
create_layer_dispatch_table
(
layer
):
dispatch_table_layer
[
layer
.
__class__
]
=
reduce_Layer
return
layer
_parse_every_object
(
obj
,
lambda
v
:
isinstance
(
v
,
core
.
Layer
),
create_layer_dispatch_table
)
def
add_dispatch_table
():
# This is not a good method, because the pickle module has been modified.
pickle
.
dispatch_table
[
core
.
VarBase
]
=
re
ud
ce_varbase
pickle
.
dispatch_table
[
ParamBase
]
=
re
ud
ce_varbase
pickle
.
dispatch_table
[
core
.
VarBase
]
=
re
du
ce_varbase
pickle
.
dispatch_table
[
ParamBase
]
=
re
du
ce_varbase
pickle
.
dispatch_table
[
core
.
LoDTensor
]
=
reduce_LoDTensor
pickle
.
dispatch_table
.
update
(
dispatch_table_layer
)
def
pop_dispatch_table
():
pickle
.
dispatch_table
.
pop
(
core
.
VarBase
)
pickle
.
dispatch_table
.
pop
(
core
.
LoDTensor
)
pickle
.
dispatch_table
.
pop
(
ParamBase
)
for
k
in
dispatch_table_layer
:
pickle
.
dispatch_table
.
pop
(
k
)
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
...
...
@@ -273,10 +304,10 @@ def _pickle_save(obj, f, protocol):
pickler
=
pickle
.
Pickler
(
f
,
protocol
)
pickler
.
dispatch_table
=
copyreg
.
dispatch_table
.
copy
()
pickler
.
dispatch_table
[
core
.
VarBase
]
=
re
ud
ce_varbase
pickler
.
dispatch_table
[
core
.
VarBase
]
=
re
du
ce_varbase
pickler
.
dispatch_table
[
core
.
LoDTensor
]
=
reduce_LoDTensor
pickler
.
dispatch_table
[
ParamBase
]
=
re
ud
ce_varbase
pickler
.
dispatch_table
[
ParamBase
]
=
re
du
ce_varbase
pickler
.
dispatch_table
.
update
(
dispatch_table_layer
)
pickler
.
dump
(
obj
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录