diff --git a/example/resnet50_imagenet2012_THOR/model/model_thor.py b/example/resnet50_imagenet2012_THOR/model/model_thor.py index b8cd27470c2c83b9bbf1021709b73b28a57ffeba..3106b044530fa6b3e0f23325713881aed71cc645 100644 --- a/example/resnet50_imagenet2012_THOR/model/model_thor.py +++ b/example/resnet50_imagenet2012_THOR/model/model_thor.py @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check from mindspore.train import amp -from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks +from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager from mindspore.train.parallel_utils import ParallelMode from model.dataset_helper import DatasetHelper @@ -374,7 +374,6 @@ class Model: self._train_network.set_broadcast_flag() # build callback list - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.train_network = self._train_network cb_params.epoch_num = epoch @@ -385,17 +384,17 @@ class Model: cb_params.parallel_mode = self._parallel_mode cb_params.device_number = self._device_number cb_params.train_dataset = train_dataset - cb_params.list_callback = list_callback + cb_params.list_callback = callbacks - if dataset_sink_mode: - if context.get_context("mode") == context.PYNATIVE_MODE: + with _CallbackManager(callbacks) as list_callback: + if not dataset_sink_mode: + self._train_process(epoch, train_dataset, list_callback, cb_params) + elif context.get_context("mode") == context.PYNATIVE_MODE: logger.warning("The pynative mode cannot support dataset sink mode currently." "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) - else: - self._train_process(epoch, train_dataset, list_callback, cb_params) def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ @@ -408,7 +407,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ iter_first_order = self._frequency - 1 @@ -473,7 +472,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ dataset_helper, _ = self._exec_preprocess(self._train_network, @@ -580,7 +579,7 @@ class Model: Args: valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: @@ -619,7 +618,7 @@ class Model: Args: valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: @@ -678,7 +677,6 @@ class Model: if not self._metric_fns: raise ValueError("metric fn can not be None or empty.") - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.eval_network = self._eval_network cb_params.valid_dataset = valid_dataset @@ -691,9 +689,10 @@ class Model: self._clear_metrics() - if dataset_sink_mode: - return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) - return self._eval_process(valid_dataset, list_callback, cb_params) + with _CallbackManager(callbacks) as list_callback: + if dataset_sink_mode: + return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) + return self._eval_process(valid_dataset, list_callback, cb_params) def predict(self, *predict_data): """ diff --git a/mindspore/train/callback/callback.py b/mindspore/train/callback/callback.py index 7df804af0f0eca1a3d871d9b0ee7d92fee0d9a81..822fa3cb6686ef973c69f6237d95362e3b304c2b 100644 --- a/mindspore/train/callback/callback.py +++ b/mindspore/train/callback/callback.py @@ -18,6 +18,7 @@ import os import stat import shutil import time +from contextlib import ExitStack import numpy as np import mindspore.context as context @@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list): return ret -def _build_callbacks(callbacks): - """ - Contain a list of callback. - - Args: - callbacks (list): Callback functions list, Support None, a single Callback object, or a list. - - Returns: - List, a list of callback functions. - """ - if callbacks: - if isinstance(callbacks, tuple): - raise TypeError("Callbacks cannot be a tuple. Please check it.") - if not isinstance(callbacks, list): - callbacks = [callbacks] - else: - callbacks = [] - - excute_callbacks = [] - for cb in callbacks: - if cb is None or not isinstance(cb, Callback): - raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.") - excute_callbacks.append(cb) - - return _ListCallback(excute_callbacks) - - -class _ListCallback: - """ - Sequential execution of callback functions. - - Execute Callback functions at certain points. - - Args: - callbacks (list): Callback functions list. - """ - def __init__(self, callbacks): - super(_ListCallback, self).__init__() - self._callbacks = callbacks - - def begin(self, run_context): - """Called once before network training.""" - for cb in self._callbacks: - cb.begin(run_context) - - def epoch_begin(self, run_context): - """Called before each epoch begin.""" - for cb in self._callbacks: - cb.epoch_begin(run_context) - - def epoch_end(self, run_context): - """Called after each epoch finished.""" - for cb in self._callbacks: - cb.epoch_end(run_context) - - def step_begin(self, run_context): - """Called before each epoch begin.""" - for cb in self._callbacks: - cb.step_begin(run_context) - - def step_end(self, run_context): - """Called after each step finished.""" - for cb in self._callbacks: - cb.step_end(run_context) - - def end(self, run_context): - """Called once after network training.""" - for cb in self._callbacks: - cb.end(run_context) - - class Callback: """ - Abstract base class used to build a callback function. + Abstract base class used to build a callback class. Callbacks are context managers + which will be entered and exited when passing into the Model. + You can leverage this mechanism to init and release resources automatically. Callback function will execution some operating to the current step or epoch. @@ -369,8 +301,13 @@ class Callback: >>> print_cb = Print_info() >>> model.train(epoch, dataset, callbacks=print_cb) """ - def __init__(self): - pass + + def __enter__(self): + """Return the enter target.""" + return self + + def __exit__(self, *err): + """Release resources here if have any.""" def begin(self, run_context): """ @@ -421,6 +358,67 @@ class Callback: """ +class _CallbackManager(Callback): + """ + Sequential execution of callback functions. + + Execute Callback functions at certain points. + + Args: + callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list. + """ + + def __init__(self, callbacks): + self._callbacks, self._stack = [], None + if isinstance(callbacks, Callback): + self._callbacks.append(callbacks) + elif callbacks is not None: + for cb in callbacks: + if not isinstance(cb, Callback): + raise TypeError("%r is not an instance of %r" % (cb, Callback)) + self._callbacks.append(cb) + + def __enter__(self): + if self._stack is None: + self._stack = ExitStack().__enter__() + self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] + return self + + def __exit__(self, *err): + return self._stack.__exit__(*err) + + def begin(self, run_context): + """Called once before network training.""" + for cb in self._callbacks: + cb.begin(run_context) + + def epoch_begin(self, run_context): + """Called before each epoch begin.""" + for cb in self._callbacks: + cb.epoch_begin(run_context) + + def epoch_end(self, run_context): + """Called after each epoch finished.""" + for cb in self._callbacks: + cb.epoch_end(run_context) + + def step_begin(self, run_context): + """Called before each epoch begin.""" + for cb in self._callbacks: + cb.step_begin(run_context) + + def step_end(self, run_context): + """Called after each step finished.""" + for cb in self._callbacks: + cb.step_end(run_context) + + def end(self, run_context): + """Called once after network training.""" + for cb in self._callbacks: + cb.end(run_context) + + + class SummaryStep(Callback): """ The summary callback class. @@ -435,6 +433,13 @@ class SummaryStep(Callback): raise ValueError("`flush_step` should be int and greater than 0") self._summary = summary self._flush_step = flush_step + def __enter__(self): + self._summary.__enter__() + return self + + def __exit__(self, *err): + return self._summary.__exit__(*err) + def step_end(self, run_context): """ diff --git a/mindspore/train/model.py b/mindspore/train/model.py index cf6a91216cdace00db17d74e6c6b516babf6d6bc..68042d8d0ae323523e0acfddeb61bedc0b13d596 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -19,7 +19,7 @@ from mindspore import log as logger from ..common.tensor import Tensor from ..nn.metrics import get_metrics from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool -from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks +from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check @@ -334,8 +334,6 @@ class Model: if self._parameter_broadcast: self._train_network.set_broadcast_flag() - # build callback list - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.train_network = self._train_network cb_params.epoch_num = epoch @@ -346,17 +344,18 @@ class Model: cb_params.parallel_mode = self._parallel_mode cb_params.device_number = self._device_number cb_params.train_dataset = train_dataset - cb_params.list_callback = list_callback + cb_params.list_callback = callbacks - if dataset_sink_mode: - if context.get_context("mode") == context.PYNATIVE_MODE: + # build callback list + with _CallbackManager(callbacks) as list_callback: + if not dataset_sink_mode: + self._train_process(epoch, train_dataset, list_callback, cb_params) + elif context.get_context("mode") == context.PYNATIVE_MODE: logger.warning("The pynative mode cannot support dataset sink mode currently." "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) - else: - self._train_process(epoch, train_dataset, list_callback, cb_params) def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ @@ -369,7 +368,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ dataset_helper, train_network = self._exec_preprocess(self._train_network, @@ -417,7 +416,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ dataset_helper, _ = self._exec_preprocess(self._train_network, @@ -524,7 +523,7 @@ class Model: Args: valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: @@ -563,7 +562,7 @@ class Model: Args: valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. Returns: @@ -622,7 +621,6 @@ class Model: if not self._metric_fns: raise ValueError("metric fn can not be None or empty.") - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.eval_network = self._eval_network cb_params.valid_dataset = valid_dataset @@ -635,9 +633,10 @@ class Model: self._clear_metrics() - if dataset_sink_mode: - return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) - return self._eval_process(valid_dataset, list_callback, cb_params) + with _CallbackManager(callbacks) as list_callback: + if dataset_sink_mode: + return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) + return self._eval_process(valid_dataset, list_callback, cb_params) def predict(self, *predict_data): """ diff --git a/tests/st/networks/models/resnet50/src_thor/model_thor.py b/tests/st/networks/models/resnet50/src_thor/model_thor.py index c633d913aca53c17cb9f514c4d57a3fb45e700c4..ee799c4b740291c7df4a8341fe8c875b7f229ea0 100644 --- a/tests/st/networks/models/resnet50/src_thor/model_thor.py +++ b/tests/st/networks/models/resnet50/src_thor/model_thor.py @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check from mindspore.train import amp -from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks +from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager from mindspore.train.parallel_utils import ParallelMode from .dataset_helper import DatasetHelper @@ -392,7 +392,6 @@ class Model: self._train_network.set_broadcast_flag() # build callback list - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.train_network = self._train_network cb_params.epoch_num = epoch @@ -403,17 +402,17 @@ class Model: cb_params.parallel_mode = self._parallel_mode cb_params.device_number = self._device_number cb_params.train_dataset = train_dataset - cb_params.list_callback = list_callback + cb_params.list_callback = callbacks - if dataset_sink_mode: - if context.get_context("mode") == context.PYNATIVE_MODE: + with _CallbackManager(callbacks) as list_callback: + if not dataset_sink_mode: + self._train_process(epoch, train_dataset, list_callback, cb_params) + elif context.get_context("mode") == context.PYNATIVE_MODE: logger.warning("The pynative mode cannot support dataset sink mode currently." "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) - else: - self._train_process(epoch, train_dataset, list_callback, cb_params) def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ @@ -426,7 +425,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ iter_first_order = self._frequency - 1 @@ -490,7 +489,7 @@ class Model: returned and passed to the network. Otherwise, a tuple (data, label) should be returned, and the data and label are passed to the network and loss function respectively. - list_callback (_ListCallback): Executor of callback list. Default: None. + list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ dataset_helper, _ = self._exec_preprocess(self._train_network, @@ -695,7 +694,6 @@ class Model: if not self._metric_fns: raise ValueError("metric fn can not be None or empty.") - list_callback = _build_callbacks(callbacks) cb_params = _InternalCallbackParam() cb_params.eval_network = self._eval_network cb_params.valid_dataset = valid_dataset @@ -708,9 +706,10 @@ class Model: self._clear_metrics() - if dataset_sink_mode: - return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) - return self._eval_process(valid_dataset, list_callback, cb_params) + with _CallbackManager(callbacks) as list_callback: + if dataset_sink_mode: + return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) + return self._eval_process(valid_dataset, list_callback, cb_params) def predict(self, *predict_data): """ diff --git a/tests/ut/python/train/summary/test_image_summary.py b/tests/ut/python/train/summary/test_image_summary.py index 5e5bf2b3c3a9125247187f5cb30d771561f0628e..e650442cd3f023792698d273263028086f2a9626 100644 --- a/tests/ut/python/train/summary/test_image_summary.py +++ b/tests/ut/python/train/summary/test_image_summary.py @@ -156,12 +156,19 @@ def get_dataset(): class ImageSummaryCallback: - def __init__(self, summaryRecord): - self._summaryRecord = summaryRecord + + def __init__(self, summary_record): + self._summary_record = summary_record + + def __enter__(self): + return self + + def __exit__(self, *err): + pass def record(self, step, train_network=None): - self._summaryRecord.record(step, train_network) - self._summaryRecord.flush() + self._summary_record.record(step, train_network) + self._summary_record.flush() def test_image_summary_train(): diff --git a/tests/ut/python/train/test_training.py b/tests/ut/python/train/test_training.py index 92625e54f970e4c57063705011ef92be15a95c98..a007d18571b20d7fd389725b5ced26267ec87308 100644 --- a/tests/ut/python/train/test_training.py +++ b/tests/ut/python/train/test_training.py @@ -180,6 +180,12 @@ class CallbackTest: def __init__(self): pass + def __enter__(self): + return self + + def __exit__(self, *err): + pass + def record(self, step, *args): print(step, args) diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index da564e3f9c2cabfa69529f825236614b185772c6..b0879ebc0ed6ddefda250fcf6ec6f02256fcebec 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -15,6 +15,7 @@ """test callback function.""" import os import stat +from unittest import mock import numpy as np import pytest @@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Momentum from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ - _build_callbacks, CheckpointConfig, _set_cur_net + _CallbackManager, Callback, CheckpointConfig, _set_cur_net class Net(nn.Cell): @@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode(): run_context = RunContext(cb_params) loss_cb = LossMonitor(1) callbacks = [loss_cb] - callbacklist = _build_callbacks(callbacks) - callbacklist.begin(run_context) - callbacklist.epoch_begin(run_context) - callbacklist.step_begin(run_context) - callbacklist.step_end(run_context) - callbacklist.epoch_end(run_context) - callbacklist.end(run_context) + with _CallbackManager(callbacks) as callbacklist: + callbacklist.begin(run_context) + callbacklist.epoch_begin(run_context) + callbacklist.step_begin(run_context) + callbacklist.step_end(run_context) + callbacklist.epoch_end(run_context) + callbacklist.end(run_context) def test_loss_monitor_normal_mode(): @@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds(): ckpt_cb2.step_end(run_context) -def test_build_callbacks(): - """Test_build_callbacks.""" +def test_CallbackManager(): + """TestCallbackManager.""" ck_obj = ModelCheckpoint() loss_cb_1 = LossMonitor(1) callbacks = [None] with pytest.raises(TypeError): - callbacks = _build_callbacks(callbacks) + _CallbackManager(callbacks) callbacks = ['Error'] with pytest.raises(TypeError): - callbacks = _build_callbacks(callbacks) + _CallbackManager(callbacks) callbacks = [ck_obj, loss_cb_1, 'Error', None] with pytest.raises(TypeError): - _ = _build_callbacks(callbacks) + _CallbackManager(callbacks) + + +def test_CallbackManager_exit_called(): + with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: + cb1, cb2 = Callback(), Callback() + with _CallbackManager([cb1, cb2]): + pass + for call_args in mock_exit.call_args_list: + assert call_args == mock.call(mock.ANY, None, None, None) + assert mock_exit.call_count == 2 + + +def test_CallbackManager_exit_called_when_raises(): + with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: + cb1, cb2 = Callback(), Callback() + with pytest.raises(ValueError): + with _CallbackManager([cb1, cb2]): + raise ValueError() + for call_args in mock_exit.call_args_list: + assert call_args == mock.call(*[mock.ANY] * 4) + assert mock_exit.call_count == 2 + + +def test_CallbackManager_begin_called(): + context = dict() + with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin: + cb1, cb2 = Callback(), Callback() + with _CallbackManager([cb1, cb2]) as cm: + cm.begin(context) + for call_args in mock_begin.call_args_list: + assert call_args == mock.call(context) + assert mock_begin.call_count == 2 def test_RunContext(): """Test RunContext.""" context_err = 666 with pytest.raises(TypeError): - _ = RunContext(context_err) + RunContext(context_err) cb_params = _InternalCallbackParam() cb_params.member1 = 1