未验证 提交 8d42540f 编写于 作者: T Tian 提交者: GitHub

add paddle.async_save to reduce time cost by checkpoint saving (#55115)

* add paddle.async_save to reduce time cost by checkpoint saving

* adapt save_for_auto_inference to paddle.async_save

* modify UT

* modify UT

* fix on cpu only version

* revert commit on save_auto_inference

* fix threading
上级 64f25adf
......@@ -345,6 +345,7 @@ from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
from .framework import load # noqa: F401
from .framework import async_save, clear_async_save_task_queue # noqa: F401
from .distributed import DataParallel # noqa: F401
from .framework import set_default_dtype # noqa: F401
......
......@@ -1049,6 +1049,88 @@ class TestSaveLoad(unittest.TestCase):
)
class TestAsyncSaveLoad(unittest.TestCase):
    """Checks for paddle.async_save: save/load round trip and rejected inputs."""

    def setUp(self):
        # Run in dygraph mode with fixed seeds, and give each test its own
        # temporary directory for checkpoint files.
        paddle.disable_static()
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def build_and_train_model(self):
        """Create a LinearNet with an Adam optimizer, train it, return both."""
        net = LinearNet()
        criterion = nn.CrossEntropyLoss()
        optimizer = opt.Adam(learning_rate=0.001, parameters=net.parameters())
        # TODO: using new DataLoader cause unknown Timeout on windows, replace it
        batches = random_batch_reader()
        train(net, batches, criterion, optimizer)
        return net, optimizer

    def check_load_state_dict(self, orig_dict, load_dict):
        """Assert every entry of orig_dict equals its counterpart in load_dict."""
        for key, expected in orig_dict.items():
            loaded = load_dict[key]
            # Loaded entries that expose ``numpy()`` are converted through it;
            # anything else is wrapped with ``np.array``.
            if hasattr(loaded, 'numpy'):
                actual = loaded.numpy()
            else:
                actual = np.array(loaded)
            np.testing.assert_array_equal(expected.numpy(), actual)

    def test_async_save_load(self):
        net, optimizer = self.build_and_train_model()
        tmp_root = self.temp_dir.name

        # save
        params_path = os.path.join(
            tmp_root, "test_paddle_async_save_load.linear.pdparams"
        )
        optim_path = os.path.join(
            tmp_root, "test_paddle_async_save_load.linear.pdopt"
        )
        params_state = net.state_dict()
        optim_state = optimizer.state_dict()
        paddle.async_save(params_state, params_path, sync_other_task=True)
        paddle.async_save(optim_state, optim_path)
        # Wait for both background saves before reading the files back.
        paddle.clear_async_save_task_queue()

        # load
        loaded_params = paddle.load(params_path)
        loaded_optim = paddle.load(optim_path)
        self.check_load_state_dict(params_state, loaded_params)
        self.check_load_state_dict(optim_state, loaded_optim)

        # test assertion on illegal object
        bad_obj = (1, 2, 3)
        bad_obj_path = os.path.join(
            tmp_root, "test_paddle_async_save_load.tuple.pdparams"
        )
        with self.assertRaises(TypeError):
            paddle.async_save(bad_obj, bad_obj_path)

        # test assertion on static graph
        paddle.enable_static()
        static_path = os.path.join(
            tmp_root,
            "static_mode_test/test_paddle_async_save_load.linear.pdparams",
        )
        with self.assertRaises(ValueError):
            paddle.async_save(params_state, static_path)
class TestSaveLoadProgram(unittest.TestCase):
def test_save_load_program(self):
paddle.enable_static()
......
......@@ -33,6 +33,7 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..fluid.dygraph.base import grad # noqa: F401
from .io import save # noqa: F401
from .io import load # noqa: F401
from .io import async_save, clear_async_save_task_queue # noqa: F401
from .io_utils import _open_file_buffer # noqa: F401
from .io_utils import is_parameter # noqa: F401
......
......@@ -17,6 +17,7 @@ import copyreg
import os
import pickle
import sys
import threading
import warnings
from collections.abc import Iterable
......@@ -48,6 +49,81 @@ from .io_utils import (
)
__all__ = []
# Background save threads that have been started but not yet waited on.
async_save_queue = []


