提交 13030f75 编写于 作者: K kouzhenzhong

lineage: decouple train/eval lineage with summary_writer

上级 98516a84
......@@ -29,7 +29,7 @@ from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, L
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
LineageGetModelFileError, LineageLogError
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.utils import try_except
from mindinsight.lineagemgr.common.utils import try_except, make_directory
from mindinsight.lineagemgr.common.validator.model_parameter import RunContextArgs, \
EvalParameter
from mindinsight.lineagemgr.collection.model.base import Metadata
......@@ -50,9 +50,12 @@ class TrainLineage(Callback):
Collect lineage of a training job.
Args:
summary_record (SummaryRecord): SummaryRecord is used to record
the summary value, and summary_record is an instance of SummaryRecord,
see mindspore.train.summary.SummaryRecord.
summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
is used to record the summary value(see mindspore.train.summary.SummaryRecord),
or a log dir(as a `str`) to be passed to `LineageSummary` to create
a lineage summary recorder. It should be noted that instead of making
use of summary_record to record lineage info directly, we obtain
log dir from it then create a new summary file to write lineage info.
raise_exception (bool): Whether to raise exception when error occurs in
TrainLineage. If True, raise exception. If False, catch exception
and continue. Default: False.
......@@ -74,18 +77,25 @@ class TrainLineage(Callback):
>>> lineagemgr = TrainLineage(summary_record=summary_writer)
>>> model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
"""
def __init__(self, summary_record, raise_exception=False, user_defined_info=None):
def __init__(self,
summary_record,
raise_exception=False,
user_defined_info=None):
super(TrainLineage, self).__init__()
try:
validate_raise_exception(raise_exception)
self.raise_exception = raise_exception
validate_summary_record(summary_record)
self.summary_record = summary_record
if isinstance(summary_record, str):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_path = summary_log_path + '_lineage'
self.lineage_summary = LineageSummary(self.lineage_log_dir)
self.initial_learning_rate = None
......@@ -113,8 +123,7 @@ class TrainLineage(Callback):
log.info('Initialize training lineage collection...')
if self.user_defined_info:
lineage_summary = LineageSummary(summary_log_path=self.lineage_log_path)
lineage_summary.record_user_defined_info(self.user_defined_info)
self.lineage_summary.record_user_defined_info(self.user_defined_info)
if not isinstance(run_context, RunContext):
error_msg = f'Invalid TrainLineage run_context.'
......@@ -147,8 +156,7 @@ class TrainLineage(Callback):
dataset_graph_dict = json.loads(dataset_graph_json_str)
log.info('Logging dataset graph...')
try:
lineage_summary = LineageSummary(self.lineage_log_path)
lineage_summary.record_dataset_graph(dataset_graph=dataset_graph_dict)
self.lineage_summary.record_dataset_graph(dataset_graph=dataset_graph_dict)
except Exception as error:
error_msg = f'Dataset graph log error in TrainLineage begin: {error}'
log.error(error_msg)
......@@ -210,8 +218,7 @@ class TrainLineage(Callback):
log.info('Logging lineage information...')
try:
lineage_summary = LineageSummary(self.lineage_log_path)
lineage_summary.record_train_lineage(train_lineage)
self.lineage_summary.record_train_lineage(train_lineage)
except IOError as error:
error_msg = f'End error in TrainLineage: {error}'
log.error(error_msg)
......@@ -228,10 +235,12 @@ class EvalLineage(Callback):
"""
Collect lineage of an evaluation job.
Args:
summary_record (SummaryRecord): SummaryRecord is used to record
the summary value, and summary_record is an instance of SummaryRecord,
see mindspore.train.summary.SummaryRecord.
summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
is used to record the summary value(see mindspore.train.summary.SummaryRecord),
or a log dir(as a `str`) to be passed to `LineageSummary` to create
a lineage summary recorder. It should be noted that instead of making
use of summary_record to record lineage info directly, we obtain
log dir from it then create a new summary file to write lineage info.
raise_exception (bool): Whether to raise exception when error occurs in
EvalLineage. If True, raise exception. If False, catch exception
and continue. Default: False.
......@@ -253,18 +262,25 @@ class EvalLineage(Callback):
>>> lineagemgr = EvalLineage(summary_record=summary_writer)
>>> model.eval(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
"""
def __init__(self, summary_record, raise_exception=False, user_defined_info=None):
def __init__(self,
summary_record,
raise_exception=False,
user_defined_info=None):
super(EvalLineage, self).__init__()
try:
validate_raise_exception(raise_exception)
self.raise_exception = raise_exception
validate_summary_record(summary_record)
self.summary_record = summary_record
if isinstance(summary_record, str):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_path = summary_log_path + '_lineage'
self.lineage_summary = LineageSummary(self.lineage_log_dir)
self.user_defined_info = user_defined_info
if user_defined_info:
......@@ -289,8 +305,7 @@ class EvalLineage(Callback):
LineageLogError: If recording lineage information fails.
"""
if self.user_defined_info:
lineage_summary = LineageSummary(summary_log_path=self.lineage_log_path)
lineage_summary.record_user_defined_info(self.user_defined_info)
self.lineage_summary.record_user_defined_info(self.user_defined_info)
if not isinstance(run_context, RunContext):
error_msg = f'Invalid EvalLineage run_context.'
......@@ -312,8 +327,7 @@ class EvalLineage(Callback):
log.info('Logging evaluation job lineage...')
try:
lineage_summary = LineageSummary(self.lineage_log_path)
lineage_summary.record_evaluation_lineage(eval_lineage)
self.lineage_summary.record_evaluation_lineage(eval_lineage)
except IOError as error:
error_msg = f'End error in EvalLineage: {error}'
log.error(error_msg)
......
......@@ -13,14 +13,15 @@
# limitations under the License.
# ============================================================================
"""Lineage utils."""
from functools import wraps
import os
import re
from functools import wraps
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
LineageGetModelFileError, LineageLogError, LineageParamValueError, LineageDirNotExistError, \
LineageParamSummaryPathError
LineageGetModelFileError, LineageLogError, LineageParamValueError, LineageParamTypeError, \
LineageDirNotExistError, LineageParamSummaryPathError
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.validator.validate import validate_path
from mindinsight.utils.exceptions import MindInsightException
......@@ -76,3 +77,29 @@ def get_timestamp(filename):
"""Get timestamp from filename."""
timestamp = int(re.search(SummaryWatcher().SUMMARY_FILENAME_REGEX, filename)[1])
return timestamp
def make_directory(path):
"""Make directory."""
real_path = None
if path is None or not isinstance(path, str) or not path.strip():
log.error("Invalid input path: %r.", path)
raise LineageParamTypeError("Invalid path type")
# convert relative path to abs path
path = os.path.realpath(path)
log.debug("The abs path is %r", path)
# check path exist and its write permissions]
if os.path.exists(path):
real_path = path
else:
# All exceptions need to be caught because create directory maybe have some limit(permissions)
log.debug("The directory(%s) doesn't exist, will create it", path)
try:
os.makedirs(path, exist_ok=True)
real_path = path
except PermissionError as e:
log.error("No write permission on the directory(%r), error = %r", path, e)
raise LineageParamTypeError("No write permission on the directory.")
return real_path
......@@ -14,7 +14,7 @@
# ============================================================================
"""Validate the parameters."""
import os
import re
from marshmallow import ValidationError
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
......@@ -31,6 +31,9 @@ try:
except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!')
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
TRAIN_RUN_CONTEXT_ERROR_MAPPING = {
'optimizer': LineageErrors.PARAM_OPTIMIZER_ERROR,
'loss_fn': LineageErrors.PARAM_LOSS_FN_ERROR,
......@@ -511,3 +514,27 @@ def validate_added_info(added_info: dict):
raise LineageParamValueError("'remark' must be str.")
# length of remark should be in [0, 128].
validate_range("length of remark", len(value), min_value=0, max_value=128)
def validate_str_by_regular(target, reg=None, flag=re.ASCII):
"""
Validate string by given regular.
Args:
target: target string.
reg: pattern.
flag: pattern mode.
Raises:
LineageParamValueError, if string not match given pattern.
Returns:
bool, if target matches pattern, return True.
"""
if reg is None:
reg = _name_re
if re.match(reg, target, flag) is None:
raise LineageParamValueError("'{}' is illegal, it should be match "
"regular'{}' by flags'{}'".format(target, reg, flag))
return True
......@@ -13,12 +13,18 @@
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import socket
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
from mindinsight.lineagemgr.common.log import logger as log
# Set the Event mark
EVENT_FILE_NAME_MARK = "out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK = "_lineage"
def package_dataset_graph(graph):
"""
......@@ -345,3 +351,21 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
def get_lineage_file_name():
"""
Get lineage file name.
Lineage filename format is:
EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.
Returns:
str, the name of event log file.
"""
time_second = str(int(time.time()))
hostname = socket.gethostname()
file_name = f'{EVENT_FILE_NAME_MARK}summary.{time_second}.{hostname}{LINEAGE_FILE_NAME_MARK}'
return file_name
......@@ -13,11 +13,13 @@
# limitations under the License.
# ============================================================================
"""Record message to summary log."""
import os
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.validator.validate import validate_file_path
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from ._summary_adapter import package_dataset_graph, package_user_defined_info
from ._summary_adapter import package_dataset_graph, package_user_defined_info, get_lineage_file_name
class LineageSummary:
......@@ -26,20 +28,24 @@ class LineageSummary:
Recording train lineage and evaluation lineage to summary log.
Args:
summary_log_path (str): Summary log path.
lineage_log_dir (str): lineage log dir.
override (bool): If override the summary log exist.
Raises:
IOError: Write to summary log failed or file_path is a dir.
IOError: Write to summary log failed.
Examples:
>>> summary_log_path = "./test.log"
>>> train_lineage = {"train_network": "ResNet"}
>>> lineage_summary = LineageSummary(summary_log_path=summary_log_path)
>>> lineage_summary = LineageSummary(lineage_log_dir="./")
>>> lineage_summary.record_train_lineage(train_lineage)
"""
def __init__(self, summary_log_path=None, override=False):
self.event_writer = EventWriter(summary_log_path, override)
def __init__(self,
lineage_log_dir=None,
override=False):
lineage_log_name = get_lineage_file_name()
self.lineage_log_path = os.path.join(lineage_log_dir, lineage_log_name)
validate_file_path(self.lineage_log_path)
self.event_writer = EventWriter(self.lineage_log_path, override)
@staticmethod
def package_train_message(run_context_args):
......
......@@ -73,6 +73,7 @@ class TestModelLineage(TestCase):
]
cls.run_context['list_callback'] = _ListCallback(callback)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -84,9 +85,10 @@ class TestModelLineage(TestCase):
train_callback = TrainLineage(self.summary_record, True)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = self.summary_record.full_file_name + '_lineage'
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -103,9 +105,28 @@ class TestModelLineage(TestCase):
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = self.summary_record.full_file_name + '_lineage'
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_lineage_with_log_dir(self):
"""Test TrainLineage with log_dir."""
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'log_dir')
train_callback = TrainLineage(summary_record=summary_dir)
train_callback.begin(RunContext(self.run_context))
assert summary_dir == train_callback.lineage_log_dir
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -127,6 +148,7 @@ class TestModelLineage(TestCase):
res = get_summary_lineage(SUMMARY_DIR)
assert res.get('hyper_parameters', {}).get('epoch') == 14
@pytest.mark.scene_eval(3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -142,6 +164,7 @@ class TestModelLineage(TestCase):
eval_run_context['step_num'] = 32
eval_callback.end(RunContext(eval_run_context))
@pytest.mark.scene_eval(3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -168,6 +191,7 @@ class TestModelLineage(TestCase):
assert res == expect_res
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -177,31 +201,30 @@ class TestModelLineage(TestCase):
@mock.patch('os.path.getsize')
def test_multiple_trains(self, *args):
"""
Callback TrainLineage and EvalLineage for mutltiple times.
Callback TrainLineage and EvalLineage for multiple times.
Write TrainLineage and EvalLineage in different files under same directory.
EvalLineage log file end with '_lineage'.
"""
args[0].return_value = 10
for i in range(2):
summary_record = SummaryRecord(SUMMARY_DIR_2,
create_time=int(time.time()) + i)
eval_record = SummaryRecord(SUMMARY_DIR_2,
file_prefix='events.eval.',
create_time=int(time.time() + 10) + i,
file_suffix='.MS_lineage')
summary_record = SummaryRecord(SUMMARY_DIR_2, create_time=int(time.time()))
eval_record = SummaryRecord(SUMMARY_DIR_2, create_time=int(time.time()) + 1)
train_callback = TrainLineage(summary_record, True)
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))
time.sleep(1)
eval_callback = EvalLineage(eval_record, True)
eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78 + i + 1}
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
eval_run_context['step_num'] = 32
eval_callback.end(RunContext(eval_run_context))
time.sleep(1)
file_num = os.listdir(SUMMARY_DIR_2)
assert len(file_num) == 8
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -234,6 +257,7 @@ class TestModelLineage(TestCase):
assert res.get('algorithm', {}).get('network') == 'ResNet'
assert res.get('hyper_parameters', {}).get('optimizer') == 'Momentum'
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -246,7 +270,6 @@ class TestModelLineage(TestCase):
full_file_name = summary_record.full_file_name
assert os.path.isfile(full_file_name) is True
assert os.path.isfile(full_file_name + "_lineage") is False
train_callback = TrainLineage(summary_record, True)
eval_callback = EvalLineage(summary_record, False)
with self.assertRaises(LineageParamRunContextError):
......@@ -256,6 +279,7 @@ class TestModelLineage(TestCase):
assert len(file_num) == 1
assert os.path.isfile(full_file_name + "_lineage") is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -276,6 +300,7 @@ class TestModelLineage(TestCase):
assert len(file_num) == 1
assert os.path.isfile(full_file_name + "_lineage") is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -300,6 +325,7 @@ class TestModelLineage(TestCase):
assert os.path.isfile(full_file_name) is True
assert os.path.getsize(full_file_name) == 0
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -317,7 +343,7 @@ class TestModelLineage(TestCase):
summary_record = SummaryRecord(SUMMARY_DIR_3)
train_callback = TrainLineage(summary_record, True)
train_callback.begin(RunContext(self.run_context))
full_file_name = summary_record.full_file_name + "_lineage"
full_file_name = train_callback.lineage_summary.lineage_log_path
file_size1 = os.path.getsize(full_file_name)
train_callback.end(RunContext(self.run_context))
file_size2 = os.path.getsize(full_file_name)
......@@ -327,6 +353,7 @@ class TestModelLineage(TestCase):
file_size3 = os.path.getsize(full_file_name)
assert file_size3 == file_size2
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -338,7 +365,7 @@ class TestModelLineage(TestCase):
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'run4')
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
summary_record = SummaryRecord(summary_dir, file_suffix='_MS_lineage')
summary_record = SummaryRecord(summary_dir, file_suffix='_MS_lineage_none')
full_file_name = summary_record.full_file_name
assert full_file_name.endswith('_lineage')
assert full_file_name.endswith('_lineage_none')
assert os.path.isfile(full_file_name)
......@@ -53,6 +53,19 @@ def pytest_collection_modifyitems(items):
if module_item is not None:
module_item.append(item)
ordered_items = split_items.get(COLLECTION_MODULE)
item_scenes = []
for item in ordered_items:
scenes = [
marker for marker in item.own_markers
if marker.name.startswith('scene')
]
if scenes:
scene_mark = scenes[0].args[0]
else:
scene_mark = 0
item_scenes.append((item, scene_mark))
sorted_item_scenes = sorted(item_scenes, key=lambda x: x[1])
ordered_items = [item_scene[0] for item_scene in sorted_item_scenes]
ordered_items.extend(split_items.get(API_MODULE))
items[:] = ordered_items
......
......@@ -205,6 +205,7 @@ class TestModelLineage(TestCase):
self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2)
self.assertTrue('Invalid value for raise_exception.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
......@@ -216,20 +217,24 @@ class TestModelLineage(TestCase):
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
args[0].return_value = None
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path))
eval_lineage.end(self.my_run_context(self.run_context))
args[0].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_end_except_run_context(self, *args):
"""Test EvalLineage.end method when run_context is invalid.."""
args[0].return_value = True
args[1].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
eval_lineage.end(self.run_context)
self.assertTrue('Invalid EvalLineage run_context.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
......@@ -242,11 +247,13 @@ class TestModelLineage(TestCase):
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
eval_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in EvalLineage' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
......@@ -259,6 +266,7 @@ class TestModelLineage(TestCase):
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
eval_lineage.end(self.my_run_context(self.run_context))
......
......@@ -64,7 +64,7 @@ class TestSummaryRecord(TestCase):
def test_record_train_lineage(self, write_file):
"""Test record_train_lineage."""
write_file.return_value = True
lineage_summray = LineageSummary(summary_log_path="test.log")
lineage_summray = LineageSummary(lineage_log_dir="test.log")
lineage_summray.record_train_lineage(self.run_context_args)
def test_package_evaluation_message(self):
......@@ -76,5 +76,5 @@ class TestSummaryRecord(TestCase):
def test_record_eval_lineage(self, write_file):
"""Test record_eval_lineage."""
write_file.return_value = True
lineage_summray = LineageSummary(summary_log_path="test.log")
lineage_summray = LineageSummary(lineage_log_dir="test.log")
lineage_summray.record_evaluation_lineage(self.eval_args)
......@@ -15,6 +15,7 @@
"""MindSpore Mock Interface"""
import os
import time
import socket
class SummaryRecord:
......@@ -22,13 +23,15 @@ class SummaryRecord:
def __init__(self,
log_dir: str,
file_prefix: str = "events.",
file_suffix: str = ".MS",
file_prefix: str = "events",
file_suffix: str = "_MS",
create_time=int(time.time())):
self.log_dir = log_dir
self.prefix = file_prefix
self.suffix = file_suffix
file_name = file_prefix + 'summary.' + str(create_time) + file_suffix
hostname = socket.gethostname()
file_name = f'{file_prefix}.out.events.summary.{str(create_time)}.{hostname}{file_suffix}'
self.full_file_name = os.path.join(log_dir, file_name)
permissions = os.R_OK | os.W_OK | os.X_OK
mode = permissions << 6
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册