Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8e5d8a1c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e5d8a1c
编写于
6月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2305 Enhance callback module and strongly check if callbacks is list or not
Merge pull request !2305 from 李鸿章/callback
上级
10195c0d
4c0d12fd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
19 addition
and
9 deletion
+19
-9
mindspore/train/callback/__init__.py
mindspore/train/callback/__init__.py
+2
-0
mindspore/train/callback/_callback.py
mindspore/train/callback/_callback.py
+13
-4
tests/ut/python/utils/test_callback.py
tests/ut/python/utils/test_callback.py
+4
-5
未找到文件。
mindspore/train/callback/__init__.py
浏览文件 @
8e5d8a1c
...
...
@@ -18,6 +18,8 @@ from ._callback import Callback
from
._callback
import
CallbackManager
as
_CallbackManager
from
._callback
import
InternalCallbackParam
as
_InternalCallbackParam
from
._callback
import
RunContext
from
._callback
import
checkpoint_cb_for_save_op
as
_checkpoint_cb_for_save_op
from
._callback
import
set_cur_net
as
_set_cur_net
from
._checkpoint
import
CheckpointConfig
from
._checkpoint
import
CheckpointManager
as
_CheckpointManager
from
._checkpoint
import
ModelCheckpoint
...
...
mindspore/train/callback/_callback.py
浏览文件 @
8e5d8a1c
...
...
@@ -160,16 +160,25 @@ class CallbackManager(Callback):
self
.
_callbacks
,
self
.
_stack
=
[],
None
if
isinstance
(
callbacks
,
Callback
):
self
.
_callbacks
.
append
(
callbacks
)
elif
callbacks
is
not
None
:
elif
isinstance
(
callbacks
,
list
)
:
for
cb
in
callbacks
:
if
not
isinstance
(
cb
,
Callback
):
raise
TypeError
(
"
%r is not an instance of %r"
%
(
cb
,
Callback
)
)
raise
TypeError
(
"
The 'callbacks' contains not-a-Callback item."
)
self
.
_callbacks
.
append
(
cb
)
elif
callbacks
is
not
None
:
raise
TypeError
(
"The 'callbacks' is not a Callback or a list of Callback."
)
def
__enter__
(
self
):
if
self
.
_stack
is
None
:
self
.
_stack
=
ExitStack
().
__enter__
()
self
.
_callbacks
=
[
self
.
_stack
.
enter_context
(
cb
)
for
cb
in
self
.
_callbacks
]
callbacks
,
self
.
_stack
=
[],
ExitStack
().
__enter__
()
for
callback
in
self
.
_callbacks
:
target
=
self
.
_stack
.
enter_context
(
callback
)
if
not
isinstance
(
target
,
Callback
):
logger
.
warning
(
"Please return 'self' or a Callback as the enter target."
)
callbacks
.
append
(
callback
)
else
:
callbacks
.
append
(
target
)
self
.
_callbacks
=
callbacks
return
self
def
__exit__
(
self
,
*
err
):
...
...
tests/ut/python/utils/test_callback.py
浏览文件 @
8e5d8a1c
...
...
@@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor
from
mindspore.nn
import
TrainOneStepCell
,
WithLossCell
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback
import
ModelCheckpoint
,
RunContext
,
LossMonitor
,
_InternalCallbackParam
,
\
_CallbackManager
,
Callback
,
CheckpointConfig
from
mindspore.train.callback._callback
import
set_cur_net
,
checkpoint_cb_for_save_op
_CallbackManager
,
Callback
,
CheckpointConfig
,
_set_cur_net
,
_checkpoint_cb_for_save_op
from
mindspore.train.callback._checkpoint
import
_check_file_name_prefix
,
_chg_ckpt_file_name_if_same_exist
class
Net
(
nn
.
Cell
):
...
...
@@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op():
one_param
[
'name'
]
=
"conv1.weight"
one_param
[
'data'
]
=
Tensor
(
np
.
random
.
randint
(
0
,
255
,
[
1
,
3
,
224
,
224
]),
dtype
=
mstype
.
float32
)
parameter_list
.
append
(
one_param
)
checkpoint_cb_for_save_op
(
parameter_list
)
_
checkpoint_cb_for_save_op
(
parameter_list
)
def
test_checkpoint_cb_for_save_op_update_net
():
...
...
@@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net():
one_param
[
'data'
]
=
Tensor
(
np
.
ones
(
shape
=
(
64
,
3
,
3
,
3
)),
dtype
=
mstype
.
float32
)
parameter_list
.
append
(
one_param
)
net
=
Net
()
set_cur_net
(
net
)
checkpoint_cb_for_save_op
(
parameter_list
)
_
set_cur_net
(
net
)
_
checkpoint_cb_for_save_op
(
parameter_list
)
assert
net
.
conv
.
weight
.
default_input
.
asnumpy
()[
0
][
0
][
0
][
0
]
==
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录