Commit 0c3d96a9 (magicwindyyd/mindspore, fork of MindSpore/mindspore)

Authored on Jun 18, 2020 by mindspore-ci-bot; committed via Gitee on Jun 18, 2020.
Parents: b106c220, ecc45915

!2236 Refactor the callback module in an encapsulated way

Merge pull request !2236 from 李鸿章/callback

Showing 13 changed files with 504 additions and 418 deletions.
example/resnet50_imagenet2012_THOR/model/model_thor.py (+1, -1)
mindspore/ccsrc/utils/callbacks.cc (+3, -3)
mindspore/ccsrc/utils/callbacks_ge.cc (+3, -3)
mindspore/train/callback/__init__.py (+11, -3)
mindspore/train/callback/_callback.py (+260, -0)
mindspore/train/callback/_checkpoint.py (+63, -398)
mindspore/train/callback/_loss_monitor.py (+62, -0)
mindspore/train/callback/_summary_step.py (+56, -0)
mindspore/train/callback/_time_monitor.py (+35, -0)
mindspore/train/model.py (+1, -1)
tests/st/networks/models/resnet50/src_thor/model_thor.py (+1, -1)
tests/ut/python/utils/test_callback.py (+7, -7)
tests/ut/python/utils/test_serialize.py (+1, -1)
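The net effect of the change list above is one private module per concern under mindspore/train/callback/, with __init__.py re-exporting the public names. Sketched below from the file list; the per-module contents follow in the diffs:

```
mindspore/train/callback/
├── __init__.py       public re-exports; the package API is unchanged
├── _callback.py      Callback base class, CallbackManager, RunContext, save-op hooks
├── _checkpoint.py    CheckpointConfig, CheckpointManager, ModelCheckpoint
├── _loss_monitor.py  LossMonitor
├── _summary_step.py  SummaryStep
└── _time_monitor.py  TimeMonitor
```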
example/resnet50_imagenet2012_THOR/model/model_thor.py

```diff
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from model.dataset_helper import DatasetHelper
```
mindspore/ccsrc/utils/callbacks.cc

```diff
@@ -26,9 +26,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
```
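These constants name the Python module and functions that the checkpoint and summary save ops resolve at run time, which is why the C++ strings and the Python rename had to change in lockstep. A rough Python illustration of that resolution; the importlib lookup here is an assumption about what the C++ python adapter does, not a copy of it:

```python
import importlib

# Values mirror the C++ constants after this commit.
PYTHON_MOD_CALLBACK_MODULE = "mindspore.train.callback._callback"
PYTHON_FUN_PROCESS_CHECKPOINT = "checkpoint_cb_for_save_op"
PYTHON_FUN_PROCESS_SUMMARY = "summary_cb_for_save_op"

# Locate the callback module by path, then the hook functions by name.
callback_module = importlib.import_module(PYTHON_MOD_CALLBACK_MODULE)
checkpoint_cb = getattr(callback_module, PYTHON_FUN_PROCESS_CHECKPOINT)
summary_cb = getattr(callback_module, PYTHON_FUN_PROCESS_SUMMARY)
# checkpoint_cb(parameter_list) / summary_cb(summary_list) are what the
# Save and Summary ops invoke.
```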
mindspore/ccsrc/utils/callbacks_ge.cc

```diff
@@ -25,9 +25,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
```
mindspore/train/callback/__init__.py

```diff
@@ -14,7 +14,15 @@
 # ============================================================================
 """Callback related classes and functions."""
-from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
+from ._callback import Callback
+from ._callback import CallbackManager as _CallbackManager
+from ._callback import InternalCallbackParam as _InternalCallbackParam
+from ._callback import RunContext
+from ._checkpoint import CheckpointConfig
+from ._checkpoint import CheckpointManager as _CheckpointManager
+from ._checkpoint import ModelCheckpoint
+from ._loss_monitor import LossMonitor
+from ._summary_step import SummaryStep
+from ._time_monitor import TimeMonitor
 
 __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
            "SummaryStep", "CheckpointConfig", "RunContext"]
```
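Because __init__.py re-exports every public name from the new private modules, user code written against the old single-file layout keeps working. A minimal sketch of the unchanged public surface; the CheckpointConfig and ModelCheckpoint argument names follow the 2020-era API, and the model/dataset setup is assumed:

```python
# Public imports are unchanged by the refactor; only the private submodules
# (_callback, _checkpoint, _loss_monitor, ...) are new.
from mindspore.train.callback import (Callback, CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, SummaryStep, TimeMonitor)

config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./checkpoints", config=config)
loss_cb = LossMonitor(per_print_times=1)
# model.train(epoch, dataset, callbacks=[ckpt_cb, loss_cb])  # model, dataset assumed
```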
mindspore/train/callback/_callback.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""
from contextlib import ExitStack

from mindspore import log as logger
from mindspore.train.serialization import _fill_param_into_net
from mindspore.train.summary.summary_record import _cache_summary_tensor_data

_cur_net = None


def set_cur_net(net):
    """
    Set the current net for which we are saving a checkpoint.

    Args:
        net (Cell): train network
    """
    global _cur_net
    _cur_net = net


def checkpoint_cb_for_save_op(parameter_list):
    """
    The checkpoint callback function for MindSpore.

    Will be executed by the checkpoint save op.

    Args:
        parameter_list (list): Format is like [{"name": name, "data": value}] and the value type is Tensor.

    Returns:
        bool, True means the checkpoint was saved successfully.
    """
    if _cur_net is None:
        logger.warning("_cur_net is None. Parameters are not updated.")
        return False

    logger.info("Update parameters in the net.")
    _fill_param_into_net(_cur_net, parameter_list)
    set_cur_net(None)
    return True


def summary_cb_for_save_op(summary_list):
    """
    The summary callback function for MindSpore.

    Will be executed by the summary op.

    Args:
        summary_list (list): Format is like [{"name": tag_name, "data": tensor}, ...] and the value is Scalar/Tensor.

    Returns:
        bool, True means the summary was saved successfully.
    """
    ret = _cache_summary_tensor_data(summary_list)
    return ret


class Callback:
    """
    Abstract base class used to build a callback class. Callbacks are context managers
    which will be entered and exited when passed into the Model.
    You can leverage this mechanism to initialize and release resources automatically.

    Each callback function performs some operation at the current step or epoch.

    Examples:
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> model.train(epoch, dataset, callbacks=print_cb)
    """

    def __enter__(self):
        """Return the enter target."""
        return self

    def __exit__(self, *err):
        """Release resources here if there are any."""

    def begin(self, run_context):
        """
        Called once before the network starts executing.

        Args:
            run_context (RunContext): Includes some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch begins.

        Args:
            run_context (RunContext): Includes some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finishes.

        Args:
            run_context (RunContext): Includes some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step begins.

        Args:
            run_context (RunContext): Includes some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finishes.

        Args:
            run_context (RunContext): Includes some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training finishes.

        Args:
            run_context (RunContext): Includes some information of the model.
        """


class CallbackManager(Callback):
    """
    Sequential execution of callback functions.

    Executes Callback functions at certain points.

    Args:
        callbacks (Optional[list[Callback], Callback]): None, a single callback, or a list of callbacks.
    """

    def __init__(self, callbacks):
        self._callbacks, self._stack = [], None
        if isinstance(callbacks, Callback):
            self._callbacks.append(callbacks)
        elif callbacks is not None:
            for cb in callbacks:
                if not isinstance(cb, Callback):
                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
                self._callbacks.append(cb)

    def __enter__(self):
        if self._stack is None:
            self._stack = ExitStack().__enter__()
            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
        return self

    def __exit__(self, *err):
        return self._stack.__exit__(*err)

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begins."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finishes."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called before each step begins."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finishes."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)


class InternalCallbackParam(dict):
    """Internal callback object's parameters."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value


class RunContext:
    """
    Provides information about the model.

    Provides information about the original request to the model function.
    Callback objects can stop the loop by calling request_stop() of run_context.

    Args:
        original_args (dict): Holds the related information of the model.
    """

    def __init__(self, original_args):
        if not isinstance(original_args, dict):
            raise TypeError("The arg of RunContext should be dict type.")
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        """
        Get the _original_args object.

        Returns:
            Dict, an object holding the original arguments of the model.
        """
        return self._original_args

    def request_stop(self):
        """
        Set the stop-requested flag during training.

        Callbacks can use this function to request that iteration stop.
        model.train() checks whether this has been called.
        """
        self._stop_requested = True

    def get_stop_requested(self):
        """
        Return whether a stop has been requested.

        Returns:
            bool, if True, model.train() stops iterating.
        """
        return self._stop_requested
```
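A minimal sketch of how a training loop drives _CallbackManager, _InternalCallbackParam, and RunContext together, in the way Model.train does. The loop body, the StopAfter callback, and its max_steps parameter are illustrative, not MindSpore internals:

```python
from mindspore.train.callback import Callback, RunContext, _CallbackManager, _InternalCallbackParam


class StopAfter(Callback):
    """Illustrative callback: request a stop once a step budget is reached."""
    def __init__(self, max_steps):
        self.max_steps = max_steps

    def step_end(self, run_context):
        if run_context.original_args().cur_step_num >= self.max_steps:
            run_context.request_stop()


cb_params = _InternalCallbackParam()  # a dict with attribute access
cb_params.cur_epoch_num = 0
cb_params.cur_step_num = 0

# The manager is itself a context manager: entering it enters every callback.
with _CallbackManager([StopAfter(max_steps=3)]) as callbacks:
    run_context = RunContext(cb_params)
    callbacks.begin(run_context)
    for epoch in range(2):            # stand-in for the real training loop
        cb_params.cur_epoch_num = epoch + 1
        callbacks.epoch_begin(run_context)
        for _ in range(5):
            cb_params.cur_step_num += 1
            callbacks.step_begin(run_context)
            callbacks.step_end(run_context)
            if run_context.get_stop_requested():
                break
        callbacks.epoch_end(run_context)
        if run_context.get_stop_requested():
            break
    callbacks.end(run_context)
```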
mindspore/train/callback/callback.py → mindspore/train/callback/_checkpoint.py

```diff
@@ -12,93 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Callback related classes and functions."""
+"""Checkpoint related classes and functions."""
 import os
-import stat
 import shutil
+import stat
 import time
-from contextlib import ExitStack
-
-import numpy as np
 
 import mindspore.context as context
-from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
-from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative, check_bool
-from mindspore.common.tensor import Tensor
-from mindspore.train.summary.summary_record import _cache_summary_tensor_data
+from mindspore._checkparam import check_bool, check_int_non_negative
+from mindspore.train._utils import _make_directory
+from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
+
+from ._callback import Callback, set_cur_net
 
 _cur_dir = os.getcwd()
-_cur_net = None
 _save_dir = _cur_dir
 
-
-class _CheckpointManager:
-    """Manage checkpoint files according to train_config of checkpoint."""
-    def __init__(self):
-        self._ckpoint_filelist = []
-
-    @property
-    def ckpoint_filelist(self):
-        """Get all the related checkpoint files managed here."""
-        return self._ckpoint_filelist
-
-    @property
-    def ckpoint_num(self):
-        """Get the number of the related checkpoint files managed here."""
-        return len(self._ckpoint_filelist)
-
-    def update_ckpoint_filelist(self, directory, prefix):
-        """Update the checkpoint file list."""
-        self._ckpoint_filelist = []
-        files = os.listdir(directory)
-        for filename in files:
-            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
-                mid_name = filename[len(prefix):-5]
-                flag = True
-                for char in mid_name:
-                    if char.isalpha():
-                        flag = False
-                if flag:
-                    self._ckpoint_filelist.append(directory + '/' + filename)
-
-    def remove_ckpoint_file(self, file_name):
-        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
-        try:
-            os.chmod(file_name, stat.S_IWRITE)
-            os.remove(file_name)
-            self._ckpoint_filelist.remove(file_name)
-        except OSError:
-            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
-        except ValueError:
-            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
-
-    def remove_oldest_ckpoint_file(self):
-        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
-        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
-        self.remove_ckpoint_file(ckpoint_files[0])
-
-    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
-        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
-        movs = []
-        oldest_file = ''
-        oldest_time = cur_time
-        for ck_file in self._ckpoint_filelist:
-            modify_time = os.path.getmtime(ck_file)
-            if cur_time - modify_time < 60 * minutes:
-                movs.append(ck_file)
-                if modify_time < oldest_time:
-                    oldest_time = modify_time
-                    oldest_file = ck_file
-        for mv_file in movs:
-            if mv_file == oldest_file:
-                continue
-            self.remove_ckpoint_file(mv_file)
-
 
 def _check_file_name_prefix(file_name_prefix):
     """
@@ -234,282 +166,6 @@ class CheckpointConfig:
         return checkpoint_policy
 
-
-def _set_cur_net(net):
-    """
-    Set current net for which we are using to save checkpoint.
-
-    Args:
-        net (Cell): train network
-    """
-    global _cur_net
-    _cur_net = net
-
-
-def _checkpoint_cb_for_save_op(parameter_list):
-    """
-    The checkpoint callback function for MindSpore.
-
-    Will be executed by checkpoint save op.
-
-    Args:
-        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
-
-    Returns:
-        bool, true: means save checkpoint success.
-    """
-    if _cur_net is None:
-        logger.warning("_cur_net is None. parameters are not updated.")
-        return False
-
-    logger.info("update parameters in the net.")
-    _fill_param_into_net(_cur_net, parameter_list)
-    _set_cur_net(None)
-    return True
-
-
-def _summary_cb_for_save_op(summary_list):
-    """
-    The summary callback function for MindSpore.
-
-    Will be executed by summary op.
-
-    Args:
-        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
-
-    Returns:
-        bool, true: means save summary success.
-    """
-    ret = _cache_summary_tensor_data(summary_list)
-    return ret
-
-
-class Callback:
-    """
-    Abstract base class used to build a callback class. Callbacks are context managers
-    which will be entered and exited when passing into the Model.
-    You can leverage this mechanism to init and release resources automatically.
-
-    Callback function will execution some operating to the current step or epoch.
-
-    Examples:
-        >>> class Print_info(Callback):
-        >>>     def step_end(self, run_context):
-        >>>         cb_params = run_context.original_args()
-        >>>         print(cb_params.cur_epoch_num)
-        >>>         print(cb_params.cur_step_num)
-        >>>
-        >>> print_cb = Print_info()
-        >>> model.train(epoch, dataset, callbacks=print_cb)
-    """
-
-    def __enter__(self):
-        """Return the enter target."""
-        return self
-
-    def __exit__(self, *err):
-        """Release resources here if have any."""
-
-    def begin(self, run_context):
-        """
-        Called once before the network executing.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_end(self, run_context):
-        """
-        Called after each epoch finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_end(self, run_context):
-        """
-        Called after each step finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def end(self, run_context):
-        """
-        Called once after network training.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-
-class _CallbackManager(Callback):
-    """
-    Sequential execution of callback functions.
-
-    Execute Callback functions at certain points.
-
-    Args:
-        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
-    """
-
-    def __init__(self, callbacks):
-        self._callbacks, self._stack = [], None
-        if isinstance(callbacks, Callback):
-            self._callbacks.append(callbacks)
-        elif callbacks is not None:
-            for cb in callbacks:
-                if not isinstance(cb, Callback):
-                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
-                self._callbacks.append(cb)
-
-    def __enter__(self):
-        if self._stack is None:
-            self._stack = ExitStack().__enter__()
-            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
-        return self
-
-    def __exit__(self, *err):
-        return self._stack.__exit__(*err)
-
-    def begin(self, run_context):
-        """Called once before network training."""
-        for cb in self._callbacks:
-            cb.begin(run_context)
-
-    def epoch_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.epoch_begin(run_context)
-
-    def epoch_end(self, run_context):
-        """Called after each epoch finished."""
-        for cb in self._callbacks:
-            cb.epoch_end(run_context)
-
-    def step_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.step_begin(run_context)
-
-    def step_end(self, run_context):
-        """Called after each step finished."""
-        for cb in self._callbacks:
-            cb.step_end(run_context)
-
-    def end(self, run_context):
-        """Called once after network training."""
-        for cb in self._callbacks:
-            cb.end(run_context)
-
-
-class SummaryStep(Callback):
-    """
-    The summary callback class.
-
-    Args:
-        summary (Object): Summary recode object.
-        flush_step (int): Number of interval steps to execute. Default: 10.
-    """
-
-    def __init__(self, summary, flush_step=10):
-        super(SummaryStep, self).__init__()
-        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
-            raise ValueError("`flush_step` should be int and greater than 0")
-        self._summary = summary
-        self._flush_step = flush_step
-
-    def __enter__(self):
-        self._summary.__enter__()
-        return self
-
-    def __exit__(self, *err):
-        return self._summary.__exit__(*err)
-
-    def step_end(self, run_context):
-        """
-        Save summary.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        if cb_params.cur_step_num % self._flush_step == 0:
-            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
-
-    @property
-    def summary_file_name(self):
-        return self._summary.full_file_name
-
-
-class _InternalCallbackParam(dict):
-    """Internal callback object's parameters."""
-
-    def __getattr__(self, key):
-        return self[key]
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-
-class RunContext:
-    """
-    Provides information about the model.
-
-    Run call being made. Provides information about original request to model function.
-    callback objects can stop the loop by calling request_stop() of run_context.
-
-    Args:
-        original_args (dict): Holding the related information of model etc.
-    """
-
-    def __init__(self, original_args):
-        if not isinstance(original_args, dict):
-            raise TypeError("The arg of RunContext should be dict type.")
-        self._original_args = original_args
-        self._stop_requested = False
-
-    def original_args(self):
-        """
-        Get the _original_args object.
-
-        Returns:
-            Dict, a object holding the original arguments of model.
-        """
-        return self._original_args
-
-    def request_stop(self):
-        """
-        Sets stop requested during training.
-
-        Callbacks can use this function to request stop of iterations.
-        model.train() checks whether this is called or not.
-        """
-        self._stop_requested = True
-
-    def get_stop_requested(self):
-        """
-        Returns whether a stop is requested or not.
-
-        Returns:
-            bool, if true, model.train() stops iterations.
-        """
-        return self._stop_requested
-
 
 class ModelCheckpoint(Callback):
     """
@@ -553,7 +209,7 @@ class ModelCheckpoint(Callback):
         self._config = config
 
         # get existing checkpoint files
-        self._manager = _CheckpointManager()
+        self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._graph_saved = False
@@ -633,7 +289,7 @@ class ModelCheckpoint(Callback):
             self._last_triggered_step = cb_params.cur_step_num
 
             if context.get_context("enable_ge"):
-                _set_cur_net(cb_params.train_network)
+                set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
             _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
@@ -648,57 +304,66 @@ class ModelCheckpoint(Callback):
         return self._latest_ckpt_file_name
 
-
-class LossMonitor(Callback):
-    """
-    Monitor the loss in training.
-
-    If the loss is NAN or INF, it will terminate training.
-
-    Note:
-        If per_print_times is 0 do not print loss.
-
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
-
-    Raises:
-        ValueError: If print_step is not int or less than zero.
-    """
-
-    def __init__(self, per_print_times=1):
-        super(LossMonitor, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0.")
-        self._per_print_times = per_print_times
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        loss = cb_params.net_outputs
-
-        if isinstance(loss, (tuple, list)):
-            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
-                loss = loss[0]
-
-        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
-            loss = np.mean(loss.asnumpy())
-
-        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-
-        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
-            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
-                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
-        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
-            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
-
-
-class TimeMonitor(Callback):
-    """Time Monitor."""
-    def __init__(self, data_size):
-        super(TimeMonitor, self).__init__()
-        self.data_size = data_size
-
-    def epoch_begin(self, run_context):
-        self.epoch_time = time.time()
-
-    def epoch_end(self, run_context):
-        epoch_mseconds = (time.time() - self.epoch_time) * 1000
-        per_step_mseconds = epoch_mseconds / self.data_size
-        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
+class CheckpointManager:
+    """Manage checkpoint files according to train_config of checkpoint."""
+
+    def __init__(self):
+        self._ckpoint_filelist = []
+
+    @property
+    def ckpoint_filelist(self):
+        """Get all the related checkpoint files managed here."""
+        return self._ckpoint_filelist
+
+    @property
+    def ckpoint_num(self):
+        """Get the number of the related checkpoint files managed here."""
+        return len(self._ckpoint_filelist)
+
+    def update_ckpoint_filelist(self, directory, prefix):
+        """Update the checkpoint file list."""
+        self._ckpoint_filelist = []
+        files = os.listdir(directory)
+        for filename in files:
+            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
+                mid_name = filename[len(prefix):-5]
+                flag = True
+                for char in mid_name:
+                    if char.isalpha():
+                        flag = False
+                if flag:
+                    self._ckpoint_filelist.append(directory + '/' + filename)
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+            self._ckpoint_filelist.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def remove_oldest_ckpoint_file(self):
+        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
+        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
+        self.remove_ckpoint_file(ckpoint_files[0])
+
+    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
+        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
+        movs = []
+        oldest_file = ''
+        oldest_time = cur_time
+        for ck_file in self._ckpoint_filelist:
+            modify_time = os.path.getmtime(ck_file)
+            if cur_time - modify_time < 60 * minutes:
+                movs.append(ck_file)
+                if modify_time < oldest_time:
+                    oldest_time = modify_time
+                    oldest_file = ck_file
+        for mv_file in movs:
+            if mv_file == oldest_file:
+                continue
+            self.remove_ckpoint_file(mv_file)
```
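A short sketch of the manager's housekeeping cycle, assuming a ./checkpoints directory of resnet*.ckpt files (both names are made up for illustration); the class is re-exported as _CheckpointManager for internal use:

```python
import time

from mindspore.train.callback import _CheckpointManager

manager = _CheckpointManager()
# Collect "resnet*.ckpt" files whose part after the prefix is purely numeric.
manager.update_ckpoint_filelist("./checkpoints", "resnet")

if manager.ckpoint_num > 5:               # keep_checkpoint_max-style cap
    manager.remove_oldest_ckpoint_file()  # drops the oldest file by mtime

# Alternatively, keep a single checkpoint per 10-minute window and
# remove the other files generated within that window.
manager.keep_one_ckpoint_per_minutes(10, time.time())
```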
mindspore/train/callback/_loss_monitor.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LossMonitor Callback class."""
import numpy as np

from mindspore.common.tensor import Tensor

from ._callback import Callback


class LossMonitor(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, it will terminate training.

    Note:
        If per_print_times is 0, do not print the loss.

    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.

    Raises:
        ValueError: If per_print_times is not an int or is less than zero.
    """

    def __init__(self, per_print_times=1):
        super(LossMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
```
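Attaching the monitor is unchanged from the old layout. A small usage sketch; the model and dataset setup is assumed:

```python
from mindspore.train.callback import LossMonitor

# Print the loss every 100 steps; per_print_times=0 disables printing entirely,
# and a NaN/Inf float loss raises ValueError to terminate training.
loss_cb = LossMonitor(per_print_times=100)
# model.train(epoch, dataset, callbacks=[loss_cb])  # model, dataset assumed
```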
mindspore/train/callback/_summary_step.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryStep Callback class."""
from ._callback import Callback


class SummaryStep(Callback):
    """
    The summary callback class.

    Args:
        summary (Object): Summary record object.
        flush_step (int): Number of interval steps between executions. Default: 10.
    """

    def __init__(self, summary, flush_step=10):
        super(SummaryStep, self).__init__()
        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
            raise ValueError("`flush_step` should be int and greater than 0")
        self._summary = summary
        self._flush_step = flush_step

    def __enter__(self):
        self._summary.__enter__()
        return self

    def __exit__(self, *err):
        return self._summary.__exit__(*err)

    def step_end(self, run_context):
        """
        Save summary.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self._flush_step == 0:
            self._summary.record(cb_params.cur_step_num, cb_params.train_network)

    @property
    def summary_file_name(self):
        return self._summary.full_file_name
```
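A usage sketch; the SummaryRecord import path and its log_dir argument reflect the 2020-era API and should be treated as assumptions:

```python
from mindspore.train.callback import SummaryStep
from mindspore.train.summary import SummaryRecord  # import path assumed for this era

summary_record = SummaryRecord(log_dir="./summary")      # log_dir assumed
summary_cb = SummaryStep(summary_record, flush_step=10)  # record every 10 steps
# model.train(epoch, dataset, callbacks=[summary_cb])    # model, dataset assumed
# SummaryStep enters and exits summary_record as a context manager, so the
# record file is closed when training finishes.
```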
mindspore/train/callback/_time_monitor.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TimeMonitor Callback class."""
import time

from ._callback import Callback


class TimeMonitor(Callback):
    """Time Monitor."""

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
```
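A usage sketch; the data_size value (steps per epoch) is illustrative:

```python
from mindspore.train.callback import TimeMonitor

# data_size is the number of steps per epoch, used to derive per-step time;
# 1875 is an illustrative value (e.g. 60000 samples / batch size 32).
time_cb = TimeMonitor(data_size=1875)
# model.train(epoch, dataset, callbacks=[time_cb])  # model, dataset assumed
```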
mindspore/train/model.py

```diff
@@ -19,7 +19,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
-from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
```
tests/st/networks/models/resnet50/src_thor/model_thor.py

```diff
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from .dataset_helper import DatasetHelper
```
tests/ut/python/utils/test_callback.py

```diff
@@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
-    _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
+from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
+    _CallbackManager, Callback, CheckpointConfig
+from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
+from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
 
 
 class Net(nn.Cell):
     """Net definition."""
@@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op():
     one_param['name'] = "conv1.weight"
     one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
     parameter_list.append(one_param)
-    _checkpoint_cb_for_save_op(parameter_list)
+    checkpoint_cb_for_save_op(parameter_list)
 
 
 def test_checkpoint_cb_for_save_op_update_net():
@@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
     one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
     parameter_list.append(one_param)
     net = Net()
-    _set_cur_net(net)
-    _checkpoint_cb_for_save_op(parameter_list)
+    set_cur_net(net)
+    checkpoint_cb_for_save_op(parameter_list)
     assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
```
tests/ut/python/utils/test_serialize.py

```diff
@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.train.callback.callback import _CheckpointManager
+from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
     _exec_save_checkpoint, export, _save_graph
 from ..ut_filter import non_graph_engine
```