提交 9514b52a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2147 Add a callback named SummaryCollector and delete SummaryStep callback

Merge pull request !2147 from ougongchang/master
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
# ============================================================================ # ============================================================================
"""Train utility.""" """Train utility."""
import os import os
from collections.abc import Iterable
import numpy as np import numpy as np
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
...@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor): ...@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
raise ValueError('The tensor should not be empty.') raise ValueError('The tensor should not be empty.')
return np_value return np_value
def _check_lineage_value(plugin, value): def _check_lineage_value(plugin, value):
"""Check the lineage value.""" """Check the lineage value."""
def raises(plugin, prototype): def raises(plugin, prototype):
...@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value): ...@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo): if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
raises(plugin, UserDefinedInfo) raises(plugin, UserDefinedInfo)
def check_value_type(arg_name, arg_value, valid_types):
"""Checks whether a value is instance of some types."""
valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
is_valid = True
# bool is subclass of int, so for a bool value, we need to extra check
if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
is_valid = False
if not isinstance(arg_value, valid_types):
is_valid = False
if not is_valid:
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
f'bug got {type(arg_value).__name__}.')
...@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig ...@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
from ._checkpoint import CheckpointManager as _CheckpointManager from ._checkpoint import CheckpointManager as _CheckpointManager
from ._checkpoint import ModelCheckpoint from ._checkpoint import ModelCheckpoint
from ._loss_monitor import LossMonitor from ._loss_monitor import LossMonitor
from ._summary_step import SummaryStep
from ._time_monitor import TimeMonitor from ._time_monitor import TimeMonitor
from ._summary_collector import SummaryCollector
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"] __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryCollector", "CheckpointConfig", "RunContext"]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define dataset graph related operations."""
import json
from importlib import import_module
from mindspore.train import lineage_pb2
class DatasetGraph:
"""Handle the data graph and packages it into binary data."""
def package_dataset_graph(self, dataset):
"""
packages dataset graph into binary data
Args:
dataset (MindData): refer to MindDataset
Returns:
DatasetGraph, a object of lineage_pb2.DatasetGraph.
"""
dataset_package = import_module('mindspore.dataset')
dataset_dict = dataset_package.serialize(dataset)
json_str = json.dumps(dataset_dict, indent=2)
dataset_dict = json.loads(json_str)
dataset_graph_proto = lineage_pb2.DatasetGraph()
if "children" in dataset_dict:
children = dataset_dict.pop("children")
if children:
self._package_children(children=children, message=dataset_graph_proto)
self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
return dataset_graph_proto
def _package_children(self, children, message):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for child in children:
if child:
child_graph_message = getattr(message, "children").add()
grandson = child.pop("children")
if grandson:
self._package_children(children=grandson, message=child_graph_message)
# package other parameters
self._package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(self, operation, message):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for key, value in operation.items():
if value and key == "operations":
for operator in value:
self._package_enhancement_operation(
operator,
message.operations.add()
)
elif value and key == "sampler":
self._package_enhancement_operation(
value,
message.sampler
)
else:
self._package_parameter(key, value, message.parameter)
def _package_enhancement_operation(self, operation, message):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for key, value in operation.items():
if isinstance(value, list):
if all(isinstance(ele, int) for ele in value):
message.size.extend(value)
else:
message.weights.extend(value)
else:
self._package_parameter(key, value, message.operationParam)
@staticmethod
def _package_parameter(key, value, message):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if isinstance(value, str):
message.mapStr[key] = value
elif isinstance(value, bool):
message.mapBool[key] = value
elif isinstance(value, int):
message.mapInt[key] = value
elif isinstance(value, float):
message.mapDouble[key] = value
elif isinstance(value, list) and key != "operations":
if value:
replace_value_list = list(map(lambda x: "" if x is None else x, value))
message.mapStrList[key].strValue.extend(replace_value_list)
elif value is None:
message.mapStr[key] = "None"
else:
raise ValueError(f"Parameter {key} is not supported in event package.")
此差异已折叠。
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Model.""" """Model."""
from collections.abc import Iterable
import numpy as np import numpy as np
from mindspore import log as logger from mindspore import log as logger
...@@ -345,7 +347,8 @@ class Model: ...@@ -345,7 +347,8 @@ class Model:
cb_params.parallel_mode = self._parallel_mode cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset cb_params.train_dataset = train_dataset
cb_params.list_callback = callbacks cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
# build callback list # build callback list
with _CallbackManager(callbacks) as list_callback: with _CallbackManager(callbacks) as list_callback:
...@@ -358,6 +361,17 @@ class Model: ...@@ -358,6 +361,17 @@ class Model:
else: else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
@staticmethod
def _transform_callbacks(callbacks):
"""Transform callback to a list."""
if callbacks is None:
return []
if isinstance(callbacks, Iterable):
return list(callbacks)
return [callbacks]
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
""" """
Training process. The data would be passed to network through dataset channel. Training process. The data would be passed to network through dataset channel.
...@@ -449,6 +463,7 @@ class Model: ...@@ -449,6 +463,7 @@ class Model:
scaling_sens = self._get_scaling_sens() scaling_sens = self._get_scaling_sens()
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
outputs = self._train_network(*next_element) outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
...@@ -628,6 +643,7 @@ class Model: ...@@ -628,6 +643,7 @@ class Model:
cb_params.batch_num = valid_dataset.get_dataset_size() cb_params.batch_num = valid_dataset.get_dataset_size()
cb_params.mode = "eval" cb_params.mode = "eval"
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks)
self._eval_network.set_train(mode=False) self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'
......
...@@ -12,45 +12,32 @@ ...@@ -12,45 +12,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""SummaryStep Callback class.""" """Summary's enumeration file."""
from enum import Enum
from ._callback import Callback
class BaseEnum(Enum):
"""The base enum class."""
class SummaryStep(Callback): @classmethod
""" def to_list(cls):
The summary callback class. """Converts the enumeration into a list."""
return [member.value for member in cls.__members__.values()]
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def __init__(self, summary, flush_step=10): class PluginEnum(BaseEnum):
super(SummaryStep, self).__init__() """The list of plugins currently supported by the summary."""
if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: GRAPH = 'graph'
raise ValueError("`flush_step` should be int and greater than 0") SCALAR = 'scalar'
self._summary = summary IMAGE = 'image'
self._flush_step = flush_step TENSOR = 'tensor'
HISTOGRAM = 'histogram'
TRAIN_LINEAGE = 'train_lineage'
EVAL_LINEAGE = 'eval_lineage'
DATASET_GRAPH = 'dataset_graph'
def __enter__(self):
self._summary.__enter__()
return self
def __exit__(self, *err): class ModeEnum(BaseEnum):
return self._summary.__exit__(*err) """The modes currently supported by the summary."""
TRAIN = 'train'
def step_end(self, run_context): EVAL = 'eval'
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self._flush_step == 0:
self._summary.record(cb_params.cur_step_num, cb_params.train_network)
@property
def summary_file_name(self):
return self._summary.full_file_name
...@@ -75,7 +75,7 @@ class TestGpuSummary: ...@@ -75,7 +75,7 @@ class TestGpuSummary:
if not os.path.exists(self.summary_dir): if not os.path.exists(self.summary_dir):
os.mkdir(self.summary_dir) os.mkdir(self.summary_dir)
def teardown_emthod(self): def teardown_method(self):
"""Run after method.""" """Run after method."""
if os.path.exists(self.summary_dir): if os.path.exists(self.summary_dir):
shutil.rmtree(self.summary_dir) shutil.rmtree(self.summary_dir)
......
...@@ -20,8 +20,8 @@ import numpy as np ...@@ -20,8 +20,8 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Model, context from mindspore import Model, context
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep from mindspore.train.summary import SummaryRecord
from mindspore.train.summary.summary_record import SummaryRecord from mindspore.train.callback import SummaryCollector
from .....dataset_mock import MindData from .....dataset_mock import MindData
CUR_DIR = os.getcwd() CUR_DIR = os.getcwd()
...@@ -107,16 +107,9 @@ def test_graph_summary_sample(): ...@@ -107,16 +107,9 @@ def test_graph_summary_sample():
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
model.train(2, dataset) model.train(2, dataset)
# step 2: create the Event
for i in range(1, 5): for i in range(1, 5):
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_graph_summary_sample")
def test_graph_summary_callback(): def test_graph_summary_callback():
dataset = get_dataset() dataset = get_dataset()
...@@ -125,18 +118,8 @@ def test_graph_summary_callback(): ...@@ -125,18 +118,8 @@ def test_graph_summary_callback():
optim = Momentum(net.trainable_params(), 0.1, 0.9) optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer: summary_collector = SummaryCollector(SUMMARY_DIR,
summary_cb = SummaryStep(test_writer, 1) collect_freq=1,
model.train(2, dataset, callbacks=summary_cb) keep_default_action=False,
collect_specified_data={'collect_graph': True})
model.train(1, dataset, callbacks=[summary_collector])
def test_graph_summary_callback2():
dataset = get_dataset()
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
summary_cb = SummaryStep(test_writer, 1)
model.train(2, dataset, callbacks=summary_cb)
...@@ -26,9 +26,8 @@ import mindspore.nn as nn ...@@ -26,9 +26,8 @@ import mindspore.nn as nn
from mindspore import Model, context from mindspore import Model, context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from mindspore.train.summary.summary_record import SummaryRecord, \ from mindspore.train.callback import Callback
_cache_summary_tensor_data
from .....dataset_mock import MindData from .....dataset_mock import MindData
CUR_DIR = os.getcwd() CUR_DIR = os.getcwd()
...@@ -155,7 +154,8 @@ def get_dataset(): ...@@ -155,7 +154,8 @@ def get_dataset():
return dataset return dataset
class ImageSummaryCallback: class ImageSummaryCallback(Callback):
"""Image summary callback."""
def __init__(self, summary_record): def __init__(self, summary_record):
self._summary_record = summary_record self._summary_record = summary_record
...@@ -164,9 +164,10 @@ class ImageSummaryCallback: ...@@ -164,9 +164,10 @@ class ImageSummaryCallback:
return self return self
def __exit__(self, *err): def __exit__(self, *err):
pass self._summary_record.close()
def record(self, step, train_network=None): def record(self, step, train_network=None):
"""record data."""
self._summary_record.record(step, train_network) self._summary_record.record(step, train_network)
self._summary_record.flush() self._summary_record.flush()
...@@ -183,9 +184,8 @@ def test_image_summary_train(): ...@@ -183,9 +184,8 @@ def test_image_summary_train():
# step 2: create the Event # step 2: create the Event
model = get_model() model = get_model()
fn = ImageSummaryCallback(test_writer) callback = ImageSummaryCallback(test_writer)
summary_recode = SummaryStep(fn, 1) model.train(2, dataset, callbacks=[callback])
model.train(2, dataset, callbacks=summary_recode)
# step 3: send the event to mq # step 3: send the event to mq
......
...@@ -24,11 +24,9 @@ import random ...@@ -24,11 +24,9 @@ import random
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.callback import SummaryStep
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
CUR_DIR = os.getcwd() CUR_DIR = os.getcwd()
...@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2(): ...@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
def test_validate(): def test_validate():
with SummaryRecord(SUMMARY_DIR) as sr: with SummaryRecord(SUMMARY_DIR) as sr:
with pytest.raises(ValueError):
SummaryStep(sr, 0)
with pytest.raises(ValueError):
SummaryStep(sr, -1)
with pytest.raises(ValueError):
SummaryStep(sr, 1.2)
with pytest.raises(ValueError):
SummaryStep(sr, True)
with pytest.raises(ValueError):
SummaryStep(sr, "str")
sr.record(1) sr.record(1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
sr.record(False) sr.record(False)
...@@ -215,17 +203,3 @@ def test_validate(): ...@@ -215,17 +203,3 @@ def test_validate():
sr.record("str") sr.record("str")
with pytest.raises(ValueError): with pytest.raises(ValueError):
sr.record(sr) sr.record(sr)
SummaryStep(sr, 1)
with pytest.raises(ValueError):
SummaryStep(sr, 1.2)
with pytest.raises(ValueError):
SummaryStep(sr, False)
with pytest.raises(ValueError):
SummaryStep(sr, "str")
with pytest.raises(ValueError):
SummaryStep(sr, (1, 2))
with pytest.raises(ValueError):
SummaryStep(sr, [3, 4])
with pytest.raises(ValueError):
SummaryStep(sr, sr)
...@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string(): ...@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
log.debug("begin test_summaryrecord_input_null_string") log.debug("begin test_summaryrecord_input_null_string")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord("") with SummaryRecord(""):
pass
except: except:
assert True assert True
else: else:
...@@ -71,7 +72,8 @@ def test_summaryrecord_input_None(): ...@@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
log.debug("begin test_summaryrecord_input_None") log.debug("begin test_summaryrecord_input_None")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord(None) with SummaryRecord(None):
pass
except: except:
assert True assert True
else: else:
...@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1(): ...@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
log.debug("begin test_summaryrecord_input_relative_dir_1") log.debug("begin test_summaryrecord_input_relative_dir_1")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord("./test_temp_summary_event_file/") with SummaryRecord("./test_temp_summary_event_file/"):
pass
except: except:
assert False assert False
else: else:
...@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2(): ...@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
log.debug("begin test_summaryrecord_input_relative_dir_2") log.debug("begin test_summaryrecord_input_relative_dir_2")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord("../summary/") with SummaryRecord("../summary/"):
pass
except: except:
assert False assert False
else: else:
...@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir(): ...@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
log.debug("begin test_summaryrecord_input_invalid_type_dir") log.debug("begin test_summaryrecord_input_invalid_type_dir")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord(32) with SummaryRecord(32):
pass
except: except:
assert True assert True
else: else:
...@@ -119,7 +124,8 @@ def test_mulit_layer_directory(): ...@@ -119,7 +124,8 @@ def test_mulit_layer_directory():
log.debug("begin test_mulit_layer_directory") log.debug("begin test_mulit_layer_directory")
# step 0: create the thread # step 0: create the thread
try: try:
SummaryRecord("./test_temp_summary_event_file/test/t1/") with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
pass
except: except:
assert False assert False
else: else:
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the exception parameter scenario for summary collector."""
import os
import tempfile
import shutil
import pytest
from mindspore.train.callback import SummaryCollector
class TestSummaryCollector:
"""Test the exception parameter for summary collector."""
base_summary_dir = ''
def setup_class(self):
"""Run before test this class."""
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
def teardown_class(self):
"""Run after test this class."""
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
@pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
def test_params_with_summary_dir_value_error(self, summary_dir):
"""Test the exception scenario for summary dir."""
if isinstance(summary_dir, str):
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
'but got empty string.'
else:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
def test_params_with_summary_dir_not_dir(self):
"""Test the given summary dir parameter is not a directory."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_file = os.path.join(summary_dir, 'temp_file.txt')
with open(summary_file, 'w') as file_handle:
file_handle.write('temp')
print(os.path.isfile(summary_file))
with pytest.raises(NotADirectoryError):
SummaryCollector(summary_dir=summary_file)
@pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
def test_params_with_collect_freq_exception(self, collect_freq):
"""Test the exception scenario for collect freq."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
if isinstance(collect_freq, int):
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
assert expected_msg == str(exc.value)
else:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
f'bug got {type(collect_freq).__name__}.'
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("action", [None, 123, '', '123'])
def test_params_with_action_exception(self, action):
"""Test the exception scenario for action."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
f"bug got {type(action).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [123])
def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
"""Test type error scenario for collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(collect_specified_data).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [
{
123: 123
},
{
None: True
}
])
def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
"""Test the key of collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data)[0]
expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
f"bug got {type(param_name).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_metric': None
},
{
'collect_graph': 123
},
{
'histogram_regular': 123
},
])
def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
"""Test the value of collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data)[0]
param_value = collect_specified_data[param_name]
expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'bug got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
def test_params_with_collect_specified_data_unexpected_key(self):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
data = {'unexpected_key': True}
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data=data)
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("custom_lineage_data", [
123,
{
'custom': {}
},
{
'custom': None
},
{
123: 'custom'
}
])
def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
"""Test the custom lineage data parameter type error."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)
if not isinstance(custom_lineage_data, dict):
expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(custom_lineage_data).__name__}."
else:
param_name = list(custom_lineage_data)[0]
param_value = custom_lineage_data[param_name]
if not isinstance(param_name, str):
arg_name = f'custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
f'bug got {type(param_name).__name__}.'
else:
arg_name = f'the value of custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
f'bug got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
...@@ -20,8 +20,8 @@ import pytest ...@@ -20,8 +20,8 @@ import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Model, context from mindspore import Model, context
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.callback import Callback
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
from ....dataset_mock import MindData from ....dataset_mock import MindData
...@@ -174,7 +174,7 @@ class TestGraphMode: ...@@ -174,7 +174,7 @@ class TestGraphMode:
model.train(1, dataset) model.train(1, dataset)
class CallbackTest: class CallbackTest(Callback):
""" CallbackTest definition """ """ CallbackTest definition """
def __init__(self): def __init__(self):
...@@ -186,19 +186,19 @@ class CallbackTest: ...@@ -186,19 +186,19 @@ class CallbackTest:
def __exit__(self, *err): def __exit__(self, *err):
pass pass
def record(self, step, *args): def step_end(self, run_context):
print(step, args) cb_params = run_context.original_args()
print(cb_params.cur_epoch_num, cb_params.cur_step_num)
def test_train_callback(test_with_simu): def test_train_callback(test_with_simu):
""" test_train_callback """ """ test_train_callback """
dataset = get_dataset() dataset = get_dataset()
model = get_model() model = get_model()
fn = CallbackTest() callback = CallbackTest()
summary_recode = SummaryStep(fn, 2)
if test_with_simu: if test_with_simu:
return return
model.train(2, dataset, callbacks=summary_recode) model.train(2, dataset, callbacks=callback)
log = logging.getLogger("test") log = logging.getLogger("test")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册