提交 4a614930 编写于 作者: C chenchao99

fix ut/st about cmetrics

上级 228a448b
...@@ -33,6 +33,96 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \ ...@@ -33,6 +33,96 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \
from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH
LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
LINEAGE_FILTRATION_EXCEPT_RUN = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN2 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
@pytest.mark.usefixtures("create_summary_dir") @pytest.mark.usefixtures("create_summary_dir")
class TestModelApi(TestCase): class TestModelApi(TestCase):
"""Test lineage information query interface.""" """Test lineage information query interface."""
...@@ -67,36 +157,7 @@ class TestModelApi(TestCase): ...@@ -67,36 +157,7 @@ class TestModelApi(TestCase):
total_res = get_summary_lineage(SUMMARY_DIR) total_res = get_summary_lineage(SUMMARY_DIR)
partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters']) partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters'])
partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm']) partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm'])
expect_total_res = { expect_total_res = LINEAGE_INFO_RUN1
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
expect_partial_res1 = { expect_partial_res1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': { 'hyper_parameters': {
...@@ -139,7 +200,7 @@ class TestModelApi(TestCase): ...@@ -139,7 +200,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_single @pytest.mark.env_single
def test_get_summary_lineage_exception(self): def test_get_summary_lineage_exception_1(self):
"""Test the interface of get_summary_lineage with exception.""" """Test the interface of get_summary_lineage with exception."""
# summary path does not exist # summary path does not exist
self.assertRaisesRegex( self.assertRaisesRegex(
...@@ -183,6 +244,14 @@ class TestModelApi(TestCase): ...@@ -183,6 +244,14 @@ class TestModelApi(TestCase):
keys=None keys=None
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_get_summary_lineage_exception_2(self):
"""Test the interface of get_summary_lineage with exception."""
# keys is invalid # keys is invalid
self.assertRaisesRegex( self.assertRaisesRegex(
LineageParamValueError, LineageParamValueError,
...@@ -250,64 +319,9 @@ class TestModelApi(TestCase): ...@@ -250,64 +319,9 @@ class TestModelApi(TestCase):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_EXCEPT_RUN,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), LINEAGE_FILTRATION_RUN1,
'loss_function': 'SoftmaxCrossEntropyWithLogits', LINEAGE_FILTRATION_RUN2
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
], ],
'count': 3 'count': 3
} }
...@@ -357,46 +371,8 @@ class TestModelApi(TestCase): ...@@ -357,46 +371,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_RUN2,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), LINEAGE_FILTRATION_RUN1
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -432,46 +408,8 @@ class TestModelApi(TestCase): ...@@ -432,46 +408,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_RUN2,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), LINEAGE_FILTRATION_RUN1
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -498,44 +436,8 @@ class TestModelApi(TestCase): ...@@ -498,44 +436,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_EXCEPT_RUN,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), LINEAGE_FILTRATION_RUN1
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -674,6 +576,14 @@ class TestModelApi(TestCase): ...@@ -674,6 +576,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the condition of offset is invalid # the condition of offset is invalid
search_condition = { search_condition = {
'offset': 1.0 'offset': 1.0
...@@ -712,6 +622,14 @@ class TestModelApi(TestCase): ...@@ -712,6 +622,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_4(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the sorted_type not supported # the sorted_type not supported
search_condition = { search_condition = {
'sorted_name': 'summary_dir', 'sorted_name': 'summary_dir',
...@@ -753,6 +671,14 @@ class TestModelApi(TestCase): ...@@ -753,6 +671,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_5(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the summary dir is invalid in search condition # the summary dir is invalid in search condition
search_condition = { search_condition = {
'summary_dir': { 'summary_dir': {
...@@ -811,7 +737,7 @@ class TestModelApi(TestCase): ...@@ -811,7 +737,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_single @pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self): def test_filter_summary_lineage_exception_6(self):
"""Test the abnormal execution of the filter_summary_lineage interface.""" """Test the abnormal execution of the filter_summary_lineage interface."""
# gt > lt # gt > lt
search_condition1 = { search_condition1 = {
......
...@@ -22,6 +22,8 @@ import tempfile ...@@ -22,6 +22,8 @@ import tempfile
import pytest import pytest
from ....utils import mindspore from ....utils import mindspore
from ....utils.mindspore.dataset.engine.serializer_deserializer import \
SERIALIZED_PIPELINE
sys.modules['mindspore'] = mindspore sys.modules['mindspore'] = mindspore
...@@ -32,52 +34,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run') ...@@ -32,52 +34,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run')
COLLECTION_MODULE = 'TestModelLineage' COLLECTION_MODULE = 'TestModelLineage'
API_MODULE = 'TestModelApi' API_MODULE = 'TestModelApi'
DATASET_GRAPH = { DATASET_GRAPH = SERIALIZED_PIPELINE
'op_type': 'BatchDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'drop_remainder': True,
'batch_size': 10,
'children': [
{
'op_type': 'MapDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'input_columns': [
'label'
],
'output_columns': [
None
],
'operations': [
{
'tensor_op_module': 'minddata.transforms.c_transforms',
'tensor_op_name': 'OneHot',
'num_classes': 10
}
],
'children': [
{
'op_type': 'MnistDataset',
'shard_id': None,
'num_shards': None,
'op_module': 'minddata.dataengine.datasets',
'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData',
'num_parallel_workers': None,
'shuffle': None,
'num_samples': 100,
'sampler': {
'sampler_module': 'minddata.dataengine.samplers',
'sampler_name': 'RandomSampler',
'replacement': True,
'num_samples': 100
},
'children': []
}
]
}
]
}
def get_module_name(nodeid): def get_module_name(nodeid):
"""Get the module name from nodeid.""" """Get the module name from nodeid."""
......
...@@ -24,24 +24,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \ ...@@ -24,24 +24,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageQuerySummaryDataError LineageQuerySummaryDataError
class TestSearchModel(TestCase): LINEAGE_FILTRATION_BASE = {
"""Test the restful api of search_model."""
def setUp(self):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
def test_search_model_success(self, *args):
"""Test the success of model_success."""
base_dir = '/path/to/test_lineage_summary_dir_base'
args[0].return_value = {
'object': [
{
'summary_dir': base_dir,
'accuracy': None, 'accuracy': None,
'mae': None, 'mae': None,
'mse': None, 'mse': None,
...@@ -57,9 +40,8 @@ class TestSearchModel(TestCase): ...@@ -57,9 +40,8 @@ class TestSearchModel(TestCase):
'batch_size': 32, 'batch_size': 32,
'loss': 0.029999999329447746, 'loss': 0.029999999329447746,
'model_size': 128 'model_size': 128
}, }
{ LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(base_dir, 'run1'),
'accuracy': 0.78, 'accuracy': 0.78,
'mae': None, 'mae': None,
'mse': None, 'mse': None,
...@@ -75,6 +57,32 @@ class TestSearchModel(TestCase): ...@@ -75,6 +57,32 @@ class TestSearchModel(TestCase):
'batch_size': 32, 'batch_size': 32,
'loss': 0.029999999329447746, 'loss': 0.029999999329447746,
'model_size': 128 'model_size': 128
}
class TestSearchModel(TestCase):
"""Test the restful api of search_model."""
def setUp(self):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
def test_search_model_success(self, *args):
"""Test the success of model_success."""
base_dir = '/path/to/test_lineage_summary_dir_base'
args[0].return_value = {
'object': [
{
'summary_dir': base_dir,
**LINEAGE_FILTRATION_BASE
},
{
'summary_dir': os.path.join(base_dir, 'run1'),
**LINEAGE_FILTRATION_RUN1
} }
], ],
'count': 2 'count': 2
...@@ -93,39 +101,11 @@ class TestSearchModel(TestCase): ...@@ -93,39 +101,11 @@ class TestSearchModel(TestCase):
'object': [ 'object': [
{ {
'summary_dir': './', 'summary_dir': './',
'accuracy': None, **LINEAGE_FILTRATION_BASE
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 12,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}, },
{ {
'summary_dir': './run1', 'summary_dir': './run1',
'accuracy': 0.78, **LINEAGE_FILTRATION_RUN1
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': 64,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
} }
], ],
'count': 2 'count': 2
......
...@@ -62,14 +62,10 @@ class TestModelLineage(TestCase): ...@@ -62,14 +62,10 @@ class TestModelLineage(TestCase):
self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception)) self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
'LineageSummary.record_dataset_graph') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
'validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin(self, *args): def test_begin(self, *args):
"""Test TrainLineage.begin method.""" """Test TrainLineage.begin method."""
...@@ -84,14 +80,10 @@ class TestModelLineage(TestCase): ...@@ -84,14 +80,10 @@ class TestModelLineage(TestCase):
args[4].assert_called() args[4].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
'LineageSummary.record_dataset_graph') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
'validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin_error(self, *args): def test_begin_error(self, *args):
"""Test TrainLineage.begin method.""" """Test TrainLineage.begin method."""
...@@ -124,15 +116,11 @@ class TestModelLineage(TestCase): ...@@ -124,15 +116,11 @@ class TestModelLineage(TestCase):
train_lineage.begin(self.my_run_context(run_context)) train_lineage.begin(self.my_run_context(run_context))
self.assertTrue('The parameter optimizer is invalid.' in str(context.exception)) self.assertTrue('The parameter optimizer is invalid.' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
...@@ -160,15 +148,11 @@ class TestModelLineage(TestCase): ...@@ -160,15 +148,11 @@ class TestModelLineage(TestCase):
train_lineage.end(self.run_context) train_lineage.end(self.run_context)
self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception)) self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
...@@ -188,15 +172,11 @@ class TestModelLineage(TestCase): ...@@ -188,15 +172,11 @@ class TestModelLineage(TestCase):
train_lineage.end(self.my_run_context(self.run_context)) train_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in TrainLineage:' in str(context.exception)) self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
......
...@@ -20,13 +20,7 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher ...@@ -20,13 +20,7 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.path_parser import SummaryPathParser from mindinsight.lineagemgr.common.path_parser import SummaryPathParser
class TestSummaryPathParser(TestCase): MOCK_SUMMARY_DIRS = [
"""Test the class of SummaryPathParser."""
@mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_summary_dirs(self, *args):
"""Test the function of get_summary_dirs."""
args[0].return_value = [
{ {
'relative_path': './relative_path0' 'relative_path': './relative_path0'
}, },
...@@ -36,25 +30,8 @@ class TestSummaryPathParser(TestCase): ...@@ -36,25 +30,8 @@ class TestSummaryPathParser(TestCase):
{ {
'relative_path': './relative_path1' 'relative_path': './relative_path1'
} }
] ]
MOCK_SUMMARIES = [
expected_result = [
'/path/to/base/relative_path0',
'/path/to/base',
'/path/to/base/relative_path1'
]
base_dir = '/path/to/base'
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual(expected_result, result)
args[0].return_value = []
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual([], result)
@mock.patch.object(SummaryWatcher, 'list_summaries')
def test_get_latest_lineage_summary(self, *args):
"""Test the function of get_latest_lineage_summary."""
args[0].return_value = [
{ {
'file_name': 'file0', 'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970) 'create_time': datetime.fromtimestamp(1582031970)
...@@ -71,7 +48,34 @@ class TestSummaryPathParser(TestCase): ...@@ -71,7 +48,34 @@ class TestSummaryPathParser(TestCase):
'file_name': 'file1_lineage', 'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971) 'create_time': datetime.fromtimestamp(1582031971)
} }
]
class TestSummaryPathParser(TestCase):
"""Test the class of SummaryPathParser."""
@mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_summary_dirs(self, *args):
"""Test the function of get_summary_dirs."""
args[0].return_value = MOCK_SUMMARY_DIRS
expected_result = [
'/path/to/base/relative_path0',
'/path/to/base',
'/path/to/base/relative_path1'
] ]
base_dir = '/path/to/base'
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual(expected_result, result)
args[0].return_value = []
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual([], result)
@mock.patch.object(SummaryWatcher, 'list_summaries')
def test_get_latest_lineage_summary(self, *args):
"""Test the function of get_latest_lineage_summary."""
args[0].return_value = MOCK_SUMMARIES
summary_dir = '/path/to/summary_dir' summary_dir = '/path/to/summary_dir'
result = SummaryPathParser.get_latest_lineage_summary(summary_dir) result = SummaryPathParser.get_latest_lineage_summary(summary_dir)
self.assertEqual('/path/to/summary_dir/file1_lineage', result) self.assertEqual('/path/to/summary_dir/file1_lineage', result)
...@@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase): ...@@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase):
@mock.patch.object(SummaryWatcher, 'list_summary_directories') @mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_latest_lineage_summaries(self, *args): def test_get_latest_lineage_summaries(self, *args):
"""Test the function of get_latest_lineage_summaries.""" """Test the function of get_latest_lineage_summaries."""
args[0].return_value = [ args[0].return_value = MOCK_SUMMARY_DIRS
{ args[1].return_value = MOCK_SUMMARIES
'relative_path': './relative_path0'
},
{
'relative_path': './'
},
{
'relative_path': './relative_path1'
}
]
args[1].return_value = [
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file0_lineage',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file1',
'create_time': datetime.fromtimestamp(1582031971)
},
{
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
expected_result = [ expected_result = [
'/path/to/base/relative_path0/file1_lineage', '/path/to/base/relative_path0/file1_lineage',
......
...@@ -26,27 +26,23 @@ from mindinsight.utils.exceptions import MindInsightException ...@@ -26,27 +26,23 @@ from mindinsight.utils.exceptions import MindInsightException
class TestValidateSearchModelCondition(TestCase): class TestValidateSearchModelCondition(TestCase):
"""Test the mothod of validate_search_model_condition.""" """Test the mothod of validate_search_model_condition."""
def test_validate_search_model_condition(self): def test_validate_search_model_condition_param_type_error(self):
"""Test the mothod of validate_search_model_condition.""" """Test the mothod of validate_search_model_condition with LineageParamTypeError."""
condition = { condition = {
'summary_dir': 'xxx' 'summary_dir': 'xxx'
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_type_error(
LineageParamTypeError,
'The search_condition element summary_dir should be dict.', 'The search_condition element summary_dir should be dict.',
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_param_value_error(self):
"""Test the mothod of validate_search_model_condition with LineageParamValueError."""
condition = { condition = {
'xxx': 'xxx' 'xxx': 'xxx'
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_value_error(
LineageParamValueError,
'The search attribute not supported.', 'The search attribute not supported.',
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -55,22 +51,38 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -55,22 +51,38 @@ class TestValidateSearchModelCondition(TestCase):
'xxx': 'xxx' 'xxx': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_value_error(
LineageParamValueError,
"The compare condition should be in", "The compare condition should be in",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = {
1: {
"ge": 8.0
}
}
self._assert_raise_of_lineage_param_value_error(
"The search attribute not supported.",
condition
)
condition = {
'metric_': {
"ge": 8.0
}
}
self._assert_raise_of_lineage_param_value_error(
"The search attribute not supported.",
condition
)
def test_validate_search_model_condition_mindinsight_exception_1(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
"offset": 100001 "offset": 100001
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"Invalid input offset. 0 <= offset <= 100000", "Invalid input offset. 0 <= offset <= 100000",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -80,11 +92,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -80,11 +92,9 @@ class TestValidateSearchModelCondition(TestCase):
}, },
'limit': 10 'limit': 10
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter summary_dir is invalid. It should be a dict and "
"The parameter summary_dir is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -93,11 +103,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -93,11 +103,9 @@ class TestValidateSearchModelCondition(TestCase):
'in': 1.0 'in': 1.0
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -106,24 +114,22 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -106,24 +114,22 @@ class TestValidateSearchModelCondition(TestCase):
'lt': True 'lt': True
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_2(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
'learning_rate': { 'learning_rate': {
'gt': [1.0] 'gt': [1.0]
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -132,11 +138,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -132,11 +138,9 @@ class TestValidateSearchModelCondition(TestCase):
'ge': 1 'ge': 1
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter loss_function is invalid. It should be a dict and "
"The parameter loss_function is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -145,12 +149,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -145,12 +149,9 @@ class TestValidateSearchModelCondition(TestCase):
'in': 2 'in': 2
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter train_dataset_count is invalid. It should be a dict " "The parameter train_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -162,14 +163,14 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -162,14 +163,14 @@ class TestValidateSearchModelCondition(TestCase):
'eq': 'xxx' 'eq': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter network is invalid. It should be a dict and "
"The parameter network is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_3(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
'batch_size': { 'batch_size': {
'lt': 2, 'lt': 2,
...@@ -179,11 +180,8 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -179,11 +180,8 @@ class TestValidateSearchModelCondition(TestCase):
'eq': 222 'eq': 222
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter batch_size is invalid. It should be a non-negative integer.", "The parameter batch_size is invalid. It should be a non-negative integer.",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -192,12 +190,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -192,12 +190,9 @@ class TestValidateSearchModelCondition(TestCase):
'lt': -2 'lt': -2
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter test_dataset_count is invalid. It should be a dict " "The parameter test_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -206,11 +201,8 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -206,11 +201,8 @@ class TestValidateSearchModelCondition(TestCase):
'lt': False 'lt': False
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter epoch is invalid. It should be a positive integer.", "The parameter epoch is invalid. It should be a positive integer.",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -219,65 +211,79 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -219,65 +211,79 @@ class TestValidateSearchModelCondition(TestCase):
"ge": "" "ge": ""
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_4(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
"train_dataset_count": { "train_dataset_count": {
"ge": 8.0 "ge": 8.0
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter train_dataset_count is invalid. It should be a dict " "The parameter train_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = { condition = {
1: { 'metric_attribute': {
"ge": 8.0 'ge': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
LineageParamValueError, "The parameter metric_attribute is invalid. "
"The search attribute not supported.", "It should be a dict and the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = { def _assert_raise(self, exception, msg, condition):
'metric_': { """
"ge": 8.0 Assert raise by unittest.
}
}
LineageParamValueError('The search attribute not supported.')
self.assertRaisesRegex(
LineageParamValueError,
"The search attribute not supported.",
validate_search_model_condition,
SearchModelConditionParameter,
condition
)
condition = { Args:
'metric_attribute': { exception (Type): Exception class expected to be raised.
'ge': 'xxx' msg (msg): Expected error message.
} condition (dict): The parameter of search condition.
} """
self.assertRaisesRegex( self.assertRaisesRegex(
MindInsightException, exception,
"The parameter metric_attribute is invalid. " msg,
"It should be a dict and the value should be a float or a integer",
validate_search_model_condition, validate_search_model_condition,
SearchModelConditionParameter, SearchModelConditionParameter,
condition condition
) )
def _assert_raise_of_mindinsight_exception(self, msg, condition):
"""
Assert raise of MindinsightException by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(MindInsightException, msg, condition)
def _assert_raise_of_lineage_param_value_error(self, msg, condition):
"""
Assert raise of LineageParamValueError by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(LineageParamValueError, msg, condition)
def _assert_raise_of_lineage_param_type_error(self, msg, condition):
"""
Assert raise of LineageParamTypeError by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(LineageParamTypeError, msg, condition)
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
"""The event data in querier test.""" """The event data in querier test."""
import json import json
from ....utils.mindspore.dataset.engine.serializer_deserializer import \
SERIALIZED_PIPELINE
EVENT_TRAIN_DICT_0 = { EVENT_TRAIN_DICT_0 = {
'wall_time': 1581499557.7017336, 'wall_time': 1581499557.7017336,
'train_lineage': { 'train_lineage': {
...@@ -373,49 +376,4 @@ EVENT_DATASET_DICT_0 = { ...@@ -373,49 +376,4 @@ EVENT_DATASET_DICT_0 = {
} }
} }
DATASET_DICT_0 = { DATASET_DICT_0 = SERIALIZED_PIPELINE
'op_type': 'BatchDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'drop_remainder': True,
'batch_size': 10,
'children': [
{
'op_type': 'MapDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'input_columns': [
'label'
],
'output_columns': [
None
],
'operations': [
{
'tensor_op_module': 'minddata.transforms.c_transforms',
'tensor_op_name': 'OneHot',
'num_classes': 10
}
],
'children': [
{
'op_type': 'MnistDataset',
'shard_id': None,
'num_shards': None,
'op_module': 'minddata.dataengine.datasets',
'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData',
'num_parallel_workers': None,
'shuffle': None,
'num_samples': 100,
'sampler': {
'sampler_module': 'minddata.dataengine.samplers',
'sampler_name': 'RandomSampler',
'replacement': True,
'num_samples': 100
},
'children': []
}
]
}
]
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册