提交 99a2ab4b 编写于 作者: C changzherui

modify asyn save checkpoint bug

上级 28f873e9
......@@ -18,6 +18,7 @@ import os
import stat
import time
import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
......@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback):
_to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt)
thread_list = threading.enumerate()
if len(thread_list) > 1:
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()
......
......@@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
data_list[key].append(data)
if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
thr.start()
else:
_exec_save(ckpt_file_name, data_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册