Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
95a0f87b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95a0f87b
编写于
11月 27, 2020
作者:
C
Chen Weihang
提交者:
GitHub
11月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support jit.save datra parallel (#29135)
上级
449903de
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
50 addition
and
6 deletion
+50
-6
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+16
-6
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+34
-0
未找到文件。
python/paddle/fluid/dygraph/jit.py
浏览文件 @
95a0f87b
...
@@ -581,6 +581,16 @@ def save(layer, path, input_spec=None, **configs):
...
@@ -581,6 +581,16 @@ def save(layer, path, input_spec=None, **configs):
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
%
type
(
layer
))
%
type
(
layer
))
# NOTE(chenweihang): If the input layer be wrapped by DataParallel,
# the args and kwargs of forward method will can't be parsed by
# function_spec, so here we save DataParallel._layers instead
# DataParallel it self
# NOTE(chenweihang): using inner_layer, do not change input layer
if
isinstance
(
layer
,
paddle
.
DataParallel
):
inner_layer
=
layer
.
_layers
else
:
inner_layer
=
layer
# path check
# path check
file_prefix
=
os
.
path
.
basename
(
path
)
file_prefix
=
os
.
path
.
basename
(
path
)
if
file_prefix
==
""
:
if
file_prefix
==
""
:
...
@@ -596,8 +606,8 @@ def save(layer, path, input_spec=None, **configs):
...
@@ -596,8 +606,8 @@ def save(layer, path, input_spec=None, **configs):
# avoid change user given input_spec
# avoid change user given input_spec
inner_input_spec
=
None
inner_input_spec
=
None
if
input_spec
is
not
None
:
if
input_spec
is
not
None
:
for
attr_func
in
dir
(
layer
):
for
attr_func
in
dir
(
inner_
layer
):
static_func
=
getattr
(
layer
,
attr_func
,
None
)
static_func
=
getattr
(
inner_
layer
,
attr_func
,
None
)
if
isinstance
(
static_func
,
if
isinstance
(
static_func
,
StaticFunction
)
and
'forward'
!=
attr_func
:
StaticFunction
)
and
'forward'
!=
attr_func
:
raise
ValueError
(
raise
ValueError
(
...
@@ -623,14 +633,14 @@ def save(layer, path, input_spec=None, **configs):
...
@@ -623,14 +633,14 @@ def save(layer, path, input_spec=None, **configs):
configs
=
_parse_save_configs
(
configs
)
configs
=
_parse_save_configs
(
configs
)
scope
=
core
.
Scope
()
scope
=
core
.
Scope
()
extra_var_info
=
dict
()
extra_var_info
=
dict
()
for
attr_func
in
dir
(
layer
):
for
attr_func
in
dir
(
inner_
layer
):
static_func
=
getattr
(
layer
,
attr_func
,
None
)
static_func
=
getattr
(
inner_
layer
,
attr_func
,
None
)
if
isinstance
(
static_func
,
StaticFunction
):
if
isinstance
(
static_func
,
StaticFunction
):
concrete_program
=
static_func
.
concrete_program
concrete_program
=
static_func
.
concrete_program
elif
'forward'
==
attr_func
:
elif
'forward'
==
attr_func
:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# transform in jit.save, if input_spec is incomplete, declarative will throw error
static_forward
=
declarative
(
static_forward
=
declarative
(
layer
.
forward
,
input_spec
=
inner_input_spec
)
inner_
layer
.
forward
,
input_spec
=
inner_input_spec
)
concrete_program
=
static_forward
.
concrete_program
concrete_program
=
static_forward
.
concrete_program
# the input_spec has been used in declarative, which is equal to
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# @declarative with input_spec and jit.save without input_spec,
...
@@ -663,7 +673,7 @@ def save(layer, path, input_spec=None, **configs):
...
@@ -663,7 +673,7 @@ def save(layer, path, input_spec=None, **configs):
# saved to inference program may not need by dygraph Layer,
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
# we only record the state_dict variable's structured name
state_names_dict
=
dict
()
state_names_dict
=
dict
()
for
structured_name
,
var
in
six
.
iteritems
(
layer
.
state_dict
()):
for
structured_name
,
var
in
six
.
iteritems
(
inner_
layer
.
state_dict
()):
state_names_dict
[
var
.
name
]
=
structured_name
state_names_dict
[
var
.
name
]
=
structured_name
# 4. share parameters from Layer to scope & record var info
# 4. share parameters from Layer to scope & record var info
...
...
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
95a0f87b
...
@@ -863,5 +863,39 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
...
@@ -863,5 +863,39 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
layer
,
model_path
,
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
])])
layer
,
model_path
,
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
])])
class
TestJitSaveLoadDataParallel
(
unittest
.
TestCase
):
def
verify_inference_correctness
(
self
,
layer
,
path
):
layer
.
eval
()
loaded_layer
=
paddle
.
jit
.
load
(
path
)
loaded_layer
.
eval
()
# inference & compare
x
=
paddle
.
to_tensor
(
np
.
random
.
random
((
1
,
784
)).
astype
(
'float32'
))
pred
=
layer
(
x
).
numpy
()
loaded_pred
=
loaded_layer
(
x
).
numpy
()
self
.
assertTrue
(
np
.
array_equal
(
pred
,
loaded_pred
),
msg
=
"Result diff when load and inference:
\n
layer result:
\n
{}
\n
"
\
"loaded layer result:
\n
{}"
.
format
(
pred
,
loaded_pred
))
def
test_jit_save_data_parallel_with_inputspec
(
self
):
layer
=
LinearNetNotDeclarative
(
784
,
1
)
layer
=
paddle
.
DataParallel
(
layer
)
path
=
"jit_save_data_parallel_with_inputspec/model"
paddle
.
jit
.
save
(
layer
=
layer
,
path
=
path
,
input_spec
=
[
InputSpec
(
shape
=
[
None
,
784
])])
self
.
verify_inference_correctness
(
layer
,
path
)
def
test_jit_save_data_parallel_with_to_static
(
self
):
layer
=
LinearNetWithInputSpec
(
784
,
1
)
layer
=
paddle
.
DataParallel
(
layer
)
path
=
"jit_save_data_parallel_with_to_static/model"
paddle
.
jit
.
save
(
layer
,
path
)
self
.
verify_inference_correctness
(
layer
,
path
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录