Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9514b52a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9514b52a
编写于
6月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2147 Add a callback named SummaryCollector and delete SummaryStep callback
Merge pull request !2147 from ougongchang/master
上级
5dac9c4c
939cd29d
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
1217 addition
and
75 deletion
+1217
-75
mindspore/train/_utils.py
mindspore/train/_utils.py
+21
-0
mindspore/train/callback/__init__.py
mindspore/train/callback/__init__.py
+3
-2
mindspore/train/callback/_dataset_graph.py
mindspore/train/callback/_dataset_graph.py
+128
-0
mindspore/train/callback/_summary_collector.py
mindspore/train/callback/_summary_collector.py
+786
-0
mindspore/train/model.py
mindspore/train/model.py
+17
-1
mindspore/train/summary/enum.py
mindspore/train/summary/enum.py
+43
-0
tests/st/summary/test_gpu_summary.py
tests/st/summary/test_gpu_summary.py
+1
-1
tests/ut/python/train/summary/test_graph_summary.py
tests/ut/python/train/summary/test_graph_summary.py
+7
-24
tests/ut/python/train/summary/test_image_summary.py
tests/ut/python/train/summary/test_image_summary.py
+8
-8
tests/ut/python/train/summary/test_summary.py
tests/ut/python/train/summary/test_summary.py
+0
-26
tests/ut/python/train/summary/test_summary_abnormal_input.py
tests/ut/python/train/summary/test_summary_abnormal_input.py
+12
-6
tests/ut/python/train/summary/test_summary_collector.py
tests/ut/python/train/summary/test_summary_collector.py
+184
-0
tests/ut/python/train/test_training.py
tests/ut/python/train/test_training.py
+7
-7
未找到文件。
mindspore/train/_utils.py
浏览文件 @
9514b52a
...
...
@@ -14,7 +14,10 @@
# ============================================================================
"""Train utility."""
import
os
from
collections.abc
import
Iterable
import
numpy
as
np
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.dtype
import
dtype_to_nptype
,
pytype_to_dtype
from
mindspore.common
import
dtype
as
mstype
...
...
@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
raise
ValueError
(
'The tensor should not be empty.'
)
return
np_value
def
_check_lineage_value
(
plugin
,
value
):
"""Check the lineage value."""
def
raises
(
plugin
,
prototype
):
...
...
@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
if
plugin
==
'custom_lineage_data'
and
not
isinstance
(
value
,
UserDefinedInfo
):
raises
(
plugin
,
UserDefinedInfo
)
def
check_value_type
(
arg_name
,
arg_value
,
valid_types
):
"""Checks whether a value is instance of some types."""
valid_types
=
tuple
(
valid_types
)
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
is_valid
=
True
# bool is subclass of int, so for a bool value, we need to extra check
if
isinstance
(
arg_value
,
int
)
and
isinstance
(
arg_value
,
bool
)
and
bool
not
in
valid_types
:
is_valid
=
False
if
not
isinstance
(
arg_value
,
valid_types
):
is_valid
=
False
if
not
is_valid
:
raise
TypeError
(
f
'For `
{
arg_name
}
` the type should be a valid type of
{
[
t
.
__name__
for
t
in
valid_types
]
}
, '
f
'bug got
{
type
(
arg_value
).
__name__
}
.'
)
mindspore/train/callback/__init__.py
浏览文件 @
9514b52a
...
...
@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
from
._checkpoint
import
CheckpointManager
as
_CheckpointManager
from
._checkpoint
import
ModelCheckpoint
from
._loss_monitor
import
LossMonitor
from
._summary_step
import
SummaryStep
from
._time_monitor
import
TimeMonitor
from
._summary_collector
import
SummaryCollector
__all__
=
[
"Callback"
,
"LossMonitor"
,
"TimeMonitor"
,
"ModelCheckpoint"
,
"SummaryStep"
,
"CheckpointConfig"
,
"RunContext"
]
__all__
=
[
"Callback"
,
"LossMonitor"
,
"TimeMonitor"
,
"ModelCheckpoint"
,
"SummaryCollector"
,
"CheckpointConfig"
,
"RunContext"
]
mindspore/train/callback/_dataset_graph.py
0 → 100644
浏览文件 @
9514b52a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define dataset graph related operations."""
import
json
from
importlib
import
import_module
from
mindspore.train
import
lineage_pb2
class
DatasetGraph
:
"""Handle the data graph and packages it into binary data."""
def
package_dataset_graph
(
self
,
dataset
):
"""
packages dataset graph into binary data
Args:
dataset (MindData): refer to MindDataset
Returns:
DatasetGraph, a object of lineage_pb2.DatasetGraph.
"""
dataset_package
=
import_module
(
'mindspore.dataset'
)
dataset_dict
=
dataset_package
.
serialize
(
dataset
)
json_str
=
json
.
dumps
(
dataset_dict
,
indent
=
2
)
dataset_dict
=
json
.
loads
(
json_str
)
dataset_graph_proto
=
lineage_pb2
.
DatasetGraph
()
if
"children"
in
dataset_dict
:
children
=
dataset_dict
.
pop
(
"children"
)
if
children
:
self
.
_package_children
(
children
=
children
,
message
=
dataset_graph_proto
)
self
.
_package_current_dataset
(
operation
=
dataset_dict
,
message
=
dataset_graph_proto
)
return
dataset_graph_proto
def
_package_children
(
self
,
children
,
message
):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for
child
in
children
:
if
child
:
child_graph_message
=
getattr
(
message
,
"children"
).
add
()
grandson
=
child
.
pop
(
"children"
)
if
grandson
:
self
.
_package_children
(
children
=
grandson
,
message
=
child_graph_message
)
# package other parameters
self
.
_package_current_dataset
(
operation
=
child
,
message
=
child_graph_message
)
def
_package_current_dataset
(
self
,
operation
,
message
):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
value
and
key
==
"operations"
:
for
operator
in
value
:
self
.
_package_enhancement_operation
(
operator
,
message
.
operations
.
add
()
)
elif
value
and
key
==
"sampler"
:
self
.
_package_enhancement_operation
(
value
,
message
.
sampler
)
else
:
self
.
_package_parameter
(
key
,
value
,
message
.
parameter
)
def
_package_enhancement_operation
(
self
,
operation
,
message
):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
isinstance
(
value
,
list
):
if
all
(
isinstance
(
ele
,
int
)
for
ele
in
value
):
message
.
size
.
extend
(
value
)
else
:
message
.
weights
.
extend
(
value
)
else
:
self
.
_package_parameter
(
key
,
value
,
message
.
operationParam
)
@
staticmethod
def
_package_parameter
(
key
,
value
,
message
):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if
isinstance
(
value
,
str
):
message
.
mapStr
[
key
]
=
value
elif
isinstance
(
value
,
bool
):
message
.
mapBool
[
key
]
=
value
elif
isinstance
(
value
,
int
):
message
.
mapInt
[
key
]
=
value
elif
isinstance
(
value
,
float
):
message
.
mapDouble
[
key
]
=
value
elif
isinstance
(
value
,
list
)
and
key
!=
"operations"
:
if
value
:
replace_value_list
=
list
(
map
(
lambda
x
:
""
if
x
is
None
else
x
,
value
))
message
.
mapStrList
[
key
].
strValue
.
extend
(
replace_value_list
)
elif
value
is
None
:
message
.
mapStr
[
key
]
=
"None"
else
:
raise
ValueError
(
f
"Parameter
{
key
}
is not supported in event package."
)
mindspore/train/callback/_summary_collector.py
0 → 100644
浏览文件 @
9514b52a
此差异已折叠。
点击以展开。
mindspore/train/model.py
浏览文件 @
9514b52a
...
...
@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Model."""
from
collections.abc
import
Iterable
import
numpy
as
np
from
mindspore
import
log
as
logger
...
...
@@ -345,7 +347,8 @@ class Model:
cb_params
.
parallel_mode
=
self
.
_parallel_mode
cb_params
.
device_number
=
self
.
_device_number
cb_params
.
train_dataset
=
train_dataset
cb_params
.
list_callback
=
callbacks
cb_params
.
list_callback
=
self
.
_transform_callbacks
(
callbacks
)
cb_params
.
train_dataset_element
=
None
# build callback list
with
_CallbackManager
(
callbacks
)
as
list_callback
:
...
...
@@ -358,6 +361,17 @@ class Model:
else
:
self
.
_train_dataset_sink_process
(
epoch
,
train_dataset
,
list_callback
,
cb_params
)
@
staticmethod
def
_transform_callbacks
(
callbacks
):
"""Transform callback to a list."""
if
callbacks
is
None
:
return
[]
if
isinstance
(
callbacks
,
Iterable
):
return
list
(
callbacks
)
return
[
callbacks
]
def
_train_dataset_sink_process
(
self
,
epoch
,
train_dataset
,
list_callback
=
None
,
cb_params
=
None
):
"""
Training process. The data would be passed to network through dataset channel.
...
...
@@ -449,6 +463,7 @@ class Model:
scaling_sens
=
self
.
_get_scaling_sens
()
next_element
=
tuple
(
next_element
)
+
(
Tensor
(
scaling_sens
,
mstype
.
float32
),)
cb_params
.
train_dataset_element
=
next_element
outputs
=
self
.
_train_network
(
*
next_element
)
cb_params
.
net_outputs
=
outputs
if
self
.
_loss_scale_manager
and
self
.
_loss_scale_manager
.
get_drop_overflow_update
():
...
...
@@ -628,6 +643,7 @@ class Model:
cb_params
.
batch_num
=
valid_dataset
.
get_dataset_size
()
cb_params
.
mode
=
"eval"
cb_params
.
cur_step_num
=
0
cb_params
.
list_callback
=
self
.
_transform_callbacks
(
callbacks
)
self
.
_eval_network
.
set_train
(
mode
=
False
)
self
.
_eval_network
.
phase
=
'eval'
...
...
mindspore/train/
callback/_summary_step
.py
→
mindspore/train/
summary/enum
.py
浏览文件 @
9514b52a
...
...
@@ -12,45 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryStep Callback class."""
"""Summary's enumeration file."""
from
enum
import
Enum
from
._callback
import
Callback
class
BaseEnum
(
Enum
):
"""The base enum class."""
class
SummaryStep
(
Callback
):
"""
The summary callback class.
@
classmethod
def
to_list
(
cls
):
"""Converts the enumeration into a list."""
return
[
member
.
value
for
member
in
cls
.
__members__
.
values
()]
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def
__init__
(
self
,
summary
,
flush_step
=
10
):
super
(
SummaryStep
,
self
).
__init__
()
if
not
isinstance
(
flush_step
,
int
)
or
isinstance
(
flush_step
,
bool
)
or
flush_step
<=
0
:
raise
ValueError
(
"`flush_step` should be int and greater than 0"
)
self
.
_summary
=
summary
self
.
_flush_step
=
flush_step
class
PluginEnum
(
BaseEnum
):
"""The list of plugins currently supported by the summary."""
GRAPH
=
'graph'
SCALAR
=
'scalar'
IMAGE
=
'image'
TENSOR
=
'tensor'
HISTOGRAM
=
'histogram'
TRAIN_LINEAGE
=
'train_lineage'
EVAL_LINEAGE
=
'eval_lineage'
DATASET_GRAPH
=
'dataset_graph'
def
__enter__
(
self
):
self
.
_summary
.
__enter__
()
return
self
def
__exit__
(
self
,
*
err
):
return
self
.
_summary
.
__exit__
(
*
err
)
def
step_end
(
self
,
run_context
):
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params
=
run_context
.
original_args
()
if
cb_params
.
cur_step_num
%
self
.
_flush_step
==
0
:
self
.
_summary
.
record
(
cb_params
.
cur_step_num
,
cb_params
.
train_network
)
@
property
def
summary_file_name
(
self
):
return
self
.
_summary
.
full_file_name
class
ModeEnum
(
BaseEnum
):
"""The modes currently supported by the summary."""
TRAIN
=
'train'
EVAL
=
'eval'
tests/st/summary/test_gpu_summary.py
浏览文件 @
9514b52a
...
...
@@ -75,7 +75,7 @@ class TestGpuSummary:
if
not
os
.
path
.
exists
(
self
.
summary_dir
):
os
.
mkdir
(
self
.
summary_dir
)
def
teardown_
em
thod
(
self
):
def
teardown_
me
thod
(
self
):
"""Run after method."""
if
os
.
path
.
exists
(
self
.
summary_dir
):
shutil
.
rmtree
(
self
.
summary_dir
)
...
...
tests/ut/python/train/summary/test_graph_summary.py
浏览文件 @
9514b52a
...
...
@@ -20,8 +20,8 @@ import numpy as np
import
mindspore.nn
as
nn
from
mindspore
import
Model
,
context
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.
callback
import
SummaryStep
from
mindspore.train.
summary.summary_record
import
SummaryRecord
from
mindspore.train.
summary
import
SummaryRecord
from
mindspore.train.
callback
import
SummaryCollector
from
.....dataset_mock
import
MindData
CUR_DIR
=
os
.
getcwd
()
...
...
@@ -107,16 +107,9 @@ def test_graph_summary_sample():
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
model
.
_train_network
)
as
test_writer
:
model
.
train
(
2
,
dataset
)
# step 2: create the Event
for
i
in
range
(
1
,
5
):
test_writer
.
record
(
i
)
# step 3: send the event to mq
# step 4: accept the event and write the file
log
.
debug
(
"finished test_graph_summary_sample"
)
def
test_graph_summary_callback
():
dataset
=
get_dataset
()
...
...
@@ -125,18 +118,8 @@ def test_graph_summary_callback():
optim
=
Momentum
(
net
.
trainable_params
(),
0.1
,
0.9
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
model
.
_train_network
)
as
test_writer
:
summary_cb
=
SummaryStep
(
test_writer
,
1
)
model
.
train
(
2
,
dataset
,
callbacks
=
summary_cb
)
def
test_graph_summary_callback2
():
dataset
=
get_dataset
()
net
=
Net
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optim
=
Momentum
(
net
.
trainable_params
(),
0.1
,
0.9
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
with
SummaryRecord
(
SUMMARY_DIR
,
file_suffix
=
"_MS_GRAPH"
,
network
=
net
)
as
test_writer
:
summary_cb
=
SummaryStep
(
test_writer
,
1
)
model
.
train
(
2
,
dataset
,
callbacks
=
summary_cb
)
summary_collector
=
SummaryCollector
(
SUMMARY_DIR
,
collect_freq
=
1
,
keep_default_action
=
False
,
collect_specified_data
=
{
'collect_graph'
:
True
})
model
.
train
(
1
,
dataset
,
callbacks
=
[
summary_collector
])
tests/ut/python/train/summary/test_image_summary.py
浏览文件 @
9514b52a
...
...
@@ -26,9 +26,8 @@ import mindspore.nn as nn
from
mindspore
import
Model
,
context
from
mindspore
import
Tensor
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback
import
SummaryStep
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
\
_cache_summary_tensor_data
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
_cache_summary_tensor_data
from
mindspore.train.callback
import
Callback
from
.....dataset_mock
import
MindData
CUR_DIR
=
os
.
getcwd
()
...
...
@@ -155,7 +154,8 @@ def get_dataset():
return
dataset
class
ImageSummaryCallback
:
class
ImageSummaryCallback
(
Callback
):
"""Image summary callback."""
def
__init__
(
self
,
summary_record
):
self
.
_summary_record
=
summary_record
...
...
@@ -164,9 +164,10 @@ class ImageSummaryCallback:
return
self
def
__exit__
(
self
,
*
err
):
pass
self
.
_summary_record
.
close
()
def
record
(
self
,
step
,
train_network
=
None
):
"""record data."""
self
.
_summary_record
.
record
(
step
,
train_network
)
self
.
_summary_record
.
flush
()
...
...
@@ -183,9 +184,8 @@ def test_image_summary_train():
# step 2: create the Event
model
=
get_model
()
fn
=
ImageSummaryCallback
(
test_writer
)
summary_recode
=
SummaryStep
(
fn
,
1
)
model
.
train
(
2
,
dataset
,
callbacks
=
summary_recode
)
callback
=
ImageSummaryCallback
(
test_writer
)
model
.
train
(
2
,
dataset
,
callbacks
=
[
callback
])
# step 3: send the event to mq
...
...
tests/ut/python/train/summary/test_summary.py
浏览文件 @
9514b52a
...
...
@@ -24,11 +24,9 @@ import random
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.train.callback
import
SummaryStep
from
mindspore.train.summary.summary_record
import
SummaryRecord
,
_cache_summary_tensor_data
CUR_DIR
=
os
.
getcwd
()
...
...
@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
def
test_validate
():
with
SummaryRecord
(
SUMMARY_DIR
)
as
sr
:
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
0
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
-
1
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
1.2
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
True
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
"str"
)
sr
.
record
(
1
)
with
pytest
.
raises
(
ValueError
):
sr
.
record
(
False
)
...
...
@@ -215,17 +203,3 @@ def test_validate():
sr
.
record
(
"str"
)
with
pytest
.
raises
(
ValueError
):
sr
.
record
(
sr
)
SummaryStep
(
sr
,
1
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
1.2
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
False
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
"str"
)
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
(
1
,
2
))
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
[
3
,
4
])
with
pytest
.
raises
(
ValueError
):
SummaryStep
(
sr
,
sr
)
tests/ut/python/train/summary/test_summary_abnormal_input.py
浏览文件 @
9514b52a
...
...
@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
log
.
debug
(
"begin test_summaryrecord_input_null_string"
)
# step 0: create the thread
try
:
SummaryRecord
(
""
)
with
SummaryRecord
(
""
):
pass
except
:
assert
True
else
:
...
...
@@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
log
.
debug
(
"begin test_summaryrecord_input_None"
)
# step 0: create the thread
try
:
SummaryRecord
(
None
)
with
SummaryRecord
(
None
):
pass
except
:
assert
True
else
:
...
...
@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_1"
)
# step 0: create the thread
try
:
SummaryRecord
(
"./test_temp_summary_event_file/"
)
with
SummaryRecord
(
"./test_temp_summary_event_file/"
):
pass
except
:
assert
False
else
:
...
...
@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
log
.
debug
(
"begin test_summaryrecord_input_relative_dir_2"
)
# step 0: create the thread
try
:
SummaryRecord
(
"../summary/"
)
with
SummaryRecord
(
"../summary/"
):
pass
except
:
assert
False
else
:
...
...
@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
log
.
debug
(
"begin test_summaryrecord_input_invalid_type_dir"
)
# step 0: create the thread
try
:
SummaryRecord
(
32
)
with
SummaryRecord
(
32
):
pass
except
:
assert
True
else
:
...
...
@@ -119,7 +124,8 @@ def test_mulit_layer_directory():
log
.
debug
(
"begin test_mulit_layer_directory"
)
# step 0: create the thread
try
:
SummaryRecord
(
"./test_temp_summary_event_file/test/t1/"
)
with
SummaryRecord
(
"./test_temp_summary_event_file/test/t1/"
):
pass
except
:
assert
False
else
:
...
...
tests/ut/python/train/summary/test_summary_collector.py
0 → 100644
浏览文件 @
9514b52a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the exception parameter scenario for summary collector."""
import
os
import
tempfile
import
shutil
import
pytest
from
mindspore.train.callback
import
SummaryCollector
class
TestSummaryCollector
:
"""Test the exception parameter for summary collector."""
base_summary_dir
=
''
def
setup_class
(
self
):
"""Run before test this class."""
self
.
base_summary_dir
=
tempfile
.
mkdtemp
(
suffix
=
'summary'
)
def
teardown_class
(
self
):
"""Run after test this class."""
if
os
.
path
.
exists
(
self
.
base_summary_dir
):
shutil
.
rmtree
(
self
.
base_summary_dir
)
@
pytest
.
mark
.
parametrize
(
"summary_dir"
,
[
1234
,
None
,
True
,
''
])
def
test_params_with_summary_dir_value_error
(
self
,
summary_dir
):
"""Test the exception scenario for summary dir."""
if
isinstance
(
summary_dir
,
str
):
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
)
assert
str
(
exc
.
value
)
==
'For `summary_dir` the value should be a valid string of path, '
\
'but got empty string.'
else
:
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
)
assert
'For `summary_dir` the type should be a valid type'
in
str
(
exc
.
value
)
def
test_params_with_summary_dir_not_dir
(
self
):
"""Test the given summary dir parameter is not a directory."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
summary_file
=
os
.
path
.
join
(
summary_dir
,
'temp_file.txt'
)
with
open
(
summary_file
,
'w'
)
as
file_handle
:
file_handle
.
write
(
'temp'
)
print
(
os
.
path
.
isfile
(
summary_file
))
with
pytest
.
raises
(
NotADirectoryError
):
SummaryCollector
(
summary_dir
=
summary_file
)
@
pytest
.
mark
.
parametrize
(
"collect_freq"
,
[
None
,
0
,
0.01
])
def
test_params_with_collect_freq_exception
(
self
,
collect_freq
):
"""Test the exception scenario for collect freq."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
if
isinstance
(
collect_freq
,
int
):
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
collect_freq
=
collect_freq
)
expected_msg
=
f
'For `collect_freq` the value should be greater than 0, but got `
{
collect_freq
}
`.'
assert
expected_msg
==
str
(
exc
.
value
)
else
:
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
collect_freq
=
collect_freq
)
expected_msg
=
f
"For `collect_freq` the type should be a valid type of ['int'], "
\
f
'bug got
{
type
(
collect_freq
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"action"
,
[
None
,
123
,
''
,
'123'
])
def
test_params_with_action_exception
(
self
,
action
):
"""Test the exception scenario for action."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
=
summary_dir
,
keep_default_action
=
action
)
expected_msg
=
f
"For `keep_default_action` the type should be a valid type of ['bool'], "
\
f
"bug got
{
type
(
action
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
123
])
def
test_params_with_collect_specified_data_type_error
(
self
,
collect_specified_data
):
"""Test type error scenario for collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
expected_msg
=
f
"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], "
\
f
"bug got
{
type
(
collect_specified_data
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
{
123
:
123
},
{
None
:
True
}
])
def
test_params_with_collect_specified_data_key_type_error
(
self
,
collect_specified_data
):
"""Test the key of collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
param_name
=
list
(
collect_specified_data
)[
0
]
expected_msg
=
f
"For `
{
param_name
}
` the type should be a valid type of ['str'], "
\
f
"bug got
{
type
(
param_name
).
__name__
}
."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"collect_specified_data"
,
[
{
'collect_metric'
:
None
},
{
'collect_graph'
:
123
},
{
'histogram_regular'
:
123
},
])
def
test_params_with_collect_specified_data_value_type_error
(
self
,
collect_specified_data
):
"""Test the value of collect specified data param."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
collect_specified_data
)
param_name
=
list
(
collect_specified_data
)[
0
]
param_value
=
collect_specified_data
[
param_name
]
expected_type
=
"['bool']"
if
param_name
!=
'histogram_regular'
else
"['str', 'NoneType']"
expected_msg
=
f
'For `
{
param_name
}
` the type should be a valid type of
{
expected_type
}
, '
\
f
'bug got
{
type
(
param_value
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
def
test_params_with_collect_specified_data_unexpected_key
(
self
):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
data
=
{
'unexpected_key'
:
True
}
with
pytest
.
raises
(
ValueError
)
as
exc
:
SummaryCollector
(
summary_dir
,
collect_specified_data
=
data
)
expected_msg
=
f
"For `collect_specified_data` the keys
{
set
(
data
)
}
are unsupported."
assert
expected_msg
==
str
(
exc
.
value
)
@
pytest
.
mark
.
parametrize
(
"custom_lineage_data"
,
[
123
,
{
'custom'
:
{}
},
{
'custom'
:
None
},
{
123
:
'custom'
}
])
def
test_params_with_custom_lineage_data_type_error
(
self
,
custom_lineage_data
):
"""Test the custom lineage data parameter type error."""
summary_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
base_summary_dir
)
with
pytest
.
raises
(
TypeError
)
as
exc
:
SummaryCollector
(
summary_dir
,
custom_lineage_data
=
custom_lineage_data
)
if
not
isinstance
(
custom_lineage_data
,
dict
):
expected_msg
=
f
"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], "
\
f
"bug got
{
type
(
custom_lineage_data
).
__name__
}
."
else
:
param_name
=
list
(
custom_lineage_data
)[
0
]
param_value
=
custom_lineage_data
[
param_name
]
if
not
isinstance
(
param_name
,
str
):
arg_name
=
f
'custom_lineage_data ->
{
param_name
}
'
expected_msg
=
f
"For `
{
arg_name
}
` the type should be a valid type of ['str'], "
\
f
'bug got
{
type
(
param_name
).
__name__
}
.'
else
:
arg_name
=
f
'the value of custom_lineage_data ->
{
param_name
}
'
expected_msg
=
f
"For `
{
arg_name
}
` the type should be a valid type of ['int', 'str', 'float'], "
\
f
'bug got
{
type
(
param_value
).
__name__
}
.'
assert
expected_msg
==
str
(
exc
.
value
)
tests/ut/python/train/test_training.py
浏览文件 @
9514b52a
...
...
@@ -20,8 +20,8 @@ import pytest
import
mindspore.nn
as
nn
from
mindspore
import
Model
,
context
from
mindspore
import
Tensor
from
mindspore.train.callback
import
Callback
from
mindspore.nn.optim
import
Momentum
from
mindspore.train.callback
import
SummaryStep
from
..ut_filter
import
non_graph_engine
from
....dataset_mock
import
MindData
...
...
@@ -174,7 +174,7 @@ class TestGraphMode:
model
.
train
(
1
,
dataset
)
class
CallbackTest
:
class
CallbackTest
(
Callback
)
:
""" CallbackTest definition """
def
__init__
(
self
):
...
...
@@ -186,19 +186,19 @@ class CallbackTest:
def
__exit__
(
self
,
*
err
):
pass
def
record
(
self
,
step
,
*
args
):
print
(
step
,
args
)
def
step_end
(
self
,
run_context
):
cb_params
=
run_context
.
original_args
()
print
(
cb_params
.
cur_epoch_num
,
cb_params
.
cur_step_num
)
def
test_train_callback
(
test_with_simu
):
""" test_train_callback """
dataset
=
get_dataset
()
model
=
get_model
()
fn
=
CallbackTest
()
summary_recode
=
SummaryStep
(
fn
,
2
)
callback
=
CallbackTest
()
if
test_with_simu
:
return
model
.
train
(
2
,
dataset
,
callbacks
=
summary_recode
)
model
.
train
(
2
,
dataset
,
callbacks
=
callback
)
log
=
logging
.
getLogger
(
"test"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录