提交 08a496d0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1276 Callbacks as context managers

Merge pull request !1276 from 李鸿章/context_manager
......@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from model.dataset_helper import DatasetHelper
......@@ -374,7 +374,6 @@ class Model:
self._train_network.set_broadcast_flag()
# build callback list
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
......@@ -385,17 +384,17 @@ class Model:
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = list_callback
cb_params.list_callback = callbacks
if dataset_sink_mode:
if context.get_context("mode") == context.PYNATIVE_MODE:
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
......@@ -408,7 +407,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
iter_first_order = self._frequency - 1
......@@ -473,7 +472,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, _ = self._exec_preprocess(self._train_network,
......@@ -580,7 +579,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
......@@ -619,7 +618,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
......@@ -678,7 +677,6 @@ class Model:
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
......@@ -691,9 +689,10 @@ class Model:
self._clear_metrics()
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
def predict(self, *predict_data):
"""
......
......@@ -18,6 +18,7 @@ import os
import stat
import shutil
import time
from contextlib import ExitStack
import numpy as np
import mindspore.context as context
......@@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list):
return ret
def _build_callbacks(callbacks):
"""
Contain a list of callback.
Args:
callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
Returns:
List, a list of callback functions.
"""
if callbacks:
if isinstance(callbacks, tuple):
raise TypeError("Callbacks cannot be a tuple. Please check it.")
if not isinstance(callbacks, list):
callbacks = [callbacks]
else:
callbacks = []
excute_callbacks = []
for cb in callbacks:
if cb is None or not isinstance(cb, Callback):
raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
excute_callbacks.append(cb)
return _ListCallback(excute_callbacks)
class _ListCallback:
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (list): Callback functions list.
"""
def __init__(self, callbacks):
super(_ListCallback, self).__init__()
self._callbacks = callbacks
def begin(self, run_context):
"""Called once before network training."""
for cb in self._callbacks:
cb.begin(run_context)
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.epoch_begin(run_context)
def epoch_end(self, run_context):
"""Called after each epoch finished."""
for cb in self._callbacks:
cb.epoch_end(run_context)
def step_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.step_begin(run_context)
def step_end(self, run_context):
"""Called after each step finished."""
for cb in self._callbacks:
cb.step_end(run_context)
def end(self, run_context):
"""Called once after network training."""
for cb in self._callbacks:
cb.end(run_context)
class Callback:
"""
Abstract base class used to build a callback function.
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can leverage this mechanism to init and release resources automatically.
Callback function will execution some operating to the current step or epoch.
......@@ -369,8 +301,13 @@ class Callback:
>>> print_cb = Print_info()
>>> model.train(epoch, dataset, callbacks=print_cb)
"""
def __init__(self):
pass
def __enter__(self):
"""Return the enter target."""
return self
def __exit__(self, *err):
"""Release resources here if have any."""
def begin(self, run_context):
"""
......@@ -421,6 +358,67 @@ class Callback:
"""
class _CallbackManager(Callback):
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
"""
def __init__(self, callbacks):
self._callbacks, self._stack = [], None
if isinstance(callbacks, Callback):
self._callbacks.append(callbacks)
elif callbacks is not None:
for cb in callbacks:
if not isinstance(cb, Callback):
raise TypeError("%r is not an instance of %r" % (cb, Callback))
self._callbacks.append(cb)
def __enter__(self):
if self._stack is None:
self._stack = ExitStack().__enter__()
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
return self
def __exit__(self, *err):
return self._stack.__exit__(*err)
def begin(self, run_context):
"""Called once before network training."""
for cb in self._callbacks:
cb.begin(run_context)
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.epoch_begin(run_context)
def epoch_end(self, run_context):
"""Called after each epoch finished."""
for cb in self._callbacks:
cb.epoch_end(run_context)
def step_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.step_begin(run_context)
def step_end(self, run_context):
"""Called after each step finished."""
for cb in self._callbacks:
cb.step_end(run_context)
def end(self, run_context):
"""Called once after network training."""
for cb in self._callbacks:
cb.end(run_context)
class SummaryStep(Callback):
"""
The summary callback class.
......@@ -435,6 +433,13 @@ class SummaryStep(Callback):
raise ValueError("`flush_step` should be int and greater than 0")
self._summary = summary
self._flush_step = flush_step
def __enter__(self):
self._summary.__enter__()
return self
def __exit__(self, *err):
return self._summary.__exit__(*err)
def step_end(self, run_context):
"""
......
......@@ -19,7 +19,7 @@ from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
......@@ -334,8 +334,6 @@ class Model:
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
# build callback list
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
......@@ -346,17 +344,18 @@ class Model:
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = list_callback
cb_params.list_callback = callbacks
if dataset_sink_mode:
if context.get_context("mode") == context.PYNATIVE_MODE:
# build callback list
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
......@@ -369,7 +368,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, train_network = self._exec_preprocess(self._train_network,
......@@ -417,7 +416,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, _ = self._exec_preprocess(self._train_network,
......@@ -524,7 +523,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
......@@ -563,7 +562,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
......@@ -622,7 +621,6 @@ class Model:
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
......@@ -635,9 +633,10 @@ class Model:
self._clear_metrics()
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
def predict(self, *predict_data):
"""
......
......@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper
......@@ -392,7 +392,6 @@ class Model:
self._train_network.set_broadcast_flag()
# build callback list
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
......@@ -403,17 +402,17 @@ class Model:
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = list_callback
cb_params.list_callback = callbacks
if dataset_sink_mode:
if context.get_context("mode") == context.PYNATIVE_MODE:
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
......@@ -426,7 +425,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
iter_first_order = self._frequency - 1
......@@ -490,7 +489,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, _ = self._exec_preprocess(self._train_network,
......@@ -695,7 +694,6 @@ class Model:
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
......@@ -708,9 +706,10 @@ class Model:
self._clear_metrics()
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
def predict(self, *predict_data):
"""
......
......@@ -156,12 +156,19 @@ def get_dataset():
class ImageSummaryCallback:
def __init__(self, summaryRecord):
self._summaryRecord = summaryRecord
def __init__(self, summary_record):
self._summary_record = summary_record
def __enter__(self):
return self
def __exit__(self, *err):
pass
def record(self, step, train_network=None):
self._summaryRecord.record(step, train_network)
self._summaryRecord.flush()
self._summary_record.record(step, train_network)
self._summary_record.flush()
def test_image_summary_train():
......
......@@ -180,6 +180,12 @@ class CallbackTest:
def __init__(self):
pass
def __enter__(self):
return self
def __exit__(self, *err):
pass
def record(self, step, *args):
print(step, args)
......
......@@ -15,6 +15,7 @@
"""test callback function."""
import os
import stat
from unittest import mock
import numpy as np
import pytest
......@@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
_checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
_build_callbacks, CheckpointConfig, _set_cur_net
_CallbackManager, Callback, CheckpointConfig, _set_cur_net
class Net(nn.Cell):
......@@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode():
run_context = RunContext(cb_params)
loss_cb = LossMonitor(1)
callbacks = [loss_cb]
callbacklist = _build_callbacks(callbacks)
callbacklist.begin(run_context)
callbacklist.epoch_begin(run_context)
callbacklist.step_begin(run_context)
callbacklist.step_end(run_context)
callbacklist.epoch_end(run_context)
callbacklist.end(run_context)
with _CallbackManager(callbacks) as callbacklist:
callbacklist.begin(run_context)
callbacklist.epoch_begin(run_context)
callbacklist.step_begin(run_context)
callbacklist.step_end(run_context)
callbacklist.epoch_end(run_context)
callbacklist.end(run_context)
def test_loss_monitor_normal_mode():
......@@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds():
ckpt_cb2.step_end(run_context)
def test_build_callbacks():
"""Test_build_callbacks."""
def test_CallbackManager():
"""TestCallbackManager."""
ck_obj = ModelCheckpoint()
loss_cb_1 = LossMonitor(1)
callbacks = [None]
with pytest.raises(TypeError):
callbacks = _build_callbacks(callbacks)
_CallbackManager(callbacks)
callbacks = ['Error']
with pytest.raises(TypeError):
callbacks = _build_callbacks(callbacks)
_CallbackManager(callbacks)
callbacks = [ck_obj, loss_cb_1, 'Error', None]
with pytest.raises(TypeError):
_ = _build_callbacks(callbacks)
_CallbackManager(callbacks)
def test_CallbackManager_exit_called():
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
cb1, cb2 = Callback(), Callback()
with _CallbackManager([cb1, cb2]):
pass
for call_args in mock_exit.call_args_list:
assert call_args == mock.call(mock.ANY, None, None, None)
assert mock_exit.call_count == 2
def test_CallbackManager_exit_called_when_raises():
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
cb1, cb2 = Callback(), Callback()
with pytest.raises(ValueError):
with _CallbackManager([cb1, cb2]):
raise ValueError()
for call_args in mock_exit.call_args_list:
assert call_args == mock.call(*[mock.ANY] * 4)
assert mock_exit.call_count == 2
def test_CallbackManager_begin_called():
context = dict()
with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
cb1, cb2 = Callback(), Callback()
with _CallbackManager([cb1, cb2]) as cm:
cm.begin(context)
for call_args in mock_begin.call_args_list:
assert call_args == mock.call(context)
assert mock_begin.call_count == 2
def test_RunContext():
"""Test RunContext."""
context_err = 666
with pytest.raises(TypeError):
_ = RunContext(context_err)
RunContext(context_err)
cb_params = _InternalCallbackParam()
cb_params.member1 = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册