提交 bd65d62e 编写于 作者: L luopengting

fix lineage collection for network, loss_fn and dataset path

上级 faa697ce
......@@ -39,7 +39,8 @@ try:
from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
from mindspore.nn import Cell, Optimizer
from mindspore.nn.loss.loss import _Loss
from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
import mindspore.dataset as ds
except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!')
......@@ -432,7 +433,7 @@ class AnalyzeObject:
if hasattr(network, backbone_key):
backbone = getattr(network, backbone_key)
backbone_name = type(backbone).__name__
elif network is not None:
if backbone_name is None and network is not None:
backbone_name = type(network).__name__
return backbone_name
......@@ -498,8 +499,8 @@ class AnalyzeObject:
log.debug('dataset_batch_num: %d', batch_num)
log.debug('dataset_batch_size: %d', batch_size)
dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
if dataset_path:
dataset_path = '/'.join(dataset_path.split('/')[:-1])
if dataset_path and os.path.isfile(dataset_path):
dataset_path, _ = os.path.split(dataset_path)
dataset_size = int(batch_num * batch_size)
if dataset_type == 'train':
......@@ -516,14 +517,24 @@ class AnalyzeObject:
Get dataset path of MindDataset object.
Args:
output_dataset (Union[MindDataset, Dataset]): See
mindspore.dataengine.datasets.Dataset.
output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
See mindspore.dataengine.datasets.Dataset.
Returns:
str, dataset path.
"""
if isinstance(output_dataset, MindDataset):
dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
Cifar100Dataset, VOCDataset, CelebADataset)
dataset_file_set = (MindDataset, ManifestDataset)
dataset_files_set = (TFRecordDataset, TextFileDataset)
if isinstance(output_dataset, dataset_file_set):
return output_dataset.dataset_file
if isinstance(output_dataset, dataset_dir_set):
return output_dataset.dataset_dir
if isinstance(output_dataset, dataset_files_set):
return output_dataset.dataset_files[0]
return self.get_dataset_path(output_dataset.input[0])
@staticmethod
......@@ -544,7 +555,7 @@ class AnalyzeObject:
dataset_path = AnalyzeObject().get_dataset_path(dataset)
except IndexError:
dataset_path = None
validate_file_path(dataset_path, allow_empty=True)
dataset_path = validate_file_path(dataset_path, allow_empty=True)
return dataset_path
@staticmethod
......
......@@ -182,8 +182,8 @@ def validate_file_path(file_path, allow_empty=False):
"""
try:
if allow_empty and not file_path:
return
safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
return file_path
return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
except ValidationError as error:
log.error(str(error))
raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR,
......
......@@ -323,11 +323,13 @@ class TestAnalyzer(TestCase):
res = self.analyzer.get_dataset_path_wrapped(dataset)
assert res == '/path/to/cifar10'
@mock.patch('os.path.isfile')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_dataset_path_wrapped')
def test_analyze_dataset(self, mock_get_path):
def test_analyze_dataset(self, mock_get_path, mock_isfile):
"""Test analyze_dataset method."""
mock_get_path.return_value = '/path/to/mindinsightset'
mock_isfile.return_value = True
dataset = MindDataset(
dataset_size=10,
dataset_file='/path/to/mindinsightset'
......
......@@ -13,5 +13,5 @@
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset.engine."""
from .datasets import Dataset, MindDataset
from .datasets import *
from .serializer_deserializer import serialize
......@@ -38,3 +38,75 @@ class MindDataset(Dataset):
def __init__(self, dataset_size=None, dataset_file=None):
super(MindDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class ImageFolderDatasetV2(Dataset):
"""Mock the MindSpore ImageFolderDatasetV2 class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(ImageFolderDatasetV2, self).__init__(dataset_size)
self.dataset_file = dataset_file
class MnistDataset(Dataset):
"""Mock the MindSpore MnistDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(MnistDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class Cifar10Dataset(Dataset):
"""Mock the MindSpore Cifar10Dataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar10Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class Cifar100Dataset(Dataset):
"""Mock the MindSpore Cifar100Dataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(Cifar100Dataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class VOCDataset(Dataset):
"""Mock the MindSpore VOCDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(VOCDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class CelebADataset(Dataset):
"""Mock the MindSpore CelebADataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(CelebADataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class ManifestDataset(Dataset):
"""Mock the MindSpore ManifestDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(ManifestDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class TFRecordDataset(Dataset):
"""Mock the MindSpore TFRecordDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(TFRecordDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
class TextFileDataset(Dataset):
"""Mock the MindSpore TextFileDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(TextFileDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册