提交 2b670b15 编写于 作者: L luopengting

move log operations to tests.utils, change import method to relative path

上级 e7a0496e
...@@ -21,10 +21,10 @@ from unittest.mock import Mock ...@@ -21,10 +21,10 @@ from unittest.mock import Mock
import pytest import pytest
from flask import Response from flask import Response
from tests.st.func.datavisual import constants from . import constants
from tests.st.func.datavisual.utils.log_operations import LogOperations from . import globals as gbl
from tests.st.func.datavisual.utils.utils import check_loading_done from ....utils.log_operations import LogOperations
from tests.st.func.datavisual.utils import globals as gbl from ....utils.tools import check_loading_done
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
...@@ -55,7 +55,8 @@ def init_summary_logs(): ...@@ -55,7 +55,8 @@ def init_summary_logs():
os.mkdir(summary_base_dir, mode=mode) os.mkdir(summary_base_dir, mode=mode)
global summaries_metadata, mock_data_manager global summaries_metadata, mock_data_manager
log_operations = LogOperations() log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST) summaries_metadata = log_operations.create_summary_logs(
summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager.start_load_data(reload_interval=0) mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager) check_loading_done(mock_data_manager)
...@@ -73,7 +74,7 @@ def init_summary_logs(): ...@@ -73,7 +74,7 @@ def init_summary_logs():
# Maximum number of loads is `MAX_DATA_LOADER_SIZE`. # Maximum number of loads is `MAX_DATA_LOADER_SIZE`.
for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE):
summaries_metadata.pop("./%s%d" % (constants.SUMMARY_PREFIX, i)) summaries_metadata.pop("./%s%d" % (constants.SUMMARY_DIR_PREFIX, i))
yield yield
finally: finally:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import tempfile import tempfile
SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name
SUMMARY_PREFIX = "summary" SUMMARY_DIR_PREFIX = "summary"
SUMMARY_DIR_NUM_FIRST = 5 SUMMARY_DIR_NUM_FIRST = 5
SUMMARY_DIR_NUM_SECOND = 11 SUMMARY_DIR_NUM_SECOND = 11
......
...@@ -22,8 +22,8 @@ import os ...@@ -22,8 +22,8 @@ import os
import json import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes'
......
...@@ -23,8 +23,8 @@ import json ...@@ -23,8 +23,8 @@ import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node'
......
...@@ -22,13 +22,14 @@ import os ...@@ -22,13 +22,14 @@ import os
import json import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names'
class TestSearchNodes: class TestSearchNodes:
"""Test search nodes restful APIs.""" """Test searching nodes restful APIs."""
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
......
...@@ -20,9 +20,9 @@ Usage: ...@@ -20,9 +20,9 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -19,8 +19,8 @@ Usage: ...@@ -19,8 +19,8 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -19,8 +19,8 @@ Usage: ...@@ -19,8 +19,8 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.constants import SUMMARY_DIR_NUM from ..constants import SUMMARY_DIR_NUM
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/train-jobs' BASE_URL = '/v1/mindinsight/datavisual/train-jobs'
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for graph."""
import json
import os
import time
from google.protobuf import json_format
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class GraphLogGenerator(LogGenerator):
"""
Log generator for graph.
This is a log generator writing graph. User can use it to generate fake
summary logs about graph.
"""
def generate_log(self, file_path, graph_dict):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
graph_dict (dict): A dict consists of graph node information.
Returns:
dict, generated scalar metadata.
"""
graph_event = self.generate_event(dict(graph=graph_dict))
self._write_log_from_event(file_path, graph_event)
return graph_dict
def generate_event(self, values):
"""
Method for generating graph event.
Args:
values (dict): Graph values. e.g. {'graph': graph_dict}.
Returns:
summary_pb2.Event.
"""
graph_json = {
'wall_time': time.time(),
'graph_def': values.get('graph'),
}
graph_event = json_format.Parse(json.dumps(graph_json), summary_pb2.Event())
return graph_event
if __name__ == "__main__":
graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph)
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes from .....utils.tools import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
......
...@@ -21,10 +21,10 @@ Usage: ...@@ -21,10 +21,10 @@ Usage:
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES from .conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from tests.ut.datavisual.utils.utils import get_url from ....utils.tools import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
......
...@@ -21,8 +21,8 @@ Usage: ...@@ -21,8 +21,8 @@ Usage:
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES from .conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.utils import get_url from ....utils.tools import get_url
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.processors.graph_processor import GraphProcessor
......
...@@ -21,9 +21,9 @@ Usage: ...@@ -21,9 +21,9 @@ Usage:
from unittest.mock import patch from unittest.mock import patch
from werkzeug.exceptions import MethodNotAllowed, NotFound from werkzeug.exceptions import MethodNotAllowed, NotFound
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES from ...backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.utils import get_url from ....utils.tools import get_url
from mindinsight.datavisual.processors import scalars_processor from mindinsight.datavisual.processors import scalars_processor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
""" """
Function: Function:
Test mindinsight.datavisual.data_transform.log_generators.data_loader_generator Test mindinsight.datavisual.data_transform.loader_generators.data_loader_generator
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
...@@ -26,7 +26,7 @@ import tempfile ...@@ -26,7 +26,7 @@ import tempfile
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ...mock import MockLogger
from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
......
...@@ -23,7 +23,7 @@ import shutil ...@@ -23,7 +23,7 @@ import shutil
import tempfile import tempfile
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
from mindinsight.datavisual.data_transform import data_loader from mindinsight.datavisual.data_transform import data_loader
......
...@@ -27,8 +27,8 @@ from unittest.mock import Mock ...@@ -27,8 +27,8 @@ from unittest.mock import Mock
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.utils import check_loading_done from ....utils.tools import check_loading_done
from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager, ms_data_loader from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
......
...@@ -22,7 +22,7 @@ import threading ...@@ -22,7 +22,7 @@ import threading
from collections import namedtuple from collections import namedtuple
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import events_data from mindinsight.datavisual.data_transform import events_data
......
...@@ -24,7 +24,7 @@ import shutil ...@@ -24,7 +24,7 @@ import shutil
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from mindinsight.datavisual.data_transform import ms_data_loader from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
......
...@@ -27,9 +27,9 @@ from unittest.mock import patch ...@@ -27,9 +27,9 @@ from unittest.mock import patch
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations from ....utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
...@@ -70,14 +70,10 @@ class TestGraphProcessor: ...@@ -70,14 +70,10 @@ class TestGraphProcessor:
"""Load graph record.""" """Load graph record."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
graph_base_path = os.path.join(os.path.dirname(__file__), log_operation = LogOperations()
os.pardir, "utils", "log_generators", "graph_base.json") self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._temp_path, self._graph_dict = LogOperations.generate_log(
PluginNameEnum.GRAPH.value, log_dir, dict(graph_base_path=graph_base_path))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager( self._mock_data_manager = data_manager.DataManager(
...@@ -94,7 +90,8 @@ class TestGraphProcessor: ...@@ -94,7 +90,8 @@ class TestGraphProcessor:
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, _, _ = LogOperations.generate_log( log_operation = LogOperations()
self._temp_path, _, _ = log_operation.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image"))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
......
...@@ -22,9 +22,9 @@ import tempfile ...@@ -22,9 +22,9 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations from ....utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -73,12 +73,11 @@ class TestImagesProcessor: ...@@ -73,12 +73,11 @@ class TestImagesProcessor:
""" """
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._images_metadata, self._images_values = LogOperations.generate_log( log_operation = LogOperations()
self._temp_path, self._images_metadata, self._images_values = log_operation.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name)) PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
...@@ -178,14 +177,10 @@ class TestImagesProcessor: ...@@ -178,14 +177,10 @@ class TestImagesProcessor:
test_step_index = 0 test_step_index = 0
test_step = self._steps_list[test_step_index] test_step = self._steps_list[test_step_index]
expected_image_tensor = self._images_values.get(test_step)
image_processor = ImageProcessor(self._mock_data_manager) image_processor = ImageProcessor(self._mock_data_manager)
results = image_processor.get_single_image(self._train_id, test_tag_name, test_step) results = image_processor.get_single_image(self._train_id, test_tag_name, test_step)
recv_image_tensor = get_image_tensor_from_bytes(results)
expected_image_tensor = self._images_values.get(test_step)
image_generator = LogOperations.get_log_generator(PluginNameEnum.IMAGE.value)
recv_image_tensor = image_generator.get_image_tensor_from_bytes(results)
assert recv_image_tensor.any() == expected_image_tensor.any() assert recv_image_tensor.any() == expected_image_tensor.any()
......
...@@ -22,9 +22,9 @@ import tempfile ...@@ -22,9 +22,9 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations from ....utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -65,12 +65,11 @@ class TestScalarsProcessor: ...@@ -65,12 +65,11 @@ class TestScalarsProcessor:
"""Load scalar record.""" """Load scalar record."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._scalars_metadata, self._scalars_values = LogOperations.generate_log( log_operation = LogOperations()
self._temp_path, self._scalars_metadata, self._scalars_values = log_operation.generate_log(
PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
......
...@@ -24,9 +24,9 @@ import time ...@@ -24,9 +24,9 @@ import time
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger from ..mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations from ....utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs from ....utils.tools import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -70,17 +70,14 @@ class TestTrainTaskManager: ...@@ -70,17 +70,14 @@ class TestTrainTaskManager:
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def load_data(self): def load_data(self):
"""Load data.""" """Load data."""
log_operation = LogOperations()
self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []} self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []}
self._events_names = [] self._events_names = []
self._train_id_list = [] self._train_id_list = []
graph_base_path = os.path.join(os.path.dirname(__file__),
os.pardir, "utils", "log_generators", "graph_base.json")
self._root_dir = tempfile.mkdtemp() self._root_dir = tempfile.mkdtemp()
for i in range(self._dir_num): for i in range(self._dir_num):
dir_path = tempfile.mkdtemp(dir=self._root_dir) dir_path = tempfile.mkdtemp(dir=self._root_dir)
tmp_tag_name = self._tag_name + '_' + str(i) tmp_tag_name = self._tag_name + '_' + str(i)
event_name = str(i) + "_name" event_name = str(i) + "_name"
train_id = dir_path.replace(self._root_dir, ".") train_id = dir_path.replace(self._root_dir, ".")
...@@ -89,20 +86,17 @@ class TestTrainTaskManager: ...@@ -89,20 +86,17 @@ class TestTrainTaskManager:
log_settings = dict( log_settings = dict(
steps=self._steps_list, steps=self._steps_list,
tag=tmp_tag_name, tag=tmp_tag_name,
graph_base_path=graph_base_path,
time=time.time()) time=time.time())
if i % 3 != 0: if i % 3 != 0:
LogOperations.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
self._plugins_id_map['image'].append(train_id) self._plugins_id_map['image'].append(train_id)
if i % 3 != 1: if i % 3 != 1:
LogOperations.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings)
self._plugins_id_map['scalar'].append(train_id) self._plugins_id_map['scalar'].append(train_id)
if i % 3 != 2: if i % 3 != 2:
LogOperations.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings)
self._plugins_id_map['graph'].append(train_id) self._plugins_id_map['graph'].append(train_id)
self._events_names.append(event_name) self._events_names.append(event_name)
self._train_id_list.append(train_id) self._train_id_list.append(train_id)
self._generated_path.append(self._root_dir) self._generated_path.append(self._root_dir)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mask string to crc32."""
CRC_TABLE_32 = (
0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF,
0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C,
0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57,
0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A,
0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E,
0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD,
0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696,
0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957,
0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3,
0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F,
0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395,
0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859,
0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312,
0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE,
0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90,
0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C,
0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8,
0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6, 0x502036A5,
0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E,
0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D,
0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19,
0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8,
0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3,
0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540,
0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A,
0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6,
0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82,
0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E,
0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351
)
_CRC = 0
_MASK = 0xFFFFFFFF
def _uint32(x):
"""Transform x's type to uint32."""
return x & 0xFFFFFFFF
def _get_crc_checksum(crc, data):
"""Get crc checksum."""
crc ^= _MASK
for d in data:
crc_table_index = (crc ^ d) & 0xFF
crc = (CRC_TABLE_32[crc_table_index] ^ (crc >> 8)) & _MASK
crc ^= _MASK
return crc
def get_mask_from_string(data):
"""
Get masked crc from data.
Args:
data (byte): Byte string of data.
Returns:
uint32, masked crc.
"""
crc = _get_crc_checksum(_CRC, data)
crc = _uint32(crc & _MASK)
crc = _uint32(((crc >> 15) | _uint32(crc << 17)) + 0xA282EAD8)
return crc
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for images."""
import io
import time
import numpy as np
from PIL import Image
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ImagesLogGenerator(LogGenerator):
"""
Log generator for images.
This is a log generator writing images. User can use it to generate fake
summary logs about images.
"""
def generate_event(self, values):
"""
Method for generating image event.
Args:
values (dict): A dict contains:
{
wall_time (float): Timestamp.
step (int): Train step.
image (np.array): Pixels tensor.
tag (str): Tag name.
}
Returns:
summary_pb2.Event.
"""
image_event = summary_pb2.Event()
image_event.wall_time = values.get('wall_time')
image_event.step = values.get('step')
height, width, channel, image_string = self._get_image_string(values.get('image'))
value = image_event.summary.value.add()
value.tag = values.get('tag')
value.image.height = height
value.image.width = width
value.image.colorspace = channel
value.image.encoded_image = image_string
return image_event
def _get_image_string(self, image_tensor):
"""
Generate image string from tensor.
Args:
image_tensor (np.array): Pixels tensor.
Returns:
int, height.
int, width.
int, channel.
bytes, image_string.
"""
height, width, channel = image_tensor.shape
scaled_height = int(height)
scaled_width = int(width)
image = Image.fromarray(image_tensor)
image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
output = io.BytesIO()
image.save(output, format='PNG')
image_string = output.getvalue()
output.close()
return height, width, channel, image_string
def _make_image_tensor(self, shape):
"""
Make image tensor according to shape.
Args:
shape (list): Shape of image, consists of height, width, channel.
Returns:
np.array, image tensor.
"""
image = np.prod(shape)
image_tensor = (np.arange(image, dtype=float)).reshape(shape)
image_tensor = image_tensor / np.max(image_tensor) * 255
image_tensor = image_tensor.astype(np.uint8)
return image_tensor
def generate_log(self, file_path, steps_list, tag_name):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
steps_list (list): A list consists of step.
tag_name (str): Tag name.
Returns:
list[dict], generated image metadata.
dict, generated image tensors.
"""
images_values = dict()
images_metadata = []
for step in steps_list:
wall_time = time.time()
# height, width, channel
image_tensor = self._make_image_tensor([5, 5, 3])
image_metadata = dict()
image_metadata.update({'wall_time': wall_time})
image_metadata.update({'step': step})
image_metadata.update({'height': image_tensor.shape[0]})
image_metadata.update({'width': image_tensor.shape[1]})
images_metadata.append(image_metadata)
images_values.update({step: image_tensor})
values = dict(
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
self._write_log_one_step(file_path, values)
return images_metadata, images_values
def get_image_tensor_from_bytes(self, image_string):
"""Get image tensor from bytes."""
img = Image.open(io.BytesIO(image_string))
image_tensor = np.array(img)
return image_tensor
if __name__ == "__main__":
images_log_generator = ImagesLogGenerator()
test_file_name = '%s.%s.%s' % ('image', 'summary', str(time.time()))
test_steps = [1, 3, 5]
test_tags = "test_image_tag_name"
images_log_generator.generate_log(test_file_name, test_steps, test_tags)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base log Generator."""
import struct
from abc import abstractmethod
from tests.ut.datavisual.utils import crc32
class LogGenerator:
"""
Base log generator.
This is a base class for log generators. User can use it to generate fake
summary logs.
"""
@abstractmethod
def generate_event(self, values):
"""
Abstract method for generating event.
Args:
values (dict): Values.
Returns:
summary_pb2.Event.
"""
def _write_log_one_step(self, file_path, values):
"""
Write log one step.
Args:
file_path (str): File path to write.
values (dict): Values.
"""
event = self.generate_event(values)
self._write_log_from_event(file_path, event)
@staticmethod
def _write_log_from_event(file_path, event):
"""
Write log by event.
Args:
file_path (str): File path to write.
event (summary_pb2.Event): Event object in proto.
"""
send_msg = event.SerializeToString()
header = struct.pack('<Q', len(send_msg))
header_crc = struct.pack('<I', crc32.get_mask_from_string(header))
footer_crc = struct.pack('<I', crc32.get_mask_from_string(send_msg))
write_event = header + header_crc + send_msg + footer_crc
with open(file_path, "ab") as f:
f.write(write_event)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for scalars."""
import time
import numpy as np
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ScalarsLogGenerator(LogGenerator):
"""
Log generator for scalars.
This is a log generator writing scalars. User can use it to generate fake
summary logs about scalar.
"""
def generate_event(self, values):
"""
Method for generating scalar event.
Args:
values (dict): A dict contains:
{
wall_time (float): Timestamp.
step (int): Train step.
value (float): Scalar value.
tag (str): Tag name.
}
Returns:
summary_pb2.Event.
"""
scalar_event = summary_pb2.Event()
scalar_event.wall_time = values.get('wall_time')
scalar_event.step = values.get('step')
value = scalar_event.summary.value.add()
value.tag = values.get('tag')
value.scalar_value = values.get('value')
return scalar_event
def generate_log(self, file_path, steps_list, tag_name):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
steps_list (list): A list consists of step.
tag_name (str): Tag name.
Returns:
list[dict], generated scalar metadata.
None, to be consistent with return value of ImageGenerator.
"""
scalars_metadata = []
for step in steps_list:
scalar_metadata = dict()
wall_time = time.time()
value = np.random.rand()
scalar_metadata.update({'wall_time': wall_time})
scalar_metadata.update({'step': step})
scalar_metadata.update({'value': value})
scalars_metadata.append(scalar_metadata)
scalar_metadata.update({"tag": tag_name})
self._write_log_one_step(file_path, scalar_metadata)
return scalars_metadata, None
if __name__ == "__main__":
scalars_log_generator = ScalarsLogGenerator()
test_file_name = '%s.%s.%s' % ('scalar', 'summary', str(time.time()))
test_steps = [1, 3, 5]
test_tag = "test_scalar_tag_name"
scalars_log_generator.generate_log(test_file_name, test_steps, test_tag)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""log operations module."""
import json
import os
import time
from tests.ut.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum
log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
PluginNameEnum.SCALAR.value: ScalarsLogGenerator()
}
class LogOperations:
"""Log Operations class."""
@staticmethod
def generate_log(plugin_name, log_dir, log_settings, valid=True):
"""
Generate log.
Args:
plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'.
log_dir (str): Log path to write log.
log_settings (dict): Info about the log, e.g.:
{
current_time (int): Timestamp in summary file name, not necessary.
graph_base_path (str): Path of graph_bas.json, necessary for `graph`.
steps (list[int]): Steps for `image` and `scalar`, default is [1].
tag (str): Tag name, default is 'default_tag'.
}
valid (bool): If true, summary name will be valid.
Returns:
str, Summary log path.
"""
current_time = log_settings.get('time', int(time.time()))
current_time = int(current_time)
log_generator = log_generators.get(plugin_name)
if valid:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time)))
else:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time)))
if plugin_name == PluginNameEnum.GRAPH.value:
graph_base_path = log_settings.get('graph_base_path')
with open(graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f)
graph_dict = log_generator.generate_log(temp_path, graph_dict)
return temp_path, graph_dict
steps_list = log_settings.get('steps', [1])
tag_name = log_settings.get('tag', 'default_tag')
metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name)
return temp_path, metadata, values
@staticmethod
def get_log_generator(plugin_name):
"""Get log generator."""
return log_generators.get(plugin_name)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Description: This file is used for some common util.
"""
import os
import shutil
import time
from urllib.parse import urlencode
from mindinsight.datavisual.common.enums import DataManagerStatus
def get_url(url, params):
"""
Concatenate the URL and params.
Args:
url (str): A link requested. For example, http://example.com.
params (dict): A dict consists of params. For example, {'offset': 1, 'limit':'100}.
Returns:
str, like http://example.com?offset=1&limit=100
"""
return url + '?' + urlencode(params)
def delete_files_or_dirs(path_list):
"""Delete files or dirs in path_list."""
for path in path_list:
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)
def check_loading_done(data_manager, time_limit=15):
"""If loading data for more than `time_limit` seconds, exit."""
start_time = time.time()
while data_manager.status != DataManagerStatus.DONE.value:
time_used = time.time() - start_time
if time_used > time_limit:
break
time.sleep(0.1)
continue
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -18,8 +18,7 @@ import os ...@@ -18,8 +18,7 @@ import os
import time import time
from google.protobuf import json_format from google.protobuf import json_format
from .log_generator import LogGenerator
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
...@@ -74,7 +73,7 @@ class GraphLogGenerator(LogGenerator): ...@@ -74,7 +73,7 @@ class GraphLogGenerator(LogGenerator):
if __name__ == "__main__": if __name__ == "__main__":
graph_log_generator = GraphLogGenerator() graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators--", "graph_base.json")
with open(graph_base_path, 'r') as load_f: with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f) graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph) graph_log_generator.generate_log(test_file_name, graph)
...@@ -18,7 +18,7 @@ import time ...@@ -18,7 +18,7 @@ import time
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import struct import struct
from abc import abstractmethod from abc import abstractmethod
from tests.st.func.datavisual.utils import crc32 from ...utils import crc32
class LogGenerator: class LogGenerator:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import time import time
import numpy as np import numpy as np
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator from .log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
......
...@@ -19,10 +19,9 @@ import json ...@@ -19,10 +19,9 @@ import json
import os import os
import time import time
from tests.st.func.datavisual.constants import SUMMARY_PREFIX from .log_generators.graph_log_generator import GraphLogGenerator
from tests.st.func.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator from .log_generators.images_log_generator import ImagesLogGenerator
from tests.st.func.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator from .log_generators.scalars_log_generator import ScalarsLogGenerator
from tests.st.func.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
...@@ -39,6 +38,7 @@ class LogOperations: ...@@ -39,6 +38,7 @@ class LogOperations:
self._step_num = 3 self._step_num = 3
self._tag_num = 2 self._tag_num = 2
self._time_count = 0 self._time_count = 0
self._graph_base_path = os.path.join(os.path.dirname(__file__), "log_generators", "graph_base.json")
def _get_steps(self): def _get_steps(self):
"""Get steps.""" """Get steps."""
...@@ -61,9 +61,7 @@ class LogOperations: ...@@ -61,9 +61,7 @@ class LogOperations:
metadata_dict["plugins"].update({plugin_name: list()}) metadata_dict["plugins"].update({plugin_name: list()})
log_generator = log_generators.get(plugin_name) log_generator = log_generators.get(plugin_name)
if plugin_name == PluginNameEnum.GRAPH.value: if plugin_name == PluginNameEnum.GRAPH.value:
graph_base_path = os.path.join(os.path.dirname(__file__), with open(self._graph_base_path, 'r') as load_f:
os.pardir, "utils", "log_generators", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f) graph_dict = json.load(load_f)
values = log_generator.generate_log(file_path, graph_dict) values = log_generator.generate_log(file_path, graph_dict)
metadata_dict["actual_values"].update({plugin_name: values}) metadata_dict["actual_values"].update({plugin_name: values})
...@@ -82,13 +80,13 @@ class LogOperations: ...@@ -82,13 +80,13 @@ class LogOperations:
self._time_count += 1 self._time_count += 1
return metadata_dict return metadata_dict
def create_summary_logs(self, summary_base_dir, summary_dir_num, start_index=0): def create_summary_logs(self, summary_base_dir, summary_dir_num, dir_prefix, start_index=0):
"""Create summary logs in summary_base_dir.""" """Create summary logs in summary_base_dir."""
summary_metadata = dict() summary_metadata = dict()
steps_list = self._get_steps() steps_list = self._get_steps()
tag_name_list = self._get_tags() tag_name_list = self._get_tags()
for i in range(start_index, summary_dir_num + start_index): for i in range(start_index, summary_dir_num + start_index):
log_dir = os.path.join(summary_base_dir, f'{SUMMARY_PREFIX}{i}') log_dir = os.path.join(summary_base_dir, f'{dir_prefix}{i}')
os.makedirs(log_dir) os.makedirs(log_dir)
train_id = log_dir.replace(summary_base_dir, ".") train_id = log_dir.replace(summary_base_dir, ".")
...@@ -120,3 +118,47 @@ class LogOperations: ...@@ -120,3 +118,47 @@ class LogOperations:
metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list) metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list)
return {train_id: metadata_dict} return {train_id: metadata_dict}
def generate_log(self, plugin_name, log_dir, log_settings=None, valid=True):
"""
Generate log for ut.
Args:
plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'.
log_dir (str): Log path to write log.
log_settings (dict): Info about the log, e.g.:
{
current_time (int): Timestamp in summary file name, not necessary.
graph_base_path (str): Path of graph_bas.json, necessary for `graph`.
steps (list[int]): Steps for `image` and `scalar`, default is [1].
tag (str): Tag name, default is 'default_tag'.
}
valid (bool): If true, summary name will be valid.
Returns:
str, Summary log path.
"""
if log_settings is None:
log_settings = dict()
current_time = log_settings.get('time', int(time.time()))
current_time = int(current_time)
log_generator = log_generators.get(plugin_name)
if valid:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time)))
else:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time)))
if plugin_name == PluginNameEnum.GRAPH.value:
with open(self._graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f)
graph_dict = log_generator.generate_log(temp_path, graph_dict)
return temp_path, graph_dict
steps_list = log_settings.get('steps', [1])
tag_name = log_settings.get('tag', 'default_tag')
metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name)
return temp_path, metadata, values
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册