提交 3dc6f6f2 编写于 作者: O ougongchang

add more ut and st for SummaryCollector

Has fixed collecting optimizer error when mode is eval
上级 d6d93f16
......@@ -161,7 +161,7 @@ class SummaryCollector(Callback):
self._check_custom_lineage_data(custom_lineage_data)
self._custom_lineage_data = custom_lineage_data
self._optimizer = None
self._temp_optimizer = None
self._has_saved_train_network = False
self._has_saved_custom_data = False
self._is_parse_loss_success = True
......@@ -369,15 +369,15 @@ class SummaryCollector(Callback):
input_data = getattr(cb_params, 'train_dataset_element', None)
if input_data is None:
self._collect_specified_data['collect_input_data'] = False
logger.info("There is not a `train_dataset_element` in cb_params.")
logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.")
return
if isinstance(input_data, (list, tuple)):
input_data = input_data[0]
try:
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
except ValueError as ex:
logger.warning(str(ex))
except ValueError:
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
self._collect_specified_data['collect_input_data'] = False
return
......@@ -418,8 +418,8 @@ class SummaryCollector(Callback):
try:
self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
except ValueError as exc:
logger.warning(str(exc))
except ValueError:
logger.warning("The output of network is not a scalar, so will not collect loss in SummaryCollector.")
self._collect_specified_data['collect_metric'] = False
def _get_loss(self, cb_params):
......@@ -438,7 +438,7 @@ class SummaryCollector(Callback):
output = cb_params.net_outputs
if output is None:
logger.warning("Can not find any output by this network.")
logger.warning("Can not find any output by this network, so will not collect loss in SummaryCollector.")
self._is_parse_loss_success = False
return None
......@@ -448,7 +448,7 @@ class SummaryCollector(Callback):
# If the output is a list, since the default network returns loss first,
# we assume that the first one is loss.
loss = output[0]
elif isinstance(output, Tensor) and (not output.shape or output.shape == [1]):
elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)):
loss_numpy = output.asnumpy()
loss = float(np.atleast_1d(loss_numpy)[0])
else:
......@@ -473,15 +473,15 @@ class SummaryCollector(Callback):
"""
# 'optimizer_failed' means find optimizer failed, so we will not collect data about optimizer.
optimizer_failed = 'Failed'
if self._optimizer == optimizer_failed:
if self._temp_optimizer == optimizer_failed:
return None
if self._optimizer is not None:
return self._optimizer
if self._temp_optimizer is not None:
return self._temp_optimizer
optimizer = cb_params.optimizer
if optimizer is None:
network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_work
network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network
optimizer = self._parse_optimizer_by_network(network)
if optimizer is None or not isinstance(optimizer, Optimizer):
......@@ -489,7 +489,7 @@ class SummaryCollector(Callback):
"optimizer, so we will not collect data about optimizer in SummaryCollector.")
optimizer = None
self._optimizer = optimizer if optimizer is not None else optimizer_failed
self._temp_optimizer = optimizer if optimizer is not None else optimizer_failed
return optimizer
......@@ -765,7 +765,7 @@ class SummaryCollector(Callback):
cb_params (_InternalCallbackParam): Callback parameters.
Returns:
Union[Loss_fn, None], a Cell object, if parse failed, will return None.
Union[Cell, None], a Cell object, if parse failed, will return None.
"""
loss_fn = cb_params.loss_fn
if loss_fn is not None:
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test model train """
import os
import numpy as np
from apply_momentum import ApplyMomentum
import mindspore.context as context
import mindspore.nn as nn
from mindspore.nn import wrap
from mindspore import Tensor, Model
from mindspore.common.api import ms_function
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.ops import operations as P
from mindspore.train.summary.summary_record import SummaryRecord
CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
context.set_context(device_target="Ascend")
class MsWrapper(nn.Cell):
def __init__(self, network):
super(MsWrapper, self).__init__(auto_prefix=False)
self._network = network
@ms_function
def construct(self, *args):
return self._network(*args)
def me_train_tensor(net, input_np, label_np, epoch_size=2):
context.set_context(mode=context.GRAPH_MODE)
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])),
filter(lambda x: x.requires_grad, net.get_parameters()))
Model(net, loss, opt)
_network = wrap.WithLossCell(net, loss)
_train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt))
_train_net.set_train()
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer:
for epoch in range(0, epoch_size):
print(f"epoch %d" % (epoch))
output = _train_net(Tensor(input_np), Tensor(label_np))
summary_writer.record(i)
print("********output***********")
print(output.asnumpy())
def me_infer_tensor(net, input_np):
net.set_train()
net = MsWrapper(net)
output = net(Tensor(input_np))
return output
def test_net():
class Net(nn.Cell):
def __init__(self, cin, cout):
super(Net, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.conv = nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same")
self.bn = nn.BatchNorm2d(cin, momentum=0.1, eps=0.0001)
self.add = P.TensorAdd()
self.relu = P.ReLU()
self.mean = P.ReduceMean(keep_dims=True)
self.reshape = P.Reshape()
self.dense = nn.Dense(cin, cout)
def construct(self, input_x):
output = input_x
output = self.maxpool(output)
identity = output
output = self.conv(output)
output = self.bn(output)
output = self.add(output, identity)
output = self.relu(output)
output = self.mean(output, (-2, -1))
output = self.reshape(output, (32, -1))
output = self.dense(output)
return output
net = Net(2048, 1001)
input_np = np.ones([32, 2048, 14, 14]).astype(np.float32) * 0.01
label_np = np.ones([32]).astype(np.int32)
me_train_tensor(net, input_np, label_np)
# me_infer_tensor(net, input_np)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Summary gpu st."""
import os
import random
import tempfile
import shutil
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.train.summary.summary_record import SummaryRecord
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class SummaryNet(nn.Cell):
"""Summary net."""
def __init__(self, tag_tuple=None, scalar=1):
super(SummaryNet, self).__init__()
self.summary_s = P.ScalarSummary()
self.summary_i = P.ImageSummary()
self.summary_t = P.TensorSummary()
self.histogram_summary = P.HistogramSummary()
self.add = P.TensorAdd()
self.tag_tuple = tag_tuple
self.scalar = scalar
def construct(self, x, y, image):
"""Run summary net."""
self.summary_i("image", image)
self.summary_s("x1", x)
z = self.add(x, y)
self.summary_t("z1", z)
self.histogram_summary("histogram", z)
return z
def train_summary_record(test_writer, steps):
"""Train and record summary."""
net = SummaryNet()
out_me_dict = {}
for i in range(0, steps):
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
image = Tensor(np.array([[[[1.2]]]]).astype(np.float32))
out_put = net(x, y, image)
test_writer.record(i)
out_me_dict[i] = out_put.asnumpy()
return out_me_dict
class TestGpuSummary:
"""Test Gpu summary."""
summary_dir = tempfile.mkdtemp(suffix='_gpu_summary')
def setup_method(self):
"""Run before method."""
if not os.path.exists(self.summary_dir):
os.mkdir(self.summary_dir)
def teardown_method(self):
"""Run after method."""
if os.path.exists(self.summary_dir):
shutil.rmtree(self.summary_dir)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_summary_step10_summaryrecord1(self):
"""Test record 10 step summary."""
with SummaryRecord(self.summary_dir) as test_writer:
train_summary_record(test_writer, steps=10)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test model train """
import os
import re
import tempfile
import shutil
import pytest
from mindspore import dataset as ds
from mindspore import nn, Tensor, context
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim import Momentum
from mindspore.dataset.transforms import c_transforms as C
from mindspore.dataset.transforms.vision import c_transforms as CV
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import SummaryCollector
from tests.summary_utils import SummaryReader
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
"""Define LeNet5 network."""
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.scalar_summary = P.ScalarSummary()
self.image_summary = P.ImageSummary()
self.histogram_summary = P.HistogramSummary()
self.tensor_summary = P.TensorSummary()
self.channel = Tensor(channel)
def construct(self, data):
"""define construct."""
self.image_summary('image', data)
output = self.conv1(data)
self.histogram_summary('histogram', output)
output = self.relu(output)
self.tensor_summary('tensor', output)
output = self.max_pool2d(output)
output = self.conv2(output)
output = self.relu(output)
output = self.max_pool2d(output)
output = self.flatten(output)
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
self.scalar_summary('scalar', self.channel)
return output
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
"""create dataset for train or test"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift=0.0)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
mnist_ds = mnist_ds.shuffle(buffer_size=10000) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
class TestSummary:
"""Test summary collector the basic function."""
base_summary_dir = ''
mnist_path = '/home/workspace/mindspore_dataset/mnist'
@classmethod
def setup_class(cls):
"""Run before test this class."""
cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')
@classmethod
def teardown_class(cls):
"""Run after test this class."""
if os.path.exists(cls.base_summary_dir):
shutil.rmtree(cls.base_summary_dir)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_summary_ascend(self):
"""Test summary ascend."""
context.set_context(mode=context.GRAPH_MODE)
self._run_network()
def _run_network(self, dataset_sink_mode=True):
lenet = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Accuracy()})
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=1)
ds_train = create_dataset(os.path.join(self.mnist_path, "train"))
model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)
ds_eval = create_dataset(os.path.join(self.mnist_path, "test"))
model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector])
self._check_summary_result(summary_dir)
@staticmethod
def _check_summary_result(summary_dir):
summary_file_path = ''
for file in os.listdir(summary_dir):
if re.search("_MS", file):
summary_file_path = os.path.join(summary_dir, file)
break
assert not summary_file_path
with SummaryReader(summary_file_path) as summary_reader:
tags = set()
# Read the event that record by SummaryCollector.begin
summary_reader.read_event()
summary_event = summary_reader.read_event()
for value in summary_event.summary.value:
tags.add(value.tag)
# There will not record input data when dataset sink mode is True
expected_tags = ['conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
'fc2.weight/auto', 'histogram', 'image', 'scalar', 'tensor']
assert set(expected_tags) == tags
......@@ -38,6 +38,7 @@ class SummaryReader:
def __init__(self, canonical_file_path, ignore_version_event=True):
self._file_path = canonical_file_path
self._ignore_version_event = ignore_version_event
self._file_handler = None
def __enter__(self):
self._file_handler = open(self._file_path, "rb")
......
......@@ -16,9 +16,50 @@
import os
import tempfile
import shutil
from importlib import import_module
from unittest import mock
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import Parameter
from mindspore.train.callback import SummaryCollector
from mindspore.train.callback import _InternalCallbackParam
from mindspore.train.summary.enum import ModeEnum, PluginEnum
from mindspore.train.summary import SummaryRecord
from mindspore.nn import Cell
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.operations import TensorAdd
_VALUE_CACHE = list()
def add_value(plugin, name, value):
"""This function is mock the function in SummaryRecord."""
global _VALUE_CACHE
_VALUE_CACHE.append((plugin, name, value))
def get_value():
"""Get the value which is added by add_value function."""
global _VALUE_CACHE
value = _VALUE_CACHE
_VALUE_CACHE = list()
return value
class CustomNet(Cell):
"""Define custom netwrok."""
def __init__(self):
super(CustomNet, self).__init__()
self.add = TensorAdd
self.optimizer = Optimizer(learning_rate=1, parameters=[Parameter(Tensor(1), 'weight')])
def construct(self, data):
return data
class TestSummaryCollector:
......@@ -34,6 +75,10 @@ class TestSummaryCollector:
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
def teardown_method(self):
"""Run after each test function."""
get_value()
@pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
def test_params_with_summary_dir_value_error(self, summary_dir):
"""Test the exception scenario for summary dir."""
......@@ -182,3 +227,151 @@ class TestSummaryCollector:
f'bug got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
def test_check_callback_with_multi_instances(self):
"""Use multi SummaryCollector instances to test check_callback function."""
cb_params = _InternalCallbackParam()
cb_params.list_callback = [
SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)),
SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
]
with pytest.raises(ValueError) as exc:
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
def test_collect_input_data_with_train_dataset_element_none(self):
"""Test the param 'train_dataset_element' in cb_params is none."""
cb_params = _InternalCallbackParam()
cb_params.train_dataset_element = None
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
summary_collector._collect_input_data(cb_params)
assert not summary_collector._collect_specified_data['collect_input_data']
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_input_data_success(self, mock_add_value):
"""Mock a image data, and collect image data success."""
mock_add_value.side_effect = add_value
cb_params = _InternalCallbackParam()
image_data = Tensor(np.random.randint(0, 255, size=(1, 1, 1, 1)).astype(np.uint8))
cb_params.train_dataset_element = image_data
with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector:
summary_collector._collect_input_data(cb_params)
# Note Here need to asssert the result and expected data
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_dataset_graph_success(self, mock_add_value):
"""Test collect dataset graph."""
dataset = import_module('mindspore.dataset')
mock_add_value.side_effect = add_value
cb_params = _InternalCallbackParam()
cb_params.train_dataset = dataset.MnistDataset(dataset_dir=tempfile.mkdtemp(dir=self.base_summary_dir))
cb_params.mode = ModeEnum.TRAIN.value
with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector:
summary_collector._collect_dataset_graph(cb_params)
plugin, name, _ = get_value()[0]
assert plugin == 'dataset_graph'
assert name == 'train_dataset'
@pytest.mark.parametrize("net_output, expected_loss", [
(1, Tensor(1)),
([1], Tensor(1)),
([Tensor(1)], Tensor(1)),
(Tensor([1]), Tensor(1)),
(tuple([1]), Tensor(1)),
(None, None)
])
def test_get_loss(self, net_output, expected_loss):
"""Test get loss success and failed."""
cb_params = _InternalCallbackParam()
cb_params.net_outputs = net_output
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
assert summary_collector._is_parse_loss_success
assert summary_collector._get_loss(cb_params) == expected_loss
if expected_loss is None:
assert not summary_collector._is_parse_loss_success
def test_get_optimizer_from_cb_params_success(self):
"""Test get optimizer success from cb params."""
cb_params = _InternalCallbackParam()
cb_params.optimizer = Optimizer(learning_rate=0.1, parameters=[Parameter(Tensor(1), 'weight')])
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
optimizer = summary_collector._get_optimizer(cb_params)
assert optimizer == cb_params.optimizer
# Test get optimizer again
assert summary_collector._get_optimizer(cb_params) == cb_params.optimizer
@pytest.mark.parametrize('mode', [ModeEnum.TRAIN.value, ModeEnum.EVAL.value])
def test_get_optimizer_from_network(self, mode):
"""Get optimizer from train network"""
cb_params = _InternalCallbackParam()
cb_params.optimizer = None
cb_params.mode = mode
if mode == ModeEnum.TRAIN.value:
cb_params.train_network = CustomNet()
else:
cb_params.eval_network = CustomNet()
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
optimizer = summary_collector._get_optimizer(cb_params)
assert isinstance(optimizer, Optimizer)
def test_get_optimizer_failed(self):
"""Test get optimizer failed."""
class Net(Cell):
"""Define net."""
def __init__(self):
super(Net, self).__init__()
self.add = TensorAdd()
def construct(self, data):
return data
cb_params = _InternalCallbackParam()
cb_params.optimizer = None
cb_params.train_network = Net()
cb_params.mode = ModeEnum.TRAIN.value
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
optimizer = summary_collector._get_optimizer(cb_params)
assert optimizer is None
assert summary_collector._temp_optimizer == 'Failed'
# Test get optimizer again
optimizer = summary_collector._get_optimizer(cb_params)
assert optimizer is None
assert summary_collector._temp_optimizer == 'Failed'
@pytest.mark.parametrize("histogram_regular, expected_names, expected_values", [
(
'conv1|conv2',
['conv1.weight1/auto', 'conv2.weight2/auto', 'conv1.bias1/auto'],
[1, 2, 3]
),
(
None,
['conv1.weight1/auto', 'conv2.weight2/auto', 'conv1.bias1/auto', 'conv3.bias/auto', 'conv5.bias/auto'],
[1, 2, 3, 4, 5]
)
])
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_histogram_from_regular(self, mock_add_value, histogram_regular, expected_names, expected_values):
"""Test collect histogram from regular success."""
mock_add_value.side_effect = add_value
cb_params = _InternalCallbackParam()
parameters = [
Parameter(Tensor(1), 'conv1.weight1'),
Parameter(Tensor(2), 'conv2.weight2'),
Parameter(Tensor(3), 'conv1.bias1'),
Parameter(Tensor(4), 'conv3.bias'),
Parameter(Tensor(5), 'conv5.bias'),
Parameter(Tensor(6), 'conv6.bias'),
]
cb_params.optimizer = Optimizer(learning_rate=0.1, parameters=parameters)
with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector:
summary_collector._collect_specified_data['histogram_regular'] = histogram_regular
summary_collector._collect_histogram(cb_params)
result = get_value()
assert PluginEnum.HISTOGRAM.value == result[0][0]
assert expected_names == [data[1] for data in result]
assert expected_values == [data[2] for data in result]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册