diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 76b2434ef501b88cec21e9b5fe0d6fe455e0906f..885e55570ba558270d99f0cf8a328dfcaa5067ce 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -345,6 +345,7 @@ from .autograd import set_grad_enabled  # noqa: F401
 from .autograd import is_grad_enabled  # noqa: F401
 from .framework import save  # noqa: F401
 from .framework import load  # noqa: F401
+from .framework import async_save, clear_async_save_task_queue  # noqa: F401
 from .distributed import DataParallel  # noqa: F401
 
 from .framework import set_default_dtype  # noqa: F401
diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
index f9d1e0d70a721470ac91767c16ae27ccc1058ee5..08f72e3dc02c7870d56ba49e6ae83f0e929dcfe2 100644
--- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py
@@ -1049,6 +1049,88 @@ class TestSaveLoad(unittest.TestCase):
         )
 
 
+class TestAsyncSaveLoad(unittest.TestCase):
+    def setUp(self):
+        # async_save raises ValueError in static mode, so make sure dygraph mode is on
+        paddle.disable_static()
+
+        # fix the random seeds
+        paddle.seed(SEED)
+        paddle.framework.random._manual_program_seed(SEED)
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def build_and_train_model(self):
+        # create network
+        layer = LinearNet()
+        loss_fn = nn.CrossEntropyLoss()
+
+        adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
+
+        # create data loader
+        # TODO: using new DataLoader causes an unknown timeout on Windows, so it is replaced by random_batch_reader for now
+        loader = random_batch_reader()
+
+        # train
+        train(layer, loader, loss_fn, adam)
+
+        return layer, adam
+
+    def check_load_state_dict(self, orig_dict, load_dict):
+        for var_name, value in orig_dict.items():
+            load_value = (
+                load_dict[var_name].numpy()
+                if hasattr(load_dict[var_name], 'numpy')
+                else np.array(load_dict[var_name])
+            )
+            np.testing.assert_array_equal(value.numpy(), load_value)
+
+    def test_async_save_load(self):
+        layer, opt = self.build_and_train_model()
+
+        # save both state dicts asynchronously, then wait for the queued save tasks
+        layer_save_path = os.path.join(
+            self.temp_dir.name, "test_paddle_async_save_load.linear.pdparams"
+        )
+        opt_save_path = os.path.join(
+            self.temp_dir.name, "test_paddle_async_save_load.linear.pdopt"
+        )
+        layer_state_dict = layer.state_dict()
+        opt_state_dict = opt.state_dict()
+
+        paddle.async_save(
+            layer_state_dict, layer_save_path, sync_other_task=True
+        )
+        paddle.async_save(opt_state_dict, opt_save_path)
+        paddle.clear_async_save_task_queue()
+
+        # load the files back and compare them with the saved state dicts
+        load_layer_state_dict = paddle.load(layer_save_path)
+        load_opt_state_dict = paddle.load(opt_save_path)
+
+        self.check_load_state_dict(layer_state_dict, load_layer_state_dict)
+        self.check_load_state_dict(opt_state_dict, load_opt_state_dict)
+
+        # a tuple is neither a dict nor a Tensor, so async_save must raise TypeError
+        some_tuple_obj = (1, 2, 3)
+        tuple_save_path = os.path.join(
+            self.temp_dir.name, "test_paddle_async_save_load.tuple.pdparams"
+        )
+        with self.assertRaises(TypeError):
+            paddle.async_save(some_tuple_obj, tuple_save_path)
+
+        # async_save is not supported in static graph mode and must raise ValueError
+        paddle.enable_static()
+        static_save_path = os.path.join(
+            self.temp_dir.name,
+            "static_mode_test/test_paddle_async_save_load.linear.pdparams",
+        )
+        with self.assertRaises(ValueError):
+            paddle.async_save(layer_state_dict, static_save_path)
+
+
 class TestSaveLoadProgram(unittest.TestCase):
     def test_save_load_program(self):
         paddle.enable_static()
diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py
index 1aca2973e344e47ea41db0eddb6687ecce72aedc..a4be42f71ccd101de6d7e1497d08ad119473dd93 100755
--- a/python/paddle/framework/__init__.py
+++ b/python/paddle/framework/__init__.py
@@ -33,6 +33,7 @@ from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
 from ..fluid.dygraph.base import grad  # noqa: F401
 from .io import save  # noqa: F401
 from .io import load  # noqa: F401
