Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ea4836e1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea4836e1
编写于
5月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!713 Use a resident process to write summary files
Merge pull request !713 from 李鸿章/summary_record
上级
95d4665d
69d3abfd
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
196 addition
and
591 deletion
+196
-591
mindspore/train/summary/_event_writer.py
mindspore/train/summary/_event_writer.py
+61
-75
mindspore/train/summary/_summary_adapter.py
mindspore/train/summary/_summary_adapter.py
+56
-147
mindspore/train/summary/_summary_scheduler.py
mindspore/train/summary/_summary_scheduler.py
+0
-308
mindspore/train/summary/summary_record.py
mindspore/train/summary/summary_record.py
+79
-61
未找到文件。
mindspore/train/summary/_event_writer.py
浏览文件 @
ea4836e1
...
...
@@ -14,91 +14,77 @@
# ============================================================================
"""Writes events to disk in a logdir."""
import
os
import
time
import
stat
from
mindspore
import
log
as
logger
from
collections
import
deque
from
multiprocessing
import
Pool
,
Process
,
Queue
,
cpu_count
from
..._c_expression
import
EventWriter_
from
._summary_adapter
import
package_
init
_event
from
._summary_adapter
import
package_
summary
_event
class
_WrapEventWriter
(
EventWriter_
):
"""
Wrap the c++ EventWriter object.
def
_pack
(
result
,
step
):
summary_event
=
package_summary_event
(
result
,
step
)
return
summary_event
.
SerializeToString
()
Args:
full_file_name (str): Include directory and file name.
"""
def
__init__
(
self
,
full_file_name
):
if
full_file_name
is
not
None
:
EventWriter_
.
__init__
(
self
,
full_file_name
)
class
EventRecord
:
class
EventWriter
(
Process
):
"""
Creates a `Event
File
Writer` and write event to file.
Creates a `EventWriter` and write event to file.
Args:
f
ull_file_name
(str): Summary event file path and file name.
flush_
time
(int): The flush seconds to flush the pending events to disk. Default: 120.
f
ilepath
(str): Summary event file path and file name.
flush_
interval
(int): The flush seconds to flush the pending events to disk. Default: 120.
"""
def
__init__
(
self
,
full_file_name
:
str
,
flush_time
:
int
=
120
):
self
.
full_file_name
=
full_file_name
# The first event will be flushed immediately.
self
.
flush_time
=
flush_time
self
.
next_flush_time
=
0
# create event write object
self
.
event_writer
=
self
.
_create_event_file
()
self
.
_init_event_file
()
# count the events
self
.
event_count
=
0
def
_create_event_file
(
self
):
"""Create the event write file."""
with
open
(
self
.
full_file_name
,
'w'
):
os
.
chmod
(
self
.
full_file_name
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
# create c++ event write object
event_writer
=
_WrapEventWriter
(
self
.
full_file_name
)
return
event_writer
def
_init_event_file
(
self
):
"""Send the init event to file."""
self
.
event_writer
.
Write
((
package_init_event
()).
SerializeToString
())
self
.
flush
()
return
True
def
write_event_to_file
(
self
,
event_str
):
"""Write the event to file."""
self
.
event_writer
.
Write
(
event_str
)
def
get_data_count
(
self
):
"""Return the event count."""
return
self
.
event_count
def
flush_cycle
(
self
):
"""Flush file by timer."""
self
.
event_count
=
self
.
event_count
+
1
# Flush the event writer every so often.
now
=
int
(
time
.
time
())
if
now
>
self
.
next_flush_time
:
self
.
flush
()
# update the flush time
self
.
next_flush_time
=
now
+
self
.
flush_time
def
count_event
(
self
):
"""Count event."""
logger
.
debug
(
"Write the event count is %r"
,
self
.
event_count
)
self
.
event_count
=
self
.
event_count
+
1
return
self
.
event_count
def
__init__
(
self
,
filepath
:
str
,
flush_interval
:
int
)
->
None
:
super
().
__init__
()
with
open
(
filepath
,
'w'
):
os
.
chmod
(
filepath
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
self
.
_writer
=
EventWriter_
(
filepath
)
self
.
_queue
=
Queue
(
cpu_count
()
*
2
)
self
.
start
()
def
run
(
self
):
with
Pool
()
as
pool
:
deq
=
deque
()
while
True
:
while
deq
and
deq
[
0
].
ready
():
self
.
_writer
.
Write
(
deq
.
popleft
().
get
())
if
not
self
.
_queue
.
empty
():
action
,
data
=
self
.
_queue
.
get
()
if
action
==
'WRITE'
:
if
not
isinstance
(
data
,
(
str
,
bytes
)):
deq
.
append
(
pool
.
apply_async
(
_pack
,
data
))
else
:
self
.
_writer
.
Write
(
data
)
elif
action
==
'FLUSH'
:
self
.
_writer
.
Flush
()
elif
action
==
'END'
:
break
for
res
in
deq
:
self
.
_writer
.
Write
(
res
.
get
())
self
.
_writer
.
Shut
()
def
write
(
self
,
data
)
->
None
:
"""
Write the event to file.
Args:
data (Optional[str, Tuple[list, int]]): The data to write.
"""
self
.
_queue
.
put
((
'WRITE'
,
data
))
def
flush
(
self
):
"""Flush the event file to disk."""
self
.
event_writer
.
Flush
()
"""Flush the writer."""
self
.
_queue
.
put
((
'FLUSH'
,
None
))
def
close
(
self
)
->
None
:
"""Close the writer."""
self
.
_queue
.
put
((
'END'
,
None
))
self
.
join
()
def
close
(
self
):
"""Flush the event file to disk and close the file."""
self
.
flush
()
self
.
event_writer
.
Shut
()
def
__del__
(
self
)
->
None
:
self
.
close
()
mindspore/train/summary/_summary_adapter.py
浏览文件 @
ea4836e1
...
...
@@ -13,17 +13,17 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
import
time
import
socket
import
math
from
enum
import
Enum
,
unique
import
time
import
numpy
as
np
from
PIL
import
Image
from
mindspore
import
log
as
logger
from
..summary_pb2
import
Event
from
..anf_ir_pb2
import
ModelProto
,
DataType
from
..._checkparam
import
_check_str_by_regular
from
..anf_ir_pb2
import
DataType
,
ModelProto
from
..summary_pb2
import
Event
# define the MindSpore image format
MS_IMAGE_TENSOR_FORMAT
=
'NCHW'
...
...
@@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
# Set the init event of version and mark
EVENT_FILE_INIT_VERSION_MARK
=
"Mindspore.Event:"
EVENT_FILE_INIT_VERSION
=
1
# cache the summary data dict
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_dict
=
{}
def
save_summary_data
(
data_id
,
data
):
"""Save the global summary cache."""
global
g_summary_data_dict
g_summary_data_dict
[
data_id
]
=
data
def
del_summary_data
(
data_id
):
"""Save the global summary cache."""
global
g_summary_data_dict
if
data_id
in
g_summary_data_dict
:
del
g_summary_data_dict
[
data_id
]
else
:
logger
.
warning
(
"Can't del the data because data_id(%r) "
"does not have data in g_summary_data_dict"
,
data_id
)
def
get_summary_data
(
data_id
):
"""Save the global summary cache."""
ret
=
None
global
g_summary_data_dict
if
data_id
in
g_summary_data_dict
:
ret
=
g_summary_data_dict
.
get
(
data_id
)
else
:
logger
.
warning
(
"The data_id(%r) does not have data in g_summary_data_dict"
,
data_id
)
return
ret
@
unique
class
SummaryType
(
Enum
):
"""
Summary type.
Args:
SCALAR (Number): Summary Scalar enum.
TENSOR (Number): Summary TENSOR enum.
IMAGE (Number): Summary image enum.
GRAPH (Number): Summary graph enum.
HISTOGRAM (Number): Summary histogram enum.
INVALID (Number): Unknow type.
"""
SCALAR
=
1
# Scalar summary
TENSOR
=
2
# Tensor summary
IMAGE
=
3
# Image summary
GRAPH
=
4
# graph
HISTOGRAM
=
5
# Histogram Summary
INVALID
=
0xFF
# unknow type
def
get_event_file_name
(
prefix
,
suffix
):
...
...
@@ -138,7 +89,7 @@ def package_graph_event(data):
return
graph_event
def
package_summary_event
(
data_
id
,
step
):
def
package_summary_event
(
data_
list
,
step
):
"""
Package the summary to event protobuffer.
...
...
@@ -149,50 +100,37 @@ def package_summary_event(data_id, step):
Returns:
Summary, the summary event.
"""
data_list
=
get_summary_data
(
data_id
)
if
data_list
is
None
:
logger
.
error
(
"The step(%r) does not have record data."
,
step
)
del_summary_data
(
data_id
)
# create the event of summary
summary_event
=
Event
()
summary
=
summary_event
.
summary
summary_event
.
wall_time
=
time
.
time
()
summary_event
.
step
=
int
(
step
)
for
value
in
data_list
:
tag
=
value
[
"nam
e"
]
summary_type
=
value
[
"_typ
e"
]
data
=
value
[
"data"
]
summary_type
=
value
[
"typ
e"
]
tag
=
value
[
"nam
e"
]
logger
.
debug
(
"Now process %r summary, tag = %r"
,
summary_type
,
tag
)
summary_value
=
summary
.
value
.
add
()
summary_value
.
tag
=
tag
# get the summary type and parse the tag
if
summary_type
is
SummaryType
.
SCALAR
:
logger
.
debug
(
"Now process Scalar summary, tag = %r"
,
tag
)
summary_value
=
summary
.
value
.
add
()
summary_value
.
tag
=
tag
if
summary_type
==
'Scalar'
:
summary_value
.
scalar_value
=
_get_scalar_summary
(
tag
,
data
)
elif
summary_type
is
SummaryType
.
TENSOR
:
logger
.
debug
(
"Now process Tensor summary, tag = %r"
,
tag
)
summary_value
=
summary
.
value
.
add
()
summary_value
.
tag
=
tag
elif
summary_type
==
'Tensor'
:
summary_tensor
=
summary_value
.
tensor
_get_tensor_summary
(
tag
,
data
,
summary_tensor
)
elif
summary_type
is
SummaryType
.
IMAGE
:
logger
.
debug
(
"Now process Image summary, tag = %r"
,
tag
)
summary_value
=
summary
.
value
.
add
()
summary_value
.
tag
=
tag
elif
summary_type
==
'Image'
:
summary_image
=
summary_value
.
image
_get_image_summary
(
tag
,
data
,
summary_image
,
MS_IMAGE_TENSOR_FORMAT
)
elif
summary_type
is
SummaryType
.
HISTOGRAM
:
logger
.
debug
(
"Now process Histogram summary, tag = %r"
,
tag
)
summary_value
=
summary
.
value
.
add
()
summary_value
.
tag
=
tag
elif
summary_type
==
'Histogram'
:
summary_histogram
=
summary_value
.
histogram
_fill_histogram_summary
(
tag
,
data
,
summary_histogram
)
else
:
# The data is invalid ,jump the data
logger
.
error
(
"Summary type is error, tag = %r"
,
tag
)
continue
logger
.
error
(
"Summary type(%r) is error, tag = %r"
,
summary_type
,
tag
)
summary_event
.
wall_time
=
time
.
time
()
summary_event
.
step
=
int
(
step
)
return
summary_event
...
...
@@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value):
# So consider the dim = 1, shape = (1,) tensor is scalar
scalar_value
=
np_value
[
0
]
if
np_value
.
shape
!=
(
1
,):
logger
.
error
(
"The tensor is not Scalar, tag = %r,
Value = %r"
,
tag
,
np_valu
e
)
logger
.
error
(
"The tensor is not Scalar, tag = %r,
Shape = %r"
,
tag
,
np_value
.
shap
e
)
else
:
np_list
=
np_value
.
reshape
(
-
1
).
tolist
()
scalar_value
=
np_list
[
0
]
logger
.
error
(
"The value is not Scalar, tag = %r,
Value = %r"
,
tag
,
np_value
)
logger
.
error
(
"The value is not Scalar, tag = %r,
ndim = %r"
,
tag
,
np_value
.
ndim
)
logger
.
debug
(
"The tag(%r) value is: %r"
,
tag
,
scalar_value
)
return
scalar_value
...
...
@@ -307,8 +245,7 @@ def _calc_histogram_bins(count):
Returns:
int, number of histogram bins.
"""
number_per_bucket
=
10
max_bins
=
90
max_bins
,
max_per_bin
=
90
,
10
if
not
count
:
return
1
...
...
@@ -318,78 +255,50 @@ def _calc_histogram_bins(count):
return
3
if
count
<=
880
:
# note that math.ceil(881/10) + 1 equals 90
return
int
(
math
.
ceil
(
count
/
number_per_bucket
)
+
1
)
return
count
//
max_per_bin
+
1
return
max_bins
def
_fill_histogram_summary
(
tag
:
str
,
np_value
:
np
.
array
,
summary_histogram
)
->
None
:
def
_fill_histogram_summary
(
tag
:
str
,
np_value
:
np
.
ndarray
,
summary
)
->
None
:
"""
Package the histogram summary.
Args:
tag (str): Summary tag describe.
np_value (np.array): Summary data.
summary
_histogram
(summary_pb2.Summary.Histogram): Summary histogram data.
np_value (np.
nd
array): Summary data.
summary (summary_pb2.Summary.Histogram): Summary histogram data.
"""
logger
.
debug
(
"Set(%r) the histogram summary value"
,
tag
)
# Default bucket for tensor with no valid data.
default_bucket_left
=
-
0.5
default_bucket_width
=
1.0
if
np_value
.
size
==
0
:
bucket
=
summary_histogram
.
buckets
.
add
()
bucket
.
left
=
default_bucket_left
bucket
.
width
=
default_bucket_width
bucket
.
count
=
0
summary_histogram
.
nan_count
=
0
summary_histogram
.
pos_inf_count
=
0
summary_histogram
.
neg_inf_count
=
0
summary_histogram
.
max
=
0
summary_histogram
.
min
=
0
summary_histogram
.
sum
=
0
summary_histogram
.
count
=
0
return
summary_histogram
.
nan_count
=
np
.
count_nonzero
(
np
.
isnan
(
np_value
))
summary_histogram
.
pos_inf_count
=
np
.
count_nonzero
(
np
.
isposinf
(
np_value
))
summary_histogram
.
neg_inf_count
=
np
.
count_nonzero
(
np
.
isneginf
(
np_value
))
summary_histogram
.
count
=
np_value
.
size
masked_value
=
np
.
ma
.
masked_invalid
(
np_value
)
tensor_max
=
masked_value
.
max
()
tensor_min
=
masked_value
.
min
()
tensor_sum
=
masked_value
.
sum
()
# No valid value in tensor.
if
tensor_max
is
np
.
ma
.
masked
:
bucket
=
summary_histogram
.
buckets
.
add
()
bucket
.
left
=
default_bucket_left
bucket
.
width
=
default_bucket_width
bucket
.
count
=
0
summary_histogram
.
max
=
np
.
nan
summary_histogram
.
min
=
np
.
nan
summary_histogram
.
sum
=
0
return
bin_number
=
_calc_histogram_bins
(
masked_value
.
count
())
counts
,
edges
=
np
.
histogram
(
np_value
,
bins
=
bin_number
,
range
=
(
tensor_min
,
tensor_max
))
ma_value
=
np
.
ma
.
masked_invalid
(
np_value
)
total
,
valid
=
np_value
.
size
,
ma_value
.
count
()
invalids
=
[]
for
isfn
in
np
.
isnan
,
np
.
isposinf
,
np
.
isneginf
:
if
total
-
valid
>
sum
(
invalids
):
count
=
np
.
count_nonzero
(
isfn
(
np_value
))
invalids
.
append
(
count
)
else
:
invalids
.
append
(
0
)
for
ind
,
count
in
enumerate
(
counts
):
bucket
=
summary_histogram
.
buckets
.
add
()
bucket
.
left
=
edges
[
ind
]
bucket
.
width
=
edges
[
ind
+
1
]
-
edges
[
ind
]
bucket
.
count
=
count
summary
.
count
=
total
summary
.
nan_count
,
summary
.
pos_inf_count
,
summary
.
neg_inf_count
=
invalids
if
not
valid
:
logger
.
warning
(
'There are no valid values in the ndarray(size=%d, shape=%d)'
,
total
,
np_value
.
shape
)
# summary.{min, max, sum} are 0s by default, no need to explicitly set
else
:
summary
.
min
=
ma_value
.
min
()
summary
.
max
=
ma_value
.
max
()
summary
.
sum
=
ma_value
.
sum
()
bins
=
_calc_histogram_bins
(
valid
)
range_
=
summary
.
min
,
summary
.
max
hists
,
edges
=
np
.
histogram
(
np_value
,
bins
=
bins
,
range
=
range_
)
summary_histogram
.
max
=
tensor_max
summary_histogram
.
min
=
tensor_min
summary_histogram
.
sum
=
tensor_sum
for
hist
,
edge1
,
edge2
in
zip
(
hists
,
edges
,
edges
[
1
:]):
bucket
=
summary
.
buckets
.
add
()
bucket
.
width
=
edge2
-
edge1
bucket
.
count
=
hist
bucket
.
left
=
edge1
def
_get_image_summary
(
tag
:
str
,
np_value
,
summary_image
,
input_format
=
'NCHW'
):
...
...
@@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
"""
logger
.
debug
(
"Set(%r) the image summary value"
,
tag
)
if
np_value
.
ndim
!=
4
:
logger
.
error
(
"The value is not Image, tag = %r,
Value = %r"
,
tag
,
np_value
)
logger
.
error
(
"The value is not Image, tag = %r,
ndim = %r"
,
tag
,
np_value
.
ndim
)
# convert the tensor format
tensor
=
_convert_image_format
(
np_value
,
input_format
)
...
...
@@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
"""
out_tensor
=
None
if
np_tensor
.
ndim
!=
len
(
input_format
):
logger
.
error
(
"The tensor
(%r) can't convert the format(%r) because dim not same"
,
np_tensor
,
input_format
)
logger
.
error
(
"The tensor
with dim(%r) can't convert the format(%r) because dim not same"
,
np_tensor
.
ndim
,
input_format
)
return
out_tensor
input_format
=
input_format
.
upper
()
...
...
@@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
# check the tensor format
if
tensor
.
ndim
!=
4
or
tensor
.
shape
[
1
]
!=
3
:
logger
.
error
(
"The image tensor
(%r) is not 'NCHW' format"
,
tensor
)
logger
.
error
(
"The image tensor
with ndim(%r) and shape(%r) is not 'NCHW' format"
,
tensor
.
ndim
,
tensor
.
shape
)
return
out_canvas
# expand the N
...
...
mindspore/train/summary/_summary_scheduler.py
已删除
100644 → 0
浏览文件 @
95d4665d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Schedule the event writer process."""
import
multiprocessing
as
mp
from
enum
import
Enum
,
unique
from
mindspore
import
log
as
logger
from
..._c_expression
import
Tensor
from
._summary_adapter
import
SummaryType
,
package_summary_event
,
save_summary_data
# define the type of summary
FORMAT_SCALAR_STR
=
"Scalar"
FORMAT_TENSOR_STR
=
"Tensor"
FORMAT_IMAGE_STR
=
"Image"
FORMAT_HISTOGRAM_STR
=
"Histogram"
FORMAT_BEGIN_SLICE
=
"[:"
FORMAT_END_SLICE
=
"]"
# cache the summary data dict
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_id
=
0
g_summary_data_dict
=
{}
# cache the summary data file
g_summary_writer_id
=
0
g_summary_file
=
{}
@
unique
class
ScheduleMethod
(
Enum
):
"""Schedule method type."""
FORMAL_WORKER
=
0
# use the formal worker that receive small size data by queue
TEMP_WORKER
=
1
# use the Temp worker that receive big size data by the global value(avoid copy)
CACHE_DATA
=
2
# Cache data util have idle worker to process it
@
unique
class
WorkerStatus
(
Enum
):
"""Worker status."""
WORKER_INIT
=
0
# data is exist but not process
WORKER_PROCESSING
=
1
# data is processing
WORKER_PROCESSED
=
2
# data already processed
def
_parse_tag_format
(
tag
:
str
):
"""
Parse the tag.
Args:
tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
Returns:
Tuple, (SummaryType, summary_tag).
"""
summary_type
=
SummaryType
.
INVALID
summary_tag
=
tag
if
tag
is
None
:
logger
.
error
(
"The tag is None"
)
return
summary_type
,
summary_tag
# search the slice
slice_begin
=
FORMAT_BEGIN_SLICE
slice_end
=
FORMAT_END_SLICE
index
=
tag
.
rfind
(
slice_begin
)
if
index
is
-
1
:
logger
.
error
(
"The tag(%s) have not the key slice."
,
tag
)
return
summary_type
,
summary_tag
# slice the tag
summary_tag
=
tag
[:
index
]
# check the slice end
if
tag
[
-
1
:]
!=
slice_end
:
logger
.
error
(
"The tag(%s) end format is error"
,
tag
)
return
summary_type
,
summary_tag
# check the type
type_str
=
tag
[
index
+
2
:
-
1
]
logger
.
debug
(
"The summary_tag is = %r"
,
summary_tag
)
logger
.
debug
(
"The type_str value is = %r"
,
type_str
)
if
type_str
==
FORMAT_SCALAR_STR
:
summary_type
=
SummaryType
.
SCALAR
elif
type_str
==
FORMAT_TENSOR_STR
:
summary_type
=
SummaryType
.
TENSOR
elif
type_str
==
FORMAT_IMAGE_STR
:
summary_type
=
SummaryType
.
IMAGE
elif
type_str
==
FORMAT_HISTOGRAM_STR
:
summary_type
=
SummaryType
.
HISTOGRAM
else
:
logger
.
error
(
"The tag(%s) type is invalid."
,
tag
)
summary_type
=
SummaryType
.
INVALID
return
summary_type
,
summary_tag
class
SummaryDataManager
:
"""Manage the summary global data cache."""
def
__init__
(
self
):
global
g_summary_data_dict
self
.
size
=
len
(
g_summary_data_dict
)
@
classmethod
def
summary_data_save
(
cls
,
data
):
"""Save the global summary cache."""
global
g_summary_data_id
data_id
=
g_summary_data_id
save_summary_data
(
data_id
,
data
)
g_summary_data_id
+=
1
return
data_id
@
classmethod
def
summary_file_set
(
cls
,
event_writer
):
"""Support the many event_writer."""
global
g_summary_file
,
g_summary_writer_id
g_summary_writer_id
+=
1
g_summary_file
[
g_summary_writer_id
]
=
event_writer
return
g_summary_writer_id
@
classmethod
def
summary_file_get
(
cls
,
writer_id
=
1
):
ret
=
None
global
g_summary_file
if
writer_id
in
g_summary_file
:
ret
=
g_summary_file
.
get
(
writer_id
)
return
ret
class
WorkerScheduler
:
"""
Create worker and schedule data to worker.
Args:
writer_id (int): The index of writer.
"""
def
__init__
(
self
,
writer_id
):
# Create the process of write event file
self
.
write_lock
=
mp
.
Lock
()
# Schedule info for all worker
# Format: {worker: (step, WorkerStatus)}
self
.
schedule_table
=
{}
# write id
self
.
writer_id
=
writer_id
self
.
has_graph
=
False
def
dispatch
(
self
,
step
,
data
):
"""
Select schedule strategy and dispatch data.
Args:
step (Number): The number of step index.
data (Object): The data of recode for summary.
Retruns:
bool, run successfully or not.
"""
# save the data to global cache , convert the tensor to numpy
result
,
size
,
data
=
self
.
_data_convert
(
data
)
if
result
is
False
:
logger
.
error
(
"The step(%r) summary data(%r) is invalid."
,
step
,
size
)
return
False
data_id
=
SummaryDataManager
.
summary_data_save
(
data
)
self
.
_start_worker
(
step
,
data_id
)
return
True
def
_start_worker
(
self
,
step
,
data_id
):
"""
Start worker.
Args:
step (Number): The index of recode.
data_id (str): The id of work.
Return:
bool, run successfully or not.
"""
# assign the worker
policy
=
self
.
_make_policy
()
if
policy
==
ScheduleMethod
.
TEMP_WORKER
:
worker
=
SummaryDataProcess
(
step
,
data_id
,
self
.
write_lock
,
self
.
writer_id
)
# update the schedule table
self
.
schedule_table
[
worker
]
=
(
step
,
data_id
,
WorkerStatus
.
WORKER_INIT
)
# start the worker
worker
.
start
()
else
:
logger
.
error
(
"Do not support the other scheduler policy now."
)
# update the scheduler infor
self
.
_update_scheduler
()
return
True
def
_data_convert
(
self
,
data_list
):
"""Convert the data."""
if
data_list
is
None
:
logger
.
warning
(
"The step does not have record data."
)
return
False
,
0
,
None
# convert the summary to numpy
size
=
0
for
v_dict
in
data_list
:
tag
=
v_dict
[
"name"
]
data
=
v_dict
[
"data"
]
# confirm the data is valid
summary_type
,
summary_tag
=
_parse_tag_format
(
tag
)
if
summary_type
==
SummaryType
.
INVALID
:
logger
.
error
(
"The data type is invalid, tag = %r, tensor = %r"
,
tag
,
data
)
return
False
,
0
,
None
if
isinstance
(
data
,
Tensor
):
# get the summary type and parse the tag
v_dict
[
"name"
]
=
summary_tag
v_dict
[
"type"
]
=
summary_type
v_dict
[
"data"
]
=
data
.
asnumpy
()
size
+=
v_dict
[
"data"
].
size
else
:
logger
.
error
(
"The data type is invalid, tag = %r, tensor = %r"
,
tag
,
data
)
return
False
,
0
,
None
return
True
,
size
,
data_list
def
_update_scheduler
(
self
):
"""Check the worker status and update schedule table."""
workers
=
list
(
self
.
schedule_table
.
keys
())
for
worker
in
workers
:
if
not
worker
.
is_alive
():
# update the table
worker
.
join
()
del
self
.
schedule_table
[
worker
]
def
close
(
self
):
"""Confirm all worker is end."""
workers
=
self
.
schedule_table
.
keys
()
for
worker
in
workers
:
if
worker
.
is_alive
():
worker
.
join
()
def
_make_policy
(
self
):
"""Select the schedule strategy by data."""
# now only support the temp worker
return
ScheduleMethod
.
TEMP_WORKER
class
SummaryDataProcess
(
mp
.
Process
):
"""
Process that consume the summarydata.
Args:
step (int): The index of step.
data_id (int): The index of summary data.
write_lock (Lock): The process lock for writer same file.
writer_id (int): The index of writer.
"""
def
__init__
(
self
,
step
,
data_id
,
write_lock
,
writer_id
):
super
(
SummaryDataProcess
,
self
).
__init__
()
self
.
daemon
=
True
self
.
writer_id
=
writer_id
self
.
writer
=
SummaryDataManager
.
summary_file_get
(
self
.
writer_id
)
if
self
.
writer
is
None
:
logger
.
error
(
"The writer_id(%r) does not have writer"
,
writer_id
)
self
.
step
=
step
self
.
data_id
=
data_id
self
.
write_lock
=
write_lock
self
.
name
=
"SummaryDataConsumer_"
+
str
(
self
.
step
)
def
run
(
self
):
"""The consumer is process the step data and exit."""
# convert the data to event
# All exceptions need to be caught and end the queue
try
:
logger
.
debug
(
"process(%r) process a data(%r)"
,
self
.
name
,
self
.
step
)
# package the summary event
summary_event
=
package_summary_event
(
self
.
data_id
,
self
.
step
)
# send the event to file
self
.
_write_summary
(
summary_event
)
except
Exception
as
e
:
logger
.
error
(
"Summary data mq consumer exception occurred, value = %r"
,
e
)
def
_write_summary
(
self
,
summary_event
):
"""
Write the summary to event file.
Note:
The write record format:
1 uint64 : data length.
2 uint32 : mask crc value of data length.
3 bytes : data.
4 uint32 : mask crc value of data.
Args:
summary_event (Event): The summary event of proto.
"""
event_str
=
summary_event
.
SerializeToString
()
self
.
write_lock
.
acquire
()
self
.
writer
.
write_event_to_file
(
event_str
)
self
.
writer
.
flush
()
self
.
write_lock
.
release
()
mindspore/train/summary/summary_record.py
浏览文件 @
ea4836e1
...
...
@@ -14,17 +14,22 @@
# ============================================================================
"""Record the summary event."""
import
os
import
re
import
threading
from
mindspore
import
log
as
logger
from
._summary_scheduler
import
WorkerScheduler
,
SummaryDataManager
from
._summary_adapter
import
get_event_file_name
,
package_graph_event
from
._event_writer
import
EventRecord
from
.._utils
import
_make_directory
from
..._c_expression
import
Tensor
from
..._checkparam
import
_check_str_by_regular
from
.._utils
import
_make_directory
from
._event_writer
import
EventWriter
from
._summary_adapter
import
get_event_file_name
,
package_graph_event
,
package_init_event
# for the moment, this lock is for caution's sake,
# there are actually no any concurrencies happening.
_summary_lock
=
threading
.
Lock
()
# cache the summary data
_summary_tensor_cache
=
{}
_summary_lock
=
threading
.
Lock
()
def
_cache_summary_tensor_data
(
summary
):
...
...
@@ -34,14 +39,18 @@ def _cache_summary_tensor_data(summary):
Args:
summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
"""
_summary_lock
.
acquire
()
if
"SummaryRecord"
in
_summary_tensor_cache
:
for
record
in
summary
:
_summary_tensor_cache
[
"SummaryRecord"
].
append
(
record
)
else
:
_summary_tensor_cache
[
"SummaryRecord"
]
=
summary
_summary_lock
.
release
()
return
True
with
_summary_lock
:
for
item
in
summary
:
_summary_tensor_cache
[
item
[
'name'
]]
=
item
[
'data'
]
return
True
def
_get_summary_tensor_data
():
global
_summary_tensor_cache
with
_summary_lock
:
data
=
_summary_tensor_cache
_summary_tensor_cache
=
{}
return
data
class
SummaryRecord
:
...
...
@@ -71,6 +80,7 @@ class SummaryRecord:
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
>>> file_prefix="xxx_", file_suffix="_yyy")
"""
def
__init__
(
self
,
log_dir
,
queue_max_size
=
0
,
...
...
@@ -101,26 +111,18 @@ class SummaryRecord:
self
.
prefix
=
file_prefix
self
.
suffix
=
file_suffix
self
.
network
=
network
self
.
has_graph
=
False
self
.
_closed
=
False
# create the summary writer file
self
.
event_file_name
=
get_event_file_name
(
self
.
prefix
,
self
.
suffix
)
if
self
.
log_path
[
-
1
:]
==
'/'
:
self
.
full_file_name
=
self
.
log_path
+
self
.
event_file_name
else
:
self
.
full_file_name
=
self
.
log_path
+
'/'
+
self
.
event_file_name
try
:
self
.
full_file_name
=
os
.
path
.
realpath
(
self
.
full
_file_name
)
self
.
full_file_name
=
os
.
path
.
join
(
self
.
log_path
,
self
.
event
_file_name
)
except
Exception
as
ex
:
raise
RuntimeError
(
ex
)
self
.
event_writer
=
EventRecord
(
self
.
full_file_name
,
self
.
flush_time
)
self
.
writer_id
=
SummaryDataManager
.
summary_file_set
(
self
.
event_writer
)
self
.
worker_scheduler
=
WorkerScheduler
(
self
.
writer_id
)
self
.
step
=
0
self
.
_closed
=
False
self
.
network
=
network
self
.
has_graph
=
False
self
.
event_writer
=
EventWriter
(
self
.
full_file_name
,
self
.
flush_time
)
self
.
event_writer
.
write
(
package_init_event
().
SerializeToString
())
def
record
(
self
,
step
,
train_network
=
None
):
"""
...
...
@@ -145,42 +147,34 @@ class SummaryRecord:
if
not
isinstance
(
step
,
int
)
or
isinstance
(
step
,
bool
):
raise
ValueError
(
"`step` should be int"
)
# Set the current summary of train step
self
.
step
=
step
if
self
.
network
is
not
None
and
self
.
has_graph
is
False
:
if
self
.
network
is
not
None
and
not
self
.
has_graph
:
graph_proto
=
self
.
network
.
get_func_graph_proto
()
if
graph_proto
is
None
and
train_network
is
not
None
:
graph_proto
=
train_network
.
get_func_graph_proto
()
if
graph_proto
is
None
:
logger
.
error
(
"Failed to get proto for graph"
)
else
:
self
.
event_writer
.
write_event_to_file
(
package_graph_event
(
graph_proto
).
SerializeToString
())
self
.
event_writer
.
flush
()
self
.
event_writer
.
write
(
package_graph_event
(
graph_proto
).
SerializeToString
())
self
.
has_graph
=
True
data
=
_summary_tensor_cache
.
get
(
"SummaryRecord"
)
if
data
is
None
:
if
not
_summary_tensor_cache
:
return
True
data
=
_
summary_tensor_cache
.
get
(
"SummaryRecord"
)
if
data
is
None
:
logger
.
error
(
"The step(%r) does not have record data."
,
s
elf
.
s
tep
)
data
=
_
get_summary_tensor_data
(
)
if
not
data
:
logger
.
error
(
"The step(%r) does not have record data."
,
step
)
return
False
if
self
.
queue_max_size
>
0
and
len
(
data
)
>
self
.
queue_max_size
:
logger
.
error
(
"The size of data record is %r, which is greater than queue_max_size %r."
,
len
(
data
),
self
.
queue_max_size
)
# clean the data of cache
del
_summary_tensor_cache
[
"SummaryRecord"
]
# process the data
self
.
worker_scheduler
.
dispatch
(
self
.
step
,
data
)
# count & flush
self
.
event_writer
.
count_event
()
self
.
event_writer
.
flush_cycle
()
logger
.
debug
(
"Send the summary data to scheduler for saving, step = %d"
,
self
.
step
)
result
=
self
.
_data_convert
(
data
)
if
not
result
:
logger
.
error
(
"The step(%r) summary data is invalid."
,
step
)
return
False
self
.
event_writer
.
write
((
result
,
step
))
logger
.
debug
(
"Send the summary data to scheduler for saving, step = %d"
,
step
)
return
True
@
property
...
...
@@ -196,7 +190,7 @@ class SummaryRecord:
Returns:
String, the full path of log file.
"""
return
self
.
event_writer
.
full_file_name
return
self
.
full_file_name
def
flush
(
self
):
"""
...
...
@@ -224,20 +218,44 @@ class SummaryRecord:
>>> summary_record.close()
"""
if
not
self
.
_closed
:
self
.
_check_data_before_close
()
self
.
worker_scheduler
.
close
()
# event writer flush and close
self
.
event_writer
.
close
()
self
.
_closed
=
True
def
__del__
(
self
):
"""Process exit is called."""
if
hasattr
(
self
,
"worker_scheduler"
):
if
self
.
worker_scheduler
:
self
.
close
()
def
_check_data_before_close
(
self
):
"Check whether there is any data in the cache, and if so, call record"
data
=
_summary_tensor_cache
.
get
(
"SummaryRecord"
)
if
data
is
not
None
:
self
.
record
(
self
.
step
)
def
_data_convert
(
self
,
summary
):
"""Convert the data."""
# convert the summary to numpy
result
=
[]
for
name
,
data
in
summary
.
items
():
# confirm the data is valid
summary_tag
,
summary_type
=
SummaryRecord
.
_parse_from
(
name
)
if
summary_tag
is
None
:
logger
.
error
(
"The data type is invalid, name = %r, tensor = %r"
,
name
,
data
)
return
None
if
isinstance
(
data
,
Tensor
):
result
.
append
({
'name'
:
summary_tag
,
'data'
:
data
.
asnumpy
(),
'_type'
:
summary_type
})
else
:
logger
.
error
(
"The data type is invalid, name = %r, tensor = %r"
,
name
,
data
)
return
None
return
result
@
staticmethod
def
_parse_from
(
name
:
str
=
None
):
"""
Parse the tag and type from name.
Args:
name (str): Format: TAG[:TYPE].
Returns:
Tuple, (summary_tag, summary_type).
"""
if
name
is
None
:
logger
.
error
(
"The name is None"
)
return
None
,
None
match
=
re
.
match
(
r
'(.+)\[:(.+)\]'
,
name
)
if
match
:
return
match
.
groups
()
logger
.
error
(
"The name(%r) format is invalid, expected 'TAG[:TYPE]'."
,
name
)
return
None
,
None
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录