提交 b779f3d7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!12 Fix Pylint Warnings

Merge pull request !12 from 李鸿章/fix_pylint
...@@ -18,13 +18,10 @@ Description: This file is used for some common util. ...@@ -18,13 +18,10 @@ Description: This file is used for some common util.
import os import os
import shutil import shutil
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from flask import Response from flask import Response
from . import constants
from . import globals as gbl
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
...@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat ...@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools from mindinsight.datavisual.utils import tools
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from . import constants
from . import globals as gbl
summaries_metadata = None summaries_metadata = None
mock_data_manager = None mock_data_manager = None
summary_base_dir = constants.SUMMARY_BASE_DIR summary_base_dir = constants.SUMMARY_BASE_DIR
...@@ -55,18 +57,21 @@ def init_summary_logs(): ...@@ -55,18 +57,21 @@ def init_summary_logs():
os.mkdir(summary_base_dir, mode=mode) os.mkdir(summary_base_dir, mode=mode)
global summaries_metadata, mock_data_manager global summaries_metadata, mock_data_manager
log_operations = LogOperations() log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs( summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, constants.SUMMARY_DIR_PREFIX) constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager.start_load_data(reload_interval=0) mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager) check_loading_done(mock_data_manager)
summaries_metadata.update(log_operations.create_summary_logs( summaries_metadata.update(
summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST)) log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND,
summaries_metadata.update(log_operations.create_multiple_logs( constants.SUMMARY_DIR_NUM_FIRST))
summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM)) summaries_metadata.update(
summaries_metadata.update(log_operations.create_reservoir_log( log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME,
summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(
log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME,
constants.RESERVOIR_STEP_NUM))
mock_data_manager.start_load_data(reload_interval=0) mock_data_manager.start_load_data(reload_interval=0)
# Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING.
......
...@@ -20,13 +20,13 @@ Usage: ...@@ -20,13 +20,13 @@ Usage:
""" """
import pytest import pytest
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
BASE_URL = '/v1/mindinsight/datavisual/image/metadata' BASE_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/image/single-image' BASE_URL = '/v1/mindinsight/datavisual/image/single-image'
......
...@@ -19,11 +19,12 @@ Usage: ...@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/plugins' BASE_URL = '/v1/mindinsight/datavisual/plugins'
......
...@@ -19,11 +19,12 @@ Usage: ...@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/single-job' BASE_URL = '/v1/mindinsight/datavisual/single-job'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from .. import globals as gbl
from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -26,12 +26,12 @@ from unittest import TestCase ...@@ -26,12 +26,12 @@ from unittest import TestCase
import pytest import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError,
LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \ LineageParamTypeError, LineageParamValueError,
LineageSearchConditionParamError, LineageFileNotFoundError LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
LINEAGE_INFO_RUN1 = { LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
......
...@@ -22,8 +22,7 @@ import tempfile ...@@ -22,8 +22,7 @@ import tempfile
import pytest import pytest
from ....utils import mindspore from ....utils import mindspore
from ....utils.mindspore.dataset.engine.serializer_deserializer import \ from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
SERIALIZED_PIPELINE
sys.modules['mindspore'] = mindspore sys.modules['mindspore'] = mindspore
......
...@@ -21,14 +21,15 @@ Usage: ...@@ -21,14 +21,15 @@ Usage:
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from .conftest import TRAIN_ROUTES
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainTask: class TestTrainTask:
"""Test train task api.""" """Test train task api."""
...@@ -36,9 +37,7 @@ class TestTrainTask: ...@@ -36,9 +37,7 @@ class TestTrainTask:
_scalar_log_generator = ScalarsLogGenerator() _scalar_log_generator = ScalarsLogGenerator()
_image_log_generator = ImagesLogGenerator() _image_log_generator = ImagesLogGenerator()
@pytest.mark.parametrize( @pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name'])
"plugin_name",
['no_plugin_name', 'not_exist_plugin_name'])
def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name):
""" """
Parsing unavailable plugin name to single train task. Parsing unavailable plugin name to single train task.
......
...@@ -21,14 +21,15 @@ Usage: ...@@ -21,14 +21,15 @@ Usage:
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from .conftest import TRAIN_ROUTES
from ....utils.tools import get_url
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainVisual: class TestTrainVisual:
"""Test Train Visual APIs.""" """Test Train Visual APIs."""
...@@ -95,14 +96,7 @@ class TestTrainVisual: ...@@ -95,14 +96,7 @@ class TestTrainVisual:
assert response.status_code == 200 assert response.status_code == 200
response = response.get_json() response = response.get_json()
expected_response = { expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]}
"metadatas": [{
"height": 224,
"step": 1,
"wall_time": 1572058058.1175,
"width": 448
}]
}
assert expected_response == response assert expected_response == response
def test_single_image_with_params_miss(self, client): def test_single_image_with_params_miss(self, client):
...@@ -254,8 +248,10 @@ class TestTrainVisual: ...@@ -254,8 +248,10 @@ class TestTrainVisual:
@patch.object(GraphProcessor, 'get_nodes') @patch.object(GraphProcessor, 'get_nodes')
def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client):
"""Test getting graph nodes successfully.""" """Test getting graph nodes successfully."""
def mock_get_nodes(name, node_type): def mock_get_nodes(name, node_type):
return dict(name=name, node_type=node_type) return dict(name=name, node_type=node_type)
mock_graph_processor.side_effect = mock_get_nodes mock_graph_processor.side_effect = mock_get_nodes
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
...@@ -327,10 +323,7 @@ class TestTrainVisual: ...@@ -327,10 +323,7 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ assert results['error_msg'] == "Invalid parameter value. 'offset' should " \
"be greater than or equal to 0." "be greater than or equal to 0."
@pytest.mark.parametrize( @pytest.mark.parametrize("limit", [-1, 0, 1001])
"limit",
[-1, 0, 1001]
)
@patch.object(GraphProcessor, '__init__') @patch.object(GraphProcessor, '__init__')
def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit):
"""Test getting graph node names with invalid limit.""" """Test getting graph node names with invalid limit."""
...@@ -348,14 +341,10 @@ class TestTrainVisual: ...@@ -348,14 +341,10 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. " \ assert results['error_msg'] == "Invalid parameter value. " \
"'limit' should in [1, 1000]." "'limit' should in [1, 1000]."
@pytest.mark.parametrize( @pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)])
" offset, limit",
[(0, 100), (1, 1), (0, 1000)]
)
@patch.object(GraphProcessor, '__init__') @patch.object(GraphProcessor, '__init__')
@patch.object(GraphProcessor, 'search_node_names') @patch.object(GraphProcessor, 'search_node_names')
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit):
offset, limit):
""" """
Parsing unavailable params to get image metadata. Parsing unavailable params to get image metadata.
...@@ -367,8 +356,10 @@ class TestTrainVisual: ...@@ -367,8 +356,10 @@ class TestTrainVisual:
response status code: 200. response status code: 200.
response json: dict, contains search_content, offset, and limit. response json: dict, contains search_content, offset, and limit.
""" """
def mock_search_node_names(search_content, offset, limit): def mock_search_node_names(search_content, offset, limit):
return dict(search_content=search_content, offset=int(offset), limit=int(limit)) return dict(search_content=search_content, offset=int(offset), limit=int(limit))
mock_graph_processor.side_effect = mock_search_node_names mock_graph_processor.side_effect = mock_search_node_names
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
...@@ -376,15 +367,12 @@ class TestTrainVisual: ...@@ -376,15 +367,12 @@ class TestTrainVisual:
test_train_id = "aaa" test_train_id = "aaa"
test_search_content = "bbb" test_search_content = "bbb"
params = dict(train_id=test_train_id, search=test_search_content, params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit)
offset=offset, limit=limit)
url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) url = get_url(TRAIN_ROUTES['graph_nodes_names'], params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
results = response.get_json() results = response.get_json()
assert results == dict(search_content=test_search_content, assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit))
offset=int(offset),
limit=int(limit))
def test_graph_search_single_node_with_params_is_wrong(self, client): def test_graph_search_single_node_with_params_is_wrong(self, client):
"""Test searching graph single node with params is wrong.""" """Test searching graph single node with params is wrong."""
...@@ -427,8 +415,10 @@ class TestTrainVisual: ...@@ -427,8 +415,10 @@ class TestTrainVisual:
response status code: 200. response status code: 200.
response json: name. response json: name.
""" """
def mock_search_single_node(name): def mock_search_single_node(name):
return name return name
mock_graph_processor.side_effect = mock_search_single_node mock_graph_processor.side_effect = mock_search_single_node
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
......
...@@ -20,9 +20,7 @@ from unittest import TestCase, mock ...@@ -20,9 +20,7 @@ from unittest import TestCase, mock
from flask import Response from flask import Response
from mindinsight.backend.application import APP from mindinsight.backend.application import APP
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError
LineageQuerySummaryDataError
LINEAGE_FILTRATION_BASE = { LINEAGE_FILTRATION_BASE = {
'accuracy': None, 'accuracy': None,
......
...@@ -19,15 +19,16 @@ Usage: ...@@ -19,15 +19,16 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
from unittest.mock import patch from unittest.mock import patch
from werkzeug.exceptions import MethodNotAllowed, NotFound
from ...backend.datavisual.conftest import TRAIN_ROUTES from werkzeug.exceptions import MethodNotAllowed, NotFound
from ..mock import MockLogger
from ....utils.tools import get_url
from mindinsight.datavisual.processors import scalars_processor from mindinsight.datavisual.processors import scalars_processor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from ...backend.datavisual.conftest import TRAIN_ROUTES
from ..mock import MockLogger
class TestErrorHandler: class TestErrorHandler:
"""Test train visual api.""" """Test train visual api."""
......
...@@ -22,18 +22,19 @@ import datetime ...@@ -22,18 +22,19 @@ import datetime
import os import os
import shutil import shutil
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
import pytest
from ...mock import MockLogger import pytest
from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ...mock import MockLogger
class TestDataLoaderGenerator: class TestDataLoaderGenerator:
"""Test data_loader_generator.""" """Test data_loader_generator."""
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
data_loader_generator.logger = MockLogger data_loader_generator.logger = MockLogger
...@@ -88,8 +89,9 @@ class TestDataLoaderGenerator: ...@@ -88,8 +89,9 @@ class TestDataLoaderGenerator:
mock_data_loader.return_value = True mock_data_loader.return_value = True
loader_dict = generator.generate_loaders(loader_pool=dict()) loader_dict = generator.generate_loaders(loader_pool=dict())
expected_ids = [summary.get('relative_path') expected_ids = [
for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]] summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]
]
assert sorted(loader_dict.keys()) == sorted(expected_ids) assert sorted(loader_dict.keys()) == sorted(expected_ids)
shutil.rmtree(summary_base_dir) shutil.rmtree(summary_base_dir)
......
...@@ -23,12 +23,13 @@ import shutil ...@@ -23,12 +23,13 @@ import shutil
import tempfile import tempfile
import pytest import pytest
from ..mock import MockLogger
from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
from mindinsight.datavisual.data_transform import data_loader from mindinsight.datavisual.data_transform import data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader from mindinsight.datavisual.data_transform.data_loader import DataLoader
from ..mock import MockLogger
class TestDataLoader: class TestDataLoader:
"""Test data_loader.""" """Test data_loader."""
...@@ -37,13 +38,13 @@ class TestDataLoader: ...@@ -37,13 +38,13 @@ class TestDataLoader:
def setup_class(cls): def setup_class(cls):
data_loader.logger = MockLogger data_loader.logger = MockLogger
def setup_method(self, method): def setup_method(self):
self._summary_dir = tempfile.mkdtemp() self._summary_dir = tempfile.mkdtemp()
if os.path.exists(self._summary_dir): if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir) shutil.rmtree(self._summary_dir)
os.mkdir(self._summary_dir) os.mkdir(self._summary_dir)
def teardown_method(self, method): def teardown_method(self):
if os.path.exists(self._summary_dir): if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir) shutil.rmtree(self._summary_dir)
......
...@@ -18,32 +18,29 @@ Function: ...@@ -18,32 +18,29 @@ Function:
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import time
import os import os
import shutil import shutil
import tempfile import tempfile
import time
from unittest import mock from unittest import mock
from unittest.mock import Mock from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest import pytest
from ..mock import MockLogger
from ....utils.tools import check_loading_done
from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager, ms_data_loader from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
DataLoaderGenerator from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \ from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \
LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.tools import check_loading_done
from ..mock import MockLogger
class TestDataManager: class TestDataManager:
"""Test data_manager.""" """Test data_manager."""
...@@ -101,11 +98,17 @@ class TestDataManager: ...@@ -101,11 +98,17 @@ class TestDataManager:
"and loader pool size is '3'." "and loader pool size is '3'."
shutil.rmtree(summary_base_dir) shutil.rmtree(summary_base_dir)
@pytest.mark.parametrize('params', @pytest.mark.parametrize('params', [{
[{'reload_interval': '30'}, 'reload_interval': '30'
{'reload_interval': -1}, }, {
{'reload_interval': 30, 'max_threads_count': '20'}, 'reload_interval': -1
{'reload_interval': 30, 'max_threads_count': 0}]) }, {
'reload_interval': 30,
'max_threads_count': '20'
}, {
'reload_interval': 30,
'max_threads_count': 0
}])
def test_start_load_data_with_invalid_params(self, params): def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count.""" """Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
......
...@@ -22,20 +22,24 @@ import threading ...@@ -22,20 +22,24 @@ import threading
from collections import namedtuple from collections import namedtuple
import pytest import pytest
from ..mock import MockLogger
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import events_data from mindinsight.datavisual.data_transform import events_data
from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor
from ..mock import MockLogger
class MockReservoir: class MockReservoir:
"""Use this class to replace reservoir.Reservoir in test.""" """Use this class to replace reservoir.Reservoir in test."""
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'), self._samples = [
_Tensor('wall_time3', 3, 'value3')] _Tensor('wall_time1', 1, 'value1'),
_Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')
]
def samples(self): def samples(self):
"""Replace the samples function.""" """Replace the samples function."""
...@@ -63,11 +67,12 @@ class TestEventsData: ...@@ -63,11 +67,12 @@ class TestEventsData:
def setup_method(self): def setup_method(self):
"""Mock original logger, init a EventsData object for use.""" """Mock original logger, init a EventsData object for use."""
self._ev_data = EventsData() self._ev_data = EventsData()
self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)], self._ev_data._tags_by_plugin = {
'plugin_name2': [f'tag{i}' for i in range(20, 30)]} 'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]
}
self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()})
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)}
'new_tag': MockReservoir(500)}
self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)]
def get_ev_data(self): def get_ev_data(self):
...@@ -102,8 +107,7 @@ class TestEventsData: ...@@ -102,8 +107,7 @@ class TestEventsData:
"""Test add_tensor_event success.""" """Test add_tensor_event success."""
ev_data = self.get_ev_data() ev_data = self.get_ev_data()
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1')
value='value1')
ev_data.add_tensor_event(t_event) ev_data.add_tensor_event(t_event)
assert 'tag0' not in ev_data._tags assert 'tag0' not in ev_data._tags
...@@ -111,6 +115,5 @@ class TestEventsData: ...@@ -111,6 +115,5 @@ class TestEventsData:
assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1']
assert 'tag0' not in ev_data._reservoir_by_tag assert 'tag0' not in ev_data._reservoir_by_tag
assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1']
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step,
t_event.step,
t_event.value) t_event.value)
...@@ -19,16 +19,17 @@ Usage: ...@@ -19,16 +19,17 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os import os
import tempfile
import shutil import shutil
import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from ..mock import MockLogger
from mindinsight.datavisual.data_transform import ms_data_loader from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from ..mock import MockLogger
# bytes of 3 scalar events # bytes of 3 scalar events
SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00'
...@@ -74,7 +75,8 @@ class TestMsDataLoader: ...@@ -74,7 +75,8 @@ class TestMsDataLoader:
"we will reload all files in path {}.".format(summary_dir) "we will reload all files in path {}.".format(summary_dir)
shutil.rmtree(summary_dir) shutil.rmtree(summary_dir)
def test_load_success_with_crc_pass(self, crc_pass): @pytest.mark.usefixtures('crc_pass')
def test_load_success_with_crc_pass(self):
"""Test load success.""" """Test load success."""
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
file1 = os.path.join(summary_dir, 'summary.01') file1 = os.path.join(summary_dir, 'summary.01')
...@@ -88,7 +90,8 @@ class TestMsDataLoader: ...@@ -88,7 +90,8 @@ class TestMsDataLoader:
tensors = ms_loader.get_events_data().tensors(tag[0]) tensors = ms_loader.get_events_data().tensors(tag[0])
assert len(tensors) == 3 assert len(tensors) == 3
def test_load_with_crc_fail(self, crc_fail): @pytest.mark.usefixtures('crc_fail')
def test_load_with_crc_fail(self):
"""Test when crc_fail and will not go to func _event_parse.""" """Test when crc_fail and will not go to func _event_parse."""
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
file2 = os.path.join(summary_dir, 'summary.02') file2 = os.path.join(summary_dir, 'summary.02')
...@@ -100,8 +103,10 @@ class TestMsDataLoader: ...@@ -100,8 +103,10 @@ class TestMsDataLoader:
def test_filter_event_files(self): def test_filter_event_files(self):
"""Test filter_event_files function ok.""" """Test filter_event_files function ok."""
file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', file_list = [
'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567'] 'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786',
'mysummary.123abce', 'summay.4567'
]
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
for file in file_list: for file in file_list:
with open(os.path.join(summary_dir, file), 'w'): with open(os.path.join(summary_dir, file), 'w'):
...@@ -113,6 +118,7 @@ class TestMsDataLoader: ...@@ -113,6 +118,7 @@ class TestMsDataLoader:
shutil.rmtree(summary_dir) shutil.rmtree(summary_dir)
def write_file(filename, record): def write_file(filename, record):
"""Write bytes strings to file.""" """Write bytes strings to file."""
with open(filename, 'wb') as file: with open(filename, 'wb') as file:
......
...@@ -19,18 +19,11 @@ Usage: ...@@ -19,18 +19,11 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os import os
import json
import tempfile import tempfile
from unittest.mock import Mock, patch
from unittest.mock import Mock
from unittest.mock import patch
import pytest import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor ...@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs
from ..mock import MockLogger
class TestGraphProcessor: class TestGraphProcessor:
"""Test Graph Processor api.""" """Test Graph Processor api."""
...@@ -76,8 +73,7 @@ class TestGraphProcessor: ...@@ -76,8 +73,7 @@ class TestGraphProcessor:
self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir) self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager( self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0) self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done # wait for loading done
...@@ -91,27 +87,28 @@ class TestGraphProcessor: ...@@ -91,27 +87,28 @@ class TestGraphProcessor:
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
log_operation = LogOperations() log_operation = LogOperations()
self._temp_path, _, _ = log_operation.generate_log( self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir,
PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) dict(steps=self._steps_list, tag="image"))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager( self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0) self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done # wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5) check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_nodes_with_not_exist_train_id(self, load_graph_record): @pytest.mark.usefixtures('load_graph_record')
def test_get_nodes_with_not_exist_train_id(self):
"""Test getting nodes with not exist train id.""" """Test getting nodes with not exist train id."""
test_train_id = "not_exist_train_id" test_train_id = "not_exist_train_id"
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(test_train_id, self._mock_data_manager) GraphProcessor(test_train_id, self._mock_data_manager)
assert "Can not find the train job in data manager." in exc_info.value.message assert "Can not find the train job in data manager." in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record')
@patch.object(DataManager, 'get_train_job_by_plugin') @patch.object(DataManager, 'get_train_job_by_plugin')
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record): def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin):
"""Test get nodes with loader is None.""" """Test get nodes with loader is None."""
mock_get_train_job_by_plugin.return_value = None mock_get_train_job_by_plugin.return_value = None
with pytest.raises(exceptions.SummaryLogPathInvalid): with pytest.raises(exceptions.SummaryLogPathInvalid):
...@@ -119,15 +116,12 @@ class TestGraphProcessor: ...@@ -119,15 +116,12 @@ class TestGraphProcessor:
assert mock_get_train_job_by_plugin.called assert mock_get_train_job_by_plugin.called
@pytest.mark.parametrize("name, node_type", [ @pytest.mark.usefixtures('load_graph_record')
("not_exist_name", "name_scope"), @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
("", "polymeric_scope") def test_get_nodes_with_not_exist_name(self, name, node_type):
])
def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type):
"""Test getting nodes with not exist name.""" """Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
graph_processor.get_nodes(name, node_type) graph_processor.get_nodes(name, node_type)
if name: if name:
...@@ -135,38 +129,33 @@ class TestGraphProcessor: ...@@ -135,38 +129,33 @@ class TestGraphProcessor:
else: else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
@pytest.mark.parametrize("name, node_type, result_file", [ @pytest.mark.usefixtures('load_graph_record')
(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), @pytest.mark.parametrize(
('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), "name, node_type, result_file",
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json') [(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
]) ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'),
def test_get_nodes_success(self, load_graph_record, name, node_type, result_file): ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')])
def test_get_nodes_success(self, name, node_type, result_file):
"""Test getting nodes successfully.""" """Test getting nodes successfully."""
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
results = graph_processor.get_nodes(name, node_type) results = graph_processor.get_nodes(name, node_type)
expected_file_path = os.path.join(self.graph_results_dir, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path) compare_result_with_file(results, expected_file_path)
@pytest.mark.parametrize("search_content, result_file", [ @pytest.mark.usefixtures('load_graph_record')
(None, 'test_search_node_names_with_search_content_expected_results1.json'), @pytest.mark.parametrize("search_content, result_file",
('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), [(None, 'test_search_node_names_with_search_content_expected_results1.json'),
('not_exist_search_content', None) ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'),
]) ('not_exist_search_content', None)])
def test_search_node_names_with_search_content(self, load_graph_record, def test_search_node_names_with_search_content(self, search_content, result_file):
search_content,
result_file):
"""Test search node names with search content.""" """Test search node names with search content."""
test_offset = 0 test_offset = 0
test_limit = 1000 test_limit = 1000
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager) results = graph_processor.search_node_names(search_content, test_offset, test_limit)
results = graph_processor.search_node_names(search_content,
test_offset,
test_limit)
if search_content == 'not_exist_search_content': if search_content == 'not_exist_search_content':
expected_results = {'names': []} expected_results = {'names': []}
assert results == expected_results assert results == expected_results
...@@ -174,71 +163,65 @@ class TestGraphProcessor: ...@@ -174,71 +163,65 @@ class TestGraphProcessor:
expected_file_path = os.path.join(self.graph_results_dir, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path) compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset", [-100, -1]) @pytest.mark.parametrize("offset", [-100, -1])
def test_search_node_names_with_negative_offset(self, load_graph_record, offset): def test_search_node_names_with_negative_offset(self, offset):
"""Test search node names with negative offset.""" """Test search node names with negative offset."""
test_search_content = "" test_search_content = ""
test_limit = 3 test_limit = 3
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, offset, test_limit) graph_processor.search_node_names(test_search_content, offset, test_limit)
assert "'offset' should be greater than or equal to 0." in exc_info.value.message assert "'offset' should be greater than or equal to 0." in exc_info.value.message
@pytest.mark.parametrize("offset, result_file", [ @pytest.mark.usefixtures('load_graph_record')
(1, 'test_search_node_names_with_offset_expected_results1.json') @pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')])
]) def test_search_node_names_with_offset(self, offset, result_file):
def test_search_node_names_with_offset(self, load_graph_record, offset, result_file):
"""Test search node names with offset.""" """Test search node names with offset."""
test_search_content = "Default/bn1" test_search_content = "Default/bn1"
test_offset = offset test_offset = offset
test_limit = 3 test_limit = 3
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager) results = graph_processor.search_node_names(test_search_content, test_offset, test_limit)
results = graph_processor.search_node_names(test_search_content,
test_offset,
test_limit)
expected_file_path = os.path.join(self.graph_results_dir, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path) compare_result_with_file(results, expected_file_path)
def test_search_node_names_with_wrong_limit(self, load_graph_record): @pytest.mark.usefixtures('load_graph_record')
def test_search_node_names_with_wrong_limit(self):
"""Test search node names with wrong limit.""" """Test search node names with wrong limit."""
test_search_content = "" test_search_content = ""
test_offset = 0 test_offset = 0
test_limit = 0 test_limit = 0
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, test_offset, graph_processor.search_node_names(test_search_content, test_offset, test_limit)
test_limit)
assert "'limit' should in [1, 1000]." in exc_info.value.message assert "'limit' should in [1, 1000]." in exc_info.value.message
@pytest.mark.parametrize("name, result_file", [ @pytest.mark.usefixtures('load_graph_record')
('Default/bn1', 'test_search_single_node_success_expected_results1.json') @pytest.mark.parametrize("name, result_file",
]) [('Default/bn1', 'test_search_single_node_success_expected_results1.json')])
def test_search_single_node_success(self, load_graph_record, name, result_file): def test_search_single_node_success(self, name, result_file):
"""Test searching single node successfully.""" """Test searching single node successfully."""
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
results = graph_processor.search_single_node(name) results = graph_processor.search_single_node(name)
expected_file_path = os.path.join(self.graph_results_dir, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path) compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
def test_search_single_node_with_not_exist_name(self, load_graph_record): def test_search_single_node_with_not_exist_name(self):
"""Test searching single node with not exist name.""" """Test searching single node with not exist name."""
test_name = "not_exist_name" test_name = "not_exist_name"
with pytest.raises(exceptions.NodeNotInGraphError): with pytest.raises(exceptions.NodeNotInGraphError):
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
graph_processor.search_single_node(test_name) graph_processor.search_single_node(test_name)
def test_check_graph_status_no_graph(self, load_no_graph_record): @pytest.mark.usefixtures('load_no_graph_record')
def test_check_graph_status_no_graph(self):
"""Test checking graph status no graph.""" """Test checking graph status no graph."""
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(self._train_id, self._mock_data_manager) GraphProcessor(self._train_id, self._mock_data_manager)
......
...@@ -22,9 +22,6 @@ import tempfile ...@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor ...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from ..mock import MockLogger
class TestImagesProcessor: class TestImagesProcessor:
"""Test images processor api.""" """Test images processor api."""
...@@ -101,7 +102,8 @@ class TestImagesProcessor: ...@@ -101,7 +102,8 @@ class TestImagesProcessor:
"""Load image record.""" """Load image record."""
self._init_data_manager(self._cross_steps_list) self._init_data_manager(self._cross_steps_list)
def test_get_metadata_list_with_not_exist_id(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Test getting metadata list with not exist id.""" """Test getting metadata list with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
image_processor = ImageProcessor(self._mock_data_manager) image_processor = ImageProcessor(self._mock_data_manager)
...@@ -111,7 +113,8 @@ class TestImagesProcessor: ...@@ -111,7 +113,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Test get metadata list with not exist tag.""" """Test get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
...@@ -123,7 +126,8 @@ class TestImagesProcessor: ...@@ -123,7 +126,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_success(self):
"""Test getting metadata list success.""" """Test getting metadata list success."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -132,7 +136,8 @@ class TestImagesProcessor: ...@@ -132,7 +136,8 @@ class TestImagesProcessor:
assert results == self._images_metadata assert results == self._images_metadata
def test_get_single_image_with_not_exist_id(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_id(self):
"""Test getting single image with not exist id.""" """Test getting single image with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -145,7 +150,8 @@ class TestImagesProcessor: ...@@ -145,7 +150,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_single_image_with_not_exist_tag(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_tag(self):
"""Test getting single image with not exist tag.""" """Test getting single image with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
test_step = self._steps_list[0] test_step = self._steps_list[0]
...@@ -158,7 +164,8 @@ class TestImagesProcessor: ...@@ -158,7 +164,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_single_image_with_not_exist_step(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_step(self):
"""Test getting single image with not exist step.""" """Test getting single image with not exist step."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
test_step = 10000 test_step = 10000
...@@ -171,7 +178,8 @@ class TestImagesProcessor: ...@@ -171,7 +178,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find the step with given train job id and tag." in exc_info.value.message assert "Can not find the step with given train job id and tag." in exc_info.value.message
def test_get_single_image_success(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_success(self):
"""Test getting single image successfully.""" """Test getting single image successfully."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
test_step_index = 0 test_step_index = 0
...@@ -184,7 +192,8 @@ class TestImagesProcessor: ...@@ -184,7 +192,8 @@ class TestImagesProcessor:
assert recv_image_tensor.any() == expected_image_tensor.any() assert recv_image_tensor.any() == expected_image_tensor.any()
def test_reservoir_add_sample(self, load_more_than_limit_image_record): @pytest.mark.usefixtures('load_more_than_limit_image_record')
def test_reservoir_add_sample(self):
"""Test adding sample in reservoir.""" """Test adding sample in reservoir."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -201,7 +210,8 @@ class TestImagesProcessor: ...@@ -201,7 +210,8 @@ class TestImagesProcessor:
cnt += 1 cnt += 1
assert len(self._more_steps_list) - cnt == 10 assert len(self._more_steps_list) - cnt == 10
def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record): @pytest.mark.usefixtures('load_reservoir_remove_sample_image_record')
def test_reservoir_remove_sample(self):
""" """
Test removing sample in reservoir. Test removing sample in reservoir.
......
...@@ -22,9 +22,6 @@ import tempfile ...@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor ...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestScalarsProcessor: class TestScalarsProcessor:
"""Test scalar processor api.""" """Test scalar processor api."""
...@@ -78,7 +79,8 @@ class TestScalarsProcessor: ...@@ -78,7 +79,8 @@ class TestScalarsProcessor:
# wait for loading done # wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5) check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_metadata_list_with_not_exist_id(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Get metadata list with not exist id.""" """Get metadata list with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
scalar_processor = ScalarsProcessor(self._mock_data_manager) scalar_processor = ScalarsProcessor(self._mock_data_manager)
...@@ -88,7 +90,8 @@ class TestScalarsProcessor: ...@@ -88,7 +90,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Get metadata list with not exist tag.""" """Get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
...@@ -100,7 +103,8 @@ class TestScalarsProcessor: ...@@ -100,7 +103,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_success(self):
"""Get metadata list success.""" """Get metadata list success."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
......
...@@ -18,15 +18,11 @@ Function: ...@@ -18,15 +18,11 @@ Function:
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os
import tempfile import tempfile
import time import time
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage ...@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestTrainTaskManager: class TestTrainTaskManager:
"""Test train task manager.""" """Test train task manager."""
...@@ -83,10 +83,7 @@ class TestTrainTaskManager: ...@@ -83,10 +83,7 @@ class TestTrainTaskManager:
train_id = dir_path.replace(self._root_dir, ".") train_id = dir_path.replace(self._root_dir, ".")
# Pass timestamp to write to the same file. # Pass timestamp to write to the same file.
log_settings = dict( log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time())
steps=self._steps_list,
tag=tmp_tag_name,
time=time.time())
if i % 3 != 0: if i % 3 != 0:
log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
self._plugins_id_map['image'].append(train_id) self._plugins_id_map['image'].append(train_id)
...@@ -106,7 +103,8 @@ class TestTrainTaskManager: ...@@ -106,7 +103,8 @@ class TestTrainTaskManager:
check_loading_done(self._mock_data_manager, time_limit=30) check_loading_done(self._mock_data_manager, time_limit=30)
def test_get_single_train_task_with_not_exists_train_id(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_not_exists_train_id(self):
"""Test getting single train task with not exists train_id.""" """Test getting single train task with not exists train_id."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members(): for plugin_name in PluginNameEnum.list_members():
...@@ -118,7 +116,8 @@ class TestTrainTaskManager: ...@@ -118,7 +116,8 @@ class TestTrainTaskManager:
"the train job in data manager." "the train job in data manager."
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
def test_get_single_train_task_with_params(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_params(self):
"""Test getting single train task with params.""" """Test getting single train task with params."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members(): for plugin_name in PluginNameEnum.list_members():
...@@ -132,7 +131,8 @@ class TestTrainTaskManager: ...@@ -132,7 +131,8 @@ class TestTrainTaskManager:
else: else:
assert test_train_id not in self._plugins_id_map.get(plugin_name) assert test_train_id not in self._plugins_id_map.get(plugin_name)
def test_get_plugins_with_train_id(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_plugins_with_train_id(self):
"""Test getting plugins with train id.""" """Test getting plugins with train id."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
......
...@@ -16,18 +16,16 @@ ...@@ -16,18 +16,16 @@
import os import os
import shutil import shutil
import unittest import unittest
from unittest import mock, TestCase from unittest import TestCase, mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
AnalyzeObject from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
from mindinsight.lineagemgr.common.exceptions.exceptions import \ MindInsightException)
LineageLogError, LineageGetModelFileError, MindInsightException
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset, Dataset from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \ from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell
SoftmaxCrossEntropyWithLogits from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep
from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep
from mindspore.train.summary import SummaryRecord from mindspore.train.summary import SummaryRecord
......
...@@ -15,12 +15,9 @@ ...@@ -15,12 +15,9 @@
"""Test the validate module.""" """Test the validate module."""
from unittest import TestCase from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError
LineageParamValueError, LineageParamTypeError from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.model_parameter import \ from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition
SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import \
validate_search_model_condition
from mindinsight.utils.exceptions import MindInsightException from mindinsight.utils.exceptions import MindInsightException
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
"""The event data in querier test.""" """The event data in querier test."""
import json import json
from ....utils.mindspore.dataset.engine.serializer_deserializer import \ from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
SERIALIZED_PIPELINE
EVENT_TRAIN_DICT_0 = { EVENT_TRAIN_DICT_0 = {
'wall_time': 1581499557.7017336, 'wall_time': 1581499557.7017336,
......
...@@ -18,12 +18,12 @@ from unittest import TestCase, mock ...@@ -18,12 +18,12 @@ from unittest import TestCase, mock
from google.protobuf.json_format import ParseDict from google.protobuf.json_format import ParseDict
import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException,
LineageQuerierParamException, LineageParamTypeError, \ LineageSummaryAnalyzeException,
LineageSummaryAnalyzeException, LineageSummaryParseException LineageSummaryParseException)
from mindinsight.lineagemgr.querier.querier import Querier from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
LineageInfo
from . import event_data from . import event_data
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
"""Test the query_model module.""" """Test the query_model module."""
from unittest import TestCase from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
LineageEventNotExistException, LineageEventFieldNotExistException LineageEventNotExistException)
from mindinsight.lineagemgr.querier.query_model import LineageObj from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data from . import event_data
from .test_querier import create_lineage_info, create_filtration_result from .test_querier import create_filtration_result, create_lineage_info
class TestLineageObj(TestCase): class TestLineageObj(TestCase):
......
...@@ -18,10 +18,11 @@ import os ...@@ -18,10 +18,11 @@ import os
import time import time
from google.protobuf import json_format from google.protobuf import json_format
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class GraphLogGenerator(LogGenerator): class GraphLogGenerator(LogGenerator):
""" """
......
...@@ -18,10 +18,11 @@ import time ...@@ -18,10 +18,11 @@ import time
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ImagesLogGenerator(LogGenerator): class ImagesLogGenerator(LogGenerator):
""" """
...@@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator): ...@@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator):
images_metadata.append(image_metadata) images_metadata.append(image_metadata)
images_values.update({step: image_tensor}) images_values.update({step: image_tensor})
values = dict( values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name)
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
self._write_log_one_step(file_path, values) self._write_log_one_step(file_path, values)
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
import time import time
import numpy as np import numpy as np
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ScalarsLogGenerator(LogGenerator): class ScalarsLogGenerator(LogGenerator):
""" """
......
...@@ -19,12 +19,12 @@ import json ...@@ -19,12 +19,12 @@ import json
import os import os
import time import time
from mindinsight.datavisual.common.enums import PluginNameEnum
from .log_generators.graph_log_generator import GraphLogGenerator from .log_generators.graph_log_generator import GraphLogGenerator
from .log_generators.images_log_generator import ImagesLogGenerator from .log_generators.images_log_generator import ImagesLogGenerator
from .log_generators.scalars_log_generator import ScalarsLogGenerator from .log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum
log_generators = { log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(), PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(), PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
...@@ -34,6 +34,7 @@ log_generators = { ...@@ -34,6 +34,7 @@ log_generators = {
class LogOperations: class LogOperations:
"""Log Operations.""" """Log Operations."""
def __init__(self): def __init__(self):
self._step_num = 3 self._step_num = 3
self._tag_num = 2 self._tag_num = 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册