Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8867c67d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8867c67d
编写于
6月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1935 Summary callback as collector for summary and lineage
Merge pull request !1935 from 李鸿章/policy_writer
上级
230963d0
0921c1e5
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
523 addition
and
90 deletion
+523
-90
mindspore/ccsrc/CMakeLists.txt
mindspore/ccsrc/CMakeLists.txt
+1
-0
mindspore/ccsrc/utils/lineage.proto
mindspore/ccsrc/utils/lineage.proto
+129
-0
mindspore/train/_utils.py
mindspore/train/_utils.py
+36
-0
mindspore/train/summary/_lineage_adapter.py
mindspore/train/summary/_lineage_adapter.py
+39
-0
mindspore/train/summary/_summary_adapter.py
mindspore/train/summary/_summary_adapter.py
+2
-2
mindspore/train/summary/_summary_writer.py
mindspore/train/summary/_summary_writer.py
+79
-0
mindspore/train/summary/_writer_pool.py
mindspore/train/summary/_writer_pool.py
+114
-0
mindspore/train/summary/summary_record.py
mindspore/train/summary/summary_record.py
+123
-73
tests/ut/python/train/summary/test_histogram_summary.py
tests/ut/python/train/summary/test_histogram_summary.py
+0
-15
未找到文件。
mindspore/ccsrc/CMakeLists.txt
浏览文件 @
8867c67d
...
...
@@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO)
file
(
GLOB_RECURSE PROTO_PY RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"utils/anf_ir.proto"
"utils/summary.proto"
"utils/lineage.proto"
"utils/checkpoint.proto"
)
ms_protobuf_generate_py
(
PY_SRCS PY_HDRS PY_PYS
${
PROTO_PY
}
)
...
...
mindspore/ccsrc/utils/lineage.proto
0 → 100644
浏览文件 @
8867c67d
// Copyright 2020 Huawei Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Lineage event definitions: metadata about a training/evaluation run
// (hyper-parameters, datasets, model, dataset graph, user-defined info)
// recorded alongside summary data.
syntax = "proto2";

package mindspore.irpb;
option cc_enable_arenas = true;

// Event Protocol buffer, Top define
message LineageEvent {
  // Timestamp
  required double wall_time = 1;

  // The step of train.
  optional int64 step = 2;

  oneof what {
    // An event file was started, with the specified version.
    // Now version is "Mindspore.Event:1"
    string version = 3;

    // Train lineage
    TrainLineage train_lineage = 6;

    // Evaluation lineage
    EvaluationLineage evaluation_lineage = 7;

    // Dataset graph
    DatasetGraph dataset_graph = 9;

    // User defined info
    UserDefinedInfo user_defined_info = 10;
  }
}

// User defined info
message UserDefinedInfo {
  // repeated user defined info
  repeated UserDefinedInfo user_info = 1;

  // key/value which contains both scalar and dict
  map<string, UserDefinedInfo> map_dict = 2;
  map<string, int32> map_int32 = 3;
  map<string, string> map_str = 4;
  map<string, double> map_double = 5;
}

// TrainLineage records infos of a train.
message TrainLineage {
  message HyperParameters {
    optional string optimizer = 1;
    optional float learning_rate = 2;
    optional string loss_function = 3;
    optional int32 epoch = 4;
    optional string parallel_mode = 5;
    optional int32 device_num = 6;
    optional int32 batch_size = 8;
  }

  message TrainDataset {
    optional string train_dataset_path = 1;
    optional int32 train_dataset_size = 2;
  }

  message Algorithm {
    optional string network = 1;
    optional float loss = 2;
  }

  message Model {
    optional string path = 3;
    optional int64 size = 4;
  }

  optional HyperParameters hyper_parameters = 1;
  optional TrainDataset train_dataset = 2;
  optional Algorithm algorithm = 3;
  optional Model model = 4;
}

// EvalLineage records infos of evaluation.
message EvaluationLineage {
  message ValidDataset {
    optional string valid_dataset_path = 1;
    optional int32 valid_dataset_size = 2;
  }

  optional string metric = 2;
  optional ValidDataset valid_dataset = 3;
}