def clear_async_save_task_queue():
    '''
    Block until every queued async save task has finished.

    Entries are popped from the end of ``async_save_queue`` one at a time;
    each entry that is still running is joined. The queue is empty when this
    function returns.
    '''
    while async_save_queue:
        pending = async_save_queue.pop()
        # Skip falsy placeholders (e.g. None) that carry no thread.
        if not pending:
            continue
        if pending.is_alive():
            pending.join()
def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
    '''
    async version of paddle.save.

    The object is prepared on the calling thread and then handed to a
    background thread that runs ``save(obj, path, protocol)``. The thread is
    appended to ``async_save_queue``; call ``clear_async_save_task_queue``
    to wait for it.

    Note:
        currently only support dygraph mode.
    Note:
        any argument passed through configs will be overridden by default setting.

    Args:
        obj(Object) : The object to be saved. Must be a dict (possibly nested) or a Tensor.
            A dict passed in is not modified; a copy of the mapping is saved instead.
        path(str|BytesIO) : The path/buffer of the object to be saved.
            If saved in the current directory, the input path string will be used as the file name.
        protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
            Default: 4
        sync_other_task(bool) : Determine whether to wait other async save task to be finished before this one be put in queue.
        **configs(dict, optional): compatible argument to paddle.save, but will be overridden by default setting.

    Raises:
        ValueError: If called in static graph mode.
        TypeError: If ``obj`` is neither a dict nor a Tensor.

    Examples:
        .. code-block:: python
            :name: code-example-1

            import paddle
            emb = paddle.nn.Embedding(10, 10)
            layer_state_dict = emb.state_dict()

            # call paddle.async_save with the same style of paddle.save
            paddle.async_save(layer_state_dict, "emb.pdparams")
            for i in range(10):
                # do some calculations here
                pass
            # wait if any async_save task has not been done
            paddle.clear_async_save_task_queue()
    '''
    # Function-scope import so this block does not depend on the module's
    # top-level import list.
    import copy

    if not _non_static_mode():
        raise ValueError(
            "async_save currently is not supported in static mode."
        )
    if configs:
        warnings.warn(
            "configs are not supported in async mode, will be overided by default settings."
        )

    def _pin_tensor(tensor):
        # On CUDA builds the tensor goes through pin_memory(); otherwise it
        # is passed on unchanged.
        # NOTE(review): in the unchanged case the worker thread presumably
        # shares storage with the caller, so in-place updates made while the
        # save is in flight may end up in the file -- confirm.
        return tensor.pin_memory() if core.is_compiled_with_cuda() else tensor

    # TODO: make this part async
    def _pin_state_dict(sd):
        # Work on a shallow copy (same mapping type, same key order) so the
        # caller's dict keeps its original tensors.
        pinned = copy.copy(sd)
        for key, value in sd.items():
            if isinstance(value, dict):
                pinned[key] = _pin_state_dict(value)
            elif isinstance(value, core.eager.Tensor):
                pinned[key] = _pin_tensor(value)
        return pinned

    if isinstance(obj, dict):
        obj = _pin_state_dict(obj)
    elif isinstance(obj, core.eager.Tensor):
        obj = _pin_tensor(obj)
    else:
        # other types are currently not supported
        raise TypeError(
            f"currently async_save does not support this type: {type(obj)}"
        )

    if sync_other_task:
        clear_async_save_task_queue()

    worker = threading.Thread(target=save, args=(obj, path, protocol))
    worker.start()
    async_save_queue.append(worker)
def _build_saved_state_dict(state_dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册