提交 bd65d62e 编写于 作者: L luopengting

fix lineage collection for network, loss_fn and dataset path

上级 faa697ce
...@@ -39,7 +39,8 @@ try: ...@@ -39,7 +39,8 @@ try:
from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
from mindspore.nn import Cell, Optimizer from mindspore.nn import Cell, Optimizer
from mindspore.nn.loss.loss import _Loss from mindspore.nn.loss.loss import _Loss
from mindspore.dataset.engine import Dataset, MindDataset from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
import mindspore.dataset as ds import mindspore.dataset as ds
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!') log.warning('MindSpore Not Found!')
...@@ -432,7 +433,7 @@ class AnalyzeObject: ...@@ -432,7 +433,7 @@ class AnalyzeObject:
if hasattr(network, backbone_key): if hasattr(network, backbone_key):
backbone = getattr(network, backbone_key) backbone = getattr(network, backbone_key)
backbone_name = type(backbone).__name__ backbone_name = type(backbone).__name__
elif network is not None: if backbone_name is None and network is not None:
backbone_name = type(network).__name__ backbone_name = type(network).__name__
return backbone_name return backbone_name
...@@ -498,8 +499,8 @@ class AnalyzeObject: ...@@ -498,8 +499,8 @@ class AnalyzeObject:
log.debug('dataset_batch_num: %d', batch_num) log.debug('dataset_batch_num: %d', batch_num)
log.debug('dataset_batch_size: %d', batch_size) log.debug('dataset_batch_size: %d', batch_size)
dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset) dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
if dataset_path: if dataset_path and os.path.isfile(dataset_path):
dataset_path = '/'.join(dataset_path.split('/')[:-1]) dataset_path, _ = os.path.split(dataset_path)
dataset_size = int(batch_num * batch_size) dataset_size = int(batch_num * batch_size)
if dataset_type == 'train': if dataset_type == 'train':
...@@ -516,14 +517,24 @@ class AnalyzeObject: ...@@ -516,14 +517,24 @@ class AnalyzeObject:
Get dataset path of MindDataset object. Get dataset path of MindDataset object.
Args: Args:
output_dataset (Union[MindDataset, Dataset]): See output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
mindspore.dataengine.datasets.Dataset. VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
See mindspore.dataengine.datasets.Dataset.
Returns: Returns:
str, dataset path. str, dataset path.
""" """
if isinstance(output_dataset, MindDataset): dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
Cifar100Dataset, VOCDataset, CelebADataset)
dataset_file_set = (MindDataset, ManifestDataset)
dataset_files_set = (TFRecordDataset, TextFileDataset)
if isinstance(output_dataset, dataset_file_set):
return output_dataset.dataset_file return output_dataset.dataset_file
if isinstance(output_dataset, dataset_dir_set):
return output_dataset.dataset_dir
if isinstance(output_dataset, dataset_files_set):
return output_dataset.dataset_files[0]
return self.get_dataset_path(output_dataset.input[0]) return self.get_dataset_path(output_dataset.input[0])
@staticmethod @staticmethod
...@@ -544,7 +555,7 @@ class AnalyzeObject: ...@@ -544,7 +555,7 @@ class AnalyzeObject:
dataset_path = AnalyzeObject().get_dataset_path(dataset) dataset_path = AnalyzeObject().get_dataset_path(dataset)
except IndexError: except IndexError:
dataset_path = None dataset_path = None
validate_file_path(dataset_path, allow_empty=True) dataset_path = validate_file_path(dataset_path, allow_empty=True)
return dataset_path return dataset_path
@staticmethod @staticmethod
......
...@@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False): ...@@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False):
""" """
try: try:
if allow_empty and not file_path: if allow_empty and not file_path:
return return file_path
safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None) return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
except ValidationError as error: except ValidationError as error:
log.error(str(error)) log.error(str(error))
raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR, raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
......
...@@ -323,11 +323,13 @@ class TestAnalyzer(TestCase): ...@@ -323,11 +323,13 @@ class TestAnalyzer(TestCase):
res = self.analyzer.get_dataset_path_wrapped(dataset) res = self.analyzer.get_dataset_path_wrapped(dataset)
assert res == '/path/to/cifar10' assert res == '/path/to/cifar10'
@mock.patch('os.path.isfile')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_dataset_path_wrapped') 'AnalyzeObject.get_dataset_path_wrapped')
def test_analyze_dataset(self, mock_get_path): def test_analyze_dataset(self, mock_get_path, mock_isfile):
"""Test analyze_dataset method.""" """Test analyze_dataset method."""
mock_get_path.return_value = '/path/to/mindinsightset' mock_get_path.return_value = '/path/to/mindinsightset'
mock_isfile.return_value = True
dataset = MindDataset( dataset = MindDataset(
dataset_size=10, dataset_size=10,
dataset_file='/path/to/mindinsightset' dataset_file='/path/to/mindinsightset'
......
...@@ -13,5 +13,5 @@ ...@@ -13,5 +13,5 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Mock mindspore.dataset.engine.""" """Mock mindspore.dataset.engine."""
from .datasets import Dataset, MindDataset from .datasets import *
from .serializer_deserializer import serialize from .serializer_deserializer import serialize
...@@ -38,3 +38,75 @@ class MindDataset(Dataset): ...@@ -38,3 +38,75 @@ class MindDataset(Dataset):
def __init__(self, dataset_size=None, dataset_file=None): def __init__(self, dataset_size=None, dataset_file=None):
super(MindDataset, self).__init__(dataset_size) super(MindDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file self.dataset_file = dataset_file
class ImageFolderDatasetV2(Dataset):
"""Mock the MindSpore ImageFolderDatasetV2 class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(ImageFolderDatasetV2, self).__init__(dataset_size)
self.dataset_file = dataset_file
class MnistDataset(Dataset):
"""Mock the MindSpore MnistDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(MnistDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class Cifar10Dataset(Dataset):
"""Mock the MindSpore Cifar10Dataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar10Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class Cifar100Dataset(Dataset):
"""Mock the MindSpore Cifar100Dataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar100Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class VOCDataset(Dataset):
"""Mock the MindSpore VOCDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(VOCDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class CelebADataset(Dataset):
"""Mock the MindSpore CelebADataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(CelebADataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class ManifestDataset(Dataset):
"""Mock the MindSpore ManifestDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(ManifestDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class TFRecordDataset(Dataset):
"""Mock the MindSpore TFRecordDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(TFRecordDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class TextFileDataset(Dataset):
"""Mock the MindSpore TextFileDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(TextFileDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册