Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9514b52a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9514b52a
编写于
6月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2147 Add a callback named SummaryCollector and delete SummaryStep callback
Merge pull request !2147 from ougongchang/master
上级
5dac9c4c
939cd29d
变更
13
展开全部
显示空白变更内容
内联
并排
Showing
13 changed file
with
1217 addition
and
75 deletion
+1217
-75
mindspore/train/_utils.py
mindspore/train/_utils.py
+21
-0
mindspore/train/callback/__init__.py
mindspore/train/callback/__init__.py
+3
-2
mindspore/train/callback/_dataset_graph.py
mindspore/train/callback/_dataset_graph.py
+128
-0
mindspore/train/callback/_summary_collector.py
mindspore/train/callback/_summary_collector.py
+786
-0
mindspore/train/model.py
mindspore/train/model.py
+17
-1
mindspore/train/summary/enum.py
mindspore/train/summary/enum.py
+43
-0
tests/st/summary/test_gpu_summary.py
tests/st/summary/test_gpu_summary.py
+1
-1
tests/ut/python/train/summary/test_graph_summary.py
tests/ut/python/train/summary/test_graph_summary.py
+7
-24
tests/ut/python/train/summary/test_image_summary.py
tests/ut/python/train/summary/test_image_summary.py
+8
-8
tests/ut/python/train/summary/test_summary.py
tests/ut/python/train/summary/test_summary.py
+0
-26
tests/ut/python/train/summary/test_summary_abnormal_input.py
tests/ut/python/train/summary/test_summary_abnormal_input.py
+12
-6
tests/ut/python/train/summary/test_summary_collector.py
tests/ut/python/train/summary/test_summary_collector.py
+184
-0
tests/ut/python/train/test_training.py
tests/ut/python/train/test_training.py
+7
-7
未找到文件。
mindspore/train/_utils.py
浏览文件 @
9514b52a
...
@@ -14,7 +14,10 @@
...
@@ -14,7 +14,10 @@
# ============================================================================
# ============================================================================
"""Train utility."""
"""Train utility."""
import
os
import
os
from
collections.abc
import
Iterable
import
numpy
as
np
import
numpy
as
np
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.dtype
import
dtype_to_nptype
,
pytype_to_dtype
from
mindspore.common.dtype
import
dtype_to_nptype
,
pytype_to_dtype
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
...
@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
...
@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
raise
ValueError
(
'The tensor should not be empty.'
)
raise
ValueError
(
'The tensor should not be empty.'
)
return
np_value
return
np_value
def
_check_lineage_value
(
plugin
,
value
):
def
_check_lineage_value
(
plugin
,
value
):
"""Check the lineage value."""
"""Check the lineage value."""
def
raises
(
plugin
,
prototype
):
def
raises
(
plugin
,
prototype
):
...
@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
...
@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
if
plugin
==
'custom_lineage_data'
and
not
isinstance
(
value
,
UserDefinedInfo
):
if
plugin
==
'custom_lineage_data'
and
not
isinstance
(
value
,
UserDefinedInfo
):
raises
(
plugin
,
UserDefinedInfo
)
raises
(
plugin
,
UserDefinedInfo
)
def
check_value_type
(
arg_name
,
arg_value
,
valid_types
):
"""Checks whether a value is instance of some types."""
valid_types
=
tuple
(
valid_types
)
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
is_valid
=
True
# bool is subclass of int, so for a bool value, we need to extra check
if
isinstance
(
arg_value
,
int
)
and
isinstance
(
arg_value
,
bool
)
and
bool
not
in
valid_types
:
is_valid
=
False
if
not
isinstance
(
arg_value
,
valid_types
):
is_valid
=
False
if
not
is_valid
:
raise
TypeError
(
f
'For `
{
arg_name
}
` the type should be a valid type of
{
[
t
.
__name__
for
t
in
valid_types
]
}
, '
f
'bug got
{
type
(
arg_value
).
__name__
}
.'
)
mindspore/train/callback/__init__.py
浏览文件 @
9514b52a
...
@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
...
@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
from
._checkpoint
import
CheckpointManager
as
_CheckpointManager
from
._checkpoint
import
CheckpointManager
as
_CheckpointManager
from
._checkpoint
import
ModelCheckpoint
from
._checkpoint
import
ModelCheckpoint
from
._loss_monitor
import
LossMonitor
from
._loss_monitor
import
LossMonitor
from
._summary_step
import
SummaryStep
from
._time_monitor
import
TimeMonitor
from
._time_monitor
import
TimeMonitor
from
._summary_collector
import
SummaryCollector
__all__
=
[
"Callback"
,
"LossMonitor"
,
"TimeMonitor"
,
"ModelCheckpoint"
,
"SummaryStep"
,
"CheckpointConfig"
,
"RunContext"
]
__all__
=
[
"Callback"
,
"LossMonitor"
,
"TimeMonitor"
,
"ModelCheckpoint"
,
"SummaryCollector"
,
"CheckpointConfig"
,
"RunContext"
]
mindspore/train/callback/_dataset_graph.py
0 → 100644
浏览文件 @
9514b52a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define dataset graph related operations."""
import
json
from
importlib
import
import_module
from
mindspore.train
import
lineage_pb2
class
DatasetGraph
:
"""Handle the data graph and packages it into binary data."""
def
package_dataset_graph
(
self
,
dataset
):
"""
packages dataset graph into binary data
Args:
dataset (MindData): refer to MindDataset
Returns:
DatasetGraph, a object of lineage_pb2.DatasetGraph.
"""
dataset_package
=
import_module
(
'mindspore.dataset'
)
dataset_dict
=
dataset_package
.
serialize
(
dataset
)
json_str
=
json
.
dumps
(
dataset_dict
,
indent
=
2
)
dataset_dict
=
json
.
loads
(
json_str
)
dataset_graph_proto
=
lineage_pb2
.
DatasetGraph
()
if
"children"
in
dataset_dict
:
children
=
dataset_dict
.
pop
(
"children"
)
if
children
:
self
.
_package_children
(
children
=
children
,
message
=
dataset_graph_proto
)
self
.
_package_current_dataset
(
operation
=
dataset_dict
,
message
=
dataset_graph_proto
)
return
dataset_graph_proto
def
_package_children
(
self
,
children
,
message
):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for
child
in
children
:
if
child
:
child_graph_message
=
getattr
(
message
,
"children"
).
add
()
grandson
=
child
.
pop
(
"children"
)
if
grandson
:
self
.
_package_children
(
children
=
grandson
,
message
=
child_graph_message
)
# package other parameters
self
.
_package_current_dataset
(
operation
=
child
,
message
=
child_graph_message
)
def
_package_current_dataset
(
self
,
operation
,
message
):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
value
and
key
==
"operations"
:
for
operator
in
value
:
self
.
_package_enhancement_operation
(
operator
,
message
.
operations
.
add
()
)
elif
value
and
key
==
"sampler"
:
self
.
_package_enhancement_operation
(
value
,
message
.
sampler
)
else
:
self
.
_package_parameter
(
key
,
value
,
message
.
parameter
)
def
_package_enhancement_operation
(
self
,
operation
,
message
):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
isinstance
(
value
,
list
):
if
all
(
isinstance
(
ele
,
int
)
for
ele
in
value
):
message
.
size
.
extend
(
value
)
else
:
message
.
weights
.
extend
(
value
)
else
:
self
.
_package_parameter
(
key
,
value
,
message
.
operationParam
)
@
staticmethod
def
_package_parameter
(
key
,
value
,
message
):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if
isinstance
(
value
,
str
):
message
.
mapStr
[
key
]
=
value
elif
isinstance
(
value
,
bool
):
message
.
mapBool
[
key
]
=
value
elif
isinstance
(
value
,
int
):
message
.
mapInt
[
key
]
=
value
elif
isinstance
(
value
,
float
):
message
.
mapDouble
[
key
]
=
value
elif
isinstance
(
value
,
list
)
and
key
!=
"operations"
:
if
value
:
replace_value_list
=
list
(
map
(
lambda
x
:
""
if
x
is
None
else
x
,
value
))
message
.
mapStrList
[
key
].
strValue
.
extend
(
replace_value_list
)
elif
value
is
None
:
message
.
mapStr
[
key
]
=
"None"
else
:
raise
ValueError
(
f
"Parameter
{
key
}
is not supported in event package."
)
mindspore/train/callback/_summary_collector.py
0 → 100644
浏览文件 @
9514b52a
此差异已折叠。
点击以展开。
mindspore/train/model.py
浏览文件 @
9514b52a
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""Model."""
"""Model."""
from
collections.abc
import
Iterable
import
numpy
as
np
import
numpy
as
np
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
...
@@ -345,7 +347,8 @@ class Model:
...
@@ -345,7 +347,8 @@ class Model:
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
train_dataset
=
train_dataset
cb_params
.
train_dataset
=
train_dataset
cb_params
.
list_callback
=
callbacks
cb_params
.
list_callback
=
self
.
_transform_callbacks
(
callbacks
)
cb_params
.
train_dataset_element
=
None
# build callback list
# build callback list
with
_CallbackManager
(
callbacks
)
as
list_callback
:
with
_CallbackManager
(
callbacks
)
as
list_callback
:
...
@@ -358,6 +361,17 @@ class Model:
...
@@ -358,6 +361,17 @@ class Model:
else
:
else
:
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
@
staticmethod
def
_transform_callbacks
(
callbacks
):
"""Transform callback to a list."""
if
callbacks
is
None
:
return
[]
if
isinstance
(
callbacks
,
Iterable
):
return
list
(
callbacks
)
return
[
callbacks
]
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
"""
"""
Training process. The data would be passed to network through dataset channel.
Training process. The data would be passed to network through dataset channel.
...
@@ -449,6 +463,7 @@ class Model:
...
@@ -449,6 +463,7 @@ class Model:
scaling_sens
=
self
.
_get_scaling_sens
()
scaling_sens
=
self
.
_get_scaling_sens
()
next_element
=
tuple
(
next_element
)
+
(
Tensor
(
scaling_sens
,
mstype
.
float32
),)
next_element
=
tuple
(
next_element
)
+
(
Tensor
(
scaling_sens
,
mstype
.
float32
),)
cb_params
.
train_dataset_element
=
next_element
outputs
=
self
.
_train_network
(
*
next_element
)
outputs
=
self
.
_train_network
(
*
next_element
)
cb_params
.
net_outputs
=
outputs
cb_params
.
net_outputs
=
outputs
if
self
.
_loss_scale_manager
and
self
.
_loss_scale_manager
.
get_drop_overflow_update
():
if
self
.
_loss_scale_manager
and
self
.
_loss_scale_manager
.
get_drop_overflow_update
():
...
@@ -628,6 +643,7 @@ class Model:
...
@@ -628,6 +643,7 @@ class Model:
cb_params
.
batch_num
=
valid_dataset
.
get_dataset_size
()
cb_params
.
batch_num
=
valid_dataset
.
get_dataset_size
()
cb_params
.
mode
=
"eval"
cb_params
.
mode
=
"eval"
cb_params
.
cur_step_num
=
0
cb_params
.
cur_step_num
=
0
cb_params
.
list_callback
=
self
.
_transform_callbacks
(
callbacks
)
self
.
_eval_network
.
set_train
(
mode
=
False
)
self
.
_eval_network
.
set_train
(
mode
=
False
)
self
.
_eval_network
.
phase
=
'eval'
self
.
_eval_network
.
phase
=
'eval'
...
...
mindspore/train/
callback/_summary_step
.py
→
mindspore/train/
summary/enum
.py
浏览文件 @
9514b52a
...
@@ -12,45 +12,32 @@
...
@@ -12,45 +12,32 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""SummaryStep Callback class."""
"""Summary's enumeration file."""
from
enum
import
Enum
from
._callback
import
Callback
class
BaseEnum
(
Enum
):
"""The base enum class."""
class
SummaryStep
(
Callback
):
@
classmethod
"""
def
to_list
(
cls
):
The summary callback class.
"""Converts the enumeration into a list."""
return
[
member
.
value
for
member
in
cls
.
__members__
.
values
()]
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def
__init__
(
self
,
summary
,
flush_step
=
10
):
class
PluginEnum
(
BaseEnum
):
super
(
SummaryStep
,
self
).
__init__
()
"""The list of plugins currently supported by the summary."""
if
not
isinstance
(
flush_step
,
int
)
or
isinstance
(
flush_step
,
bool
)
or
flush_step
<=
0
:
GRAPH
=
'graph'
raise
ValueError
(
"`flush_step` should be int and greater than 0"
)
SCALAR
=
'scalar'
self
.
_summary
=
summary
IMAGE
=
'image'
self
.
_flush_step
=
flush_step
TENSOR
=
'tensor'
HISTOGRAM
=
'histogram'
TRAIN_LINEAGE
=
'train_lineage'
EVAL_LINEAGE
=
'eval_lineage'
DATASET_GRAPH
=
'dataset_graph'
def
__enter__
(
self
):
self
.
_summary
.
__enter__
()
return
self
def
__exit__
(
self
,
*
err
):
class
ModeEnum
(
BaseEnum
):
return
self
.
_summary
.
__exit__
(
*
err
)
"""The modes currently supported by the summary."""
TRAIN
=
'train'
def
step_end
(
self
,
run_context
):
EVAL
=
'eval'
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params
=
run_context
.
original_args
()
if
cb_params
.
cur_step_num
%
self
.
_flush_step
==
0
:
self
.
_summary
.
record
(
cb_params
.
cur_step_num
,
cb_params
.
train_network
)
@
property
def
summary_file_name
(
self
):
return
self
.
_summary
.
full_file_name
tests/st/summary/test_gpu_summary.py
浏览文件 @
9514b52a
...
@@ -75,7 +75,7 @@ class TestGpuSummary:
...
@@ -75,7 +75,7 @@ class TestGpuSummary:
if
not
os
.
path
.
exists
(
self
.
summary_dir
):
if
not
os
.
path
.
exists
(
self
.
summary_dir
):
os
.
mkdir
(
self
.
summary_dir
)
os
.
mkdir
(
self
.
summary_dir
)
def
teardown_
em
thod
(
self
):
def
teardown_
me
thod
(
self
):
"""Run after method."""
"""Run after method."""
if
os
.
path
.
exists
(
self
.
summary_dir
):
if
os
.
path
.
exists
(
self
.
summary_dir
):
shutil
.
rmtree
(
self
.
summary_dir
)
shutil
.
rmtree
(
self
.
summary_dir
)
...
...
tests/ut/python/train/summary/test_graph_summary.py
浏览文件 @
9514b52a
...
@@ -20,8 +20,8 @@ import numpy as np
...
@@ -20,8 +20,8 @@ import numpy as np
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore
import
Model
,
context
from
mindspore
import
Model
,
context
from
mindspore.nn.optim
import
Momentum
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.
callback
import
SummaryStep
from
mindspore.train.
summary
import
SummaryRecord
from
mindspore.train.
summary.summary_record
import
SummaryRecord
from
mindspore.train.
callback
import
SummaryCollector
from
.....dataset_mock
import
MindData
from
.....dataset_mock
import
MindData
CUR_DIR
=
os
.
getcwd
()
CUR_DIR
=
os
.
getcwd
()
...
@@ -107,16 +107,9 @@ def test_graph_summary_sample():
...
@@ -107,16 +107,9 @@ def test_graph_summary_sample():
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
model
.
_train_network
)
as
test_writer
:
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
model
.
_train_network
)
as
test_writer
:
model
.
train
(
2
,
dataset
)
model
.
train
(
2
,
dataset
)
# step 2: create the Event
for
i
in
range
(
1
,
5
):
for
i
in
range
(
1
,
5
):
test_writer
.
record
(
i
)
test_writer
.
record
(
i
)
# step 3: send the event to mq
# step 4: accept the event and write the file
log
.
debug
(
"finished test_graph_summary_sample"
)
def
test_graph_summary_callback
():
def
test_graph_summary_callback
():
dataset
=
get_dataset
()
dataset
=
get_dataset
()
...
@@ -125,18 +118,8 @@ def test_graph_summary_callback():
...
@@ -125,18 +118,8 @@ def test_graph_summary_callback():
optim
=
Momentum
(
net
.
trainable_params
(),
0.1
,
0.9
)
optim
=
Momentum
(
net
.
trainable_params
(),
0.1
,
0.9
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
model
.
_train_network
)
as
test_writer
:
summary_collector
=
SummaryCollector
(
SUMMARY_DIR
,
summary_cb
=
SummaryStep
(
test_writer
,
1
)
collect_freq
=
1
,
model
.
train
(
2
,
dataset
,
callbacks
=
summary_cb
)
keep_default_action
=
False
,
collect_specified_data
=
{
'collect_graph'
:
True
})
model
.
train
(
1
,
dataset
,
callbacks
=
[
summary_collector
])
def
test_graph_summary_callback2
():
dataset
=
get_dataset
()
net
=
Net
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optim
=
Momentum
(
net
.
trainable_params
(),
0.1
,
0.9
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
net
)
as
test_writer
:
summary_cb
=
SummaryStep
(
test_writer
,
1
)
model
.
train
(
2
,
dataset
,
callbacks
=
summary_cb
)
tests/ut/python/train/summary/test_image_summary.py
浏览文件 @
9514b52a
...
@@ -26,9 +26,8 @@ import mindspore.nn as nn
...
@@ -26,9 +26,8 @@ import mindspore.nn as nn
from
mindspore
import
Model
,
context
from
mindspore
import
Model
,
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.nn.optim
import
Momentum
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback
import
SummaryStep
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
_cache_summary_tensor_data
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
\
from
mindspore.train.callback
import
Callback
_cache_summary_tensor_data
from
.....dataset_mock
import
MindData
from
.....dataset_mock
import
MindData
CUR_DIR
=
os
.
getcwd
()
CUR_DIR
=
os
.
getcwd
()
...
@@ -155,7 +154,8 @@ def get_dataset():
...
@@ -155,7 +154,8 @@ def get_dataset():
return
dataset
return
dataset
class
ImageSummaryCallback
:
class
ImageSummaryCallback
(
Callback
):
"""Image summary callback."""
def
__init__
(
self
,
summary_record
):
def
__init__
(
self
,
summary_record
):
self
.
_summary_record
=
summary_record
self
.
_summary_record
=
summary_record
...
@@ -164,9 +164,10 @@ class ImageSummaryCallback:
...
@@ -164,9 +164,10 @@ class ImageSummaryCallback:
return
self
return
self
def
__exit__
(
self
,
*
err
):
def
__exit__
(
self
,
*
err
):
pass
self
.
_summary_record
.
close
()
def
record
(
self
,
step
,
train_network
=
None
):
def
record
(
self
,
step
,
train_network
=
None
):
"""record data."""
self
.
_summary_record
.
record
(
step
,
train_network
)
self
.
_summary_record
.
record
(
step
,
train_network
)
self
.
_summary_record
.
flush
()
self
.
_summary_record
.
flush
()
...
@@ -183,9 +184,8 @@ def test_image_summary_train():
...
@@ -183,9 +184,8 @@ def test_image_summary_train():
# step 2: create the Event
# step 2: create the Event
model
=
get_model
()
model
=
get_model
()
fn
=
ImageSummaryCallback
(
test_writer
)
callback
=
ImageSummaryCallback
(
test_writer
)
summary_recode
=
SummaryStep
(
fn
,
1
)
model
.
train
(
2
,
dataset
,
callbacks
=
[
callback
])
model
.
train
(
2
,
dataset
,
callbacks
=
summary_recode
)
# step 3: send the event to mq
# step 3: send the event to mq
...
...
tests/ut/python/train/summary/test_summary.py
浏览文件 @
9514b52a
...
@@ -24,11 +24,9 @@ import random
...
@@ -24,11 +24,9 @@ import random
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.train.callback
import
SummaryStep
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
_cache_summary_tensor_data
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
_cache_summary_tensor_data
CUR_DIR
=
os
.
getcwd
()
CUR_DIR
=
os
.
getcwd
()
...
@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
...
@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
def
test_validate
():
def
test_validate
():
with
SummaryRecord
(
SUMMARY_DIR
)
as
sr
:
with
SummaryRecord
(
SUMMARY_DIR
)
as
sr
:
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
0
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
-
1
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
1.2
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
True
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
"str"
)
sr
.
record
(
1
)
sr
.
record
(
1
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
sr
.
record
(
False
)
sr
.
record
(
False
)
...
@@ -215,17 +203,3 @@ def test_validate():
...
@@ -215,17 +203,3 @@ def test_validate():
sr
.
record
(
"str"
)
sr
.
record
(
"str"
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
sr
.
record
(
sr
)
sr
.
record
(
sr
)
SummaryStep
(
sr
,
1
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
1.2
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
False
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
"str"
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
(
1
,
2
))
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
[
3
,
4
])
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
sr
)
tests/ut/python/train/summary/test_summary_abnormal_input.py
浏览文件 @
9514b52a
...
@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
...
@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
log
.
debug
(
"begin test_summaryrecord_input_null_string"
)
log
.
debug
(
"begin test_summaryrecord_input_null_string"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
""
)
with
SummaryRecord
(
""
):
pass
except
:
except
:
assert
True
assert
True
else
:
else
:
...
@@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
...
@@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
log
.
debug
(
"begin test_summaryrecord_input_None"
)
log
.
debug
(
"begin test_summaryrecord_input_None"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
None
)
with
SummaryRecord
(
None
):
pass
except
:
except
:
assert
True
assert
True
else
:
else
:
...
@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
...
@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_1"
)
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_1"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
"./test_temp_summary_event_file/"
)
with
SummaryRecord
(
"./test_temp_summary_event_file/"
):
pass
except
:
except
:
assert
False
assert
False
else
:
else
:
...
@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
...
@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_2"
)
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_2"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
"../summary/"
)
with
SummaryRecord
(
"../summary/"
):
pass
except
:
except
:
assert
False
assert
False
else
:
else
:
...
@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
...
@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
log
.
debug
(
"begin test_summaryrecord_input_invalid_type_dir"
)
log
.
debug
(
"begin test_summaryrecord_input_invalid_type_dir"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
32
)
with
SummaryRecord
(
32
):
pass
except
:
except
:
assert
True
assert
True
else
:
else
:
...
@@ -119,7 +124,8 @@ def test_mulit_layer_directory():
...
@@ -119,7 +124,8 @@ def test_mulit_layer_directory():
log
.
debug
(
"begin test_mulit_layer_directory"
)
log
.
debug
(
"begin test_mulit_layer_directory"
)
# step 0: create the thread
# step 0: create the thread
try
:
try
:
SummaryRecord
(
"./test_temp_summary_event_file/test/t1/"
)
with
SummaryRecord
(
"./test_temp_summary_event_file/test/t1/"
):
pass
except
:
except
:
assert
False
assert
False
else
:
else
:
...
...
tests/ut/python/train/summary/test_summary_collector.py
0 → 100644
浏览文件 @
9514b52a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the exception parameter scenario for summary collector."""
import
os
import
tempfile
import
shutil
import
pytest
from
mindspore.train.callback
import
SummaryCollector
class
TestSummaryCollector
:
"""Test the exception parameter for summary collector."""
base_summary_dir
=
''
def
setup_class
(
self
):
"""Run before test this class."""
self
.
base_summary_dir
=
tempfile
.
mkdtemp
(
suffix
=
'summary'
)
def
teardown_class
(
self
):
"""Run after test this class."""
if
os
.
path
.
exists
(
self
.
base_summary_dir
):
shutil
.
rmtree
(
self
.
base_summary_dir
)
@
pytest
.
mark
.
parametrize
(
"summary_dir"
,
[
1234
,
None
,
True
,
''
])
def
test_params_with_summary_dir_value_error
(
self
,
summary_dir
):
"""Test the exception scenario for summary dir."""
if
isinstance
(
summary_dir
,
str
):
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
)
assert
str
(
exc
.
value
)
==
'For `summary_dir` the value should be a valid string of path, '
\
'but got empty string.'
else
:
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
)
assert
'For `summary_dir` the type should be a valid type'
in
str
(
exc
.
value
)
def
test_params_with_summary_dir_not_dir
(
self
):
"""Test the given summary dir parameter is not a directory."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
summary_file
=
os
.
path
.
join
(
summary_dir
,
'temp_file.txt'
)
with
open
(
summary_file
,
'w'
)
as
file_handle
:
file_handle
.
write
(
'temp'
)
print
(
os
.
path
.
isfile
(
summary_file
))
with
pytest
.
raises
(
NotADirectoryError
):
SummaryCollector
(
summary_dir
=
summary_file
)
@
pytest
.
mark
.
parametrize
(
"collect_freq"
,
[
None
,
0
,
0.01
])
def
test_params_with_collect_freq_exception
(
self
,
collect_freq
):
"""Test the exception scenario for collect freq."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
if
isinstance
(
collect_freq
,
int
):
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
collect_freq
=
collect_freq
)
expected_msg
=
f
'For `collect_freq` the value should be greater than 0, but got `
{
collect_freq
}
`.'
assert
expected_msg
==
str
(
exc
.
value
)
else
:
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
collect_freq
=
collect_freq
)
expected_msg
=
f
"For `collect_freq` the type should be a valid type of ['int'], "
\
f
'bug got
{
type
(
collect_freq
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"action"
,
[
None
,
123
,
''
,
'123'
])
def
test_params_with_action_exception
(
self
,
action
):
"""Test the exception scenario for action."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
keep_default_action
=
action
)
expected_msg
=
f
"For `keep_default_action` the type should be a valid type of ['bool'], "
\
f
"bug got
{
type
(
action
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
123
])
def
test_params_with_collect_specified_data_type_error
(
self
,
collect_specified_data
):
"""Test type error scenario for collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
expected_msg
=
f
"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], "
\
f
"bug got
{
type
(
collect_specified_data
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
{
123
:
123
},
{
None
:
True
}
])
def
test_params_with_collect_specified_data_key_type_error
(
self
,
collect_specified_data
):
"""Test the key of collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
param_name
=
list
(
collect_specified_data
)[
0
]
expected_msg
=
f
"For `
{
param_name
}
` the type should be a valid type of ['str'], "
\
f
"bug got
{
type
(
param_name
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
{
'collect_metric'
:
None
},
{
'collect_graph'
:
123
},
{
'histogram_regular'
:
123
},
])
def
test_params_with_collect_specified_data_value_type_error
(
self
,
collect_specified_data
):
"""Test the value of collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
param_name
=
list
(
collect_specified_data
)[
0
]
param_value
=
collect_specified_data
[
param_name
]
expected_type
=
"['bool']"
if
param_name
!=
'histogram_regular'
else
"['str', 'NoneType']"
expected_msg
=
f
'For `
{
param_name
}
` the type should be a valid type of
{
expected_type
}
, '
\
f
'bug got
{
type
(
param_value
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
def
test_params_with_collect_specified_data_unexpected_key
(
self
):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
data
=
{
'unexpected_key'
:
True
}
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
data
)
expected_msg
=
f
"For `collect_specified_data` the keys
{
set
(
data
)
}
are unsupported."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"custom_lineage_data"
,
[
123
,
{
'custom'
:
{}
},
{
'custom'
:
None
},
{
123
:
'custom'
}
])
def
test_params_with_custom_lineage_data_type_error
(
self
,
custom_lineage_data
):
"""Test the custom lineage data parameter type error."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
custom_lineage_data
=
custom_lineage_data
)
if
not
isinstance
(
custom_lineage_data
,
dict
):
expected_msg
=
f
"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], "
\
f
"bug got
{
type
(
custom_lineage_data
).
__name__
}
."
else
:
param_name
=
list
(
custom_lineage_data
)[
0
]
param_value
=
custom_lineage_data
[
param_name
]
if
not
isinstance
(
param_name
,
str
):
arg_name
=
f
'custom_lineage_data ->
{
param_name
}
'
expected_msg
=
f
"For `
{
arg_name
}
` the type should be a valid type of ['str'], "
\
f
'bug got
{
type
(
param_name
).
__name__
}
.'
else
:
arg_name
=
f
'the value of custom_lineage_data ->
{
param_name
}
'
expected_msg
=
f
"For `
{
arg_name
}
` the type should be a valid type of ['int', 'str', 'float'], "
\
f
'bug got
{
type
(
param_value
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
tests/ut/python/train/test_training.py
浏览文件 @
9514b52a
...
@@ -20,8 +20,8 @@ import pytest
...
@@ -20,8 +20,8 @@ import pytest
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore
import
Model
,
context
from
mindspore
import
Model
,
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.train.callback
import
Callback
from
mindspore.nn.optim
import
Momentum
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback
import
SummaryStep
from
..ut_filter
import
non_graph_engine
from
..ut_filter
import
non_graph_engine
from
....dataset_mock
import
MindData
from
....dataset_mock
import
MindData
...
@@ -174,7 +174,7 @@ class TestGraphMode:
...
@@ -174,7 +174,7 @@ class TestGraphMode:
model
.
train
(
1
,
dataset
)
model
.
train
(
1
,
dataset
)
class
CallbackTest
:
class
CallbackTest
(
Callback
)
:
""" CallbackTest definition """
""" CallbackTest definition """
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -186,19 +186,19 @@ class CallbackTest:
...
@@ -186,19 +186,19 @@ class CallbackTest:
def
__exit__
(
self
,
*
err
):
def
__exit__
(
self
,
*
err
):
pass
pass
def
record
(
self
,
step
,
*
args
):
def
step_end
(
self
,
run_context
):
print
(
step
,
args
)
cb_params
=
run_context
.
original_args
()
print
(
cb_params
.
cur_epoch_num
,
cb_params
.
cur_step_num
)
def
test_train_callback
(
test_with_simu
):
def
test_train_callback
(
test_with_simu
):
""" test_train_callback """
""" test_train_callback """
dataset
=
get_dataset
()
dataset
=
get_dataset
()
model
=
get_model
()
model
=
get_model
()
fn
=
CallbackTest
()
callback
=
CallbackTest
()
summary_recode
=
SummaryStep
(
fn
,
2
)
if
test_with_simu
:
if
test_with_simu
:
return
return
model
.
train
(
2
,
dataset
,
callbacks
=
summary_recode
)
model
.
train
(
2
,
dataset
,
callbacks
=
callback
)
log
=
logging
.
getLogger
(
"test"
)
log
=
logging
.
getLogger
(
"test"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录