提交 4c0d12fd 编写于 作者: L Li Hongzhang

enhance callback module and strongly check callbacks is list or not

上级 932b7649
......@@ -18,6 +18,8 @@ from ._callback import Callback
from ._callback import CallbackManager as _CallbackManager
from ._callback import InternalCallbackParam as _InternalCallbackParam
from ._callback import RunContext
from ._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op
from ._callback import set_cur_net as _set_cur_net
from ._checkpoint import CheckpointConfig
from ._checkpoint import CheckpointManager as _CheckpointManager
from ._checkpoint import ModelCheckpoint
......
......@@ -160,16 +160,25 @@ class CallbackManager(Callback):
self._callbacks, self._stack = [], None
if isinstance(callbacks, Callback):
self._callbacks.append(callbacks)
elif callbacks is not None:
elif isinstance(callbacks, list):
for cb in callbacks:
if not isinstance(cb, Callback):
raise TypeError("%r is not an instance of %r" % (cb, Callback))
raise TypeError("The 'callbacks' contains not-a-Callback item.")
self._callbacks.append(cb)
elif callbacks is not None:
raise TypeError("The 'callbacks' is not a Callback or a list of Callback.")
def __enter__(self):
if self._stack is None:
self._stack = ExitStack().__enter__()
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
callbacks, self._stack = [], ExitStack().__enter__()
for callback in self._callbacks:
target = self._stack.enter_context(callback)
if not isinstance(target, Callback):
logger.warning("Please return 'self' or a Callback as the enter target.")
callbacks.append(callback)
else:
callbacks.append(target)
self._callbacks = callbacks
return self
def __exit__(self, *err):
......
......@@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
_CallbackManager, Callback, CheckpointConfig
from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
class Net(nn.Cell):
......@@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op():
one_param['name'] = "conv1.weight"
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
parameter_list.append(one_param)
checkpoint_cb_for_save_op(parameter_list)
_checkpoint_cb_for_save_op(parameter_list)
def test_checkpoint_cb_for_save_op_update_net():
......@@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net():
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
parameter_list.append(one_param)
net = Net()
set_cur_net(net)
checkpoint_cb_for_save_op(parameter_list)
_set_cur_net(net)
_checkpoint_cb_for_save_op(parameter_list)
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册