magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the fork source)
Commit ea4836e1

Authored May 11, 2020 by mindspore-ci-bot; committed via Gitee on May 11, 2020

!713 Use a resident process to write summary files

Merge pull request !713 from 李鸿章/summary_record

Parents: 95d4665d, 69d3abfd
Showing 4 changed files with 196 additions and 591 deletions:

    mindspore/train/summary/_event_writer.py        +61   -75
    mindspore/train/summary/_summary_adapter.py     +56  -147
    mindspore/train/summary/_summary_scheduler.py    +0  -308  (deleted)
    mindspore/train/summary/summary_record.py       +79   -61
mindspore/train/summary/_event_writer.py  (+61 -75)
@@ -14,91 +14,77 @@
 # ============================================================================
 """Writes events to disk in a logdir."""
 import os
-import time
 import stat
-from mindspore import log as logger
+from collections import deque
+from multiprocessing import Pool, Process, Queue, cpu_count
 
 from ..._c_expression import EventWriter_
-from ._summary_adapter import package_init_event
+from ._summary_adapter import package_summary_event
 
 
-class _WrapEventWriter(EventWriter_):
-    """
-    Wrap the c++ EventWriter object.
-
-    Args:
-        full_file_name (str): Include directory and file name.
-    """
-    def __init__(self, full_file_name):
-        if full_file_name is not None:
-            EventWriter_.__init__(self, full_file_name)
-
-
-class EventRecord:
-    """
-    Creates a `EventFileWriter` and write event to file.
-
-    Args:
-        full_file_name (str): Summary event file path and file name.
-        flush_time (int): The flush seconds to flush the pending events to disk. Default: 120.
-    """
-    def __init__(self, full_file_name: str, flush_time: int = 120):
-        self.full_file_name = full_file_name
-
-        # The first event will be flushed immediately.
-        self.flush_time = flush_time
-        self.next_flush_time = 0
-
-        # create event write object
-        self.event_writer = self._create_event_file()
-        self._init_event_file()
-
-        # count the events
-        self.event_count = 0
-
-    def _create_event_file(self):
-        """Create the event write file."""
-        with open(self.full_file_name, 'w'):
-            os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
-
-        # create c++ event write object
-        event_writer = _WrapEventWriter(self.full_file_name)
-        return event_writer
-
-    def _init_event_file(self):
-        """Send the init event to file."""
-        self.event_writer.Write((package_init_event()).SerializeToString())
-        self.flush()
-        return True
-
-    def write_event_to_file(self, event_str):
-        """Write the event to file."""
-        self.event_writer.Write(event_str)
-
-    def get_data_count(self):
-        """Return the event count."""
-        return self.event_count
-
-    def flush_cycle(self):
-        """Flush file by timer."""
-        self.event_count = self.event_count + 1
-        # Flush the event writer every so often.
-        now = int(time.time())
-        if now > self.next_flush_time:
-            self.flush()
-            # update the flush time
-            self.next_flush_time = now + self.flush_time
-
-    def count_event(self):
-        """Count event."""
-        logger.debug("Write the event count is %r", self.event_count)
-        self.event_count = self.event_count + 1
-        return self.event_count
+def _pack(result, step):
+    summary_event = package_summary_event(result, step)
+    return summary_event.SerializeToString()
+
+
+class EventWriter(Process):
+    """
+    Creates a `EventWriter` and write event to file.
+
+    Args:
+        filepath (str): Summary event file path and file name.
+        flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
+    """
+
+    def __init__(self, filepath: str, flush_interval: int) -> None:
+        super().__init__()
+        with open(filepath, 'w'):
+            os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
+        self._writer = EventWriter_(filepath)
+        self._queue = Queue(cpu_count() * 2)
+        self.start()
+
+    def run(self):
+        with Pool() as pool:
+            deq = deque()
+            while True:
+                while deq and deq[0].ready():
+                    self._writer.Write(deq.popleft().get())
+
+                if not self._queue.empty():
+                    action, data = self._queue.get()
+                    if action == 'WRITE':
+                        if not isinstance(data, (str, bytes)):
+                            deq.append(pool.apply_async(_pack, data))
+                        else:
+                            self._writer.Write(data)
+                    elif action == 'FLUSH':
+                        self._writer.Flush()
+                    elif action == 'END':
+                        break
+            for res in deq:
+                self._writer.Write(res.get())
+
+            self._writer.Shut()
+
+    def write(self, data) -> None:
+        """
+        Write the event to file.
+
+        Args:
+            data (Optional[str, Tuple[list, int]]): The data to write.
+        """
+        self._queue.put(('WRITE', data))
 
     def flush(self):
-        """Flush the event file to disk."""
-        self.event_writer.Flush()
+        """Flush the writer."""
+        self._queue.put(('FLUSH', None))
 
-    def close(self):
-        """Flush the event file to disk and close the file."""
-        self.flush()
-        self.event_writer.Shut()
+    def close(self) -> None:
+        """Close the writer."""
+        self._queue.put(('END', None))
+        self.join()
+
+    def __del__(self) -> None:
+        self.close()
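The new EventWriter above is the "resident process" of the commit title: one long-lived process owns the C++ writer and drains a bounded Queue, while a Pool serializes summaries in parallel and a deque preserves write order. Below is a minimal, self-contained sketch of the same pattern, not the MindSpore code itself: it writes to a plain file instead of EventWriter_, the names serialize and writer_loop are illustrative, and it blocks on Queue.get() where the committed code polls with Queue.empty().

from collections import deque
from multiprocessing import Pool, Process, Queue


def serialize(record, step):
    # stand-in for package_summary_event(...).SerializeToString()
    return f"{step}:{record}\n".encode()


def writer_loop(queue, path):
    # one resident process owns the file; pool workers only serialize
    with open(path, 'ab') as f, Pool() as pool:
        pending = deque()
        while True:
            # flush finished serializations in submission order
            while pending and pending[0].ready():
                f.write(pending.popleft().get())
            action, payload = queue.get()
            if action == 'WRITE':
                pending.append(pool.apply_async(serialize, payload))
            elif action == 'END':
                break
        for res in pending:  # drain whatever is still in flight
            f.write(res.get())


if __name__ == '__main__':
    q = Queue(4)  # bounded, so producers back off instead of piling up memory
    writer = Process(target=writer_loop, args=(q, 'events.bin'))
    writer.start()
    for step in range(3):
        q.put(('WRITE', ([{'name': 'loss', 'data': 0.1 * step}], step)))
    q.put(('END', None))
    writer.join()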
mindspore/train/summary/_summary_adapter.py  (+56 -147)
@@ -13,17 +13,17 @@
 # limitations under the License.
 # ============================================================================
 """Generate the summary event which conform to proto format."""
-import time
 import socket
-import math
+import time
-from enum import Enum, unique
 import numpy as np
 from PIL import Image
 from mindspore import log as logger
-from ..summary_pb2 import Event
-from ..anf_ir_pb2 import ModelProto, DataType
 from ..._checkparam import _check_str_by_regular
+from ..anf_ir_pb2 import DataType, ModelProto
+from ..summary_pb2 import Event
 
 # define the MindSpore image format
 MS_IMAGE_TENSOR_FORMAT = 'NCHW'

@@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
 # Set the init event of version and mark
 EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
 EVENT_FILE_INIT_VERSION = 1
-
-# cache the summary data dict
-# {id: SummaryData}
-#     |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
-g_summary_data_dict = {}
-
-
-def save_summary_data(data_id, data):
-    """Save the global summary cache."""
-    global g_summary_data_dict
-    g_summary_data_dict[data_id] = data
-
-
-def del_summary_data(data_id):
-    """Save the global summary cache."""
-    global g_summary_data_dict
-    if data_id in g_summary_data_dict:
-        del g_summary_data_dict[data_id]
-    else:
-        logger.warning("Can't del the data because data_id(%r) "
-                       "does not have data in g_summary_data_dict", data_id)
-
-
-def get_summary_data(data_id):
-    """Save the global summary cache."""
-    ret = None
-    global g_summary_data_dict
-    if data_id in g_summary_data_dict:
-        ret = g_summary_data_dict.get(data_id)
-    else:
-        logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
-    return ret
-
-
-@unique
-class SummaryType(Enum):
-    """
-    Summary type.
-
-    Args:
-        SCALAR (Number): Summary Scalar enum.
-        TENSOR (Number): Summary TENSOR enum.
-        IMAGE (Number): Summary image enum.
-        GRAPH (Number): Summary graph enum.
-        HISTOGRAM (Number): Summary histogram enum.
-        INVALID (Number): Unknow type.
-    """
-    SCALAR = 1      # Scalar summary
-    TENSOR = 2      # Tensor summary
-    IMAGE = 3       # Image summary
-    GRAPH = 4       # graph
-    HISTOGRAM = 5   # Histogram Summary
-    INVALID = 0xFF  # unknow type
 
 
 def get_event_file_name(prefix, suffix):

@@ -138,7 +89,7 @@ def package_graph_event(data):
     return graph_event
 
 
-def package_summary_event(data_id, step):
+def package_summary_event(data_list, step):
     """
     Package the summary to event protobuffer.

@@ -149,50 +100,37 @@ def package_summary_event(data_id, step):
     Returns:
         Summary, the summary event.
     """
-    data_list = get_summary_data(data_id)
-    if data_list is None:
-        logger.error("The step(%r) does not have record data.", step)
-    del_summary_data(data_id)
     # create the event of summary
     summary_event = Event()
     summary = summary_event.summary
+    summary_event.wall_time = time.time()
+    summary_event.step = int(step)
+
     for value in data_list:
         tag = value["name"]
         data = value["data"]
-        summary_type = value["type"]
+        summary_type = value["_type"]
+        logger.debug("Now process %r summary, tag = %r", summary_type, tag)
+
+        summary_value = summary.value.add()
+        summary_value.tag = tag
         # get the summary type and parse the tag
-        if summary_type is SummaryType.SCALAR:
-            logger.debug("Now process Scalar summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        if summary_type == 'Scalar':
             summary_value.scalar_value = _get_scalar_summary(tag, data)
-        elif summary_type is SummaryType.TENSOR:
-            logger.debug("Now process Tensor summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Tensor':
             summary_tensor = summary_value.tensor
             _get_tensor_summary(tag, data, summary_tensor)
-        elif summary_type is SummaryType.IMAGE:
-            logger.debug("Now process Image summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Image':
             summary_image = summary_value.image
             _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
-        elif summary_type is SummaryType.HISTOGRAM:
-            logger.debug("Now process Histogram summary, tag = %r", tag)
-            summary_value = summary.value.add()
-            summary_value.tag = tag
+        elif summary_type == 'Histogram':
             summary_histogram = summary_value.histogram
             _fill_histogram_summary(tag, data, summary_histogram)
         else:
             # The data is invalid ,jump the data
-            logger.error("Summary type is error, tag = %r", tag)
+            logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)
+            continue
 
-    summary_event.wall_time = time.time()
-    summary_event.step = int(step)
     return summary_event

@@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value):
         # So consider the dim = 1, shape = (1,) tensor is scalar
         scalar_value = np_value[0]
         if np_value.shape != (1,):
-            logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
+            logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape)
     else:
         np_list = np_value.reshape(-1).tolist()
         scalar_value = np_list[0]
-        logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
+        logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim)
 
     logger.debug("The tag(%r) value is: %r", tag, scalar_value)
     return scalar_value

@@ -307,8 +245,7 @@ def _calc_histogram_bins(count):
     Returns:
         int, number of histogram bins.
     """
-    number_per_bucket = 10
-    max_bins = 90
+    max_bins, max_per_bin = 90, 10
 
     if not count:
         return 1

@@ -318,78 +255,50 @@ def _calc_histogram_bins(count):
         return 3
     if count <= 880:
         # note that math.ceil(881/10) + 1 equals 90
-        return int(math.ceil(count / number_per_bucket) + 1)
+        return count // max_per_bin + 1
 
     return max_bins
 
 
-def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
+def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
     """
     Package the histogram summary.
 
     Args:
         tag (str): Summary tag describe.
-        np_value (np.array): Summary data.
-        summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
+        np_value (np.ndarray): Summary data.
+        summary (summary_pb2.Summary.Histogram): Summary histogram data.
     """
     logger.debug("Set(%r) the histogram summary value", tag)
-    # Default bucket for tensor with no valid data.
-    default_bucket_left = -0.5
-    default_bucket_width = 1.0
-
-    if np_value.size == 0:
-        bucket = summary_histogram.buckets.add()
-        bucket.left = default_bucket_left
-        bucket.width = default_bucket_width
-        bucket.count = 0
-
-        summary_histogram.nan_count = 0
-        summary_histogram.pos_inf_count = 0
-        summary_histogram.neg_inf_count = 0
-        summary_histogram.max = 0
-        summary_histogram.min = 0
-        summary_histogram.sum = 0
-        summary_histogram.count = 0
-        return
-
-    summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
-    summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
-    summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
-    summary_histogram.count = np_value.size
-
-    masked_value = np.ma.masked_invalid(np_value)
-    tensor_max = masked_value.max()
-    tensor_min = masked_value.min()
-    tensor_sum = masked_value.sum()
-
-    # No valid value in tensor.
-    if tensor_max is np.ma.masked:
-        bucket = summary_histogram.buckets.add()
-        bucket.left = default_bucket_left
-        bucket.width = default_bucket_width
-        bucket.count = 0
-
-        summary_histogram.max = np.nan
-        summary_histogram.min = np.nan
-        summary_histogram.sum = 0
-        return
-
-    bin_number = _calc_histogram_bins(masked_value.count())
-    counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
-
-    for ind, count in enumerate(counts):
-        bucket = summary_histogram.buckets.add()
-        bucket.left = edges[ind]
-        bucket.width = edges[ind + 1] - edges[ind]
-        bucket.count = count
-
-    summary_histogram.max = tensor_max
-    summary_histogram.min = tensor_min
-    summary_histogram.sum = tensor_sum
+    ma_value = np.ma.masked_invalid(np_value)
+    total, valid = np_value.size, ma_value.count()
+    invalids = []
+    for isfn in np.isnan, np.isposinf, np.isneginf:
+        if total - valid > sum(invalids):
+            count = np.count_nonzero(isfn(np_value))
+            invalids.append(count)
+        else:
+            invalids.append(0)
+
+    summary.count = total
+    summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
+    if not valid:
+        logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape)
+        # summary.{min, max, sum} are 0s by default, no need to explicitly set
+    else:
+        summary.min = ma_value.min()
+        summary.max = ma_value.max()
+        summary.sum = ma_value.sum()
+        bins = _calc_histogram_bins(valid)
+        range_ = summary.min, summary.max
+        hists, edges = np.histogram(np_value, bins=bins, range=range_)
+
+        for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
+            bucket = summary.buckets.add()
+            bucket.width = edge2 - edge1
+            bucket.count = hist
+            bucket.left = edge1

@@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
     """
     logger.debug("Set(%r) the image summary value", tag)
     if np_value.ndim != 4:
-        logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
+        logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)
 
     # convert the tensor format
     tensor = _convert_image_format(np_value, input_format)

@@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
     """
     out_tensor = None
     if np_tensor.ndim != len(input_format):
-        logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
-                     np_tensor, input_format)
+        logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same",
+                     np_tensor.ndim, input_format)
         return out_tensor
 
     input_format = input_format.upper()

@@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
     # check the tensor format
     if tensor.ndim != 4 or tensor.shape[1] != 3:
-        logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
+        logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format",
+                     tensor.ndim, tensor.shape)
         return out_canvas
 
     # expand the N
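The rewritten _fill_histogram_summary above leans on NumPy masked arrays: mask NaN/Inf once, derive count/min/max/sum from the mask, and histogram only over the valid range. A small standalone illustration of that technique, using plain NumPy with the protobuf summary object left out:

import numpy as np

values = np.array([1.0, 2.0, np.nan, np.inf, 3.0, -np.inf, 2.5])
masked = np.ma.masked_invalid(values)                   # masks NaN and +/-Inf

total, valid = values.size, masked.count()              # 7, 4
nan_count = np.count_nonzero(np.isnan(values))          # 1
pos_inf_count = np.count_nonzero(np.isposinf(values))   # 1
neg_inf_count = np.count_nonzero(np.isneginf(values))   # 1

if valid:
    lo, hi = masked.min(), masked.max()                 # 1.0, 3.0
    # NaN/Inf fall outside `range` and are simply not counted
    counts, edges = np.histogram(values, bins=3, range=(lo, hi))
    print(counts)                                       # [1 1 2]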
mindspore/train/summary/_summary_scheduler.py  (+0 -308, deleted; file mode 100644 → 0, content as of parent 95d4665d)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Schedule the event writer process."""
import multiprocessing as mp
from enum import Enum, unique

from mindspore import log as logger
from ..._c_expression import Tensor
from ._summary_adapter import SummaryType, package_summary_event, save_summary_data

# define the type of summary
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"

# cache the summary data dict
# {id: SummaryData}
#     |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_id = 0
g_summary_data_dict = {}
# cache the summary data file
g_summary_writer_id = 0
g_summary_file = {}


@unique
class ScheduleMethod(Enum):
    """Schedule method type."""
    FORMAL_WORKER = 0    # use the formal worker that receive small size data by queue
    TEMP_WORKER = 1      # use the Temp worker that receive big size data by the global value(avoid copy)
    CACHE_DATA = 2       # Cache data util have idle worker to process it


@unique
class WorkerStatus(Enum):
    """Worker status."""
    WORKER_INIT = 0         # data is exist but not process
    WORKER_PROCESSING = 1   # data is processing
    WORKER_PROCESSED = 2    # data already processed


def _parse_tag_format(tag: str):
    """
    Parse the tag.

    Args:
        tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].

    Returns:
        Tuple, (SummaryType, summary_tag).
    """
    summary_type = SummaryType.INVALID
    summary_tag = tag
    if tag is None:
        logger.error("The tag is None")
        return summary_type, summary_tag

    # search the slice
    slice_begin = FORMAT_BEGIN_SLICE
    slice_end = FORMAT_END_SLICE
    index = tag.rfind(slice_begin)
    if index is -1:
        logger.error("The tag(%s) have not the key slice.", tag)
        return summary_type, summary_tag

    # slice the tag
    summary_tag = tag[:index]

    # check the slice end
    if tag[-1:] != slice_end:
        logger.error("The tag(%s) end format is error", tag)
        return summary_type, summary_tag

    # check the type
    type_str = tag[index + 2: -1]
    logger.debug("The summary_tag is = %r", summary_tag)
    logger.debug("The type_str value is = %r", type_str)
    if type_str == FORMAT_SCALAR_STR:
        summary_type = SummaryType.SCALAR
    elif type_str == FORMAT_TENSOR_STR:
        summary_type = SummaryType.TENSOR
    elif type_str == FORMAT_IMAGE_STR:
        summary_type = SummaryType.IMAGE
    elif type_str == FORMAT_HISTOGRAM_STR:
        summary_type = SummaryType.HISTOGRAM
    else:
        logger.error("The tag(%s) type is invalid.", tag)
        summary_type = SummaryType.INVALID

    return summary_type, summary_tag


class SummaryDataManager:
    """Manage the summary global data cache."""
    def __init__(self):
        global g_summary_data_dict
        self.size = len(g_summary_data_dict)

    @classmethod
    def summary_data_save(cls, data):
        """Save the global summary cache."""
        global g_summary_data_id
        data_id = g_summary_data_id
        save_summary_data(data_id, data)
        g_summary_data_id += 1
        return data_id

    @classmethod
    def summary_file_set(cls, event_writer):
        """Support the many event_writer."""
        global g_summary_file, g_summary_writer_id
        g_summary_writer_id += 1
        g_summary_file[g_summary_writer_id] = event_writer
        return g_summary_writer_id

    @classmethod
    def summary_file_get(cls, writer_id=1):
        ret = None
        global g_summary_file
        if writer_id in g_summary_file:
            ret = g_summary_file.get(writer_id)
        return ret


class WorkerScheduler:
    """
    Create worker and schedule data to worker.

    Args:
        writer_id (int): The index of writer.
    """
    def __init__(self, writer_id):
        # Create the process of write event file
        self.write_lock = mp.Lock()
        # Schedule info for all worker
        # Format: {worker: (step, WorkerStatus)}
        self.schedule_table = {}
        # write id
        self.writer_id = writer_id
        self.has_graph = False

    def dispatch(self, step, data):
        """
        Select schedule strategy and dispatch data.

        Args:
            step (Number): The number of step index.
            data (Object): The data of recode for summary.

        Retruns:
            bool, run successfully or not.
        """
        # save the data to global cache , convert the tensor to numpy
        result, size, data = self._data_convert(data)
        if result is False:
            logger.error("The step(%r) summary data(%r) is invalid.", step, size)
            return False

        data_id = SummaryDataManager.summary_data_save(data)
        self._start_worker(step, data_id)
        return True

    def _start_worker(self, step, data_id):
        """
        Start worker.

        Args:
            step (Number): The index of recode.
            data_id (str): The id of work.

        Return:
            bool, run successfully or not.
        """
        # assign the worker
        policy = self._make_policy()
        if policy == ScheduleMethod.TEMP_WORKER:
            worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
            # update the schedule table
            self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
            # start the worker
            worker.start()
        else:
            logger.error("Do not support the other scheduler policy now.")

        # update the scheduler infor
        self._update_scheduler()
        return True

    def _data_convert(self, data_list):
        """Convert the data."""
        if data_list is None:
            logger.warning("The step does not have record data.")
            return False, 0, None

        # convert the summary to numpy
        size = 0
        for v_dict in data_list:
            tag = v_dict["name"]
            data = v_dict["data"]
            # confirm the data is valid
            summary_type, summary_tag = _parse_tag_format(tag)
            if summary_type == SummaryType.INVALID:
                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
                return False, 0, None
            if isinstance(data, Tensor):
                # get the summary type and parse the tag
                v_dict["name"] = summary_tag
                v_dict["type"] = summary_type
                v_dict["data"] = data.asnumpy()
                size += v_dict["data"].size
            else:
                logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
                return False, 0, None

        return True, size, data_list

    def _update_scheduler(self):
        """Check the worker status and update schedule table."""
        workers = list(self.schedule_table.keys())
        for worker in workers:
            if not worker.is_alive():
                # update the table
                worker.join()
                del self.schedule_table[worker]

    def close(self):
        """Confirm all worker is end."""
        workers = self.schedule_table.keys()
        for worker in workers:
            if worker.is_alive():
                worker.join()

    def _make_policy(self):
        """Select the schedule strategy by data."""
        # now only support the temp worker
        return ScheduleMethod.TEMP_WORKER


class SummaryDataProcess(mp.Process):
    """
    Process that consume the summarydata.

    Args:
        step (int): The index of step.
        data_id (int): The index of summary data.
        write_lock (Lock): The process lock for writer same file.
        writer_id (int): The index of writer.
    """
    def __init__(self, step, data_id, write_lock, writer_id):
        super(SummaryDataProcess, self).__init__()
        self.daemon = True
        self.writer_id = writer_id
        self.writer = SummaryDataManager.summary_file_get(self.writer_id)
        if self.writer is None:
            logger.error("The writer_id(%r) does not have writer", writer_id)
        self.step = step
        self.data_id = data_id
        self.write_lock = write_lock
        self.name = "SummaryDataConsumer_" + str(self.step)

    def run(self):
        """The consumer is process the step data and exit."""
        # convert the data to event
        # All exceptions need to be caught and end the queue
        try:
            logger.debug("process(%r) process a data(%r)", self.name, self.step)
            # package the summary event
            summary_event = package_summary_event(self.data_id, self.step)
            # send the event to file
            self._write_summary(summary_event)
        except Exception as e:
            logger.error("Summary data mq consumer exception occurred, value = %r", e)

    def _write_summary(self, summary_event):
        """
        Write the summary to event file.

        Note:
            The write record format:
            1 uint64 : data length.
            2 uint32 : mask crc value of data length.
            3 bytes : data.
            4 uint32 : mask crc value of data.

        Args:
            summary_event (Event): The summary event of proto.
        """
        event_str = summary_event.SerializeToString()
        self.write_lock.acquire()
        self.writer.write_event_to_file(event_str)
        self.writer.flush()
        self.write_lock.release()
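Although the scheduler is deleted, the record framing documented in its _write_summary docstring still describes what the C++ EventWriter_ emits: a uint64 length, a masked CRC of the length, the payload, and a masked CRC of the payload (the TensorBoard event-file layout). A runnable sketch of that framing under two stated assumptions: real event files use CRC32C (Castagnoli), so zlib.crc32 below is only a stand-in, and the production checksum actually lives in the C++ writer.

import struct
import zlib


def _masked_crc(data: bytes) -> int:
    crc = zlib.crc32(data) & 0xffffffff           # stand-in for CRC32C
    # the rotate-and-add masking used by the event-file format
    return ((crc >> 15 | crc << 17) + 0xa282ead8) & 0xffffffff


def frame_record(data: bytes) -> bytes:
    header = struct.pack('<Q', len(data))             # 1. uint64: data length
    return (header
            + struct.pack('<I', _masked_crc(header))  # 2. uint32: masked crc of length
            + data                                    # 3. bytes: data
            + struct.pack('<I', _masked_crc(data)))   # 4. uint32: masked crc of data


print(frame_record(b'\x08\x01').hex())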
mindspore/train/summary/summary_record.py  (+79 -61)
@@ -14,17 +14,22 @@
 # ============================================================================
 """Record the summary event."""
 import os
+import re
 import threading
 from mindspore import log as logger
-from ._summary_scheduler import WorkerScheduler, SummaryDataManager
-from ._summary_adapter import get_event_file_name, package_graph_event
-from ._event_writer import EventRecord
-from .._utils import _make_directory
+from ..._c_expression import Tensor
 from ..._checkparam import _check_str_by_regular
+from .._utils import _make_directory
+from ._event_writer import EventWriter
+from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
 
+# for the moment, this lock is for caution's sake,
+# there are actually no any concurrencies happening.
+_summary_lock = threading.Lock()
 # cache the summary data
 _summary_tensor_cache = {}
-_summary_lock = threading.Lock()
 
 
 def _cache_summary_tensor_data(summary):

@@ -34,16 +39,20 @@ def _cache_summary_tensor_data(summary):
     Args:
         summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
     """
-    _summary_lock.acquire()
-    if "SummaryRecord" in _summary_tensor_cache:
-        for record in summary:
-            _summary_tensor_cache["SummaryRecord"].append(record)
-    else:
-        _summary_tensor_cache["SummaryRecord"] = summary
-    _summary_lock.release()
-    return True
+    with _summary_lock:
+        for item in summary:
+            _summary_tensor_cache[item['name']] = item['data']
+        return True
+
+
+def _get_summary_tensor_data():
+    global _summary_tensor_cache
+    with _summary_lock:
+        data = _summary_tensor_cache
+        _summary_tensor_cache = {}
+        return data
 
 
 class SummaryRecord:

@@ -71,6 +80,7 @@ class SummaryRecord:
         >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
         >>>                                file_prefix="xxx_", file_suffix="_yyy")
     """
+
     def __init__(self,
                  log_dir,
                  queue_max_size=0,

@@ -101,26 +111,18 @@ class SummaryRecord:
         self.prefix = file_prefix
         self.suffix = file_suffix
+        self.network = network
+        self.has_graph = False
+        self._closed = False
 
         # create the summary writer file
         self.event_file_name = get_event_file_name(self.prefix, self.suffix)
-        if self.log_path[-1:] == '/':
-            self.full_file_name = self.log_path + self.event_file_name
-        else:
-            self.full_file_name = self.log_path + '/' + self.event_file_name
         try:
-            self.full_file_name = os.path.realpath(self.full_file_name)
+            self.full_file_name = os.path.join(self.log_path, self.event_file_name)
         except Exception as ex:
             raise RuntimeError(ex)
-        self.event_writer = EventRecord(self.full_file_name, self.flush_time)
-        self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
-        self.worker_scheduler = WorkerScheduler(self.writer_id)
-
+        self.event_writer = EventWriter(self.full_file_name, self.flush_time)
+        self.event_writer.write(package_init_event().SerializeToString())
         self.step = 0
-        self._closed = False
-        self.network = network
-        self.has_graph = False
 
     def record(self, step, train_network=None):
         """

@@ -145,42 +147,34 @@ class SummaryRecord:
         if not isinstance(step, int) or isinstance(step, bool):
             raise ValueError("`step` should be int")
         # Set the current summary of train step
         self.step = step
-        if self.network is not None and self.has_graph is False:
+        if self.network is not None and not self.has_graph:
             graph_proto = self.network.get_func_graph_proto()
             if graph_proto is None and train_network is not None:
                 graph_proto = train_network.get_func_graph_proto()
             if graph_proto is None:
                 logger.error("Failed to get proto for graph")
             else:
-                self.event_writer.write_event_to_file(
-                    package_graph_event(graph_proto).SerializeToString())
-                self.event_writer.flush()
+                self.event_writer.write(package_graph_event(graph_proto).SerializeToString())
                 self.has_graph = True
-                data = _summary_tensor_cache.get("SummaryRecord")
-                if data is None:
+                if not _summary_tensor_cache:
                     return True
 
-        data = _summary_tensor_cache.get("SummaryRecord")
-        if data is None:
-            logger.error("The step(%r) does not have record data.", self.step)
+        data = _get_summary_tensor_data()
+        if not data:
+            logger.error("The step(%r) does not have record data.", step)
             return False
         if self.queue_max_size > 0 and len(data) > self.queue_max_size:
             logger.error("The size of data record is %r, which is greater than queue_max_size %r.",
                          len(data), self.queue_max_size)
 
-        # clean the data of cache
-        del _summary_tensor_cache["SummaryRecord"]
         # process the data
-        self.worker_scheduler.dispatch(self.step, data)
-        # count & flush
-        self.event_writer.count_event()
-        self.event_writer.flush_cycle()
-        logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
+        result = self._data_convert(data)
+        if not result:
+            logger.error("The step(%r) summary data is invalid.", step)
+            return False
+        self.event_writer.write((result, step))
+        logger.debug("Send the summary data to scheduler for saving, step = %d", step)
         return True

@@ -196,7 +190,7 @@ class SummaryRecord:
         Returns:
             String, the full path of log file.
         """
-        return self.event_writer.full_file_name
+        return self.full_file_name
 
     def flush(self):
         """

@@ -224,20 +218,44 @@ class SummaryRecord:
         >>> summary_record.close()
         """
         if not self._closed:
-            self.worker_scheduler.close()
+            self._check_data_before_close()
             # event writer flush and close
             self.event_writer.close()
             self._closed = True
 
-    def __del__(self):
-        """Process exit is called."""
-        if hasattr(self, "worker_scheduler"):
-            if self.worker_scheduler:
-                self.close()
+    def _data_convert(self, summary):
+        """Convert the data."""
+        # convert the summary to numpy
+        result = []
+        for name, data in summary.items():
+            # confirm the data is valid
+            summary_tag, summary_type = SummaryRecord._parse_from(name)
+            if summary_tag is None:
+                logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
+                return None
+            if isinstance(data, Tensor):
+                result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
+            else:
+                logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
+                return None
+
+        return result
+
+    def _check_data_before_close(self):
+        "Check whether there is any data in the cache, and if so, call record"
+        data = _summary_tensor_cache.get("SummaryRecord")
+        if data is not None:
+            self.record(self.step)
+
+    @staticmethod
+    def _parse_from(name: str = None):
+        """
+        Parse the tag and type from name.
+
+        Args:
+            name (str): Format: TAG[:TYPE].
+
+        Returns:
+            Tuple, (summary_tag, summary_type).
+        """
+        if name is None:
+            logger.error("The name is None")
+            return None, None
+        match = re.match(r'(.+)\[:(.+)\]', name)
+        if match:
+            return match.groups()
+        logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
+        return None, None
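The new _parse_from above replaces the deleted scheduler's _parse_tag_format, which did the same job with rfind, slicing and explicit end-of-string checks, with a single regular expression. A quick check of how that pattern splits 'TAG[:TYPE]' names:

import re

for name in ('loss[:Scalar]', 'conv1.weight[:Histogram]', 'badname'):
    match = re.match(r'(.+)\[:(.+)\]', name)
    print(name, '->', match.groups() if match else (None, None))

# loss[:Scalar] -> ('loss', 'Scalar')
# conv1.weight[:Histogram] -> ('conv1.weight', 'Histogram')
# badname -> (None, None)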