// DatasetGraph describes a dataset pipeline as a tree of nodes.
message DatasetGraph {
  repeated DatasetGraph children = 1;
  optional OperationParameter parameter = 2;
  repeated Operation operations = 3;
  optional Operation sampler = 4;
}

// A single dataset operation and its parameters.
message Operation {
  optional OperationParameter operationParam = 1;
  repeated int32 size = 2;
  repeated float weights = 3;
}

// Heterogeneous key/value parameters of a dataset operation.
message OperationParameter {
  map<string, string> mapStr = 1;
  map<string, StrList> mapStrList = 2;
  map<string, bool> mapBool = 3;
  map<string, int32> mapInt = 4;
  map<string, double> mapDouble = 5;
}

// A list of strings, used as a map value type above.
message StrList {
  repeated string strValue = 1;
}
mindspore/train/_utils.py
浏览文件 @
8867c67d
...
...
@@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype
from
mindspore
import
log
as
logger
from
mindspore.common.api
import
_executor
from
.lineage_pb2
import
DatasetGraph
,
TrainLineage
,
EvaluationLineage
,
UserDefinedInfo
def
_convert_type
(
types
):
"""
...
...
@@ -193,3 +194,38 @@ def _to_full_shapes(shapes, device_num):
new_shape
+=
(
item
,)
new_shapes
.
append
(
new_shape
)
return
new_shapes
def
_check_to_numpy
(
plugin
,
tensor
):
"""Check the tensor and return a numpy.ndarray."""
np_value
=
tensor
.
asnumpy
()
if
plugin
==
'scalar'
:
if
np_value
.
size
==
1
:
return
np_value
raise
ValueError
(
'The tensor holds more than one value, but the scalar plugin expects on value.'
)
if
plugin
==
'image'
:
if
np_value
.
ndim
==
4
:
return
np_value
raise
ValueError
(
'The tensor seems not to hold a valid image.'
)
if
plugin
in
(
'tensor'
,
'histogram'
):
if
np_value
.
ndim
>
0
:
return
np_value
raise
ValueError
(
'The tensor should not be empty.'
)
return
np_value
def
_check_lineage_value
(
plugin
,
value
):
"""Check the lineage value."""
def
raises
(
plugin
,
prototype
):
raise
TypeError
(
f
'Plugin
{
repr
(
plugin
)
}
expects a
{
prototype
.
__name__
}
value.'
)
if
plugin
==
'dataset_graph'
and
not
isinstance
(
value
,
DatasetGraph
):
raises
(
plugin
,
DatasetGraph
)
if
plugin
==
'eval_lineage'
and
not
isinstance
(
value
,
EvaluationLineage
):
raises
(
plugin
,
EvaluationLineage
)
if
plugin
==
'train_lineage'
and
not
isinstance
(
value
,
TrainLineage
):
raises
(
plugin
,
TrainLineage
)
if
plugin
==
'custom_lineage_data'
and
not
isinstance
(
value
,
UserDefinedInfo
):
raises
(
plugin
,
UserDefinedInfo
)
mindspore/train/summary/_lineage_adapter.py
0 → 100644
浏览文件 @
8867c67d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the lineage event which conform to proto format."""
import
time
from
..lineage_pb2
import
LineageEvent
def serialize_to_lineage_event(name, value):
    """Wrap a serialized lineage proto ``value`` in a timestamped LineageEvent.

    Args:
        name (str): The lineage plugin name selecting the event field.
        value (bytes): A serialized proto message to embed.

    Returns:
        bytes, the serialized LineageEvent.
    """
    lineage_event = LineageEvent()
    lineage_event.wall_time = time.time()
    # Select the oneof field matching the plugin and fill it from the payload.
    field = _get_lineage_content(name, lineage_event)
    field.ParseFromString(value)
    return lineage_event.SerializeToString()