+from .io
import async_save, clear_async_save_task_queue  # noqa: F401
 
 from .io_utils import _open_file_buffer  # noqa: F401
 from .io_utils import is_parameter  # noqa: F401
diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py
index 4c0f7e52c0266606dd8dabafccee57ac3cb3f608..4f8f958d187ac9eaa7d20541dd3840dc83330773 100644
--- a/python/paddle/framework/io.py
+++ b/python/paddle/framework/io.py
@@ -17,6 +17,7 @@ import copyreg
 import os
 import pickle
 import sys
+import threading
 import warnings
 from collections.abc import Iterable
 
@@ -48,6 +49,101 @@ from .io_utils import (
 )
 
 __all__ = []
+async_save_queue = []
+
+
+def clear_async_save_task_queue():
+    '''
+    Wait until all queued async save tasks are done, leaving the queue empty.
+    '''
+    while async_save_queue:
+        task = async_save_queue.pop()
+        if task and task.is_alive():
+            task.join()
+
+
+def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
+    '''
+    Async version of paddle.save: the save runs on a background thread.
+
+    Note:
+        Currently only dygraph mode is supported.
+    Note:
+        Any argument passed through configs is ignored; the default settings are used.
+    Note:
+        An exception raised by the background save is not re-raised in the calling thread.
+
+    Args:
+        obj(dict|Tensor) : The object to be saved. Other types raise TypeError.
+        path(str|BytesIO) : The path/buffer of the object to be saved.
+            If saved in the current directory, the input path string will be used as the file name.
+        protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
+            Default: 4
+        sync_other_task(bool) : Whether to wait for all previously queued async save tasks to finish
+            before this one is started. Default: False
+        **configs(dict, optional): Accepted for compatibility with paddle.save, but ignored (a warning is issued).
+
+    Raises:
+        ValueError: If called in static mode.
+        TypeError: If obj is neither a dict nor a Tensor.
+
+    Examples:
+        .. code-block:: python
+            :name: code-example-1
+
+            import paddle
+            emb = paddle.nn.Embedding(10, 10)
+            layer_state_dict = emb.state_dict()
+
+            # call paddle.async_save with the same style of paddle.save
+            paddle.async_save(layer_state_dict, "emb.pdparams")
+            for i in range(10):
+                # do some calculations here
+                pass
+            # wait if any async_save task has not been done
+            paddle.clear_async_save_task_queue()
+    '''
+    if not _non_static_mode():
+        raise ValueError(
+            "async_save currently is not supported in static mode."
+        )
+    if len(configs) > 0:
+        warnings.warn(
+            "configs are not supported in async mode, will be overided by default settings."
+        )
+
+    def to_host(tensor):
+        # NOTE(review): without CUDA the tensor is passed through as-is, so the
+        # worker thread reads the caller's live tensor - confirm this is intended.
+        return tensor.pin_memory() if core.is_compiled_with_cuda() else tensor
+
+    # TODO: make this part async
+    def snapshot_state_dict(sd):
+        # Fill a shallow copy instead of assigning into ``sd``: the caller's
+        # dict (and any nested dict it shares with other objects) must keep
+        # its original tensors.
+        snapshot = sd.copy()
+        for k, v in sd.items():
+            if isinstance(v, dict):
+                snapshot[k] = snapshot_state_dict(v)
+            elif isinstance(v, core.eager.Tensor):
+                snapshot[k] = to_host(v)
+        return snapshot
+
+    if isinstance(obj, dict):
+        obj = snapshot_state_dict(obj)
+    elif isinstance(obj, core.eager.Tensor):
+        obj = to_host(obj)
+    else:
+        # other types are currently not supported
+        raise TypeError(
+            f"currently async_save does not support this type: {type(obj)}"
+        )
+    if sync_other_task:
+        clear_async_save_task_queue()
+    t = threading.Thread(target=save, args=(obj, path, protocol))
+    t.start()
+    async_save_queue.append(t)
 
 
 def _build_saved_state_dict(state_dict):