diff --git a/tests/st/func/datavisual/conftest.py b/tests/st/func/datavisual/conftest.py index 1cfc2bace865b2de7ee899f63cb093f4369314a7..bc0b2292d52b69b1c14bfc453c5b92f8545734a5 100644 --- a/tests/st/func/datavisual/conftest.py +++ b/tests/st/func/datavisual/conftest.py @@ -18,13 +18,10 @@ Description: This file is used for some common util. import os import shutil from unittest.mock import Mock + import pytest from flask import Response -from . import constants -from . import globals as gbl -from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done from mindinsight.conf import settings from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform.data_manager import DataManager @@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE from mindinsight.datavisual.utils import tools +from ....utils.log_operations import LogOperations +from ....utils.tools import check_loading_done +from . import constants +from . import globals as gbl + summaries_metadata = None mock_data_manager = None summary_base_dir = constants.SUMMARY_BASE_DIR @@ -55,18 +57,21 @@ def init_summary_logs(): os.mkdir(summary_base_dir, mode=mode) global summaries_metadata, mock_data_manager log_operations = LogOperations() - summaries_metadata = log_operations.create_summary_logs( - summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, constants.SUMMARY_DIR_PREFIX) + summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, + constants.SUMMARY_DIR_PREFIX) mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) mock_data_manager.start_load_data(reload_interval=0) check_loading_done(mock_data_manager) - summaries_metadata.update(log_operations.create_summary_logs( - summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST)) - summaries_metadata.update(log_operations.create_multiple_logs( - summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM)) - summaries_metadata.update(log_operations.create_reservoir_log( - summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) + summaries_metadata.update( + log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, + constants.SUMMARY_DIR_NUM_FIRST)) + summaries_metadata.update( + log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME, + constants.MULTIPLE_LOG_NUM)) + summaries_metadata.update( + log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME, + constants.RESERVOIR_STEP_NUM)) mock_data_manager.start_load_data(reload_interval=0) # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. diff --git a/tests/st/func/datavisual/image/test_metadata_restful_api.py b/tests/st/func/datavisual/image/test_metadata_restful_api.py index 9fa0a115c031fa4d1d1ec2a0ff524207876e3779..03fa85aeea1ea411bb3f10175d987bc46a8770de 100644 --- a/tests/st/func/datavisual/image/test_metadata_restful_api.py +++ b/tests/st/func/datavisual/image/test_metadata_restful_api.py @@ -20,13 +20,13 @@ Usage: """ import pytest -from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID -from .. import globals as gbl -from .....utils.tools import get_url - from mindinsight.conf import settings from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_url +from .. import globals as gbl +from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID + BASE_URL = '/v1/mindinsight/datavisual/image/metadata' diff --git a/tests/st/func/datavisual/image/test_single_image_restful_api.py b/tests/st/func/datavisual/image/test_single_image_restful_api.py index c40843b05512d277379efdca6b5910a520c571a6..573c5dce268a7a457c15a2fc15698e12a0fbf776 100644 --- a/tests/st/func/datavisual/image/test_single_image_restful_api.py +++ b/tests/st/func/datavisual/image/test_single_image_restful_api.py @@ -20,11 +20,11 @@ Usage: """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url, get_image_tensor_from_bytes - from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_image_tensor_from_bytes, get_url +from .. import globals as gbl + BASE_URL = '/v1/mindinsight/datavisual/image/single-image' diff --git a/tests/st/func/datavisual/scalar/test_metadata_restful_api.py b/tests/st/func/datavisual/scalar/test_metadata_restful_api.py index 0103008e58c13b7677769253594db64cdeb05e7d..c95e26e583633807a3c08f56486f0bcbdf0bd3f9 100644 --- a/tests/st/func/datavisual/scalar/test_metadata_restful_api.py +++ b/tests/st/func/datavisual/scalar/test_metadata_restful_api.py @@ -19,11 +19,12 @@ Usage: pytest tests/st/func/datavisual """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_url +from .. import globals as gbl + BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' diff --git a/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py b/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py index 825bf35cc051b162ac94a9f27d5bd0b3f576694f..95bb9a256d4d185059eaa560f53e2bd3fea163b9 100644 --- a/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py +++ b/tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py @@ -20,11 +20,11 @@ Usage: """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url - from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_url +from .. import globals as gbl + BASE_URL = '/v1/mindinsight/datavisual/plugins' diff --git a/tests/st/func/datavisual/taskmanager/test_query_single_train_task_restful_api.py b/tests/st/func/datavisual/taskmanager/test_query_single_train_task_restful_api.py index 1f4fcf26c82a395640479fa54059d2a73f42b902..e15d49f6abb2eb15149934fb912048c9fc91e64e 100644 --- a/tests/st/func/datavisual/taskmanager/test_query_single_train_task_restful_api.py +++ b/tests/st/func/datavisual/taskmanager/test_query_single_train_task_restful_api.py @@ -19,11 +19,12 @@ Usage: pytest tests/st/func/datavisual """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_url +from .. import globals as gbl + BASE_URL = '/v1/mindinsight/datavisual/single-job' diff --git a/tests/st/func/datavisual/workflow/test_image_metadata.py b/tests/st/func/datavisual/workflow/test_image_metadata.py index 1a72be7144e750eafe8f6fa9ab6c4bf10b68a4dc..5ba9ca169ca49a057423b62cd627bd0dab84e502 100644 --- a/tests/st/func/datavisual/workflow/test_image_metadata.py +++ b/tests/st/func/datavisual/workflow/test_image_metadata.py @@ -20,11 +20,11 @@ Usage: """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url - from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_url +from .. import globals as gbl + TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' diff --git a/tests/st/func/datavisual/workflow/test_single_image.py b/tests/st/func/datavisual/workflow/test_single_image.py index 163347b4aea5e7c2fdc9522fc8eacce1b4af0f73..4177fffc7d70f47e408e709cafa3c94f5b11573d 100644 --- a/tests/st/func/datavisual/workflow/test_single_image.py +++ b/tests/st/func/datavisual/workflow/test_single_image.py @@ -20,11 +20,11 @@ Usage: """ import pytest -from .. import globals as gbl -from .....utils.tools import get_url, get_image_tensor_from_bytes - from mindinsight.datavisual.common.enums import PluginNameEnum +from .....utils.tools import get_image_tensor_from_bytes, get_url +from .. import globals as gbl + TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' diff --git a/tests/st/func/lineagemgr/api/test_model_api.py b/tests/st/func/lineagemgr/api/test_model_api.py index 4ff8a3769515a6a05a7ccbee4cc74e1cd7426f92..46d5027bd3a525d69af4bcf4e2698dc2f73d9776 100644 --- a/tests/st/func/lineagemgr/api/test_model_api.py +++ b/tests/st/func/lineagemgr/api/test_model_api.py @@ -26,12 +26,12 @@ from unittest import TestCase import pytest -from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \ - LineageSearchConditionParamError, LineageFileNotFoundError -from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH +from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage +from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError, + LineageParamTypeError, LineageParamValueError, + LineageSearchConditionParamError) +from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 LINEAGE_INFO_RUN1 = { 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), diff --git a/tests/st/func/lineagemgr/conftest.py b/tests/st/func/lineagemgr/conftest.py index 801847e1ea53bf223a2707e817b63e5b2e89256d..89ed7b67b88e47e416db9b2799c06f97270d0f4c 100644 --- a/tests/st/func/lineagemgr/conftest.py +++ b/tests/st/func/lineagemgr/conftest.py @@ -22,8 +22,7 @@ import tempfile import pytest from ....utils import mindspore -from ....utils.mindspore.dataset.engine.serializer_deserializer import \ - SERIALIZED_PIPELINE +from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE sys.modules['mindspore'] = mindspore diff --git a/tests/ut/backend/datavisual/test_task_manager_api.py b/tests/ut/backend/datavisual/test_task_manager_api.py index b88f033c348544db008015082e954a2ca3da9d3d..47475fdf27088549b72a702d0530259775b5f25c 100644 --- a/tests/ut/backend/datavisual/test_task_manager_api.py +++ b/tests/ut/backend/datavisual/test_task_manager_api.py @@ -21,14 +21,15 @@ Usage: from unittest.mock import patch import pytest -from .conftest import TRAIN_ROUTES -from ....utils.log_generators.images_log_generator import ImagesLogGenerator -from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator -from ....utils.tools import get_url from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager +from ....utils.log_generators.images_log_generator import ImagesLogGenerator +from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator +from ....utils.tools import get_url +from .conftest import TRAIN_ROUTES + class TestTrainTask: """Test train task api.""" @@ -36,9 +37,7 @@ class TestTrainTask: _scalar_log_generator = ScalarsLogGenerator() _image_log_generator = ImagesLogGenerator() - @pytest.mark.parametrize( - "plugin_name", - ['no_plugin_name', 'not_exist_plugin_name']) + @pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name']) def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): """ Parsing unavailable plugin name to single train task. diff --git a/tests/ut/backend/datavisual/test_train_visual_api.py b/tests/ut/backend/datavisual/test_train_visual_api.py index 97ddee6bc60af05d3e3a92f3d1e55d048b2ab5a9..831edbf34c2e5175ecad42ba6e8a4fd116628c16 100644 --- a/tests/ut/backend/datavisual/test_train_visual_api.py +++ b/tests/ut/backend/datavisual/test_train_visual_api.py @@ -21,14 +21,15 @@ Usage: from unittest.mock import Mock, patch import pytest -from .conftest import TRAIN_ROUTES -from ....utils.tools import get_url from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor +from ....utils.tools import get_url +from .conftest import TRAIN_ROUTES + class TestTrainVisual: """Test Train Visual APIs.""" @@ -95,14 +96,7 @@ class TestTrainVisual: assert response.status_code == 200 response = response.get_json() - expected_response = { - "metadatas": [{ - "height": 224, - "step": 1, - "wall_time": 1572058058.1175, - "width": 448 - }] - } + expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]} assert expected_response == response def test_single_image_with_params_miss(self, client): @@ -254,8 +248,10 @@ class TestTrainVisual: @patch.object(GraphProcessor, 'get_nodes') def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): """Test getting graph nodes successfully.""" + def mock_get_nodes(name, node_type): return dict(name=name, node_type=node_type) + mock_graph_processor.side_effect = mock_get_nodes mock_init = Mock(return_value=None) @@ -327,10 +323,7 @@ class TestTrainVisual: assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ "be greater than or equal to 0." - @pytest.mark.parametrize( - "limit", - [-1, 0, 1001] - ) + @pytest.mark.parametrize("limit", [-1, 0, 1001]) @patch.object(GraphProcessor, '__init__') def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): """Test getting graph node names with invalid limit.""" @@ -348,14 +341,10 @@ class TestTrainVisual: assert results['error_msg'] == "Invalid parameter value. " \ "'limit' should in [1, 1000]." - @pytest.mark.parametrize( - " offset, limit", - [(0, 100), (1, 1), (0, 1000)] - ) + @pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)]) @patch.object(GraphProcessor, '__init__') @patch.object(GraphProcessor, 'search_node_names') - def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, - offset, limit): + def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit): """ Parsing unavailable params to get image metadata. @@ -367,8 +356,10 @@ class TestTrainVisual: response status code: 200. response json: dict, contains search_content, offset, and limit. """ + def mock_search_node_names(search_content, offset, limit): return dict(search_content=search_content, offset=int(offset), limit=int(limit)) + mock_graph_processor.side_effect = mock_search_node_names mock_init = Mock(return_value=None) @@ -376,15 +367,12 @@ class TestTrainVisual: test_train_id = "aaa" test_search_content = "bbb" - params = dict(train_id=test_train_id, search=test_search_content, - offset=offset, limit=limit) + params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit) url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) response = client.get(url) assert response.status_code == 200 results = response.get_json() - assert results == dict(search_content=test_search_content, - offset=int(offset), - limit=int(limit)) + assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit)) def test_graph_search_single_node_with_params_is_wrong(self, client): """Test searching graph single node with params is wrong.""" @@ -427,8 +415,10 @@ class TestTrainVisual: response status code: 200. response json: name. """ + def mock_search_single_node(name): return name + mock_graph_processor.side_effect = mock_search_single_node mock_init = Mock(return_value=None) diff --git a/tests/ut/backend/lineagemgr/test_lineage_api.py b/tests/ut/backend/lineagemgr/test_lineage_api.py index 71df15aa36a08ff882a52d2e30ae1bf9dbf5f6fd..9d68b397a1fa4519d949cd8a61756d08c16505d6 100644 --- a/tests/ut/backend/lineagemgr/test_lineage_api.py +++ b/tests/ut/backend/lineagemgr/test_lineage_api.py @@ -20,9 +20,7 @@ from unittest import TestCase, mock from flask import Response from mindinsight.backend.application import APP -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageQuerySummaryDataError - +from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError LINEAGE_FILTRATION_BASE = { 'accuracy': None, diff --git a/tests/ut/datavisual/common/test_error_handler.py b/tests/ut/datavisual/common/test_error_handler.py index e68c5bbc6721a5ef393bdd04f567f863f9c93e3b..cbe68909bb189fe0a258e97d984aa9dbaf1bfe9b 100644 --- a/tests/ut/datavisual/common/test_error_handler.py +++ b/tests/ut/datavisual/common/test_error_handler.py @@ -19,15 +19,16 @@ Usage: pytest tests/ut/datavisual """ from unittest.mock import patch -from werkzeug.exceptions import MethodNotAllowed, NotFound -from ...backend.datavisual.conftest import TRAIN_ROUTES -from ..mock import MockLogger -from ....utils.tools import get_url +from werkzeug.exceptions import MethodNotAllowed, NotFound from mindinsight.datavisual.processors import scalars_processor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor +from ....utils.tools import get_url +from ...backend.datavisual.conftest import TRAIN_ROUTES +from ..mock import MockLogger + class TestErrorHandler: """Test train visual api.""" diff --git a/tests/ut/datavisual/data_transform/loader_generators/test_data_loader_generator.py b/tests/ut/datavisual/data_transform/loader_generators/test_data_loader_generator.py index 3a3bfe56ca6b03c479820bfe8fce065e895fc07f..1e4ab3c266e125e4c90c8deb0a68d8aac4cab069 100644 --- a/tests/ut/datavisual/data_transform/loader_generators/test_data_loader_generator.py +++ b/tests/ut/datavisual/data_transform/loader_generators/test_data_loader_generator.py @@ -22,18 +22,19 @@ import datetime import os import shutil import tempfile - from unittest.mock import patch -import pytest -from ...mock import MockLogger +import pytest from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator from mindinsight.utils.exceptions import ParamValueError +from ...mock import MockLogger + class TestDataLoaderGenerator: """Test data_loader_generator.""" + @classmethod def setup_class(cls): data_loader_generator.logger = MockLogger @@ -88,8 +89,9 @@ class TestDataLoaderGenerator: mock_data_loader.return_value = True loader_dict = generator.generate_loaders(loader_pool=dict()) - expected_ids = [summary.get('relative_path') - for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]] + expected_ids = [ + summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:] + ] assert sorted(loader_dict.keys()) == sorted(expected_ids) shutil.rmtree(summary_base_dir) diff --git a/tests/ut/datavisual/data_transform/test_data_loader.py b/tests/ut/datavisual/data_transform/test_data_loader.py index 0812a061b0fa44821e80ee5d755cae358923e5e4..5ab023d3b4d08b6b253b4636f203eeb741daa602 100644 --- a/tests/ut/datavisual/data_transform/test_data_loader.py +++ b/tests/ut/datavisual/data_transform/test_data_loader.py @@ -23,12 +23,13 @@ import shutil import tempfile import pytest -from ..mock import MockLogger from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid from mindinsight.datavisual.data_transform import data_loader from mindinsight.datavisual.data_transform.data_loader import DataLoader +from ..mock import MockLogger + class TestDataLoader: """Test data_loader.""" @@ -37,13 +38,13 @@ class TestDataLoader: def setup_class(cls): data_loader.logger = MockLogger - def setup_method(self, method): + def setup_method(self): self._summary_dir = tempfile.mkdtemp() if os.path.exists(self._summary_dir): shutil.rmtree(self._summary_dir) os.mkdir(self._summary_dir) - def teardown_method(self, method): + def teardown_method(self): if os.path.exists(self._summary_dir): shutil.rmtree(self._summary_dir) diff --git a/tests/ut/datavisual/data_transform/test_data_manager.py b/tests/ut/datavisual/data_transform/test_data_manager.py index 65405b031e21c5bd245aad8207268e084e2403c4..0ed02a3df1b8cf4a0b2c91b5685b09fa0b4ef158 100644 --- a/tests/ut/datavisual/data_transform/test_data_manager.py +++ b/tests/ut/datavisual/data_transform/test_data_manager.py @@ -18,32 +18,29 @@ Function: Usage: pytest tests/ut/datavisual """ -import time import os import shutil import tempfile +import time from unittest import mock -from unittest.mock import Mock -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest -from ..mock import MockLogger -from ....utils.tools import check_loading_done from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum from mindinsight.datavisual.data_transform import data_manager, ms_data_loader from mindinsight.datavisual.data_transform.data_loader import DataLoader from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.events_data import EventsData -from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \ - DataLoaderGenerator -from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \ - MAX_DATA_LOADER_SIZE -from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \ - LoaderStruct +from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator +from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE +from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.utils.exceptions import ParamValueError +from ....utils.tools import check_loading_done +from ..mock import MockLogger + class TestDataManager: """Test data_manager.""" @@ -101,11 +98,17 @@ class TestDataManager: "and loader pool size is '3'." shutil.rmtree(summary_base_dir) - @pytest.mark.parametrize('params', - [{'reload_interval': '30'}, - {'reload_interval': -1}, - {'reload_interval': 30, 'max_threads_count': '20'}, - {'reload_interval': 30, 'max_threads_count': 0}]) + @pytest.mark.parametrize('params', [{ + 'reload_interval': '30' + }, { + 'reload_interval': -1 + }, { + 'reload_interval': 30, + 'max_threads_count': '20' + }, { + 'reload_interval': 30, + 'max_threads_count': 0 + }]) def test_start_load_data_with_invalid_params(self, params): """Test start_load_data with invalid reload_interval or invalid max_threads_count.""" summary_base_dir = tempfile.mkdtemp() diff --git a/tests/ut/datavisual/data_transform/test_events_data.py b/tests/ut/datavisual/data_transform/test_events_data.py index e984bede0f4745ca9c92ddf064fdf739296d5bd0..cd3167dadf8086c1b16de887fd20fff35f518e1e 100644 --- a/tests/ut/datavisual/data_transform/test_events_data.py +++ b/tests/ut/datavisual/data_transform/test_events_data.py @@ -22,20 +22,24 @@ import threading from collections import namedtuple import pytest -from ..mock import MockLogger from mindinsight.conf import settings from mindinsight.datavisual.data_transform import events_data from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor +from ..mock import MockLogger + class MockReservoir: """Use this class to replace reservoir.Reservoir in test.""" def __init__(self, size): self.size = size - self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'), - _Tensor('wall_time3', 3, 'value3')] + self._samples = [ + _Tensor('wall_time1', 1, 'value1'), + _Tensor('wall_time2', 2, 'value2'), + _Tensor('wall_time3', 3, 'value3') + ] def samples(self): """Replace the samples function.""" @@ -63,11 +67,12 @@ class TestEventsData: def setup_method(self): """Mock original logger, init a EventsData object for use.""" self._ev_data = EventsData() - self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)], - 'plugin_name2': [f'tag{i}' for i in range(20, 30)]} + self._ev_data._tags_by_plugin = { + 'plugin_name1': [f'tag{i}' for i in range(10)], + 'plugin_name2': [f'tag{i}' for i in range(20, 30)] + } self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) - self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), - 'new_tag': MockReservoir(500)} + self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)} self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] def get_ev_data(self): @@ -102,8 +107,7 @@ class TestEventsData: """Test add_tensor_event success.""" ev_data = self.get_ev_data() - t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', - value='value1') + t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1') ev_data.add_tensor_event(t_event) assert 'tag0' not in ev_data._tags @@ -111,6 +115,5 @@ class TestEventsData: assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] assert 'tag0' not in ev_data._reservoir_by_tag assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] - assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, - t_event.step, + assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step, t_event.value) diff --git a/tests/ut/datavisual/data_transform/test_ms_data_loader.py b/tests/ut/datavisual/data_transform/test_ms_data_loader.py index 796bde61d5480b85e3d6f2d21adbf68637d20bf7..8b63f239d9ceebba84bc47346d51bce72c9ada88 100644 --- a/tests/ut/datavisual/data_transform/test_ms_data_loader.py +++ b/tests/ut/datavisual/data_transform/test_ms_data_loader.py @@ -19,16 +19,17 @@ Usage: pytest tests/ut/datavisual """ import os -import tempfile import shutil +import tempfile from unittest.mock import Mock import pytest -from ..mock import MockLogger from mindinsight.datavisual.data_transform import ms_data_loader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader +from ..mock import MockLogger + # bytes of 3 scalar events SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' @@ -74,7 +75,8 @@ class TestMsDataLoader: "we will reload all files in path {}.".format(summary_dir) shutil.rmtree(summary_dir) - def test_load_success_with_crc_pass(self, crc_pass): + @pytest.mark.usefixtures('crc_pass') + def test_load_success_with_crc_pass(self): """Test load success.""" summary_dir = tempfile.mkdtemp() file1 = os.path.join(summary_dir, 'summary.01') @@ -88,7 +90,8 @@ class TestMsDataLoader: tensors = ms_loader.get_events_data().tensors(tag[0]) assert len(tensors) == 3 - def test_load_with_crc_fail(self, crc_fail): + @pytest.mark.usefixtures('crc_fail') + def test_load_with_crc_fail(self): """Test when crc_fail and will not go to func _event_parse.""" summary_dir = tempfile.mkdtemp() file2 = os.path.join(summary_dir, 'summary.02') @@ -100,8 +103,10 @@ class TestMsDataLoader: def test_filter_event_files(self): """Test filter_event_files function ok.""" - file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', - 'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567'] + file_list = [ + 'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786', + 'mysummary.123abce', 'summay.4567' + ] summary_dir = tempfile.mkdtemp() for file in file_list: with open(os.path.join(summary_dir, file), 'w'): @@ -113,6 +118,7 @@ class TestMsDataLoader: shutil.rmtree(summary_dir) + def write_file(filename, record): """Write bytes strings to file.""" with open(filename, 'wb') as file: diff --git a/tests/ut/datavisual/processors/test_graph_processor.py b/tests/ut/datavisual/processors/test_graph_processor.py index e2118ea62cdfcefbec47c73ba32b43efa06147f5..1a263571ac2a59f420e0b34429db67261dca42a0 100644 --- a/tests/ut/datavisual/processors/test_graph_processor.py +++ b/tests/ut/datavisual/processors/test_graph_processor.py @@ -19,18 +19,11 @@ Usage: pytest tests/ut/datavisual """ import os -import json import tempfile - -from unittest.mock import Mock -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest -from ..mock import MockLogger -from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file - from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.data_transform import data_manager @@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import ParamValueError +from ....utils.log_operations import LogOperations +from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs +from ..mock import MockLogger + class TestGraphProcessor: """Test Graph Processor api.""" @@ -76,8 +73,7 @@ class TestGraphProcessor: self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir) self._generated_path.append(summary_base_dir) - self._mock_data_manager = data_manager.DataManager( - [DataLoaderGenerator(summary_base_dir)]) + self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager.start_load_data(reload_interval=0) # wait for loading done @@ -91,27 +87,28 @@ class TestGraphProcessor: self._train_id = log_dir.replace(summary_base_dir, ".") log_operation = LogOperations() - self._temp_path, _, _ = log_operation.generate_log( - PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) + self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir, + dict(steps=self._steps_list, tag="image")) self._generated_path.append(summary_base_dir) - self._mock_data_manager = data_manager.DataManager( - [DataLoaderGenerator(summary_base_dir)]) + self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager.start_load_data(reload_interval=0) # wait for loading done check_loading_done(self._mock_data_manager, time_limit=5) - def test_get_nodes_with_not_exist_train_id(self, load_graph_record): + @pytest.mark.usefixtures('load_graph_record') + def test_get_nodes_with_not_exist_train_id(self): """Test getting nodes with not exist train id.""" test_train_id = "not_exist_train_id" with pytest.raises(ParamValueError) as exc_info: GraphProcessor(test_train_id, self._mock_data_manager) assert "Can not find the train job in data manager." in exc_info.value.message + @pytest.mark.usefixtures('load_graph_record') @patch.object(DataManager, 'get_train_job_by_plugin') - def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record): + def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin): """Test get nodes with loader is None.""" mock_get_train_job_by_plugin.return_value = None with pytest.raises(exceptions.SummaryLogPathInvalid): @@ -119,15 +116,12 @@ class TestGraphProcessor: assert mock_get_train_job_by_plugin.called - @pytest.mark.parametrize("name, node_type", [ - ("not_exist_name", "name_scope"), - ("", "polymeric_scope") - ]) - def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type): + @pytest.mark.usefixtures('load_graph_record') + @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")]) + def test_get_nodes_with_not_exist_name(self, name, node_type): """Test getting nodes with not exist name.""" with pytest.raises(ParamValueError) as exc_info: - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) graph_processor.get_nodes(name, node_type) if name: @@ -135,38 +129,33 @@ class TestGraphProcessor: else: assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message - @pytest.mark.parametrize("name, node_type, result_file", [ - (None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), - ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), - ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json') - ]) - def test_get_nodes_success(self, load_graph_record, name, node_type, result_file): + @pytest.mark.usefixtures('load_graph_record') + @pytest.mark.parametrize( + "name, node_type, result_file", + [(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), + ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), + ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')]) + def test_get_nodes_success(self, name, node_type, result_file): """Test getting nodes successfully.""" - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) results = graph_processor.get_nodes(name, node_type) expected_file_path = os.path.join(self.graph_results_dir, result_file) compare_result_with_file(results, expected_file_path) - @pytest.mark.parametrize("search_content, result_file", [ - (None, 'test_search_node_names_with_search_content_expected_results1.json'), - ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), - ('not_exist_search_content', None) - ]) - def test_search_node_names_with_search_content(self, load_graph_record, - search_content, - result_file): + @pytest.mark.usefixtures('load_graph_record') + @pytest.mark.parametrize("search_content, result_file", + [(None, 'test_search_node_names_with_search_content_expected_results1.json'), + ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), + ('not_exist_search_content', None)]) + def test_search_node_names_with_search_content(self, search_content, result_file): """Test search node names with search content.""" test_offset = 0 test_limit = 1000 - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) - results = graph_processor.search_node_names(search_content, - test_offset, - test_limit) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) + results = graph_processor.search_node_names(search_content, test_offset, test_limit) if search_content == 'not_exist_search_content': expected_results = {'names': []} assert results == expected_results @@ -174,71 +163,65 @@ class TestGraphProcessor: expected_file_path = os.path.join(self.graph_results_dir, result_file) compare_result_with_file(results, expected_file_path) + @pytest.mark.usefixtures('load_graph_record') @pytest.mark.parametrize("offset", [-100, -1]) - def test_search_node_names_with_negative_offset(self, load_graph_record, offset): + def test_search_node_names_with_negative_offset(self, offset): """Test search node names with negative offset.""" test_search_content = "" test_limit = 3 - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) with pytest.raises(ParamValueError) as exc_info: graph_processor.search_node_names(test_search_content, offset, test_limit) assert "'offset' should be greater than or equal to 0." in exc_info.value.message - @pytest.mark.parametrize("offset, result_file", [ - (1, 'test_search_node_names_with_offset_expected_results1.json') - ]) - def test_search_node_names_with_offset(self, load_graph_record, offset, result_file): + @pytest.mark.usefixtures('load_graph_record') + @pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')]) + def test_search_node_names_with_offset(self, offset, result_file): """Test search node names with offset.""" test_search_content = "Default/bn1" test_offset = offset test_limit = 3 - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) - results = graph_processor.search_node_names(test_search_content, - test_offset, - test_limit) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) + results = graph_processor.search_node_names(test_search_content, test_offset, test_limit) expected_file_path = os.path.join(self.graph_results_dir, result_file) compare_result_with_file(results, expected_file_path) - def test_search_node_names_with_wrong_limit(self, load_graph_record): + @pytest.mark.usefixtures('load_graph_record') + def test_search_node_names_with_wrong_limit(self): """Test search node names with wrong limit.""" test_search_content = "" test_offset = 0 test_limit = 0 - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) with pytest.raises(ParamValueError) as exc_info: - graph_processor.search_node_names(test_search_content, test_offset, - test_limit) + graph_processor.search_node_names(test_search_content, test_offset, test_limit) assert "'limit' should in [1, 1000]." in exc_info.value.message - @pytest.mark.parametrize("name, result_file", [ - ('Default/bn1', 'test_search_single_node_success_expected_results1.json') - ]) - def test_search_single_node_success(self, load_graph_record, name, result_file): + @pytest.mark.usefixtures('load_graph_record') + @pytest.mark.parametrize("name, result_file", + [('Default/bn1', 'test_search_single_node_success_expected_results1.json')]) + def test_search_single_node_success(self, name, result_file): """Test searching single node successfully.""" - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) results = graph_processor.search_single_node(name) expected_file_path = os.path.join(self.graph_results_dir, result_file) compare_result_with_file(results, expected_file_path) - - def test_search_single_node_with_not_exist_name(self, load_graph_record): + @pytest.mark.usefixtures('load_graph_record') + def test_search_single_node_with_not_exist_name(self): """Test searching single node with not exist name.""" test_name = "not_exist_name" with pytest.raises(exceptions.NodeNotInGraphError): - graph_processor = GraphProcessor(self._train_id, - self._mock_data_manager) + graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) graph_processor.search_single_node(test_name) - def test_check_graph_status_no_graph(self, load_no_graph_record): + @pytest.mark.usefixtures('load_no_graph_record') + def test_check_graph_status_no_graph(self): """Test checking graph status no graph.""" with pytest.raises(ParamValueError) as exc_info: GraphProcessor(self._train_id, self._mock_data_manager) diff --git a/tests/ut/datavisual/processors/test_images_processor.py b/tests/ut/datavisual/processors/test_images_processor.py index 6252ddfd08995d41bd9258a035dd200037912c99..defa711e23676c0377e5619cd8564cd0ef6294f7 100644 --- a/tests/ut/datavisual/processors/test_images_processor.py +++ b/tests/ut/datavisual/processors/test_images_processor.py @@ -22,9 +22,6 @@ import tempfile from unittest.mock import Mock import pytest -from ..mock import MockLogger -from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.data_transform import data_manager @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import ParamValueError +from ....utils.log_operations import LogOperations +from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes +from ..mock import MockLogger + class TestImagesProcessor: """Test images processor api.""" @@ -101,7 +102,8 @@ class TestImagesProcessor: """Load image record.""" self._init_data_manager(self._cross_steps_list) - def test_get_metadata_list_with_not_exist_id(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_metadata_list_with_not_exist_id(self): """Test getting metadata list with not exist id.""" test_train_id = 'not_exist_id' image_processor = ImageProcessor(self._mock_data_manager) @@ -111,7 +113,8 @@ class TestImagesProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in loader pool about the train job." in exc_info.value.message - def test_get_metadata_list_with_not_exist_tag(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_metadata_list_with_not_exist_tag(self): """Test get metadata list with not exist tag.""" test_tag_name = 'not_exist_tag_name' @@ -123,7 +126,8 @@ class TestImagesProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in this train job by given tag." in exc_info.value.message - def test_get_metadata_list_success(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_metadata_list_success(self): """Test getting metadata list success.""" test_tag_name = self._complete_tag_name @@ -132,7 +136,8 @@ class TestImagesProcessor: assert results == self._images_metadata - def test_get_single_image_with_not_exist_id(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_single_image_with_not_exist_id(self): """Test getting single image with not exist id.""" test_train_id = 'not_exist_id' test_tag_name = self._complete_tag_name @@ -145,7 +150,8 @@ class TestImagesProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in loader pool about the train job." in exc_info.value.message - def test_get_single_image_with_not_exist_tag(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_single_image_with_not_exist_tag(self): """Test getting single image with not exist tag.""" test_tag_name = 'not_exist_tag_name' test_step = self._steps_list[0] @@ -158,7 +164,8 @@ class TestImagesProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in this train job by given tag." in exc_info.value.message - def test_get_single_image_with_not_exist_step(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_single_image_with_not_exist_step(self): """Test getting single image with not exist step.""" test_tag_name = self._complete_tag_name test_step = 10000 @@ -171,7 +178,8 @@ class TestImagesProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find the step with given train job id and tag." in exc_info.value.message - def test_get_single_image_success(self, load_image_record): + @pytest.mark.usefixtures('load_image_record') + def test_get_single_image_success(self): """Test getting single image successfully.""" test_tag_name = self._complete_tag_name test_step_index = 0 @@ -184,7 +192,8 @@ class TestImagesProcessor: assert recv_image_tensor.any() == expected_image_tensor.any() - def test_reservoir_add_sample(self, load_more_than_limit_image_record): + @pytest.mark.usefixtures('load_more_than_limit_image_record') + def test_reservoir_add_sample(self): """Test adding sample in reservoir.""" test_tag_name = self._complete_tag_name @@ -201,7 +210,8 @@ class TestImagesProcessor: cnt += 1 assert len(self._more_steps_list) - cnt == 10 - def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record): + @pytest.mark.usefixtures('load_reservoir_remove_sample_image_record') + def test_reservoir_remove_sample(self): """ Test removing sample in reservoir. diff --git a/tests/ut/datavisual/processors/test_scalars_processor.py b/tests/ut/datavisual/processors/test_scalars_processor.py index 6976cd4709318082f806f158f8ecfa9a2ee6be0a..87c54ee3f0e5de8b3503b77cd2b78649ebe59451 100644 --- a/tests/ut/datavisual/processors/test_scalars_processor.py +++ b/tests/ut/datavisual/processors/test_scalars_processor.py @@ -22,9 +22,6 @@ import tempfile from unittest.mock import Mock import pytest -from ..mock import MockLogger -from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.data_transform import data_manager @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import ParamValueError +from ....utils.log_operations import LogOperations +from ....utils.tools import check_loading_done, delete_files_or_dirs +from ..mock import MockLogger + class TestScalarsProcessor: """Test scalar processor api.""" @@ -78,7 +79,8 @@ class TestScalarsProcessor: # wait for loading done check_loading_done(self._mock_data_manager, time_limit=5) - def test_get_metadata_list_with_not_exist_id(self, load_scalar_record): + @pytest.mark.usefixtures('load_scalar_record') + def test_get_metadata_list_with_not_exist_id(self): """Get metadata list with not exist id.""" test_train_id = 'not_exist_id' scalar_processor = ScalarsProcessor(self._mock_data_manager) @@ -88,7 +90,8 @@ class TestScalarsProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in loader pool about the train job." in exc_info.value.message - def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record): + @pytest.mark.usefixtures('load_scalar_record') + def test_get_metadata_list_with_not_exist_tag(self): """Get metadata list with not exist tag.""" test_tag_name = 'not_exist_tag_name' @@ -100,7 +103,8 @@ class TestScalarsProcessor: assert exc_info.value.error_code == '50540002' assert "Can not find any data in this train job by given tag." in exc_info.value.message - def test_get_metadata_list_success(self, load_scalar_record): + @pytest.mark.usefixtures('load_scalar_record') + def test_get_metadata_list_success(self): """Get metadata list success.""" test_tag_name = self._complete_tag_name diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py index f2c2e8d06a7a03fc48f535b1ad5b2788257ff7c7..ab6f38d8688059ef56102a801068211eb14e2aa5 100644 --- a/tests/ut/datavisual/processors/test_train_task_manager.py +++ b/tests/ut/datavisual/processors/test_train_task_manager.py @@ -18,15 +18,11 @@ Function: Usage: pytest tests/ut/datavisual """ -import os import tempfile import time from unittest.mock import Mock import pytest -from ..mock import MockLogger -from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.data_transform import data_manager @@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import ParamValueError +from ....utils.log_operations import LogOperations +from ....utils.tools import check_loading_done, delete_files_or_dirs +from ..mock import MockLogger + class TestTrainTaskManager: """Test train task manager.""" @@ -83,10 +83,7 @@ class TestTrainTaskManager: train_id = dir_path.replace(self._root_dir, ".") # Pass timestamp to write to the same file. - log_settings = dict( - steps=self._steps_list, - tag=tmp_tag_name, - time=time.time()) + log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time()) if i % 3 != 0: log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) self._plugins_id_map['image'].append(train_id) @@ -106,7 +103,8 @@ class TestTrainTaskManager: check_loading_done(self._mock_data_manager, time_limit=30) - def test_get_single_train_task_with_not_exists_train_id(self, load_data): + @pytest.mark.usefixtures('load_data') + def test_get_single_train_task_with_not_exists_train_id(self): """Test getting single train task with not exists train_id.""" train_task_manager = TrainTaskManager(self._mock_data_manager) for plugin_name in PluginNameEnum.list_members(): @@ -118,7 +116,8 @@ class TestTrainTaskManager: "the train job in data manager." assert exc_info.value.error_code == '50540002' - def test_get_single_train_task_with_params(self, load_data): + @pytest.mark.usefixtures('load_data') + def test_get_single_train_task_with_params(self): """Test getting single train task with params.""" train_task_manager = TrainTaskManager(self._mock_data_manager) for plugin_name in PluginNameEnum.list_members(): @@ -132,7 +131,8 @@ class TestTrainTaskManager: else: assert test_train_id not in self._plugins_id_map.get(plugin_name) - def test_get_plugins_with_train_id(self, load_data): + @pytest.mark.usefixtures('load_data') + def test_get_plugins_with_train_id(self): """Test getting plugins with train id.""" train_task_manager = TrainTaskManager(self._mock_data_manager) diff --git a/tests/ut/lineagemgr/collection/model/test_model_lineage.py b/tests/ut/lineagemgr/collection/model/test_model_lineage.py index b28039ed66714d82e9eb68e5939b187e2bc4c33f..df367705f81cf47a7fd76a23ca284706d35a1a59 100644 --- a/tests/ut/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/ut/lineagemgr/collection/model/test_model_lineage.py @@ -16,18 +16,16 @@ import os import shutil import unittest -from unittest import mock, TestCase +from unittest import TestCase, mock from unittest.mock import MagicMock -from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ - AnalyzeObject -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageLogError, LineageGetModelFileError, MindInsightException +from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage +from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError, + MindInsightException) from mindspore.common.tensor import Tensor -from mindspore.dataset.engine import MindDataset, Dataset -from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \ - SoftmaxCrossEntropyWithLogits -from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep +from mindspore.dataset.engine import Dataset, MindDataset +from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell +from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep from mindspore.train.summary import SummaryRecord diff --git a/tests/ut/lineagemgr/common/validator/test_validate.py b/tests/ut/lineagemgr/common/validator/test_validate.py index c7b2722afda3e29e876dce6f2861646fdac1b5e0..ed5c76ccc2e86b9936262df4fe8d7a17d75d8554 100644 --- a/tests/ut/lineagemgr/common/validator/test_validate.py +++ b/tests/ut/lineagemgr/common/validator/test_validate.py @@ -15,12 +15,9 @@ """Test the validate module.""" from unittest import TestCase -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageParamValueError, LineageParamTypeError -from mindinsight.lineagemgr.common.validator.model_parameter import \ - SearchModelConditionParameter -from mindinsight.lineagemgr.common.validator.validate import \ - validate_search_model_condition +from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError +from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter +from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition from mindinsight.utils.exceptions import MindInsightException diff --git a/tests/ut/lineagemgr/querier/event_data.py b/tests/ut/lineagemgr/querier/event_data.py index e9d6804412aba1ca40671bd7af805bffece1b880..38a971b5ee5ad25c632f204bf2091a77aa136ce5 100644 --- a/tests/ut/lineagemgr/querier/event_data.py +++ b/tests/ut/lineagemgr/querier/event_data.py @@ -15,8 +15,7 @@ """The event data in querier test.""" import json -from ....utils.mindspore.dataset.engine.serializer_deserializer import \ - SERIALIZED_PIPELINE +from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE EVENT_TRAIN_DICT_0 = { 'wall_time': 1581499557.7017336, diff --git a/tests/ut/lineagemgr/querier/test_querier.py b/tests/ut/lineagemgr/querier/test_querier.py index 5ade13cbdff9a72701a046fdaafc2b9c1a92ab20..5e5d86c460c6b06b03cefc53ac9336be16ea0b06 100644 --- a/tests/ut/lineagemgr/querier/test_querier.py +++ b/tests/ut/lineagemgr/querier/test_querier.py @@ -18,12 +18,12 @@ from unittest import TestCase, mock from google.protobuf.json_format import ParseDict import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageQuerierParamException, LineageParamTypeError, \ - LineageSummaryAnalyzeException, LineageSummaryParseException +from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException, + LineageSummaryAnalyzeException, + LineageSummaryParseException) from mindinsight.lineagemgr.querier.querier import Querier -from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ - LineageInfo +from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo + from . import event_data diff --git a/tests/ut/lineagemgr/querier/test_query_model.py b/tests/ut/lineagemgr/querier/test_query_model.py index 12d566f7daa2a6684b03977f8f650c263e83133e..bcf94ef24a3c3afb22f9d63c73f3ec6568968765 100644 --- a/tests/ut/lineagemgr/querier/test_query_model.py +++ b/tests/ut/lineagemgr/querier/test_query_model.py @@ -15,11 +15,12 @@ """Test the query_model module.""" from unittest import TestCase -from mindinsight.lineagemgr.common.exceptions.exceptions import \ - LineageEventNotExistException, LineageEventFieldNotExistException +from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException, + LineageEventNotExistException) from mindinsight.lineagemgr.querier.query_model import LineageObj + from . import event_data -from .test_querier import create_lineage_info, create_filtration_result +from .test_querier import create_filtration_result, create_lineage_info class TestLineageObj(TestCase): diff --git a/tests/utils/log_generators/graph_log_generator.py b/tests/utils/log_generators/graph_log_generator.py index 5e5632351ca9de9a2d689b10575902fa46237d40..2efc789e31d7cf00ce0197badfa1c27a161bdfb9 100644 --- a/tests/utils/log_generators/graph_log_generator.py +++ b/tests/utils/log_generators/graph_log_generator.py @@ -18,10 +18,11 @@ import os import time from google.protobuf import json_format -from .log_generator import LogGenerator from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 +from .log_generator import LogGenerator + class GraphLogGenerator(LogGenerator): """ diff --git a/tests/utils/log_generators/images_log_generator.py b/tests/utils/log_generators/images_log_generator.py index 98f55def38c95c4547252e12138529b915ee2664..77c44472ba68463272c18ee0497b4d2a6bb379fb 100644 --- a/tests/utils/log_generators/images_log_generator.py +++ b/tests/utils/log_generators/images_log_generator.py @@ -18,10 +18,11 @@ import time import numpy as np from PIL import Image -from .log_generator import LogGenerator from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 +from .log_generator import LogGenerator + class ImagesLogGenerator(LogGenerator): """ @@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator): images_metadata.append(image_metadata) images_values.update({step: image_tensor}) - values = dict( - wall_time=wall_time, - step=step, - image=image_tensor, - tag=tag_name - ) + values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name) self._write_log_one_step(file_path, values) diff --git a/tests/utils/log_generators/scalars_log_generator.py b/tests/utils/log_generators/scalars_log_generator.py index e49d520e9d4aad27dde44045260767383f7c2138..1e6aeeeabb55548632f62f15769cf36c30ee7f0b 100644 --- a/tests/utils/log_generators/scalars_log_generator.py +++ b/tests/utils/log_generators/scalars_log_generator.py @@ -16,10 +16,11 @@ import time import numpy as np -from .log_generator import LogGenerator from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 +from .log_generator import LogGenerator + class ScalarsLogGenerator(LogGenerator): """ diff --git a/tests/utils/log_operations.py b/tests/utils/log_operations.py index 1e8b01a47c7f2afbb43589cea99a2a743e4e5c41..bc4e9e3850417f7927b196fdf4f77bf1b2e1a41e 100644 --- a/tests/utils/log_operations.py +++ b/tests/utils/log_operations.py @@ -19,12 +19,12 @@ import json import os import time +from mindinsight.datavisual.common.enums import PluginNameEnum + from .log_generators.graph_log_generator import GraphLogGenerator from .log_generators.images_log_generator import ImagesLogGenerator from .log_generators.scalars_log_generator import ScalarsLogGenerator -from mindinsight.datavisual.common.enums import PluginNameEnum - log_generators = { PluginNameEnum.GRAPH.value: GraphLogGenerator(), PluginNameEnum.IMAGE.value: ImagesLogGenerator(), @@ -34,6 +34,7 @@ log_generators = { class LogOperations: """Log Operations.""" + def __init__(self): self._step_num = 3 self._tag_num = 2