Commit a432241f authored by mindspore-ci-bot, committed by Gitee

!151 fix calculation of lineage dataset_num, get model name

Merge pull request !151 from luopengting/lineage_fix
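In summary, this merge request changes two things: the lineage dataset count is now computed as `get_dataset_size() * get_batch_size()` (batches × samples per batch) rather than being derived from `step_num` and `epoch`, and the backbone (model) name is resolved by unwrapping nested `network` cells and reading `_backbone` where present, instead of special-casing `WithLossCell` and `TrainOneStepWithLossScaleCell`. The snippet below is a minimal, self-contained sketch of both behaviours, not the MindInsight implementation itself; `MockDataset`, `Wrapper`, `Backbone`, `dataset_size`, and `backbone_name` are illustrative stand-ins, not MindSpore or MindInsight APIs.

```python
class MockDataset:
    """Illustrative stand-in for a dataset exposing the two sizes the new code reads."""

    def get_dataset_size(self):
        """Number of batches in the dataset (batch_num)."""
        return 10

    def get_batch_size(self):
        """Number of samples per batch (batch_size)."""
        return 32


def dataset_size(dataset):
    """New-style count: batches * batch size, with no dependence on step_num or epoch."""
    return int(dataset.get_dataset_size() * dataset.get_batch_size())


class Backbone:
    """Placeholder backbone network."""


class Wrapper:
    """Stand-in for nested train cells: wrappers keep the inner cell under
    'network'; the innermost cell keeps the real model under '_backbone'."""

    def __init__(self, inner=None, backbone=None):
        if inner is not None:
            self.network = inner
        if backbone is not None:
            self._backbone = backbone


def backbone_name(network):
    """New-style lookup: unwrap nested 'network' attributes, then prefer '_backbone'."""
    while hasattr(network, 'network'):
        network = getattr(network, 'network')
    if hasattr(network, '_backbone'):
        return type(getattr(network, '_backbone')).__name__
    return type(network).__name__ if network is not None else None


print(dataset_size(MockDataset()))                            # 320
print(backbone_name(Wrapper(Wrapper(backbone=Backbone()))))   # Backbone
```

Running the sketch prints `320` and `Backbone`; the `320` matches the updated test expectations below, where `get_batch_size()` is mocked to return 32.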
@@ -37,7 +37,7 @@ from mindinsight.lineagemgr.collection.model.base import Metadata
try:
from mindspore.common.tensor import Tensor
from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
from mindspore.nn import Cell, Optimizer, WithLossCell, TrainOneStepWithLossScaleCell
from mindspore.nn import Cell, Optimizer
from mindspore.nn.loss.loss import _Loss
from mindspore.dataset.engine import Dataset, MindDataset
import mindspore.dataset as ds
@@ -412,27 +412,28 @@ class AnalyzeObject:
Returns:
str, the name of the backbone network.
"""
with_loss_cell = False
backbone = None
backbone_name = None
has_network = False
network_key = 'network'
backbone_key = '_backbone'
net_args = vars(network) if network else {}
net_cell = net_args.get('_cells') if net_args else {}
for _, value in net_cell.items():
if isinstance(value, WithLossCell):
backbone = getattr(value, '_backbone')
with_loss_cell = True
for key, value in net_cell.items():
if key == network_key:
network = value
has_network = True
break
if with_loss_cell:
backbone_name = type(backbone).__name__ \
if backbone else None
elif isinstance(network, TrainOneStepWithLossScaleCell):
backbone = getattr(network, 'network')
backbone_name = type(backbone).__name__ \
if backbone else None
else:
backbone_name = type(network).__name__ \
if network else None
if has_network:
while hasattr(network, network_key):
network = getattr(network, network_key)
if hasattr(network, backbone_key):
backbone = getattr(network, backbone_key)
backbone_name = type(backbone).__name__
elif network is not None:
backbone_name = type(network).__name__
return backbone_name
@staticmethod
@@ -489,26 +490,24 @@ class AnalyzeObject:
Returns:
dict, the lineage metadata.
"""
dataset_batch_size = dataset.get_dataset_size()
if dataset_batch_size is not None:
validate_int_params(dataset_batch_size, 'dataset_batch_size')
log.debug('dataset_batch_size: %d', dataset_batch_size)
batch_num = dataset.get_dataset_size()
batch_size = dataset.get_batch_size()
if batch_num is not None:
validate_int_params(batch_num, 'dataset_batch_num')
validate_int_params(batch_size, 'dataset_batch_size')
log.debug('dataset_batch_num: %d', batch_num)
log.debug('dataset_batch_size: %d', batch_size)
dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
if dataset_path:
dataset_path = '/'.join(dataset_path.split('/')[:-1])
step_num = lineage_dict.get('step_num')
validate_int_params(step_num, 'step_num')
log.debug('step_num: %d', step_num)
dataset_size = int(batch_num * batch_size)
if dataset_type == 'train':
lineage_dict[Metadata.train_dataset_path] = dataset_path
epoch = lineage_dict.get('epoch')
train_dataset_size = dataset_batch_size * (step_num / epoch)
lineage_dict[Metadata.train_dataset_size] = int(train_dataset_size)
lineage_dict[Metadata.train_dataset_size] = dataset_size
elif dataset_type == 'valid':
lineage_dict[Metadata.valid_dataset_path] = dataset_path
lineage_dict[Metadata.valid_dataset_size] = dataset_batch_size * step_num
lineage_dict[Metadata.valid_dataset_size] = dataset_size
return lineage_dict
......
@@ -82,7 +82,6 @@ class LineageObj:
self._lineage_info = {
self._name_summary_dir: summary_dir
}
self._filtration_result = None
self._init_lineage()
self.parse_and_update_lineage(**kwargs)
......
@@ -50,10 +50,10 @@ LINEAGE_INFO_RUN1 = {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
'train_dataset_size': 1024
},
'valid_dataset': {
'valid_dataset_size': 10240
'valid_dataset_size': 1024
},
'model': {
'path': '{"ckpt": "'
@@ -89,9 +89,9 @@ LINEAGE_FILTRATION_RUN1 = {
'model_lineage': {
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': 10240,
'test_dataset_count': 1024,
'user_defined': {},
'network': 'ResNet',
'optimizer': 'Momentum',
@@ -115,7 +115,7 @@ LINEAGE_FILTRATION_RUN2 = {
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': 10240,
'test_dataset_count': 1024,
'user_defined': {},
'network': "ResNet",
'optimizer': "Momentum",
......
@@ -334,12 +334,14 @@ class TestAnalyzer(TestCase):
)
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train')
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid')
# batch_size is mocked as 32.
assert res1 == {'step_num': 10,
'train_dataset_path': '/path/to',
'train_dataset_size': 50,
'train_dataset_size': 320,
'epoch': 2}
assert res2 == {'step_num': 5, 'valid_dataset_path': '/path/to',
'valid_dataset_size': 50}
'valid_dataset_size': 320}
def test_get_dataset_path_dataset(self):
"""Test get_dataset_path method with Dataset."""
......
@@ -27,6 +27,10 @@ class Dataset:
"""Mocked get_dataset_size."""
return self.dataset_size
def get_batch_size(self):
"""Mocked get_batch_size"""
return 32
class MindDataset(Dataset):
"""Mock the MindSpore MindDataset class."""
......