Unverified commit e25e9471, authored by Hui Zhang, committed by GitHub

[cherry-pick][jit] Jit skip forward (#45926)

* skip forward save

* fix bug

* more ci for jit skip forward
Parent: 8caaf85a
@@ -381,7 +381,8 @@ class _SaveLoadConfig(object):
 
 def _parse_save_configs(configs):
     supported_configs = [
-        'output_spec', "with_hook", "combine_params", "clip_extra"
+        'output_spec', "with_hook", "combine_params", "clip_extra",
+        "skip_forward"
     ]
 
     # input check
@@ -397,6 +398,7 @@ def _parse_save_configs(configs):
     inner_config.with_hook = configs.get('with_hook', False)
     inner_config.combine_params = configs.get("combine_params", False)
     inner_config.clip_extra = configs.get("clip_extra", False)
+    inner_config.skip_forward = configs.get("skip_forward", False)
 
     return inner_config
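
Taken together, the two hunks above make skip_forward an accepted save config of paddle.jit.save, defaulting to False. A minimal caller-side sketch, mirroring the new unit test further below (the layer name and output path are illustrative):

    import paddle


    class AddLayer(paddle.nn.Layer):

        def forward(self, x, y):
            return x + y


    layer = AddLayer()
    # With skip_forward=True the forward program is not traced or exported,
    # so no model.pdmodel file is written at this prefix.
    paddle.jit.save(layer, "/tmp/skip_forward_demo/model", skip_forward=True)
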
@@ -522,7 +524,10 @@ def _build_load_path_and_config(path, config):
             "don't know which one to load, please make sure that the specified target "
             "of ``path`` is unique." % (path, path))
     elif not prefix_format_exist and not directory_format_exist:
-        raise ValueError("The ``path`` (%s) to load model not exists." % path)
+        raise ValueError("The ``path`` (%s) to load model not exists. "
+                         "Please make sure that *.pdmodel exists or "
+                         "don't using ``skip_forward=True`` to jit.save." %
+                         path)
     else:
         if prefix_format_exist:
             file_prefix = os.path.basename(path)
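
The longer error message above is what a caller now sees when paddle.jit.load is pointed at a prefix that has no *.pdmodel next to it, for example one produced with skip_forward=True. A sketch of that failure path, reusing the illustrative prefix from the sketch above:

    import paddle

    try:
        # No model.pdmodel exists for this prefix, so loading fails with the
        # ValueError raised in _build_load_path_and_config above.
        paddle.jit.load("/tmp/skip_forward_demo/model")
    except ValueError as err:
        print(err)
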
@@ -906,6 +911,7 @@ def save(layer, path, input_spec=None, **configs):
     combine_vars = {}
     property_vals = []  # (value, key)
+    concrete_program = None
 
     for attr_func in functions:
         if isinstance(layer, Layer):
             static_func = getattr(inner_layer, attr_func, None)
@@ -921,6 +927,10 @@ def save(layer, path, input_spec=None, **configs):
                 concrete_program = static_func.concrete_program_specify_input_spec(
                     inner_input_spec, with_hook=with_hook)
             elif 'forward' == attr_func:
+                if configs.skip_forward:
+                    # do not jit.save forward function
+                    continue
+
                 # transform in jit.save, if input_spec is incomplete, declarative will throw error
                 # inner_input_spec is list[InputSpec], it should be packed with same structure
                 # as original input_spec here.
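
The continue above skips only the forward entry of the function list; other methods decorated with paddle.jit.to_static still reach concrete_program_specify_input_spec in the branch just above it. A sketch of that combination, assuming the usual multi-method export behaviour of paddle.jit.save (the layer, its infer method, and the path are illustrative):

    import paddle
    from paddle.static import InputSpec


    class Net(paddle.nn.Layer):

        def __init__(self):
            super(Net, self).__init__()
            self.fc = paddle.nn.Linear(4, 4)

        def forward(self, x):
            return self.fc(x)

        @paddle.jit.to_static(
            input_spec=[InputSpec(shape=[None, 4], dtype='float32')])
        def infer(self, x):
            return self.fc(x) + 1


    net = Net()
    # forward is skipped, while the decorated infer method is still traced and
    # exported, so concrete_program does not stay None for this layer.
    paddle.jit.save(net, "/tmp/skip_forward_multi/model", skip_forward=True)
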
@@ -1100,10 +1110,10 @@ def save(layer, path, input_spec=None, **configs):
     # file `***.pdiparams.info`
 
     # "layer" can only be Layer or function or StaticFunction.
     contain_parameter = False
-    for var in concrete_program.main_program.list_vars():
-        contain_parameter |= isinstance(var, Parameter)
+    if concrete_program is not None:
+        for var in concrete_program.main_program.list_vars():
+            contain_parameter |= isinstance(var, Parameter)
 
     if (isinstance(layer, Layer) or contain_parameter) and extra_var_info:
         with scope_guard(scope):
...
@@ -1740,6 +1740,40 @@ class TestInputSpecCompatibility(unittest.TestCase):
         shutil.rmtree(save_dir)
 
 
+class NotJitForward(paddle.nn.Layer):
+
+    def __init__(self):
+        super(NotJitForward, self).__init__()
+
+    def forward(self, x, y):
+        return x + y
+
+
+class TestNotJitForward(unittest.TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_jit_not_save_forward(self):
+        layer = NotJitForward()
+        save_dir = os.path.join(self.temp_dir.name, "jit_not_save_forward")
+        path = save_dir + "/model"
+
+        paddle.jit.save(layer=layer, path=path, skip_forward=True)
+
+        self.assertTrue(not os.path.exists(path + ".pdmodel"))
+        self.assertTrue(not os.path.exists(path + ".pdparam"))
+
+        with self.assertRaises(ValueError):
+            paddle.jit.load(path=path)
+
+        shutil.rmtree(save_dir)
+
+
 if __name__ == '__main__':
     with fluid.framework._test_eager_guard():
         unittest.main()