Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8333aea5
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8333aea5
编写于
8月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4061 Fix several minor issues
Merge pull request !4061 from LiHongzhang/fix_summary
上级
daefafbe
de43c11e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
32 additions
and
14 deletions
+32
-14
mindspore/train/callback/_summary_collector.py
mindspore/train/callback/_summary_collector.py
+26
-11
mindspore/train/summary/_summary_writer.py
mindspore/train/summary/_summary_writer.py
+5
-3
mindspore/train/summary/summary_record.py
mindspore/train/summary/summary_record.py
+1
-0
未找到文件。
mindspore/train/callback/_summary_collector.py
浏览文件 @
8333aea5
...
@@ -111,10 +111,10 @@ class SummaryCollector(Callback):
...
@@ -111,10 +111,10 @@ class SummaryCollector(Callback):
Default: None, it means there is no custom data.
Default: None, it means there is no custom data.
collect_tensor_freq (Optional[int]): Same semantic as the `collect_freq`, but controls TensorSummary only.
collect_tensor_freq (Optional[int]): Same semantic as the `collect_freq`, but controls TensorSummary only.
Because TensorSummary data is too large compared to other summary data, this parameter is used to reduce
Because TensorSummary data is too large compared to other summary data, this parameter is used to reduce
its collection. By default, TensorSummary data will be collected at most 21 steps, but not more than how
its collection. By default, TensorSummary data will be collected at most 20 steps, but not more than how
many steps other summary data will be collected.
many steps other summary data will be collected.
Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`,
Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`,
when the total steps is 600, TensorSummary will be collected 21 steps, while other summary data 61 steps,
when the total steps is 600, TensorSummary will be collected 20 steps, while other summary data 61 steps,
but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps.
but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps.
Also note that when in parallel mode, the total steps will be split evenly, which will
Also note that when in parallel mode, the total steps will be split evenly, which will
affect how many steps TensorSummary will be collected.
affect how many steps TensorSummary will be collected.
...
@@ -176,6 +176,7 @@ class SummaryCollector(Callback):
...
@@ -176,6 +176,7 @@ class SummaryCollector(Callback):
self
.
_check_positive
(
'collect_tensor_freq'
,
collect_tensor_freq
,
allow_none
=
True
)
self
.
_check_positive
(
'collect_tensor_freq'
,
collect_tensor_freq
,
allow_none
=
True
)
self
.
_collect_tensor_freq
=
collect_tensor_freq
self
.
_collect_tensor_freq
=
collect_tensor_freq
self
.
_tensor_collect_range
=
None
self
.
_check_positive
(
'max_file_size'
,
max_file_size
,
allow_none
=
True
)
self
.
_check_positive
(
'max_file_size'
,
max_file_size
,
allow_none
=
True
)
self
.
_max_file_size
=
max_file_size
self
.
_max_file_size
=
max_file_size
...
@@ -296,12 +297,6 @@ class SummaryCollector(Callback):
...
@@ -296,12 +297,6 @@ class SummaryCollector(Callback):
self
.
_record
.
set_mode
(
cb_params
.
mode
)
self
.
_record
.
set_mode
(
cb_params
.
mode
)
if
cb_params
.
mode
==
ModeEnum
.
TRAIN
.
value
:
if
self
.
_collect_tensor_freq
is
None
:
default_tensor_summary_limit
=
20
total_step
=
cb_params
.
epoch_num
*
cb_params
.
batch_num
self
.
_collect_tensor_freq
=
max
(
self
.
_collect_freq
,
total_step
//
default_tensor_summary_limit
)
def
step_end
(
self
,
run_context
):
def
step_end
(
self
,
run_context
):
cb_params
=
run_context
.
original_args
()
cb_params
=
run_context
.
original_args
()
if
cb_params
.
mode
!=
ModeEnum
.
TRAIN
.
value
:
if
cb_params
.
mode
!=
ModeEnum
.
TRAIN
.
value
:
...
@@ -322,17 +317,36 @@ class SummaryCollector(Callback):
...
@@ -322,17 +317,36 @@ class SummaryCollector(Callback):
if
self
.
_first_step
:
if
self
.
_first_step
:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self
.
_dataset_sink_mode
=
cb_params
.
cur_step_num
==
cb_params
.
batch_num
self
.
_dataset_sink_mode
=
cb_params
.
cur_step_num
==
cb_params
.
batch_num
self
.
_tensor_collect_range
=
self
.
_get_tensor_collect_range
(
cb_params
,
self
.
_dataset_sink_mode
)
self
.
_collect_at_step_end
(
cb_params
,
plugin_filter
=
None
)
self
.
_collect_at_step_end
(
cb_params
,
plugin_filter
=
None
)
self
.
_first_step
=
False
self
.
_first_step
=
False
else
:
else
:
current
=
cb_params
.
cur_epoch_num
if
self
.
_dataset_sink_mode
else
cb_params
.
cur_step_num
current
=
cb_params
.
cur_epoch_num
if
self
.
_dataset_sink_mode
else
cb_params
.
cur_step_num
if
current
%
self
.
_collect_freq
==
0
and
current
%
self
.
_collect_tensor_freq
==
0
:
if
current
%
self
.
_collect_freq
==
0
and
current
in
self
.
_tensor_collect_range
:
self
.
_collect_at_step_end
(
cb_params
,
plugin_filter
=
None
)
self
.
_collect_at_step_end
(
cb_params
,
plugin_filter
=
None
)
elif
current
%
self
.
_collect_tensor_freq
==
0
:
elif
current
in
self
.
_tensor_collect_range
:
self
.
_collect_at_step_end
(
cb_params
,
lambda
plugin
:
plugin
==
PluginEnum
.
TENSOR
.
value
)
self
.
_collect_at_step_end
(
cb_params
,
lambda
plugin
:
plugin
==
PluginEnum
.
TENSOR
.
value
)
elif
current
%
self
.
_collect_freq
==
0
:
elif
current
%
self
.
_collect_freq
==
0
:
self
.
_collect_at_step_end
(
cb_params
,
lambda
plugin
:
plugin
!=
PluginEnum
.
TENSOR
.
value
)
self
.
_collect_at_step_end
(
cb_params
,
lambda
plugin
:
plugin
!=
PluginEnum
.
TENSOR
.
value
)
def _get_tensor_collect_range(self, cb_params, dataset_sink_mode, tensor_summary_limit=20):
    """
    Get the values of the step (or epoch) counter at which TensorSummary is collected.

    Args:
        cb_params: Callback parameters; `epoch_num` and `batch_num` are read from it.
        dataset_sink_mode (bool): If True, the range is counted over `epoch_num`;
            otherwise it is counted over `epoch_num * batch_num`.
        tensor_summary_limit (int): Cap applied to the number of selected points when
            `self._collect_tensor_freq` is None. Expected to be positive. Default: 20.

    Returns:
        range, the counter values at which TensorSummary data should be collected.
    """
    total_step = cb_params.epoch_num
    if not dataset_sink_mode:
        total_step *= cb_params.batch_num

    # `total_step + 1`: `total_step` would be a value of `cb_params.cur_step_num`.
    stop = total_step + 1

    # An explicitly configured frequency always wins over the default limit.
    if self._collect_tensor_freq is not None:
        return range(0, stop, self._collect_tensor_freq)

    # Never collect tensors more often than the other summary data.
    summary_range = range(0, stop, self._collect_freq)
    if len(summary_range) <= tensor_summary_limit:
        return summary_range

    # Spread the limited number of points over the whole run;
    # `max(..., 1)` keeps the division safe when the limit is 1.
    tensor_freq = total_step // max(tensor_summary_limit - 1, 1)
    if tensor_freq > 1:
        return range(0, stop, tensor_freq)[:tensor_summary_limit]

    # `cb_params.cur_step_num` counting from `1`, when `1` is in the range, take `1` more steps.
    return range(0, stop)[:tensor_summary_limit + 1]
