提交 79a61526 编写于 作者: 李鸿章

fix pylint warnings

上级 a82483c0
......@@ -18,13 +18,10 @@ Description: This file is used for some common util.
import os
import shutil
from unittest.mock import Mock
import pytest
from flask import Response
from . import constants
from . import globals as gbl
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
......@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from . import constants
from . import globals as gbl
summaries_metadata = None
mock_data_manager = None
summary_base_dir = constants.SUMMARY_BASE_DIR
......@@ -55,18 +57,21 @@ def init_summary_logs():
os.mkdir(summary_base_dir, mode=mode)
global summaries_metadata, mock_data_manager
log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(
summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, constants.SUMMARY_DIR_PREFIX)
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
summaries_metadata.update(log_operations.create_summary_logs(
summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST))
summaries_metadata.update(log_operations.create_multiple_logs(
summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(log_operations.create_reservoir_log(
summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM))
summaries_metadata.update(
log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND,
constants.SUMMARY_DIR_NUM_FIRST))
summaries_metadata.update(
log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME,
constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(
log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME,
constants.RESERVOIR_STEP_NUM))
mock_data_manager.start_load_data(reload_interval=0)
# Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING.
......
......@@ -20,13 +20,13 @@ Usage:
"""
import pytest
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.conf import settings
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
BASE_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/image/single-image'
......
......@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/plugins'
......
......@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/single-job'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from .. import globals as gbl
from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -26,12 +26,12 @@ from unittest import TestCase
import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \
LineageSearchConditionParamError, LineageFileNotFoundError
from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH
from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError,
LineageParamTypeError, LineageParamValueError,
LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
......
......@@ -22,8 +22,7 @@ import tempfile
import pytest
from ....utils import mindspore
from ....utils.mindspore.dataset.engine.serializer_deserializer import \
SERIALIZED_PIPELINE
from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
sys.modules['mindspore'] = mindspore
......
......@@ -21,14 +21,15 @@ Usage:
from unittest.mock import patch
import pytest
from .conftest import TRAIN_ROUTES
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainTask:
"""Test train task api."""
......@@ -36,9 +37,7 @@ class TestTrainTask:
_scalar_log_generator = ScalarsLogGenerator()
_image_log_generator = ImagesLogGenerator()
@pytest.mark.parametrize(
"plugin_name",
['no_plugin_name', 'not_exist_plugin_name'])
@pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name'])
def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name):
"""
Parsing unavailable plugin name to single train task.
......
......@@ -21,14 +21,15 @@ Usage:
from unittest.mock import Mock, patch
import pytest
from .conftest import TRAIN_ROUTES
from ....utils.tools import get_url
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainVisual:
"""Test Train Visual APIs."""
......@@ -95,14 +96,7 @@ class TestTrainVisual:
assert response.status_code == 200
response = response.get_json()
expected_response = {
"metadatas": [{
"height": 224,
"step": 1,
"wall_time": 1572058058.1175,
"width": 448
}]
}
expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]}
assert expected_response == response
def test_single_image_with_params_miss(self, client):
......@@ -254,8 +248,10 @@ class TestTrainVisual:
@patch.object(GraphProcessor, 'get_nodes')
def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client):
"""Test getting graph nodes successfully."""
def mock_get_nodes(name, node_type):
return dict(name=name, node_type=node_type)
mock_graph_processor.side_effect = mock_get_nodes
mock_init = Mock(return_value=None)
......@@ -327,10 +323,7 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. 'offset' should " \
"be greater than or equal to 0."
@pytest.mark.parametrize(
"limit",
[-1, 0, 1001]
)
@pytest.mark.parametrize("limit", [-1, 0, 1001])
@patch.object(GraphProcessor, '__init__')
def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit):
"""Test getting graph node names with invalid limit."""
......@@ -348,14 +341,10 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. " \
"'limit' should in [1, 1000]."
@pytest.mark.parametrize(
" offset, limit",
[(0, 100), (1, 1), (0, 1000)]
)
@pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)])
@patch.object(GraphProcessor, '__init__')
@patch.object(GraphProcessor, 'search_node_names')
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client,
offset, limit):
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit):
"""
Parsing unavailable params to get image metadata.
......@@ -367,8 +356,10 @@ class TestTrainVisual:
response status code: 200.
response json: dict, contains search_content, offset, and limit.
"""
def mock_search_node_names(search_content, offset, limit):
return dict(search_content=search_content, offset=int(offset), limit=int(limit))
mock_graph_processor.side_effect = mock_search_node_names
mock_init = Mock(return_value=None)
......@@ -376,15 +367,12 @@ class TestTrainVisual:
test_train_id = "aaa"
test_search_content = "bbb"
params = dict(train_id=test_train_id, search=test_search_content,
offset=offset, limit=limit)
params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit)
url = get_url(TRAIN_ROUTES['graph_nodes_names'], params)
response = client.get(url)
assert response.status_code == 200
results = response.get_json()
assert results == dict(search_content=test_search_content,
offset=int(offset),
limit=int(limit))
assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit))
def test_graph_search_single_node_with_params_is_wrong(self, client):
"""Test searching graph single node with params is wrong."""
......@@ -427,8 +415,10 @@ class TestTrainVisual:
response status code: 200.
response json: name.
"""
def mock_search_single_node(name):
return name
mock_graph_processor.side_effect = mock_search_single_node
mock_init = Mock(return_value=None)
......
......@@ -20,9 +20,7 @@ from unittest import TestCase, mock
from flask import Response
from mindinsight.backend.application import APP
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageQuerySummaryDataError
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError
LINEAGE_FILTRATION_BASE = {
'accuracy': None,
......
......@@ -19,15 +19,16 @@ Usage:
pytest tests/ut/datavisual
"""
from unittest.mock import patch
from werkzeug.exceptions import MethodNotAllowed, NotFound
from ...backend.datavisual.conftest import TRAIN_ROUTES
from ..mock import MockLogger
from ....utils.tools import get_url
from werkzeug.exceptions import MethodNotAllowed, NotFound
from mindinsight.datavisual.processors import scalars_processor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from ...backend.datavisual.conftest import TRAIN_ROUTES
from ..mock import MockLogger
class TestErrorHandler:
"""Test train visual api."""
......
......@@ -22,18 +22,19 @@ import datetime
import os
import shutil
import tempfile
from unittest.mock import patch
import pytest
from ...mock import MockLogger
import pytest
from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator
from mindinsight.utils.exceptions import ParamValueError
from ...mock import MockLogger
class TestDataLoaderGenerator:
"""Test data_loader_generator."""
@classmethod
def setup_class(cls):
data_loader_generator.logger = MockLogger
......@@ -88,8 +89,9 @@ class TestDataLoaderGenerator:
mock_data_loader.return_value = True
loader_dict = generator.generate_loaders(loader_pool=dict())
expected_ids = [summary.get('relative_path')
for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]]
expected_ids = [
summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]
]
assert sorted(loader_dict.keys()) == sorted(expected_ids)
shutil.rmtree(summary_base_dir)
......
......@@ -23,12 +23,13 @@ import shutil
import tempfile
import pytest
from ..mock import MockLogger
from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
from mindinsight.datavisual.data_transform import data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from ..mock import MockLogger
class TestDataLoader:
"""Test data_loader."""
......@@ -37,13 +38,13 @@ class TestDataLoader:
def setup_class(cls):
data_loader.logger = MockLogger
def setup_method(self, method):
def setup_method(self):
self._summary_dir = tempfile.mkdtemp()
if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir)
os.mkdir(self._summary_dir)
def teardown_method(self, method):
def teardown_method(self):
if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir)
......
......@@ -18,32 +18,29 @@ Function:
Usage:
pytest tests/ut/datavisual
"""
import time
import os
import shutil
import tempfile
import time
from unittest import mock
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from ..mock import MockLogger
from ....utils.tools import check_loading_done
from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \
DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \
MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \
LoaderStruct
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.utils.exceptions import ParamValueError
from ....utils.tools import check_loading_done
from ..mock import MockLogger
class TestDataManager:
"""Test data_manager."""
......@@ -101,11 +98,17 @@ class TestDataManager:
"and loader pool size is '3'."
shutil.rmtree(summary_base_dir)
@pytest.mark.parametrize('params',
[{'reload_interval': '30'},
{'reload_interval': -1},
{'reload_interval': 30, 'max_threads_count': '20'},
{'reload_interval': 30, 'max_threads_count': 0}])
@pytest.mark.parametrize('params', [{
'reload_interval': '30'
}, {
'reload_interval': -1
}, {
'reload_interval': 30,
'max_threads_count': '20'
}, {
'reload_interval': 30,
'max_threads_count': 0
}])
def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp()
......
......@@ -22,20 +22,24 @@ import threading
from collections import namedtuple
import pytest
from ..mock import MockLogger
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import events_data
from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor
from ..mock import MockLogger
class MockReservoir:
"""Use this class to replace reservoir.Reservoir in test."""
def __init__(self, size):
self.size = size
self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')]
self._samples = [
_Tensor('wall_time1', 1, 'value1'),
_Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')
]
def samples(self):
"""Replace the samples function."""
......@@ -63,11 +67,12 @@ class TestEventsData:
def setup_method(self):
"""Mock original logger, init a EventsData object for use."""
self._ev_data = EventsData()
self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]}
self._ev_data._tags_by_plugin = {
'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]
}
self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()})
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500),
'new_tag': MockReservoir(500)}
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)}
self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)]
def get_ev_data(self):
......@@ -102,8 +107,7 @@ class TestEventsData:
"""Test add_tensor_event success."""
ev_data = self.get_ev_data()
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1',
value='value1')
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1')
ev_data.add_tensor_event(t_event)
assert 'tag0' not in ev_data._tags
......@@ -111,6 +115,5 @@ class TestEventsData:
assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1']
assert 'tag0' not in ev_data._reservoir_by_tag
assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1']
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time,
t_event.step,
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step,
t_event.value)
......@@ -19,16 +19,17 @@ Usage:
pytest tests/ut/datavisual
"""
import os
import tempfile
import shutil
import tempfile
from unittest.mock import Mock
import pytest
from ..mock import MockLogger
from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from ..mock import MockLogger
# bytes of 3 scalar events
SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00'
......@@ -74,7 +75,8 @@ class TestMsDataLoader:
"we will reload all files in path {}.".format(summary_dir)
shutil.rmtree(summary_dir)
def test_load_success_with_crc_pass(self, crc_pass):
@pytest.mark.usefixtures('crc_pass')
def test_load_success_with_crc_pass(self):
"""Test load success."""
summary_dir = tempfile.mkdtemp()
file1 = os.path.join(summary_dir, 'summary.01')
......@@ -88,7 +90,8 @@ class TestMsDataLoader:
tensors = ms_loader.get_events_data().tensors(tag[0])
assert len(tensors) == 3
def test_load_with_crc_fail(self, crc_fail):
@pytest.mark.usefixtures('crc_fail')
def test_load_with_crc_fail(self):
"""Test when crc_fail and will not go to func _event_parse."""
summary_dir = tempfile.mkdtemp()
file2 = os.path.join(summary_dir, 'summary.02')
......@@ -100,8 +103,10 @@ class TestMsDataLoader:
def test_filter_event_files(self):
"""Test filter_event_files function ok."""
file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678',
'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567']
file_list = [
'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786',
'mysummary.123abce', 'summay.4567'
]
summary_dir = tempfile.mkdtemp()
for file in file_list:
with open(os.path.join(summary_dir, file), 'w'):
......@@ -113,6 +118,7 @@ class TestMsDataLoader:
shutil.rmtree(summary_dir)
def write_file(filename, record):
"""Write bytes strings to file."""
with open(filename, 'wb') as file:
......
......@@ -19,18 +19,11 @@ Usage:
pytest tests/ut/datavisual
"""
import os
import json
import tempfile
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs
from ..mock import MockLogger
class TestGraphProcessor:
"""Test Graph Processor api."""
......@@ -76,8 +73,7 @@ class TestGraphProcessor:
self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......@@ -91,27 +87,28 @@ class TestGraphProcessor:
self._train_id = log_dir.replace(summary_base_dir, ".")
log_operation = LogOperations()
self._temp_path, _, _ = log_operation.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image"))
self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir,
dict(steps=self._steps_list, tag="image"))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_nodes_with_not_exist_train_id(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_get_nodes_with_not_exist_train_id(self):
"""Test getting nodes with not exist train id."""
test_train_id = "not_exist_train_id"
with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(test_train_id, self._mock_data_manager)
assert "Can not find the train job in data manager." in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record')
@patch.object(DataManager, 'get_train_job_by_plugin')
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record):
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin):
"""Test get nodes with loader is None."""
mock_get_train_job_by_plugin.return_value = None
with pytest.raises(exceptions.SummaryLogPathInvalid):
......@@ -119,15 +116,12 @@ class TestGraphProcessor:
assert mock_get_train_job_by_plugin.called
@pytest.mark.parametrize("name, node_type", [
("not_exist_name", "name_scope"),
("", "polymeric_scope")
])
def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
def test_get_nodes_with_not_exist_name(self, name, node_type):
"""Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info:
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.get_nodes(name, node_type)
if name:
......@@ -135,38 +129,33 @@ class TestGraphProcessor:
else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
@pytest.mark.parametrize("name, node_type, result_file", [
(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize(
"name, node_type, result_file",
[(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'),
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')
])
def test_get_nodes_success(self, load_graph_record, name, node_type, result_file):
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')])
def test_get_nodes_success(self, name, node_type, result_file):
"""Test getting nodes successfully."""
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.get_nodes(name, node_type)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.parametrize("search_content, result_file", [
(None, 'test_search_node_names_with_search_content_expected_results1.json'),
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("search_content, result_file",
[(None, 'test_search_node_names_with_search_content_expected_results1.json'),
('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'),
('not_exist_search_content', None)
])
def test_search_node_names_with_search_content(self, load_graph_record,
search_content,
result_file):
('not_exist_search_content', None)])
def test_search_node_names_with_search_content(self, search_content, result_file):
"""Test search node names with search content."""
test_offset = 0
test_limit = 1000
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
results = graph_processor.search_node_names(search_content,
test_offset,
test_limit)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_node_names(search_content, test_offset, test_limit)
if search_content == 'not_exist_search_content':
expected_results = {'names': []}
assert results == expected_results
......@@ -174,71 +163,65 @@ class TestGraphProcessor:
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset", [-100, -1])
def test_search_node_names_with_negative_offset(self, load_graph_record, offset):
def test_search_node_names_with_negative_offset(self, offset):
"""Test search node names with negative offset."""
test_search_content = ""
test_limit = 3
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, offset, test_limit)
assert "'offset' should be greater than or equal to 0." in exc_info.value.message
@pytest.mark.parametrize("offset, result_file", [
(1, 'test_search_node_names_with_offset_expected_results1.json')
])
def test_search_node_names_with_offset(self, load_graph_record, offset, result_file):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')])
def test_search_node_names_with_offset(self, offset, result_file):
"""Test search node names with offset."""
test_search_content = "Default/bn1"
test_offset = offset
test_limit = 3
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
results = graph_processor.search_node_names(test_search_content,
test_offset,
test_limit)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_node_names(test_search_content, test_offset, test_limit)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_node_names_with_wrong_limit(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_search_node_names_with_wrong_limit(self):
"""Test search node names with wrong limit."""
test_search_content = ""
test_offset = 0
test_limit = 0
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, test_offset,
test_limit)
graph_processor.search_node_names(test_search_content, test_offset, test_limit)
assert "'limit' should in [1, 1000]." in exc_info.value.message
@pytest.mark.parametrize("name, result_file", [
('Default/bn1', 'test_search_single_node_success_expected_results1.json')
])
def test_search_single_node_success(self, load_graph_record, name, result_file):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("name, result_file",
[('Default/bn1', 'test_search_single_node_success_expected_results1.json')])
def test_search_single_node_success(self, name, result_file):
"""Test searching single node successfully."""
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_single_node(name)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_single_node_with_not_exist_name(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_search_single_node_with_not_exist_name(self):
"""Test searching single node with not exist name."""
test_name = "not_exist_name"
with pytest.raises(exceptions.NodeNotInGraphError):
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.search_single_node(test_name)
def test_check_graph_status_no_graph(self, load_no_graph_record):
@pytest.mark.usefixtures('load_no_graph_record')
def test_check_graph_status_no_graph(self):
"""Test checking graph status no graph."""
with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(self._train_id, self._mock_data_manager)
......
......@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock
import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from ..mock import MockLogger
class TestImagesProcessor:
"""Test images processor api."""
......@@ -101,7 +102,8 @@ class TestImagesProcessor:
"""Load image record."""
self._init_data_manager(self._cross_steps_list)
def test_get_metadata_list_with_not_exist_id(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Test getting metadata list with not exist id."""
test_train_id = 'not_exist_id'
image_processor = ImageProcessor(self._mock_data_manager)
......@@ -111,7 +113,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Test get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name'
......@@ -123,7 +126,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_success(self):
"""Test getting metadata list success."""
test_tag_name = self._complete_tag_name
......@@ -132,7 +136,8 @@ class TestImagesProcessor:
assert results == self._images_metadata
def test_get_single_image_with_not_exist_id(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_id(self):
"""Test getting single image with not exist id."""
test_train_id = 'not_exist_id'
test_tag_name = self._complete_tag_name
......@@ -145,7 +150,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_single_image_with_not_exist_tag(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_tag(self):
"""Test getting single image with not exist tag."""
test_tag_name = 'not_exist_tag_name'
test_step = self._steps_list[0]
......@@ -158,7 +164,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_single_image_with_not_exist_step(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_step(self):
"""Test getting single image with not exist step."""
test_tag_name = self._complete_tag_name
test_step = 10000
......@@ -171,7 +178,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find the step with given train job id and tag." in exc_info.value.message
def test_get_single_image_success(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_success(self):
"""Test getting single image successfully."""
test_tag_name = self._complete_tag_name
test_step_index = 0
......@@ -184,7 +192,8 @@ class TestImagesProcessor:
assert recv_image_tensor.any() == expected_image_tensor.any()
def test_reservoir_add_sample(self, load_more_than_limit_image_record):
@pytest.mark.usefixtures('load_more_than_limit_image_record')
def test_reservoir_add_sample(self):
"""Test adding sample in reservoir."""
test_tag_name = self._complete_tag_name
......@@ -201,7 +210,8 @@ class TestImagesProcessor:
cnt += 1
assert len(self._more_steps_list) - cnt == 10
def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record):
@pytest.mark.usefixtures('load_reservoir_remove_sample_image_record')
def test_reservoir_remove_sample(self):
"""
Test removing sample in reservoir.
......
......@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock
import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestScalarsProcessor:
"""Test scalar processor api."""
......@@ -78,7 +79,8 @@ class TestScalarsProcessor:
# wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_metadata_list_with_not_exist_id(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Get metadata list with not exist id."""
test_train_id = 'not_exist_id'
scalar_processor = ScalarsProcessor(self._mock_data_manager)
......@@ -88,7 +90,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name'
......@@ -100,7 +103,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_success(self):
"""Get metadata list success."""
test_tag_name = self._complete_tag_name
......
......@@ -18,15 +18,11 @@ Function:
Usage:
pytest tests/ut/datavisual
"""
import os
import tempfile
import time
from unittest.mock import Mock
import pytest
from ..mock import MockLogger
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestTrainTaskManager:
"""Test train task manager."""
......@@ -83,10 +83,7 @@ class TestTrainTaskManager:
train_id = dir_path.replace(self._root_dir, ".")
# Pass timestamp to write to the same file.
log_settings = dict(
steps=self._steps_list,
tag=tmp_tag_name,
time=time.time())
log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time())
if i % 3 != 0:
log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
self._plugins_id_map['image'].append(train_id)
......@@ -106,7 +103,8 @@ class TestTrainTaskManager:
check_loading_done(self._mock_data_manager, time_limit=30)
def test_get_single_train_task_with_not_exists_train_id(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_not_exists_train_id(self):
"""Test getting single train task with not exists train_id."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members():
......@@ -118,7 +116,8 @@ class TestTrainTaskManager:
"the train job in data manager."
assert exc_info.value.error_code == '50540002'
def test_get_single_train_task_with_params(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_params(self):
"""Test getting single train task with params."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members():
......@@ -132,7 +131,8 @@ class TestTrainTaskManager:
else:
assert test_train_id not in self._plugins_id_map.get(plugin_name)
def test_get_plugins_with_train_id(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_plugins_with_train_id(self):
"""Test getting plugins with train id."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
......
......@@ -16,18 +16,16 @@
import os
import shutil
import unittest
from unittest import mock, TestCase
from unittest import TestCase, mock
from unittest.mock import MagicMock
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageLogError, LineageGetModelFileError, MindInsightException
from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
MindInsightException)
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset, Dataset
from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \
SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep
from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell
from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep
from mindspore.train.summary import SummaryRecord
......
......@@ -15,12 +15,9 @@
"""Test the validate module."""
from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageParamValueError, LineageParamTypeError
from mindinsight.lineagemgr.common.validator.model_parameter import \
SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import \
validate_search_model_condition
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError
from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition
from mindinsight.utils.exceptions import MindInsightException
......
......@@ -15,8 +15,7 @@
"""The event data in querier test."""
import json
from ....utils.mindspore.dataset.engine.serializer_deserializer import \
SERIALIZED_PIPELINE
from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
EVENT_TRAIN_DICT_0 = {
'wall_time': 1581499557.7017336,
......
......@@ -18,12 +18,12 @@ from unittest import TestCase, mock
from google.protobuf.json_format import ParseDict
import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageQuerierParamException, LineageParamTypeError, \
LineageSummaryAnalyzeException, LineageSummaryParseException
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException,
LineageSummaryAnalyzeException,
LineageSummaryParseException)
from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \
LineageInfo
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
from . import event_data
......
......@@ -15,11 +15,12 @@
"""Test the query_model module."""
from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageEventNotExistException, LineageEventFieldNotExistException
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
LineageEventNotExistException)
from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data
from .test_querier import create_lineage_info, create_filtration_result
from .test_querier import create_filtration_result, create_lineage_info
class TestLineageObj(TestCase):
......
......@@ -18,10 +18,11 @@ import os
import time
from google.protobuf import json_format
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class GraphLogGenerator(LogGenerator):
"""
......
......@@ -18,10 +18,11 @@ import time
import numpy as np
from PIL import Image
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ImagesLogGenerator(LogGenerator):
"""
......@@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator):
images_metadata.append(image_metadata)
images_values.update({step: image_tensor})
values = dict(
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name)
self._write_log_one_step(file_path, values)
......
......@@ -16,10 +16,11 @@
import time
import numpy as np
from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ScalarsLogGenerator(LogGenerator):
"""
......
......@@ -19,12 +19,12 @@ import json
import os
import time
from mindinsight.datavisual.common.enums import PluginNameEnum
from .log_generators.graph_log_generator import GraphLogGenerator
from .log_generators.images_log_generator import ImagesLogGenerator
from .log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum
log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
......@@ -34,6 +34,7 @@ log_generators = {
class LogOperations:
"""Log Operations."""
def __init__(self):
self._step_num = 3
self._tag_num = 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册