def
_get_lineage_content
(
name
,
event
):
if
name
==
'dataset_graph'
:
return
event
.
dataset_graph
if
name
==
'eval_lineage'
:
return
event
.
evaluation_lineage
if
name
==
'train_lineage'
:
return
event
.
train_lineage
if
name
==
'custom_lineage_data'
:
return
event
.
user_defined_info
raise
KeyError
(
f
'No such field in LineageEvent'
)
mindspore/train/summary/_summary_adapter.py
浏览文件 @
8867c67d
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
import
socket
import
platform
import
time
import
numpy
as
np
...
...
@@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix):
_check_str_by_regular
(
suffix
)
file_name
=
""
time_second
=
str
(
int
(
time
.
time
()))
hostname
=
socket
.
gethostnam
e
()
hostname
=
platform
.
nod
e
()
if
prefix
is
not
None
:
file_name
=
file_name
+
prefix
...
...
mindspore/train/summary/_summary_writer.py
0 → 100644
浏览文件 @
8867c67d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Writes events to disk in a logdir."""
import
os
import
stat
from
..._c_expression
import
EventWriter_
from
._summary_adapter
import
package_init_event
class BaseWriter:
    """Abstract event writer that lazily opens its backing file on first use."""

    def __init__(self, filepath) -> None:
        self._filepath = filepath
        # Created on first access of `writer`; stays None until then.
        self._writer: EventWriter_ = None

    def init_writer(self):
        """Hook for subclasses to write initial metadata; default does nothing."""

    @property
    def writer(self) -> EventWriter_:
        """Return the underlying native writer, creating it on first access."""
        if self._writer is None:
            # Pre-create the file so its permissions can be restricted to the
            # owner (rw-------) before the native writer opens it.
            with open(self._filepath, 'w'):
                os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
            self._writer = EventWriter_(self._filepath)
            self.init_writer()
        return self._writer

    def write(self, plugin, mode, data):
        """Write data to file; must be implemented by subclasses."""
        raise NotImplementedError()

    def flush(self):
        """Flush pending events, if the writer has been created."""
        if self._writer is not None:
            self._writer.Flush()

    def close(self):
        """Shut down the native writer, if it has been created."""
        if self._writer is not None:
            self._writer.Shut()
class SummaryWriter(BaseWriter):
    """Writer that persists summary and graph events to the summary file."""

    def init_writer(self):
        """Write the initialization event as the file's metadata header."""
        init_payload = package_init_event().SerializeToString()
        self.writer.Write(init_payload)

    def write(self, plugin, mode, data):
        """Write ``data`` when it belongs to the summary file; ignore other plugins."""
        if plugin == 'summary' or plugin == 'graph':
            self.writer.Write(data)
class LineageWriter(BaseWriter):
    """Writer that persists lineage events to the lineage file."""

    def write(self, plugin, mode, data):
        """Write ``data`` when it is a lineage plugin; ignore other plugins."""
        lineage_plugins = ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data')
        if plugin in lineage_plugins:
            self.writer.Write(data)