def
_collect_at_step_end
(
self
,
cb_params
,
plugin_filter
):
def
_collect_at_step_end
(
self
,
cb_params
,
plugin_filter
):
self
.
_collect_input_data
(
cb_params
)
self
.
_collect_input_data
(
cb_params
)
self
.
_collect_metric
(
cb_params
)
self
.
_collect_metric
(
cb_params
)
...
@@ -577,7 +591,8 @@ class SummaryCollector(Callback):
...
@@ -577,7 +591,8 @@ class SummaryCollector(Callback):
"""
"""
learning_rate
=
optimizer
.
learning_rate
learning_rate
=
optimizer
.
learning_rate
if
not
isinstance
(
learning_rate
,
Parameter
):
if
not
isinstance
(
learning_rate
,
Parameter
):
logger
.
info
(
"The learning rate detected in the optimizer is not a Parameter type, so it is not recorded."
)
logger
.
warning
(
"The learning rate detected in the optimizer "
"is not a Parameter type, so it is not recorded."
)
return
None
return
None
return
learning_rate
.
data
return
learning_rate
.
data
...
...
mindspore/train/summary/_summary_writer.py
浏览文件 @
8333aea5
...
@@ -20,6 +20,8 @@ from shutil import disk_usage
...
@@ -20,6 +20,8 @@ from shutil import disk_usage
from
..._c_expression
import
EventWriter_
from
..._c_expression
import
EventWriter_
from
._summary_adapter
import
package_init_event
from
._summary_adapter
import
package_init_event
FREE_DISK_SPACE_TIMES
=
32
class
BaseWriter
:
class
BaseWriter
:
"""BaseWriter to be subclass."""
"""BaseWriter to be subclass."""
...
@@ -45,13 +47,13 @@ class BaseWriter:
...
@@ -45,13 +47,13 @@ class BaseWriter:
def
write
(
self
,
plugin
,
data
):
def
write
(
self
,
plugin
,
data
):
"""Write data to file."""
"""Write data to file."""
if
self
.
writer
and
disk_usage
(
self
.
_filepath
).
free
<
len
(
data
)
*
32
:
raise
RuntimeError
(
f
"The disk space may be soon exhausted by the '
{
self
.
_filepath
}
'."
)
# 8: data length
# 8: data length
# 4: crc32 of data length
# 4: crc32 of data length
# 4: crc32 of data
# 4: crc32 of data
metadata_length
=
8
+
4
+
4
metadata_length
=
8
+
4
+
4
required_length
=
len
(
data
)
+
metadata_length
required_length
=
len
(
data
)
+
metadata_length
if
self
.
writer
and
disk_usage
(
self
.
_filepath
).
free
<
required_length
*
FREE_DISK_SPACE_TIMES
:
raise
RuntimeError
(
f
"The disk space may be soon exhausted by the '
{
self
.
_filepath
}
'."
)
if
self
.
_max_file_size
is
None
:
if
self
.
_max_file_size
is
None
:
self
.
writer
.
Write
(
data
)
self
.
writer
.
Write
(
data
)
elif
self
.
_max_file_size
>=
required_length
:
elif
self
.
_max_file_size
>=
required_length
:
...
@@ -77,7 +79,7 @@ class SummaryWriter(BaseWriter):
...
@@ -77,7 +79,7 @@ class SummaryWriter(BaseWriter):
def
init_writer
(
self
):
def
init_writer
(
self
):
"""Write some metadata etc."""
"""Write some metadata etc."""
self
.
write
r
.
Write
(
package_init_event
().
SerializeToString
())
self
.
write
(
'summary'
,
package_init_event
().
SerializeToString
())
def
write
(
self
,
plugin
,
data
):
def
write
(
self
,
plugin
,
data
):
"""Write data to file."""
"""Write data to file."""
...
...
mindspore/train/summary/summary_record.py
浏览文件 @
8333aea5
...
@@ -156,6 +156,7 @@ class SummaryRecord:
...
@@ -156,6 +156,7 @@ class SummaryRecord:
max_file_size
,
max_file_size
,
summary
=
self
.
full_file_name
,
summary
=
self
.
full_file_name
,
lineage
=
get_event_file_name
(
'events'
,
'_lineage'
))
lineage
=
get_event_file_name
(
'events'
,
'_lineage'
))
_get_summary_tensor_data
()
atexit
.
register
(
self
.
close
)
atexit
.
register
(
self
.
close
)
def
__enter__
(
self
):
def
__enter__
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录