Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
08a496d0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
08a496d0
编写于
6月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1276 Callbacks as context managers
Merge pull request !1276 from 李鸿章/context_manager
上级
8cb3859b
ee438aaf
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
184 addition
and
136 deletion
+184
-136
example/resnet50_imagenet2012_THOR/model/model_thor.py
example/resnet50_imagenet2012_THOR/model/model_thor.py
+14
-15
mindspore/train/callback/callback.py
mindspore/train/callback/callback.py
+79
-74
mindspore/train/model.py
mindspore/train/model.py
+15
-16
tests/st/networks/models/resnet50/src_thor/model_thor.py
tests/st/networks/models/resnet50/src_thor/model_thor.py
+12
-13
tests/ut/python/train/summary/test_image_summary.py
tests/ut/python/train/summary/test_image_summary.py
+11
-4
tests/ut/python/train/test_training.py
tests/ut/python/train/test_training.py
+6
-0
tests/ut/python/utils/test_callback.py
tests/ut/python/utils/test_callback.py
+47
-14
未找到文件。
example/resnet50_imagenet2012_THOR/model/model_thor.py
浏览文件 @
08a496d0
...
...
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from
mindspore.parallel._utils
import
_get_parallel_mode
,
_get_device_num
,
_get_global_rank
,
\
_get_parameter_broadcast
,
_device_number_check
,
_parameter_broadcast_check
from
mindspore.train
import
amp
from
mindspore.train.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
build_callbacks
from
mindspore.train.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
CallbackManager
from
mindspore.train.parallel_utils
import
ParallelMode
from
model.dataset_helper
import
DatasetHelper
...
...
@@ -374,7 +374,6 @@ class Model:
self
.
_train_network
.
set_broadcast_flag
()
# build callback list
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
train_network
=
self
.
_train_network
cb_params
.
epoch_num
=
epoch
...
...
@@ -385,17 +384,17 @@ class Model:
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
train_dataset
=
train_dataset
cb_params
.
list_callback
=
list_callback
cb_params
.
list_callback
=
callbacks
if
dataset_sink_mode
:
if
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
not
dataset_sink_mode
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
elif
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
logger
.
warning
(
"The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink."
)
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
"""
...
...
@@ -408,7 +407,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
iter_first_order
=
self
.
_frequency
-
1
...
...
@@ -473,7 +472,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper
,
_
=
self
.
_exec_preprocess
(
self
.
_train_network
,
...
...
@@ -580,7 +579,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (
List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
...
...
@@ -619,7 +618,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (
List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
...
...
@@ -678,7 +677,6 @@ class Model:
if
not
self
.
_metric_fns
:
raise
ValueError
(
"metric fn can not be None or empty."
)
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
eval_network
=
self
.
_eval_network
cb_params
.
valid_dataset
=
valid_dataset
...
...
@@ -691,9 +689,10 @@ class Model:
self
.
_clear_metrics
()
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
def
predict
(
self
,
*
predict_data
):
"""
...
...
mindspore/train/callback/callback.py
浏览文件 @
08a496d0
...
...
@@ -18,6 +18,7 @@ import os
import
stat
import
shutil
import
time
from
contextlib
import
ExitStack
import
numpy
as
np
import
mindspore.context
as
context
...
...
@@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list):
return
ret
def
_build_callbacks
(
callbacks
):
"""
Contain a list of callback.
Args:
callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
Returns:
List, a list of callback functions.
"""
if
callbacks
:
if
isinstance
(
callbacks
,
tuple
):
raise
TypeError
(
"Callbacks cannot be a tuple. Please check it."
)
if
not
isinstance
(
callbacks
,
list
):
callbacks
=
[
callbacks
]
else
:
callbacks
=
[]
excute_callbacks
=
[]
for
cb
in
callbacks
:
if
cb
is
None
or
not
isinstance
(
cb
,
Callback
):
raise
TypeError
(
"Callback must inheriting base class Callback. Some callback is Wrong. Please check it."
)
excute_callbacks
.
append
(
cb
)
return
_ListCallback
(
excute_callbacks
)
class
_ListCallback
:
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (list): Callback functions list.
"""
def
__init__
(
self
,
callbacks
):
super
(
_ListCallback
,
self
).
__init__
()
self
.
_callbacks
=
callbacks
def
begin
(
self
,
run_context
):
"""Called once before network training."""
for
cb
in
self
.
_callbacks
:
cb
.
begin
(
run_context
)
def
epoch_begin
(
self
,
run_context
):
"""Called before each epoch begin."""
for
cb
in
self
.
_callbacks
:
cb
.
epoch_begin
(
run_context
)
def
epoch_end
(
self
,
run_context
):
"""Called after each epoch finished."""
for
cb
in
self
.
_callbacks
:
cb
.
epoch_end
(
run_context
)
def
step_begin
(
self
,
run_context
):
"""Called before each epoch begin."""
for
cb
in
self
.
_callbacks
:
cb
.
step_begin
(
run_context
)
def
step_end
(
self
,
run_context
):
"""Called after each step finished."""
for
cb
in
self
.
_callbacks
:
cb
.
step_end
(
run_context
)
def
end
(
self
,
run_context
):
"""Called once after network training."""
for
cb
in
self
.
_callbacks
:
cb
.
end
(
run_context
)
class
Callback
:
"""
Abstract base class used to build a callback function.
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can leverage this mechanism to init and release resources automatically.
Callback function will execution some operating to the current step or epoch.
...
...
@@ -369,8 +301,13 @@ class Callback:
>>> print_cb = Print_info()
>>> model.train(epoch, dataset, callbacks=print_cb)
"""
def
__init__
(
self
):
pass
def
__enter__
(
self
):
"""Return the enter target."""
return
self
def
__exit__
(
self
,
*
err
):
"""Release resources here if have any."""
def
begin
(
self
,
run_context
):
"""
...
...
@@ -421,6 +358,67 @@ class Callback:
"""
class
_CallbackManager
(
Callback
):
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
"""
def
__init__
(
self
,
callbacks
):
self
.
_callbacks
,
self
.
_stack
=
[],
None
if
isinstance
(
callbacks
,
Callback
):
self
.
_callbacks
.
append
(
callbacks
)
elif
callbacks
is
not
None
:
for
cb
in
callbacks
:
if
not
isinstance
(
cb
,
Callback
):
raise
TypeError
(
"%r is not an instance of %r"
%
(
cb
,
Callback
))
self
.
_callbacks
.
append
(
cb
)
def
__enter__
(
self
):
if
self
.
_stack
is
None
:
self
.
_stack
=
ExitStack
().
__enter__
()
self
.
_callbacks
=
[
self
.
_stack
.
enter_context
(
cb
)
for
cb
in
self
.
_callbacks
]
return
self
def
__exit__
(
self
,
*
err
):
return
self
.
_stack
.
__exit__
(
*
err
)
def
begin
(
self
,
run_context
):
"""Called once before network training."""
for
cb
in
self
.
_callbacks
:
cb
.
begin
(
run_context
)
def
epoch_begin
(
self
,
run_context
):
"""Called before each epoch begin."""
for
cb
in
self
.
_callbacks
:
cb
.
epoch_begin
(
run_context
)
def
epoch_end
(
self
,
run_context
):
"""Called after each epoch finished."""
for
cb
in
self
.
_callbacks
:
cb
.
epoch_end
(
run_context
)
def
step_begin
(
self
,
run_context
):
"""Called before each epoch begin."""
for
cb
in
self
.
_callbacks
:
cb
.
step_begin
(
run_context
)
def
step_end
(
self
,
run_context
):
"""Called after each step finished."""
for
cb
in
self
.
_callbacks
:
cb
.
step_end
(
run_context
)
def
end
(
self
,
run_context
):
"""Called once after network training."""
for
cb
in
self
.
_callbacks
:
cb
.
end
(
run_context
)
class
SummaryStep
(
Callback
):
"""
The summary callback class.
...
...
@@ -435,6 +433,13 @@ class SummaryStep(Callback):
raise
ValueError
(
"`flush_step` should be int and greater than 0"
)
self
.
_summary
=
summary
self
.
_flush_step
=
flush_step
def
__enter__
(
self
):
self
.
_summary
.
__enter__
()
return
self
def
__exit__
(
self
,
*
err
):
return
self
.
_summary
.
__exit__
(
*
err
)
def
step_end
(
self
,
run_context
):
"""
...
...
mindspore/train/model.py
浏览文件 @
08a496d0
...
...
@@ -19,7 +19,7 @@ from mindspore import log as logger
from
..common.tensor
import
Tensor
from
..nn.metrics
import
get_metrics
from
.._checkparam
import
check_input_data
,
check_output_data
,
check_int_positive
,
check_bool
from
.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
build_callbacks
from
.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
CallbackManager
from
..
import
context
from
..parallel._utils
import
_get_parallel_mode
,
_get_device_num
,
_get_global_rank
,
\
_get_parameter_broadcast
,
_device_number_check
,
_parameter_broadcast_check
...
...
@@ -334,8 +334,6 @@ class Model:
if
self
.
_parameter_broadcast
:
self
.
_train_network
.
set_broadcast_flag
()
# build callback list
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
train_network
=
self
.
_train_network
cb_params
.
epoch_num
=
epoch
...
...
@@ -346,17 +344,18 @@ class Model:
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
train_dataset
=
train_dataset
cb_params
.
list_callback
=
list_callback
cb_params
.
list_callback
=
callbacks
if
dataset_sink_mode
:
if
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
# build callback list
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
not
dataset_sink_mode
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
elif
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
logger
.
warning
(
"The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink."
)
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
"""
...
...
@@ -369,7 +368,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper
,
train_network
=
self
.
_exec_preprocess
(
self
.
_train_network
,
...
...
@@ -417,7 +416,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper
,
_
=
self
.
_exec_preprocess
(
self
.
_train_network
,
...
...
@@ -524,7 +523,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (
List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
...
...
@@ -563,7 +562,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (
List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
...
...
@@ -622,7 +621,6 @@ class Model:
if
not
self
.
_metric_fns
:
raise
ValueError
(
"metric fn can not be None or empty."
)
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
eval_network
=
self
.
_eval_network
cb_params
.
valid_dataset
=
valid_dataset
...
...
@@ -635,9 +633,10 @@ class Model:
self
.
_clear_metrics
()
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
def
predict
(
self
,
*
predict_data
):
"""
...
...
tests/st/networks/models/resnet50/src_thor/model_thor.py
浏览文件 @
08a496d0
...
...
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from
mindspore.parallel._utils
import
_get_parallel_mode
,
_get_device_num
,
_get_global_rank
,
\
_get_parameter_broadcast
,
_device_number_check
,
_parameter_broadcast_check
from
mindspore.train
import
amp
from
mindspore.train.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
build_callbacks
from
mindspore.train.callback.callback
import
_InternalCallbackParam
,
RunContext
,
_
CallbackManager
from
mindspore.train.parallel_utils
import
ParallelMode
from
.dataset_helper
import
DatasetHelper
...
...
@@ -392,7 +392,6 @@ class Model:
self
.
_train_network
.
set_broadcast_flag
()
# build callback list
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
train_network
=
self
.
_train_network
cb_params
.
epoch_num
=
epoch
...
...
@@ -403,17 +402,17 @@ class Model:
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
train_dataset
=
train_dataset
cb_params
.
list_callback
=
list_callback
cb_params
.
list_callback
=
callbacks
if
dataset_sink_mode
:
if
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
not
dataset_sink_mode
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
elif
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
logger
.
warning
(
"The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink."
)
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
else
:
self
.
_train_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
"""
...
...
@@ -426,7 +425,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
iter_first_order
=
self
.
_frequency
-
1
...
...
@@ -490,7 +489,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (
_List
Callback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper
,
_
=
self
.
_exec_preprocess
(
self
.
_train_network
,
...
...
@@ -695,7 +694,6 @@ class Model:
if
not
self
.
_metric_fns
:
raise
ValueError
(
"metric fn can not be None or empty."
)
list_callback
=
_build_callbacks
(
callbacks
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
eval_network
=
self
.
_eval_network
cb_params
.
valid_dataset
=
valid_dataset
...
...
@@ -708,9 +706,10 @@ class Model:
self
.
_clear_metrics
()
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
with
_CallbackManager
(
callbacks
)
as
list_callback
:
if
dataset_sink_mode
:
return
self
.
_eval_dataset_sink_process
(
valid_dataset
,
list_callback
,
cb_params
)
return
self
.
_eval_process
(
valid_dataset
,
list_callback
,
cb_params
)
def
predict
(
self
,
*
predict_data
):
"""
...
...
tests/ut/python/train/summary/test_image_summary.py
浏览文件 @
08a496d0
...
...
@@ -156,12 +156,19 @@ def get_dataset():
class
ImageSummaryCallback
:
def
__init__
(
self
,
summaryRecord
):
self
.
_summaryRecord
=
summaryRecord
def
__init__
(
self
,
summary_record
):
self
.
_summary_record
=
summary_record
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
*
err
):
pass
def
record
(
self
,
step
,
train_network
=
None
):
self
.
_summary
R
ecord
.
record
(
step
,
train_network
)
self
.
_summary
R
ecord
.
flush
()
self
.
_summary
_r
ecord
.
record
(
step
,
train_network
)
self
.
_summary
_r
ecord
.
flush
()
def
test_image_summary_train
():
...
...
tests/ut/python/train/test_training.py
浏览文件 @
08a496d0
...
...
@@ -180,6 +180,12 @@ class CallbackTest:
def
__init__
(
self
):
pass
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
*
err
):
pass
def
record
(
self
,
step
,
*
args
):
print
(
step
,
args
)
...
...
tests/ut/python/utils/test_callback.py
浏览文件 @
08a496d0
...
...
@@ -15,6 +15,7 @@
"""test callback function."""
import
os
import
stat
from
unittest
import
mock
import
numpy
as
np
import
pytest
...
...
@@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback.callback
import
ModelCheckpoint
,
_check_file_name_prefix
,
RunContext
,
\
_checkpoint_cb_for_save_op
,
LossMonitor
,
_InternalCallbackParam
,
_chg_ckpt_file_name_if_same_exist
,
\
_
build_callbacks
,
CheckpointConfig
,
_set_cur_net
_
CallbackManager
,
Callback
,
CheckpointConfig
,
_set_cur_net
class
Net
(
nn
.
Cell
):
...
...
@@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode():
run_context
=
RunContext
(
cb_params
)
loss_cb
=
LossMonitor
(
1
)
callbacks
=
[
loss_cb
]
callbacklist
=
_build_callbacks
(
callbacks
)
callbacklist
.
begin
(
run_context
)
callbacklist
.
epoch_begin
(
run_context
)
callbacklist
.
step_begin
(
run_context
)
callbacklist
.
step_end
(
run_context
)
callbacklist
.
epoch_end
(
run_context
)
callbacklist
.
end
(
run_context
)
with
_CallbackManager
(
callbacks
)
as
callbacklist
:
callbacklist
.
begin
(
run_context
)
callbacklist
.
epoch_begin
(
run_context
)
callbacklist
.
step_begin
(
run_context
)
callbacklist
.
step_end
(
run_context
)
callbacklist
.
epoch_end
(
run_context
)
callbacklist
.
end
(
run_context
)
def
test_loss_monitor_normal_mode
():
...
...
@@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds():
ckpt_cb2
.
step_end
(
run_context
)
def
test_
build_callbacks
():
"""Test
_build_callbacks
."""
def
test_
CallbackManager
():
"""Test
CallbackManager
."""
ck_obj
=
ModelCheckpoint
()
loss_cb_1
=
LossMonitor
(
1
)
callbacks
=
[
None
]
with
pytest
.
raises
(
TypeError
):
callbacks
=
_build_callbacks
(
callbacks
)
_CallbackManager
(
callbacks
)
callbacks
=
[
'Error'
]
with
pytest
.
raises
(
TypeError
):
callbacks
=
_build_callbacks
(
callbacks
)
_CallbackManager
(
callbacks
)
callbacks
=
[
ck_obj
,
loss_cb_1
,
'Error'
,
None
]
with
pytest
.
raises
(
TypeError
):
_
=
_build_callbacks
(
callbacks
)
_CallbackManager
(
callbacks
)
def
test_CallbackManager_exit_called
():
with
mock
.
patch
.
object
(
Callback
,
'__exit__'
,
return_value
=
None
)
as
mock_exit
:
cb1
,
cb2
=
Callback
(),
Callback
()
with
_CallbackManager
([
cb1
,
cb2
]):
pass
for
call_args
in
mock_exit
.
call_args_list
:
assert
call_args
==
mock
.
call
(
mock
.
ANY
,
None
,
None
,
None
)
assert
mock_exit
.
call_count
==
2
def
test_CallbackManager_exit_called_when_raises
():
with
mock
.
patch
.
object
(
Callback
,
'__exit__'
,
return_value
=
None
)
as
mock_exit
:
cb1
,
cb2
=
Callback
(),
Callback
()
with
pytest
.
raises
(
ValueError
):
with
_CallbackManager
([
cb1
,
cb2
]):
raise
ValueError
()
for
call_args
in
mock_exit
.
call_args_list
:
assert
call_args
==
mock
.
call
(
*
[
mock
.
ANY
]
*
4
)
assert
mock_exit
.
call_count
==
2
def
test_CallbackManager_begin_called
():
context
=
dict
()
with
mock
.
patch
.
object
(
Callback
,
'begin'
,
return_value
=
None
)
as
mock_begin
:
cb1
,
cb2
=
Callback
(),
Callback
()
with
_CallbackManager
([
cb1
,
cb2
])
as
cm
:
cm
.
begin
(
context
)
for
call_args
in
mock_begin
.
call_args_list
:
assert
call_args
==
mock
.
call
(
context
)
assert
mock_begin
.
call_count
==
2
def
test_RunContext
():
"""Test RunContext."""
context_err
=
666
with
pytest
.
raises
(
TypeError
):
_
=
RunContext
(
context_err
)
RunContext
(
context_err
)
cb_params
=
_InternalCallbackParam
()
cb_params
.
member1
=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录