Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
483ba282
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
483ba282
编写于
9月 15, 2022
作者:
H
Hui Zhang
提交者:
GitHub
9月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[jit] skip forward save (#45901)
* skip forward save * fix bug * more ci for jit skip forward
上级
b042a3b1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
5 deletion
+49
-5
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+15
-5
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+34
-0
未找到文件。
python/paddle/fluid/dygraph/jit.py
浏览文件 @
483ba282
...
...
@@ -381,7 +381,8 @@ class _SaveLoadConfig(object):
def
_parse_save_configs
(
configs
):
supported_configs
=
[
'output_spec'
,
"with_hook"
,
"combine_params"
,
"clip_extra"
'output_spec'
,
"with_hook"
,
"combine_params"
,
"clip_extra"
,
"skip_forward"
]
# input check
...
...
@@ -397,6 +398,7 @@ def _parse_save_configs(configs):
inner_config
.
with_hook
=
configs
.
get
(
'with_hook'
,
False
)
inner_config
.
combine_params
=
configs
.
get
(
"combine_params"
,
False
)
inner_config
.
clip_extra
=
configs
.
get
(
"clip_extra"
,
False
)
inner_config
.
skip_forward
=
configs
.
get
(
"skip_forward"
,
False
)
return
inner_config
...
...
@@ -522,7 +524,10 @@ def _build_load_path_and_config(path, config):
"don't know which one to load, please make sure that the specified target "
"of ``path`` is unique."
%
(
path
,
path
))
elif
not
prefix_format_exist
and
not
directory_format_exist
:
raise
ValueError
(
"The ``path`` (%s) to load model not exists."
%
path
)
raise
ValueError
(
"The ``path`` (%s) to load model not exists. "
"Please make sure that *.pdmodel exists or "
"don't using ``skip_forward=True`` to jit.save."
%
path
)
else
:
if
prefix_format_exist
:
file_prefix
=
os
.
path
.
basename
(
path
)
...
...
@@ -906,6 +911,7 @@ def save(layer, path, input_spec=None, **configs):
combine_vars
=
{}
property_vals
=
[]
# (value, key)
concrete_program
=
None
for
attr_func
in
functions
:
if
isinstance
(
layer
,
Layer
):
static_func
=
getattr
(
inner_layer
,
attr_func
,
None
)
...
...
@@ -921,6 +927,10 @@ def save(layer, path, input_spec=None, **configs):
concrete_program
=
static_func
.
concrete_program_specify_input_spec
(
inner_input_spec
,
with_hook
=
with_hook
)
elif
'forward'
==
attr_func
:
if
configs
.
skip_forward
:
# do not jit.save forward function
continue
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same structure
# as original input_spec here.
...
...
@@ -1100,10 +1110,10 @@ def save(layer, path, input_spec=None, **configs):
# file `***.pdiparams.info`
# "layer" can only be Layer or function or StaticFunction.
contain_parameter
=
False
for
var
in
concrete_program
.
main_program
.
list_vars
():
contain_parameter
|=
isinstance
(
var
,
Parameter
)
if
concrete_program
is
not
None
:
for
var
in
concrete_program
.
main_program
.
list_vars
():
contain_parameter
|=
isinstance
(
var
,
Parameter
)
if
(
isinstance
(
layer
,
Layer
)
or
contain_parameter
)
and
extra_var_info
:
with
scope_guard
(
scope
):
...
...
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
483ba282
...
...
@@ -1740,6 +1740,40 @@ class TestInputSpecCompatibility(unittest.TestCase):
shutil
.
rmtree
(
save_dir
)
class
NotJitForward
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
NotJitForward
,
self
).
__init__
()
def
forward
(
self
,
x
,
y
):
return
x
+
y
class
TestNotJitForward
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
temp_dir
=
tempfile
.
TemporaryDirectory
()
def
tearDown
(
self
):
self
.
temp_dir
.
cleanup
()
def
test_jit_not_save_forward
(
self
):
layer
=
NotJitForward
()
save_dir
=
os
.
path
.
join
(
self
.
temp_dir
.
name
,
"jit_not_save_forward"
)
path
=
save_dir
+
"/model"
paddle
.
jit
.
save
(
layer
=
layer
,
path
=
path
,
skip_forward
=
True
)
self
.
assertTrue
(
not
os
.
path
.
exists
(
path
+
".pdmodel"
))
self
.
assertTrue
(
not
os
.
path
.
exists
(
path
+
".pdparam"
))
with
self
.
assertRaises
(
ValueError
):
paddle
.
jit
.
load
(
path
=
path
)
shutil
.
rmtree
(
save_dir
)
if
__name__
==
'__main__'
:
with
fluid
.
framework
.
_test_eager_guard
():
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录