提交 99a2ab4b 编写于 作者: C changzherui

modify asyn save checkpoint bug

上级 28f873e9
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import stat import stat
import time import time
import threading
import mindspore.context as context import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative from mindspore._checkparam import check_bool, check_int_non_negative
...@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback): ...@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback):
_to_save_last_ckpt = True _to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)
thread_list = threading.enumerate()
if len(thread_list) > 1:
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
from mindspore.parallel._cell_wrapper import destroy_allgather_cell from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell() destroy_allgather_cell()
......
...@@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): ...@@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
data_list[key].append(data) data_list[key].append(data)
if async_save: if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list)) thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
thr.start() thr.start()
else: else:
_exec_save(ckpt_file_name, data_list) _exec_save(ckpt_file_name, data_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册