From bd65d62e55a0f785a3aa1bb9e9e1355ca2e7e4e8 Mon Sep 17 00:00:00 2001 From: luopengting Date: Sun, 24 May 2020 13:00:52 +0800 Subject: [PATCH] fix lineage collection for network, loss_fn and dataset path --- .../collection/model/model_lineage.py | 27 ++++--- .../lineagemgr/common/validator/validate.py | 4 +- .../collection/model/test_model_lineage.py | 4 +- .../mindspore/dataset/engine/__init__.py | 2 +- .../mindspore/dataset/engine/datasets.py | 72 +++++++++++++++++++ 5 files changed, 97 insertions(+), 12 deletions(-) diff --git a/mindinsight/lineagemgr/collection/model/model_lineage.py b/mindinsight/lineagemgr/collection/model/model_lineage.py index 966a59b..6b4547c 100644 --- a/mindinsight/lineagemgr/collection/model/model_lineage.py +++ b/mindinsight/lineagemgr/collection/model/model_lineage.py @@ -39,7 +39,8 @@ try: from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep from mindspore.nn import Cell, Optimizer from mindspore.nn.loss.loss import _Loss - from mindspore.dataset.engine import Dataset, MindDataset + from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \ + VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset import mindspore.dataset as ds except (ImportError, ModuleNotFoundError): log.warning('MindSpore Not Found!') @@ -432,7 +433,7 @@ class AnalyzeObject: if hasattr(network, backbone_key): backbone = getattr(network, backbone_key) backbone_name = type(backbone).__name__ - elif network is not None: + if backbone_name is None and network is not None: backbone_name = type(network).__name__ return backbone_name @@ -498,8 +499,8 @@ class AnalyzeObject: log.debug('dataset_batch_num: %d', batch_num) log.debug('dataset_batch_size: %d', batch_size) dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset) - if dataset_path: - dataset_path = '/'.join(dataset_path.split('/')[:-1]) + if dataset_path and os.path.isfile(dataset_path): + dataset_path, _ = os.path.split(dataset_path) dataset_size = int(batch_num * batch_size) if dataset_type == 'train': @@ -516,14 +517,24 @@ class AnalyzeObject: Get dataset path of MindDataset object. Args: - output_dataset (Union[MindDataset, Dataset]): See - mindspore.dataengine.datasets.Dataset. + output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, + VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]): + See mindspore.dataengine.datasets.Dataset. Returns: str, dataset path. """ - if isinstance(output_dataset, MindDataset): + dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, + Cifar100Dataset, VOCDataset, CelebADataset) + dataset_file_set = (MindDataset, ManifestDataset) + dataset_files_set = (TFRecordDataset, TextFileDataset) + + if isinstance(output_dataset, dataset_file_set): return output_dataset.dataset_file + if isinstance(output_dataset, dataset_dir_set): + return output_dataset.dataset_dir + if isinstance(output_dataset, dataset_files_set): + return output_dataset.dataset_files[0] return self.get_dataset_path(output_dataset.input[0]) @staticmethod @@ -544,7 +555,7 @@ class AnalyzeObject: dataset_path = AnalyzeObject().get_dataset_path(dataset) except IndexError: dataset_path = None - validate_file_path(dataset_path, allow_empty=True) + dataset_path = validate_file_path(dataset_path, allow_empty=True) return dataset_path @staticmethod diff --git a/mindinsight/lineagemgr/common/validator/validate.py b/mindinsight/lineagemgr/common/validator/validate.py index 2125de2..91c71b2 100644 --- a/mindinsight/lineagemgr/common/validator/validate.py +++ b/mindinsight/lineagemgr/common/validator/validate.py @@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False): """ try: if allow_empty and not file_path: - return - safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None) + return file_path + return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None) except ValidationError as error: log.error(str(error)) raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR, diff --git a/tests/ut/lineagemgr/collection/model/test_model_lineage.py b/tests/ut/lineagemgr/collection/model/test_model_lineage.py index 81a61ea..d4c506a 100644 --- a/tests/ut/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/ut/lineagemgr/collection/model/test_model_lineage.py @@ -323,11 +323,13 @@ class TestAnalyzer(TestCase): res = self.analyzer.get_dataset_path_wrapped(dataset) assert res == '/path/to/cifar10' + @mock.patch('os.path.isfile') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' 'AnalyzeObject.get_dataset_path_wrapped') - def test_analyze_dataset(self, mock_get_path): + def test_analyze_dataset(self, mock_get_path, mock_isfile): """Test analyze_dataset method.""" mock_get_path.return_value = '/path/to/mindinsightset' + mock_isfile.return_value = True dataset = MindDataset( dataset_size=10, dataset_file='/path/to/mindinsightset' diff --git a/tests/utils/mindspore/dataset/engine/__init__.py b/tests/utils/mindspore/dataset/engine/__init__.py index bc14017..0ac981b 100644 --- a/tests/utils/mindspore/dataset/engine/__init__.py +++ b/tests/utils/mindspore/dataset/engine/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. # ============================================================================ """Mock mindspore.dataset.engine.""" -from .datasets import Dataset, MindDataset +from .datasets import * from .serializer_deserializer import serialize diff --git a/tests/utils/mindspore/dataset/engine/datasets.py b/tests/utils/mindspore/dataset/engine/datasets.py index 8fc2896..0b9f85c 100644 --- a/tests/utils/mindspore/dataset/engine/datasets.py +++ b/tests/utils/mindspore/dataset/engine/datasets.py @@ -38,3 +38,75 @@ class MindDataset(Dataset): def __init__(self, dataset_size=None, dataset_file=None): super(MindDataset, self).__init__(dataset_size) self.dataset_file = dataset_file + + +class ImageFolderDatasetV2(Dataset): + """Mock the MindSpore ImageFolderDatasetV2 class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(ImageFolderDatasetV2, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class MnistDataset(Dataset): + """Mock the MindSpore MnistDataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(MnistDataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class Cifar10Dataset(Dataset): + """Mock the MindSpore Cifar10Dataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(Cifar10Dataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class Cifar100Dataset(Dataset): + """Mock the MindSpore Cifar100Dataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(Cifar100Dataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class VOCDataset(Dataset): + """Mock the MindSpore VOCDataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(VOCDataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class CelebADataset(Dataset): + """Mock the MindSpore CelebADataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(CelebADataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class ManifestDataset(Dataset): + """Mock the MindSpore ManifestDataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(ManifestDataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class TFRecordDataset(Dataset): + """Mock the MindSpore TFRecordDataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(TFRecordDataset, self).__init__(dataset_size) + self.dataset_file = dataset_file + + +class TextFileDataset(Dataset): + """Mock the MindSpore TextFileDataset class.""" + + def __init__(self, dataset_size=None, dataset_file=None): + super(TextFileDataset, self).__init__(dataset_size) + self.dataset_file = dataset_file -- GitLab