提交 f6f80fae 编写于 作者: L luopengting

fix calculation of lineage dataset_num, get model name

上级 169c580a
...@@ -37,7 +37,7 @@ from mindinsight.lineagemgr.collection.model.base import Metadata ...@@ -37,7 +37,7 @@ from mindinsight.lineagemgr.collection.model.base import Metadata
try: try:
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
from mindspore.nn import Cell, Optimizer, WithLossCell, TrainOneStepWithLossScaleCell from mindspore.nn import Cell, Optimizer
from mindspore.nn.loss.loss import _Loss from mindspore.nn.loss.loss import _Loss
from mindspore.dataset.engine import Dataset, MindDataset from mindspore.dataset.engine import Dataset, MindDataset
import mindspore.dataset as ds import mindspore.dataset as ds
...@@ -412,27 +412,28 @@ class AnalyzeObject: ...@@ -412,27 +412,28 @@ class AnalyzeObject:
Returns: Returns:
str, the name of the backbone network. str, the name of the backbone network.
""" """
with_loss_cell = False backbone_name = None
backbone = None has_network = False
network_key = 'network'
backbone_key = '_backbone'
net_args = vars(network) if network else {} net_args = vars(network) if network else {}
net_cell = net_args.get('_cells') if net_args else {} net_cell = net_args.get('_cells') if net_args else {}
for _, value in net_cell.items(): for key, value in net_cell.items():
if isinstance(value, WithLossCell): if key == network_key:
backbone = getattr(value, '_backbone') network = value
with_loss_cell = True has_network = True
break break
if with_loss_cell: if has_network:
backbone_name = type(backbone).__name__ \ while hasattr(network, network_key):
if backbone else None network = getattr(network, network_key)
elif isinstance(network, TrainOneStepWithLossScaleCell): if hasattr(network, backbone_key):
backbone = getattr(network, 'network') backbone = getattr(network, backbone_key)
backbone_name = type(backbone).__name__ \ backbone_name = type(backbone).__name__
if backbone else None elif network is not None:
else: backbone_name = type(network).__name__
backbone_name = type(network).__name__ \
if network else None
return backbone_name return backbone_name
@staticmethod @staticmethod
...@@ -489,26 +490,24 @@ class AnalyzeObject: ...@@ -489,26 +490,24 @@ class AnalyzeObject:
Returns: Returns:
dict, the lineage metadata. dict, the lineage metadata.
""" """
dataset_batch_size = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
if dataset_batch_size is not None: batch_size = dataset.get_batch_size()
validate_int_params(dataset_batch_size, 'dataset_batch_size') if batch_num is not None:
log.debug('dataset_batch_size: %d', dataset_batch_size) validate_int_params(batch_num, 'dataset_batch_num')
validate_int_params(batch_num, 'dataset_batch_size')
log.debug('dataset_batch_num: %d', batch_num)
log.debug('dataset_batch_size: %d', batch_size)
dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset) dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
if dataset_path: if dataset_path:
dataset_path = '/'.join(dataset_path.split('/')[:-1]) dataset_path = '/'.join(dataset_path.split('/')[:-1])
step_num = lineage_dict.get('step_num') dataset_size = int(batch_num * batch_size)
validate_int_params(step_num, 'step_num')
log.debug('step_num: %d', step_num)
if dataset_type == 'train': if dataset_type == 'train':
lineage_dict[Metadata.train_dataset_path] = dataset_path lineage_dict[Metadata.train_dataset_path] = dataset_path
epoch = lineage_dict.get('epoch') lineage_dict[Metadata.train_dataset_size] = dataset_size
train_dataset_size = dataset_batch_size * (step_num / epoch)
lineage_dict[Metadata.train_dataset_size] = int(train_dataset_size)
elif dataset_type == 'valid': elif dataset_type == 'valid':
lineage_dict[Metadata.valid_dataset_path] = dataset_path lineage_dict[Metadata.valid_dataset_path] = dataset_path
lineage_dict[Metadata.valid_dataset_size] = dataset_batch_size * step_num lineage_dict[Metadata.valid_dataset_size] = dataset_size
return lineage_dict return lineage_dict
......
...@@ -82,7 +82,6 @@ class LineageObj: ...@@ -82,7 +82,6 @@ class LineageObj:
self._lineage_info = { self._lineage_info = {
self._name_summary_dir: summary_dir self._name_summary_dir: summary_dir
} }
self._filtration_result = None
self._init_lineage() self._init_lineage()
self.parse_and_update_lineage(**kwargs) self.parse_and_update_lineage(**kwargs)
......
...@@ -50,10 +50,10 @@ LINEAGE_INFO_RUN1 = { ...@@ -50,10 +50,10 @@ LINEAGE_INFO_RUN1 = {
'network': 'ResNet' 'network': 'ResNet'
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_size': 731 'train_dataset_size': 1024
}, },
'valid_dataset': { 'valid_dataset': {
'valid_dataset_size': 10240 'valid_dataset_size': 1024
}, },
'model': { 'model': {
'path': '{"ckpt": "' 'path': '{"ckpt": "'
...@@ -89,9 +89,9 @@ LINEAGE_FILTRATION_RUN1 = { ...@@ -89,9 +89,9 @@ LINEAGE_FILTRATION_RUN1 = {
'model_lineage': { 'model_lineage': {
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None, 'train_dataset_path': None,
'train_dataset_count': 731, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 10240, 'test_dataset_count': 1024,
'user_defined': {}, 'user_defined': {},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
...@@ -115,7 +115,7 @@ LINEAGE_FILTRATION_RUN2 = { ...@@ -115,7 +115,7 @@ LINEAGE_FILTRATION_RUN2 = {
'train_dataset_path': None, 'train_dataset_path': None,
'train_dataset_count': 1024, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 10240, 'test_dataset_count': 1024,
'user_defined': {}, 'user_defined': {},
'network': "ResNet", 'network': "ResNet",
'optimizer': "Momentum", 'optimizer': "Momentum",
......
...@@ -334,12 +334,14 @@ class TestAnalyzer(TestCase): ...@@ -334,12 +334,14 @@ class TestAnalyzer(TestCase):
) )
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train')
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid')
# batch_size is mocked as 32.
assert res1 == {'step_num': 10, assert res1 == {'step_num': 10,
'train_dataset_path': '/path/to', 'train_dataset_path': '/path/to',
'train_dataset_size': 50, 'train_dataset_size': 320,
'epoch': 2} 'epoch': 2}
assert res2 == {'step_num': 5, 'valid_dataset_path': '/path/to', assert res2 == {'step_num': 5, 'valid_dataset_path': '/path/to',
'valid_dataset_size': 50} 'valid_dataset_size': 320}
def test_get_dataset_path_dataset(self): def test_get_dataset_path_dataset(self):
"""Test get_dataset_path method with Dataset.""" """Test get_dataset_path method with Dataset."""
......
...@@ -27,6 +27,10 @@ class Dataset: ...@@ -27,6 +27,10 @@ class Dataset:
"""Mocked get_dataset_size.""" """Mocked get_dataset_size."""
return self.dataset_size return self.dataset_size
def get_batch_size(self):
"""Mocked get_batch_size"""
return 32
class MindDataset(Dataset): class MindDataset(Dataset):
"""Mock the MindSpore MindDataset class.""" """Mock the MindSpore MindDataset class."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册