mindspore/train/summary/_
event_writer
.py
→
mindspore/train/summary/_
writer_pool
.py
浏览文件 @
8867c67d
...
...
@@ -12,74 +12,100 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Write
s events to disk in a logdir
."""
"""Write
events to disk in a base directory
."""
import
os
import
stat
from
collections
import
deque
from
multiprocessing
import
Pool
,
Process
,
Queue
,
cpu_count
from
..._c_expression
import
EventWriter_
from
._summary_adapter
import
package_summary_event
from
._lineage_adapter
import
serialize_to_lineage_event
from
._summary_adapter
import
package_graph_event
,
package_summary_event
from
._summary_writer
import
SummaryWriter
,
LineageWriter
def
_pack
(
result
,
step
):
summary_event
=
package_summary_event
(
result
,
step
)
return
summary_event
.
SerializeToString
()
def
_pack_data
(
datadict
):
"""Pack data according to which plugin."""
result
=
[]
summaries
,
step
,
mode
=
[],
None
,
None
for
plugin
,
datalist
in
datadict
.
items
():
for
data
in
datalist
:
if
plugin
==
'graph'
:
result
.
append
([
plugin
,
data
.
get
(
'mode'
),
package_graph_event
(
data
.
get
(
'value'
)).
SerializeToString
()])
elif
plugin
in
(
'train_lineage'
,
'eval_lineage'
,
'custom_lineage_data'
,
'dataset_graph'
):
result
.
append
([
plugin
,
data
.
get
(
'mode'
),
serialize_to_lineage_event
(
plugin
,
data
.
get
(
'value'
))])
elif
plugin
in
(
'scalar'
,
'tensor'
,
'histogram'
,
'image'
):
summaries
.
append
({
'_type'
:
plugin
.
title
(),
'name'
:
data
.
get
(
'tag'
),
'data'
:
data
.
get
(
'value'
)})
step
=
data
.
get
(
'step'
)
mode
=
data
.
get
(
'mode'
)
if
summaries
:
result
.
append
([
'summary'
,
mode
,
package_summary_event
(
summaries
,
step
).
SerializeToString
()])
return
result
class
EventWriter
(
Process
):
class
WriterPool
(
Process
):
"""
Creates a `EventWriter` and write event to
file.
Use a set of pooled resident processes for writing a list of
file.
Args:
filepath (str): Summary event file path and file name
.
f
lush_interval (int): The flush seconds to flush the pending events to disk. Default: 120
.
base_dir (str): The base directory to hold all the files
.
f
ilelist (str): The mapping from short name to long filename
.
"""
def
__init__
(
self
,
filepath
:
str
,
flush_interval
:
in
t
)
->
None
:
def
__init__
(
self
,
base_dir
,
**
filedic
t
)
->
None
:
super
().
__init__
()
_
=
flush_interval
with
open
(
filepath
,
'w'
):
os
.
chmod
(
filepath
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
self
.
_writer
=
EventWriter_
(
filepath
)
self
.
_base_dir
,
self
.
_filedict
=
base_dir
,
filedict
self
.
_queue
=
Queue
(
cpu_count
()
*
2
)
self
.
start
()
def
run
(
self
):
writers
=
self
.
_get_writers
()
with
Pool
(
min
(
cpu_count
(),
32
)
)
as
pool
:
with
Pool
()
as
pool
:
deq
=
deque
()
while
True
:
while
deq
and
deq
[
0
].
ready
():
self
.
_writer
.
Write
(
deq
.
popleft
().
get
())
for
plugin
,
mode
,
data
in
deq
.
popleft
().
get
():
for
writer
in
writers
:
writer
.
write
(
plugin
,
mode
,
data
)
if
not
self
.
_queue
.
empty
():
action
,
data
=
self
.
_queue
.
get
()
if
action
==
'WRITE'
:
if
not
isinstance
(
data
,
(
str
,
bytes
)):
deq
.
append
(
pool
.
apply_async
(
_pack
,
data
))
else
:
self
.
_writer
.
Write
(
data
)
deq
.
append
(
pool
.
apply_async
(
_pack_data
,
(
data
,)))
elif
action
==
'FLUSH'
:
self
.
_writer
.
Flush
()
for
writer
in
writers
:
writer
.
flush
()
elif
action
==
'END'
:
break
for
res
in
deq
:
self
.
_writer
.
Write
(
res
.
get
())
for
result
in
deq
:
for
plugin
,
mode
,
data
in
result
.
get
():
for
writer
in
writers
:
writer
.
write
(
plugin
,
mode
,
data
)
self
.
_writer
.
Shut
()
for
writer
in
writers
:
writer
.
close
()
def
_get_writers
(
self
):
writers
=
[]
for
plugin
,
filename
in
self
.
_filedict
.
items
():
filepath
=
os
.
path
.
join
(
self
.
_base_dir
,
filename
)
if
plugin
==
'summary'
:
writers
.
append
(
SummaryWriter
(
filepath
))
elif
plugin
==
'lineage'
:
writers
.
append
(
LineageWriter
(
filepath
))
return
writers
def
write
(
self
,
data
)
->
None
:
"""
Write the event to file.
Args:
name (str): The key of a specified file.
data (Optional[str, Tuple[list, int]]): The data to write.
"""
self
.
_queue
.
put
((
'WRITE'
,
data
))
def
flush
(
self
):
"""Flush the writer."""
"""Flush the writer
and sync data to disk
."""
self
.
_queue
.
put
((
'FLUSH'
,
None
))
def
close
(
self
)
->
None
:
...
...
mindspore/train/summary/summary_record.py
浏览文件 @
8867c67d
...
...
@@ -21,9 +21,9 @@ from mindspore import log as logger
from
..._c_expression
import
Tensor
from
..._checkparam
import
_check_str_by_regular
from
.._utils
import
_make_directory
from
._
event_writer
import
EventWriter
from
._
summary_adapter
import
get_event_file_name
,
package_graph_event
,
package_init_event
from
.._utils
import
_make_directory
,
_check_to_numpy
,
_check_lineage_value
from
._
summary_adapter
import
get_event_file_name
,
package_graph_event
from
._
writer_pool
import
WriterPool
# for the moment, this lock is for caution's sake,
# there are actually no any concurrencies happening.
...
...
@@ -53,16 +53,20 @@ def _get_summary_tensor_data():
return
data
def
_dictlist
():
from
collections
import
defaultdict
return
defaultdict
(
list
)
class
SummaryRecord
:
"""
SummaryRecord is used to record the summary
value
.
SummaryRecord is used to record the summary
data and lineage data
.
Note:
The API will create an event file in a given directory and add summaries and events to it.
It writes the event log to a file by executing the record method. In addition,
if the SummaryRecord object is created and the summary operator is used in the network,
even if the record method is not called, the event in the cache will be written to the
file at the end of execution. Make sure to close the SummaryRecord object at the end.
The API will create a summary file and a lineage file lazily in a given directory and writes data to them.
It writes the data to files by executing the record method. In addition to record the data bubbled up from
the network by defining the summary operators, SummaryRecord also supports to record extra data which
can be added by calling add_value. Finally, make sure to close the SummaryRecord object at the end.
Args:
log_dir (str): The log_dir is a directory location to save the summary.
...
...
@@ -89,10 +93,12 @@ class SummaryRecord:
file_suffix
=
"_MS"
,
network
=
None
):
self
.
_event_writer
,
self
.
_closed
=
None
,
False
self
.
_closed
,
self
.
_mode
=
False
,
'train'
self
.
_data_pool
=
_dictlist
()
_check_str_by_regular
(
file_prefix
)
_check_str_by_regular
(
file_suffix
)
self
.
log_path
=
_make_directory
(
log_dir
)
if
not
isinstance
(
queue_max_size
,
int
)
or
not
isinstance
(
flush_time
,
int
):
...
...
@@ -123,16 +129,12 @@ class SummaryRecord:
except
Exception
as
ex
:
raise
RuntimeError
(
ex
)
def
_init_event_writer
(
self
):
"""Init event writer and write metadata."""
event_writer
=
EventWriter
(
self
.
full_file_name
,
self
.
flush_time
)
event_writer
.
write
(
package_init_event
().
SerializeToString
())
return
event_writer
self
.
_event_writer
=
WriterPool
(
log_dir
,
summary
=
self
.
full_file_name
,
lineage
=
get_event_file_name
(
'events'
,
'_lineage'
))
def
__enter__
(
self
):
"""Enter the context manager."""
if
not
self
.
_event_writer
:
self
.
_event_writer
=
self
.
_init_event_writer
()
if
self
.
_closed
:
raise
ValueError
(
'SummaryRecord has been closed.'
)
return
self
...
...
@@ -141,6 +143,76 @@ class SummaryRecord:
"""Exit the context manager."""
self
.
close
()
def
set_mode
(
self
,
mode
):
"""
Set the mode for the recorder to be aware. The mode is set 'train' by default.
Args:
mode (str): The mode to set, which should be 'train' or 'eval'.
Raises:
ValueError: When the mode is not recognized.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.set_mode('eval')
"""
mode_spec
=
'train'
,
'eval'
if
mode
not
in
mode_spec
:
raise
ValueError
(
f
'
{
repr
(
mode
)
}
is not a recognized mode.'
)
self
.
_mode
=
mode
def
add_value
(
self
,
plugin
,
name
,
value
):
"""
Add value to be record later on.
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
the name should be the tag name, and the value should be a Tensor.
When the plugin plugin is 'graph', the value should be a GraphProto.
When the plugin 'dataset_graph', 'train_lineage', 'eval_lineage',
or 'custom_lineage_data', the value should be a proto message.
Args:
plugin (str): The plugin for the value.
name (str): The name for the value.
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]):
\
The value to store.
- GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'.
- Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
- TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
- EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
- DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
- UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.
Raises:
ValueError: When the name is not valid.
TypeError: When the value is not a Tensor.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.add_value('scalar', 'loss', Tensor(0.1))
"""
if
plugin
in
(
'tensor'
,
'scalar'
,
'image'
,
'histogram'
):
if
not
name
or
not
isinstance
(
name
,
str
):
raise
ValueError
(
f
'
{
repr
(
name
)
}
is not a valid tag name.'
)
if
not
isinstance
(
value
,
Tensor
):
raise
TypeError
(
f
'Expect the value to be Tensor, but got
{
type
(
value
).
__name__
}
'
)
np_value
=
_check_to_numpy
(
plugin
,
value
)
self
.
_data_pool
[
plugin
].
append
(
dict
(
tag
=
name
,
mode
=
self
.
_mode
,
value
=
np_value
))
elif
plugin
in
(
'train_lineage'
,
'eval_lineage'
,
'dataset_graph'
,
'custom_lineage_data'
):
_check_lineage_value
(
plugin
,
value
)
self
.
_data_pool
[
plugin
].
append
(
dict
(
mode
=
self
.
_mode
,
value
=
value
.
SerializeToString
()))
elif
plugin
==
'graph'
:
package_graph_event
(
value
)
self
.
_data_pool
[
plugin
].
append
(
dict
(
mode
=
self
.
_mode
,
value
=
value
))
else
:
raise
ValueError
(
f
'No such plugin of
{
repr
(
plugin
)
}
'
)
def
record
(
self
,
step
,
train_network
=
None
):
"""
Record the summary.
...
...
@@ -149,12 +221,12 @@ class SummaryRecord:
step (int): Represents training step number.
train_network (Cell): The network that called the callback.
Returns:
bool, whether the record process is successful or not.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.record(step=2)
Returns:
bool, whether the record process is successful or not.
"""
logger
.
info
(
"SummaryRecord step is %r."
,
step
)
if
self
.
_closed
:
...
...
@@ -163,10 +235,6 @@ class SummaryRecord:
if
not
isinstance
(
step
,
int
)
or
isinstance
(
step
,
bool
):
raise
ValueError
(
"`step` should be int"
)
# Set the current summary of train step
if
not
self
.
_event_writer
:
self
.
_event_writer
=
self
.
_init_event_writer
()
logger
.
warning
(
'SummaryRecord should be used as context manager for a with statement.'
)
if
self
.
network
is
not
None
and
not
self
.
has_graph
:
graph_proto
=
self
.
network
.
get_func_graph_proto
()
if
graph_proto
is
None
and
train_network
is
not
None
:
...
...
@@ -174,39 +242,48 @@ class SummaryRecord:
if
graph_proto
is
None
:
logger
.
error
(
"Failed to get proto for graph"
)
else
:
self
.
_event_writer
.
write
(
package_graph_event
(
graph_proto
).
SerializeToString
()
)
self
.
_event_writer
.
write
(
{
'graph'
:
[{
'step'
:
step
,
'value'
:
graph_proto
}]}
)
self
.
has_graph
=
True
if
not
_summary_tensor_cache
:
return
True
data
=
_get_summary_tensor_data
()
if
not
data
:
logger
.
info
(
"The step(%r) does not have record data."
,
step
)
return
False
if
self
.
queue_max_size
>
0
and
len
(
data
)
>
self
.
queue_max_size
:
logger
.
error
(
"The size of data record is %r, which is greater than queue_max_size %r."
,
len
(
data
),
self
.
queue_max_size
)
# process the data
result
=
self
.
_data_convert
(
data
)
if
not
result
:
logger
.
error
(
"The step(%r) summary data is invalid."
,
step
)
return
False
self
.
_event_writer
.
write
((
result
,
step
))
logger
.
debug
(
"Send the summary data to scheduler for saving, step = %d"
,
step
)
if
self
.
_mode
==
'train'
:
self
.
_add_summary_tensor_data
()
self
.
_event_writer
.
write
(
self
.
_consume_data_pool
(
step
))
return
True
def
_add_summary_tensor_data
(
self
):
summary_data
=
_get_summary_tensor_data
()
if
not
summary_data
:
logger
.
debug
(
f
'No summary data bubbled from the network.'
)
for
name
,
tensor
in
summary_data
.
items
():
tag
,
plugin
=
SummaryRecord
.
_parse_from
(
name
)
if
(
tag
,
plugin
)
==
(
None
,
None
):
logger
.
warning
(
"The name(%r) is invalid, expected 'TAG[:TYPE]'."
,
name
)
else
:
self
.
add_value
(
plugin
.
lower
(),
tag
,
tensor
)
def
_consume_data_pool
(
self
,
step
):
try
:
for
values
in
self
.
_data_pool
.
values
():
for
value
in
values
:
value
[
'step'
]
=
step
return
self
.
_data_pool
finally
:
self
.
_data_pool
=
_dictlist
()
@
property
def
log_dir
(
self
):
"""
Get the full path of the log file.
Returns:
str, the full path of log file.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> print(summary_record.log_dir)
Returns:
String, the full path of log file.
"""
return
self
.
full_file_name
...
...
@@ -235,46 +312,19 @@ class SummaryRecord:
"""
if
not
self
.
_closed
and
self
.
_event_writer
:
# event writer flush and close
logger
.
info
(
'Please wait it may take quite some time to finish writing and closing.'
)
self
.
_event_writer
.
close
()
self
.
_closed
=
True
def
__del__
(
self
)
->
None
:
self
.
close
()
def
_data_convert
(
self
,
summary
):
"""Convert the data."""
# convert the summary to numpy
result
=
[]
for
name
,
data
in
summary
.
items
():
# confirm the data is valid
summary_tag
,
summary_type
=
SummaryRecord
.
_parse_from
(
name
)
if
summary_tag
is
None
:
logger
.
error
(
"The data type is invalid, name = %r, tensor = %r"
,
name
,
data
)
return
None
if
isinstance
(
data
,
Tensor
):
result
.
append
({
'name'
:
summary_tag
,
'data'
:
data
.
asnumpy
(),
'_type'
:
summary_type
})
else
:
logger
.
error
(
"The data type is invalid, name = %r, tensor = %r"
,
name
,
data
)
return
None
return
result
@
staticmethod
def
_parse_from
(
name
:
str
=
None
):
"""
Parse the tag and type from name.
Args:
name (str): Format: TAG[:TYPE].
Returns:
Tuple, (summary_tag, summary_type).
"""
if
name
is
None
:
logger
.
error
(
"The name is None"
)
"""Parse the tag and type from name."""
if
not
isinstance
(
name
,
str
):
return
None
,
None
match
=
re
.
match
(
r
'(.+)\[:(.+)\]'
,
name
)
if
match
:
return
match
.
groups
()
logger
.
error
(
"The name(%r) format is invalid, expected 'TAG[:TYPE]'."
,
name
)
return
None
,
None
tests/ut/python/train/summary/test_histogram_summary.py
浏览文件 @
8867c67d
...
...
@@ -84,21 +84,6 @@ def test_histogram_multi_summary():
event
=
reader
.
read_event
()
assert
event
.
summary
.
value
[
0
].
histogram
.
count
==
size
def
test_histogram_summary_scalar_tensor
():
"""Test histogram summary, input is a scalar tensor."""
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
SummaryRecord
(
tmp_dir
,
file_suffix
=
"_MS_HISTOGRAM"
)
as
test_writer
:
test_data
=
_wrap_test_data
(
Tensor
(
1
))
_cache_summary_tensor_data
(
test_data
)
test_writer
.
record
(
step
=
1
)
file_name
=
os
.
path
.
join
(
tmp_dir
,
test_writer
.
event_file_name
)
with
SummaryReader
(
file_name
)
as
reader
:
event
=
reader
.
read_event
()
assert
event
.
summary
.
value
[
0
].
histogram
.
count
==
1
def
test_histogram_summary_empty_tensor
():
"""Test histogram summary, input is an empty tensor."""
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录