提交 7ce9f9bd 编写于 作者: L Li Hongzhang

remove TrainLineage, EvalLineage and related ut/st

上级 ed594114
......@@ -15,19 +15,12 @@
"""
Lineagemgr Module Introduction.
This module provides Python APIs to collect and query the lineage of models.
Users can add the TrainLineage/EvalLineage callback to the MindSpore train/eval callback list to
collect the key parameters and results, such as, the name of the network and optimizer, the
evaluation metric and results.
This module provides Python APIs to query the lineage of models.
The APIs can be used to get the lineage information of the models. For example,
what hyperparameter is used in the model training, which model has the highest
accuracy among all the versions, etc.
"""
from mindinsight.lineagemgr.api.model import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.common.log import logger
try:
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage
except (ModuleNotFoundError, NameError, ImportError):
logger.warning('Not found MindSpore!')
__all__ = ["TrainLineage", "EvalLineage", "get_summary_lineage", "filter_summary_lineage"]
__all__ = ["get_summary_lineage", "filter_summary_lineage"]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
......@@ -37,16 +37,6 @@ class LineageParamValueError(MindInsightException):
)
class LineageParamMissingError(MindInsightException):
    """Error raised by the lineage module when a parameter is missing."""

    def __init__(self, msg):
        error_code = LineageErrors.PARAM_MISSING_ERROR
        # The message template is filled with the caller-supplied detail text.
        full_message = LineageErrorMsg.PARAM_MISSING_ERROR.value.format(msg)
        super().__init__(error=error_code, message=full_message)
class LineageParamRunContextError(MindInsightException):
"""The input parameter run_context error in lineage module."""
......@@ -67,15 +57,6 @@ class LineageGetModelFileError(MindInsightException):
)
class LineageSearchModelParamError(MindInsightException):
    """Error raised when a lineage search-model parameter is not supported."""

    def __init__(self, msg):
        error_code = LineageErrors.LINEAGE_PARAM_NOT_SUPPORT_ERROR
        # The message template is filled with the caller-supplied detail text.
        full_message = LineageErrorMsg.LINEAGE_PARAM_NOT_SUPPORT_ERROR.value.format(msg)
        super().__init__(error=error_code, message=full_message)
class LineageSummaryAnalyzeException(MindInsightException):
"""The summary analyze error in lineage module."""
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""Define schema of model lineage input parameters."""
from marshmallow import Schema, fields, ValidationError, pre_load, validates
from marshmallow.validate import Range
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
LineageErrors
......@@ -28,72 +27,10 @@ from mindinsight.utils.exceptions import MindInsightException
try:
from mindspore.dataset.engine import Dataset
from mindspore.nn import Cell, Optimizer
from mindspore.common.tensor import Tensor
from mindspore.train.callback import _ListCallback
except (ImportError, ModuleNotFoundError):
logger.error('MindSpore Not Found!')
class RunContextArgs(Schema):
    """
    Define the parameter schema for RunContext.

    Each ``pre_load`` hook rejects a value that is set (truthy) but is not an
    instance of the expected MindSpore type. Falsy values such as None are
    passed through unchanged. Every hook reports its error under the name of
    the field it checks.
    """
    optimizer = fields.Function(allow_none=True)
    loss_fn = fields.Function(allow_none=True)
    net_outputs = fields.Function(allow_none=True)
    train_network = fields.Function(allow_none=True)
    train_dataset = fields.Function(allow_none=True)
    epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
    batch_num = fields.Int(allow_none=True, validate=Range(min=0))
    cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
    parallel_mode = fields.Str(allow_none=True)
    device_number = fields.Int(allow_none=True, validate=Range(min=1))
    list_callback = fields.Function(allow_none=True)

    @pre_load
    def check_optimizer(self, data, **kwargs):
        """Reject an ``optimizer`` that is set but is not an Optimizer."""
        optimizer = data.get("optimizer")
        if optimizer and not isinstance(optimizer, Optimizer):
            raise ValidationError({'optimizer': [
                "Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
            ]})
        return data

    @pre_load
    def check_train_network(self, data, **kwargs):
        """Reject a ``train_network`` that is set but is not a Cell."""
        train_network = data.get("train_network")
        if train_network and not isinstance(train_network, Cell):
            raise ValidationError({'train_network': [
                "Parameter train_network must be an instance of mindspore.nn.Cell."]})
        return data

    @pre_load
    def check_train_dataset(self, data, **kwargs):
        """Reject a ``train_dataset`` that is set but is not a Dataset."""
        train_dataset = data.get("train_dataset")
        if train_dataset and not isinstance(train_dataset, Dataset):
            raise ValidationError({'train_dataset': [
                "Parameter train_dataset must be an instance of "
                "mindspore.dataengine.datasets.Dataset"]})
        return data

    @pre_load
    def check_loss(self, data, **kwargs):
        """Reject ``net_outputs`` that is set but is not a Tensor."""
        net_outputs = data.get("net_outputs")
        if net_outputs and not isinstance(net_outputs, Tensor):
            # The key must be the field name; it used to be misspelled
            # 'net_outpus', so the error was filed under a non-existent field.
            raise ValidationError({'net_outputs': [
                "The parameter net_outputs is invalid. It should be a Tensor."
            ]})
        return data

    @pre_load
    def check_list_callback(self, data, **kwargs):
        """Reject a ``list_callback`` that is set but is not a _ListCallback."""
        list_callback = data.get("list_callback")
        if list_callback and not isinstance(list_callback, _ListCallback):
            raise ValidationError({'list_callback': [
                "Parameter list_callback must be an instance of "
                "mindspore.train.callback._ListCallback."
            ]})
        return data
class EvalParameter(Schema):
"""Define the parameter schema for Evaluation job."""
......
......@@ -18,18 +18,13 @@ import re
from marshmallow import ValidationError
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamMissingError, \
LineageParamTypeError, LineageParamValueError, LineageDirNotExistError
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, \
LineageParamValueError, LineageDirNotExistError
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
from mindinsight.utils.exceptions import MindInsightException, ParamValueError
try:
from mindspore.nn import Cell
from mindspore.train.summary import SummaryRecord
except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!')
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
......@@ -144,31 +139,6 @@ def validate_int_params(int_param, param_name):
message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
def validate_network(network):
    """
    Verify if the network is valid.

    Args:
        network (Cell): See mindspore.nn.Cell.

    Raises:
        LineageParamMissingError: If the network is None (or otherwise falsy).
        MindInsightException: If the network is not an instance of Cell.
    """
    if not network:
        error_msg = "The input network for TrainLineage should not be None."
        log.error(error_msg)
        raise LineageParamMissingError(error_msg)

    if not isinstance(network, Cell):
        # Adjacent literals are concatenated, so the first one must end with a
        # space; otherwise the log reads "instanceof".
        log.error("Invalid network. Network should be an instance "
                  "of mindspore.nn.Cell.")
        raise MindInsightException(
            error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
            message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value
        )
def validate_file_path(file_path, allow_empty=False):
"""
Verify that the file_path is valid.
......@@ -190,28 +160,6 @@ def validate_file_path(file_path, allow_empty=False):
message=str(error))
def validate_train_run_context(schema, data):
    """
    Check the train run_context data against the given schema.

    Only the first reported field that has an entry in
    TRAIN_RUN_CONTEXT_ERROR_MAPPING triggers an exception; fields without an
    entry there are ignored.

    Args:
        schema (Schema): data schema.
        data (dict): data to check schema.

    Raises:
        MindInsightException: If the parameters are invalid.
    """
    problems = schema().validate(data)
    for field_name, detail in problems.items():
        if field_name not in TRAIN_RUN_CONTEXT_ERROR_MAPPING:
            continue
        code = TRAIN_RUN_CONTEXT_ERROR_MAPPING.get(field_name)
        # A dedicated message, when one is configured, replaces the schema's own.
        dedicated = TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(field_name)
        message = dedicated if dedicated else detail
        log.error(message)
        raise MindInsightException(error=code, message=message)
def validate_eval_run_context(schema, data):
"""
Validate mindspore evaluation job run_context data according to schema.
......@@ -257,27 +205,6 @@ def validate_search_model_condition(schema, data):
raise MindInsightException(error=error_code, message=error_msg)
def validate_summary_record(summary_record):
    """
    Check that summary_record is a SummaryRecord instance.

    Args:
        summary_record (SummaryRecord): The object to check, expected to be an
            instance of mindspore.train.summary.SummaryRecord.

    Raises:
        MindInsightException: If summary_record is not a SummaryRecord.
    """
    if isinstance(summary_record, SummaryRecord):
        return

    log.error("Invalid summary_record. It should be an instance "
              "of mindspore.train.summary.SummaryRecord.")
    raise MindInsightException(
        error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
        message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value
    )
def validate_raise_exception(raise_exception):
"""
Validate raise_exception.
......
......@@ -13,131 +13,6 @@
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import socket
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
from mindinsight.lineagemgr.common.log import logger as log
# Set the Event mark
EVENT_FILE_NAME_MARK = "out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK = "_lineage"
def package_dataset_graph(graph):
    """
    Build a LineageEvent carrying the dataset graph.

    The "children" entry, when present, is popped from ``graph``, so the
    caller's dict is modified.

    Args:
        graph (dict): Dataset graph.

    Returns:
        LineageEvent, the proto message event contains dataset graph.
    """
    event = LineageEvent()
    event.wall_time = time.time()
    graph_message = event.dataset_graph

    # NOTE(review): nesting reconstructed from a flattened source - the current
    # dataset is assumed to be packaged only when "children" is present; confirm.
    if "children" not in graph:
        return event

    child_operations = graph.pop("children")
    if child_operations:
        _package_children(children=child_operations, message=graph_message)
    _package_current_dataset(operation=graph, message=graph_message)
    return event
def _package_children(children, message):
    """
    Package children in dataset operation.

    Each non-empty child gets a new entry in ``message.children``. Its own
    "children" entry (if any) is popped and packaged recursively, then the
    remaining keys are packaged as the child's own parameters.

    Args:
        children (list[dict]): Child operations.
        message (DatasetGraph): Children proto message.
    """
    for child in children:
        if not child:
            continue
        child_graph_message = message.children.add()
        # A child without a "children" key is treated as a leaf instead of
        # raising KeyError, matching the membership check done at the top level.
        grandson = child.pop("children", None)
        if grandson:
            _package_children(children=grandson, message=child_graph_message)
        # package other parameters
        _package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(operation, message):
    """
    Copy one operation dict into its proto message.

    Non-empty "operations" and "sampler" entries are packaged as enhancement
    operations; every other entry is packaged as a plain parameter.

    Args:
        operation (dict): Operation dict.
        message (Operation): Operation proto message.
    """
    for name, content in operation.items():
        if content and name == "operations":
            for enhancement in content:
                _package_enhancement_operation(enhancement, message.operations.add())
            continue
        if content and name == "sampler":
            _package_enhancement_operation(content, message.sampler)
            continue
        _package_parameter(name, content, message.parameter)
def _package_enhancement_operation(operation, message):
    """
    Copy an enhancement operation into its proto message.

    List values go to ``message.size`` when every element is an int and to
    ``message.weights`` otherwise; non-list values are packaged as parameters.

    Args:
        operation (dict): Enhancement operation.
        message (Operation): Enhancement operation proto message.
    """
    for name, arg in operation.items():
        if not isinstance(arg, list):
            _package_parameter(name, arg, message.operationParam)
            continue
        only_ints = all(isinstance(item, int) for item in arg)
        target = message.size if only_ints else message.weights
        target.extend(arg)
def _package_parameter(key, value, message):
    """
    Store one operation argument in the proto map matching its type.

    Args:
        key (str): Operation name.
        value (Union[str, bool, int, float, list, None]): Operation args.
        message (OperationParameter): Operation proto message.

    Raises:
        LineageParamTypeError: If the value type is not supported.
    """
    if isinstance(value, str):
        message.mapStr[key] = value
        return
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        message.mapBool[key] = value
        return
    if isinstance(value, int):
        message.mapInt[key] = value
        return
    if isinstance(value, float):
        message.mapDouble[key] = value
        return
    if isinstance(value, list) and key != "operations":
        # An empty list is accepted but nothing is stored for it.
        if value:
            cleaned = ["" if item is None else item for item in value]
            message.mapStrList[key].strValue.extend(cleaned)
        return
    if value is None:
        message.mapStr[key] = "None"
        return

    error_msg = "Parameter {} is not supported " \
                "in event package.".format(key)
    log.error(error_msg)
    raise LineageParamTypeError(error_msg)
def organize_graph(graph_message):
"""
......@@ -296,76 +171,3 @@ def _organize_parameter(parameter):
parameter_result.update(result_str_list_para)
return parameter_result
def package_user_defined_info(user_dict):
    """
    Build a LineageEvent carrying the user defined info.

    Args:
        user_dict(dict): User defined info dict.

    Returns:
        LineageEvent, the proto message event contains user defined info.
    """
    event = LineageEvent()
    event.wall_time = time.time()
    _package_user_defined_info(user_dict, event.user_defined_info)
    return event
def _package_user_defined_info(user_defined_dict, user_defined_message):
    """
    Setting attribute in user defined proto message.

    Each valid item adds one entry to ``user_defined_message.user_info``. The
    entry's ``map_int32``, ``map_double`` or ``map_str`` map is chosen from the
    value type (int, float, str). Items with a non-str key or an unsupported
    value type are logged and skipped; this function does not raise for them.

    Args:
        user_defined_dict (dict): User define info dict.
        user_defined_message (LineageEvent): Proto message of user defined info.
    """
    for key, value in user_defined_dict.items():
        if not isinstance(key, str):
            error_msg = f"Invalid key type in user defined info. The {key}'s type " \
                        f"'{type(key).__name__}' is not supported. It should be str."
            log.error(error_msg)
            # Skip the item: the key has just been reported as unusable.
            continue

        if isinstance(value, int):
            attr_name = "map_int32"
        elif isinstance(value, float):
            attr_name = "map_double"
        elif isinstance(value, str):
            attr_name = "map_str"
        else:
            error_msg = f"Invalid value type in user defined info. The {value}'s type " \
                        f"'{type(value).__name__}' is not supported. It should be float, int or str."
            log.error(error_msg)
            # Skip before add() so no empty user_info entry is left behind.
            continue

        add_user_defined_info = user_defined_message.user_info.add()
        getattr(add_user_defined_info, attr_name)[key] = value
def get_lineage_file_name():
    """
    Compose the lineage file name from the current time and host name.

    Lineage filename format is:
        EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.

    Returns:
        str, the name of event log file.
    """
    timestamp = int(time.time())
    host = socket.gethostname()
    return f'{EVENT_FILE_NAME_MARK}summary.{timestamp}.{host}{LINEAGE_FILE_NAME_MARK}'
......@@ -23,25 +23,23 @@ Usage:
import os
import shutil
import time
from unittest import mock, TestCase
from unittest import TestCase, mock
import numpy as np
import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory
from mindinsight.lineagemgr import get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError
from mindinsight.utils.exceptions import MindInsightException
from mindspore.application.model_zoo.resnet import ResNet
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits, WithLossCell
from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep, _ListCallback
from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep, _ListCallback
from mindspore.train.summary import SummaryRecord
from ...conftest import SUMMARY_DIR, SUMMARY_DIR_2, SUMMARY_DIR_3, BASE_SUMMARY_DIR
from tests.utils.lineage_writer.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
from ...conftest import SUMMARY_DIR, SUMMARY_DIR_2, SUMMARY_DIR_3
from .train_one_step import TrainOneStep
......@@ -78,65 +76,6 @@ class TestModelLineage(TestCase):
"version": "v1"
}
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin(self):
    """Check the learning rate and the lineage log file after TrainLineage.begin."""
    callback = TrainLineage(self.summary_record, True, self.user_defined_info)
    context = RunContext(self.run_context)
    callback.begin(context)
    assert callback.initial_learning_rate == 0.12
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_info(self):
    """Run TrainLineage.begin with user defined info and read it back."""
    user_defined_info = dict(info="info1", version="v1", network="LeNet")
    callback = TrainLineage(self.summary_record, False, user_defined_info)
    callback.begin(RunContext(self.run_context))
    assert callback.initial_learning_rate == 0.12

    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file) is True

    lineage = filter_summary_lineage(os.path.dirname(log_file))
    # NOTE(review): the expectation uses self.user_defined_info rather than the
    # local dict passed to the callback - confirm this is intended.
    stored = lineage['object'][0]['model_lineage']['user_defined']
    assert self.user_defined_info == stored
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_lineage_with_log_dir(self):
    """Pass a directory path as summary_record and check the lineage log lands there."""
    log_dir = os.path.join(BASE_SUMMARY_DIR, 'log_dir')
    callback = TrainLineage(summary_record=log_dir)
    callback.begin(RunContext(self.run_context))
    assert log_dir == callback.lineage_log_dir
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file) is True
    # Clean up the directory produced by this test.
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -148,7 +87,7 @@ class TestModelLineage(TestCase):
def test_training_end(self, *args):
"""Test the end function in TrainLineage."""
args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
train_callback.initial_learning_rate = 0.12
train_callback.end(RunContext(self.run_context))
res = get_summary_lineage(SUMMARY_DIR)
......@@ -175,33 +114,6 @@ class TestModelLineage(TestCase):
eval_run_context['step_num'] = 32
eval_callback.end(RunContext(eval_run_context))
@pytest.mark.scene_eval(3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_eval_only(self):
    """Record only an evaluation event and check the queried lineage."""
    eval_dir = os.path.join(BASE_SUMMARY_DIR, 'eval_only_dir')
    record = SummaryRecord(eval_dir)

    # The run context dict is shared with self.run_context, not copied.
    eval_run_context = self.run_context
    eval_run_context.update({
        'metrics': {'accuracy': 0.58},
        'valid_dataset': self.run_context['train_dataset'],
        'step_num': 32,
    })

    callback = EvalLineage(record)
    callback.end(RunContext(eval_run_context))

    actual = get_summary_lineage(eval_dir, ['metric', 'dataset_graph'])
    expected = {
        'summary_dir': eval_dir,
        'dataset_graph': {},
        'metric': {'accuracy': 0.58}
    }
    assert actual == expected
    shutil.rmtree(eval_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -209,7 +121,7 @@ class TestModelLineage(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('tests.utils.lineage_writer._summary_record.get_lineage_file_name')
@mock.patch('os.path.getsize')
def test_multiple_trains(self, *args):
"""
......@@ -247,89 +159,6 @@ class TestModelLineage(TestCase):
file_num = os.listdir(SUMMARY_DIR_2)
assert len(file_num) == 8
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_train_eval(self, *args):
    """Run one training callback and one evaluation callback, then query the lineage."""
    # Patches are applied bottom-up: args[0] mocks get_file_size and
    # args[1] mocks get_lineage_file_name.
    file_size_mock, file_name_mock = args[0], args[1]
    file_size_mock.return_value = 10
    work_dir = os.path.join(BASE_SUMMARY_DIR, 'train_eval')
    make_directory(work_dir)

    file_name_mock.return_value = os.path.join(
        work_dir,
        f'train_out.events.summary.{int(time.time())}.ubuntu_lineage'
    )
    trainer = TrainLineage(work_dir)
    trainer.begin(RunContext(self.run_context))
    trainer.end(RunContext(self.run_context))

    file_name_mock.return_value = os.path.join(
        work_dir,
        f'eval_out.events.summary.{int(time.time()) + 1}.ubuntu_lineage'
    )
    evaluator = EvalLineage(work_dir)
    # The run context dict is shared with self.run_context, not copied.
    eval_run_context = self.run_context
    eval_run_context['metrics'] = {'accuracy': 0.78}
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['step_num'] = 32
    evaluator.end(RunContext(eval_run_context))

    lineage = get_summary_lineage(work_dir)
    loss_function = lineage.get('hyper_parameters', {}).get('loss_function')
    assert loss_function == 'SoftmaxCrossEntropyWithLogits'
    assert lineage.get('algorithm', {}).get('network') == 'ResNet'
    if os.path.exists(work_dir):
        shutil.rmtree(work_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_train_multi_eval(self, *args):
    """Run one training callback and two evaluation callbacks, checking the metric after each."""
    # Patches are applied bottom-up: args[0] mocks get_file_size and
    # args[1] mocks get_lineage_file_name.
    file_size_mock, file_name_mock = args[0], args[1]
    file_size_mock.return_value = 10
    work_dir = os.path.join(BASE_SUMMARY_DIR, 'train_multi_eval')
    make_directory(work_dir)

    file_name_mock.return_value = os.path.join(
        work_dir,
        'train_out.events.summary.1590107366.ubuntu_lineage')
    trainer = TrainLineage(work_dir, True)
    trainer.begin(RunContext(self.run_context))
    trainer.end(RunContext(self.run_context))

    # First evaluation.
    file_name_mock.return_value = os.path.join(
        work_dir,
        'eval_out.events.summary.1590107367.ubuntu_lineage')
    evaluator = EvalLineage(work_dir, True)
    eval_run_context = self.run_context
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['metrics'] = {'accuracy': 0.79}
    evaluator.end(RunContext(eval_run_context))
    lineage = get_summary_lineage(work_dir)
    assert lineage.get('metric', {}).get('accuracy') == 0.79

    # Second evaluation, written to a later-stamped file.
    file_name_mock.return_value = os.path.join(
        work_dir,
        'eval_out.events.summary.1590107368.ubuntu_lineage')
    evaluator = EvalLineage(work_dir, True)
    eval_run_context = self.run_context
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['metrics'] = {'accuracy': 0.80}
    evaluator.end(RunContext(eval_run_context))
    lineage = get_summary_lineage(work_dir)
    assert lineage.get('metric', {}).get('accuracy') == 0.80

    if os.path.exists(work_dir):
        shutil.rmtree(work_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -341,7 +170,7 @@ class TestModelLineage(TestCase):
def test_train_with_customized_network(self, *args):
"""Test train with customized network."""
args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
run_context_customized = self.run_context
del run_context_customized['optimizer']
del run_context_customized['net_outputs']
......@@ -363,27 +192,6 @@ class TestModelLineage(TestCase):
assert res.get('algorithm', {}).get('network') == 'ResNet'
assert res.get('hyper_parameters', {}).get('optimizer') == 'Momentum'
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception(self):
    """Pass a plain dict as run_context and expect LineageParamRunContextError."""
    record = SummaryRecord(SUMMARY_DIR_3)
    summary_file = record.full_file_name
    lineage_file = summary_file + "_lineage"
    assert os.path.isfile(summary_file) is True
    assert os.path.isfile(lineage_file) is False

    trainer = TrainLineage(record, True)
    evaluator = EvalLineage(record, False)
    # NOTE(review): nesting reconstructed from a flattened source. If begin()
    # raises, the end() call below it never runs - confirm intended.
    with self.assertRaises(LineageParamRunContextError):
        trainer.begin(self.run_context)
        evaluator.end(self.run_context)

    entries = os.listdir(SUMMARY_DIR_3)
    assert len(entries) == 1
    assert os.path.isfile(lineage_file) is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
......@@ -392,53 +200,7 @@ class TestModelLineage(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_init(self):
    """Build callbacks from an invalid summary_record and check no lineage file appears."""
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)
    record = SummaryRecord(SUMMARY_DIR_3)

    trainer = TrainLineage('fake_summary_record', False)
    evaluator = EvalLineage('fake_summary_record', False)
    trainer.begin(RunContext(self.run_context))
    evaluator.end(RunContext(self.run_context))

    entries = os.listdir(SUMMARY_DIR_3)
    summary_file = record.full_file_name
    assert len(entries) == 1
    assert os.path.isfile(summary_file + "_lineage") is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_create_file(self):
    """Make writes fail after the lineage file exists and check it stays empty."""
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)
    record = SummaryRecord(SUMMARY_DIR_3)
    evaluator = EvalLineage(record, False)
    lineage_file = record.full_file_name + "_lineage"

    # The run context dict is shared with self.run_context, not copied.
    eval_run_context = self.run_context
    eval_run_context['metrics'] = {'accuracy': 0.78}
    eval_run_context['step_num'] = 32
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']

    # The outer open creates the file; the patched open makes every write raise.
    with open(lineage_file, 'ab'):
        with mock.patch('builtins.open') as patched_open:
            writer = patched_open.return_value.__enter__.return_value
            writer.write.side_effect = IOError
            evaluator.end(RunContext(eval_run_context))

    assert os.path.isfile(lineage_file) is True
    assert os.path.getsize(lineage_file) == 0
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context')
@mock.patch('tests.utils.lineage_writer.model_lineage.validate_eval_run_context')
@mock.patch.object(AnalyzeObject, 'get_file_size', return_value=64)
def test_raise_exception_record_trainlineage(self, *args):
"""Test exception when error happened after recording training infos."""
......@@ -446,32 +208,14 @@ class TestModelLineage(TestCase):
shutil.rmtree(SUMMARY_DIR_3)
args[1].side_effect = MindInsightException(error=LineageErrors.PARAM_RUN_CONTEXT_ERROR,
message="RunContext error.")
summary_record = SummaryRecord(SUMMARY_DIR_3)
train_callback = TrainLineage(summary_record, True)
train_callback = TrainLineage(SUMMARY_DIR_3, True)
train_callback.begin(RunContext(self.run_context))
full_file_name = train_callback.lineage_summary.lineage_log_path
file_size1 = os.path.getsize(full_file_name)
train_callback.end(RunContext(self.run_context))
file_size2 = os.path.getsize(full_file_name)
assert file_size2 > file_size1
eval_callback = EvalLineage(summary_record, False)
eval_callback = EvalLineage(SUMMARY_DIR_3, False)
eval_callback.end(RunContext(self.run_context))
file_size3 = os.path.getsize(full_file_name)
assert file_size3 == file_size2
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_non_lineage_file(self):
    """Create a summary file whose suffix only resembles a lineage suffix."""
    run_dir = os.path.join(BASE_SUMMARY_DIR, 'run4')
    if os.path.exists(run_dir):
        shutil.rmtree(run_dir)
    record = SummaryRecord(run_dir, file_suffix='_MS_lineage_none')
    summary_file = record.full_file_name
    assert summary_file.endswith('_lineage_none')
    assert os.path.isfile(summary_file)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class EventWriter."""
import os
from unittest import TestCase
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer
class TestEventWriter(TestCase):
    """Write one lineage event with EventWriter and read it back."""

    def setUp(self):
        """Choose the log file the test writes to."""
        self.log_path = "./test.log"

    def test_write_event_to_file(self):
        """A serialized train message written to file is readable by the analyzer."""
        run_context_args = {"train_network": "res"}
        train_message = LineageSummary.package_train_message(run_context_args)
        content = train_message.SerializeToString()

        writer = EventWriter(self.log_path, True)
        writer.write_event_to_file(content)

        summary = LineageSummaryAnalyzer.get_summary_infos(self.log_path)
        recorded_network = summary.train_lineage.train_lineage.algorithm.network
        self.assertEqual(recorded_network, run_context_args.get("train_network"))

    def tearDown(self):
        """Remove the log file if the test left one behind."""
        # A missing file raises FileNotFoundError, a subclass of IOError, so it
        # is ignored just like any other removal failure.
        try:
            os.remove(self.log_path)
        except IOError:
            pass
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class SummaryRecord."""
import json
from unittest import TestCase, mock
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
class TestSummaryRecord(TestCase):
    """Tests for the packaging and recording methods of LineageSummary."""

    def setUp(self):
        """Build the train, evaluation and hardware argument dicts used by the tests."""
        self.run_context_args = {
            "train_network": "test_train_network",
            "loss": 0.1,
            "learning_rate": 0.1,
            "optimizer": "test_optimizer",
            "loss_function": "test_loss_function",
            "epoch": 1,
            "parallel_mode": "test_parallel_mode",
            "device_num": 1,
            "batch_size": 1,
            "train_dataset_path": "test_train_dataset_path",
            "train_dataset_size": 1,
            "model_path": "test_model_path",
            "model_size": 1,
        }
        self.eval_args = {
            "metrics": json.dumps({"acc": "test"}),
            "valid_dataset_path": "test_valid_dataset_path",
            "valid_dataset_size": 1,
        }
        self.hard_info_args = {
            "pid": 1,
            "process_start_time": 921226.0,
        }

    def test_package_train_message(self):
        """The packaged train event carries the network, optimizer and dataset path."""
        expected = self.run_context_args
        lineage = LineageSummary.package_train_message(expected).train_lineage
        self.assertEqual(lineage.algorithm.network, expected.get("train_network"))
        self.assertEqual(lineage.hyper_parameters.optimizer, expected.get("optimizer"))
        self.assertEqual(
            lineage.train_dataset.train_dataset_path,
            expected.get("train_dataset_path"),
        )

    @mock.patch.object(EventWriter, "write_event_to_file")
    def test_record_train_lineage(self, write_file):
        """record_train_lineage runs without raising when the file write is patched out."""
        write_file.return_value = True
        summary = LineageSummary(lineage_log_dir="test.log")
        summary.record_train_lineage(self.run_context_args)

    def test_package_evaluation_message(self):
        """The packaged evaluation event carries the metrics string unchanged."""
        evaluation = LineageSummary.package_evaluation_message(self.eval_args).evaluation_lineage
        self.assertEqual(evaluation.metric, self.eval_args.get("metrics"))

    @mock.patch.object(EventWriter, "write_event_to_file")
    def test_record_eval_lineage(self, write_file):
        """record_evaluation_lineage runs without raising when the file write is patched out."""
        write_file.return_value = True
        summary = LineageSummary(lineage_log_dir="test.log")
        summary.record_evaluation_lineage(self.eval_args)
......@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lineage writer module."""
from ._summary_record import LineageSummary
__all__ = ["LineageSummary"]
......@@ -17,7 +17,7 @@ import os
import stat
import struct
from mindinsight.datavisual.utils import crc32
from tests.utils import crc32
class EventWriter:
......@@ -84,6 +84,6 @@ class EventWriter:
Returns:
bytes, crc of content, 4 bytes.
"""
crc_value = crc32.GetMaskCrc32cValue(content, len(content))
crc_value = crc32.get_mask_from_string(content)
return struct.pack("<L", crc_value)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import socket
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
from mindinsight.lineagemgr.common.log import logger as log
# Prefix of the event log file name built by get_lineage_file_name().
EVENT_FILE_NAME_MARK = "out.events."
# Suffix that get_lineage_file_name() appends to mark a lineage file.
LINEAGE_FILE_NAME_MARK = "_lineage"
def package_dataset_graph(graph):
"""
Package a dataset graph dict into a lineage event.
A new LineageEvent is created and stamped with the current wall time. If
``graph`` has a "children" key, that entry is popped from ``graph`` (the
caller's dict is modified in place) and, when non-empty, handed to
``_package_children``. ``graph`` itself is passed to
``_package_current_dataset`` to package its remaining entries.
Args:
graph (dict): Dataset graph.
Returns:
LineageEvent, the proto message event that contains the dataset graph.
"""
dataset_graph_event = LineageEvent()
dataset_graph_event.wall_time = time.time()
dataset_graph = dataset_graph_event.dataset_graph
if "children" in graph:
children = graph.pop("children")
if children:
_package_children(children=children, message=dataset_graph)
# NOTE(review): indentation is not visible in this listing -- confirm whether
# the call below runs for every graph or only when "children" is present.
_package_current_dataset(operation=graph, message=dataset_graph)
return dataset_graph_event
def _package_children(children, message):
    """
    Recursively package child dataset operations into a proto message.

    For every non-empty entry of ``children`` a new element is appended to
    ``message.children``. The entry's own "children" (if any) are packaged
    recursively into that element, then the remaining keys of the entry are
    packaged through ``_package_current_dataset``.

    Note:
        The "children" key is removed from each child dict in place, so the
        dicts supplied by the caller are modified.

    Args:
        children (list[dict]): Child operations. Falsy entries are skipped.
        message (DatasetGraph): Children proto message.
    """
    for child in children:
        if not child:
            continue
        child_graph_message = message.children.add()
        # A child without a "children" key is treated like one with no
        # children instead of raising KeyError; package_dataset_graph already
        # treats the key as optional for the root operation.
        grandson = child.pop("children", None)
        if grandson:
            _package_children(children=grandson, message=child_graph_message)
        # package other parameters
        _package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(operation, message):
    """
    Package the entries of one dataset operation into a proto message.

    A non-empty "operations" entry is packaged item by item into new elements
    of ``message.operations``; a non-empty "sampler" entry is packaged into
    ``message.sampler``. Every other entry (including empty "operations" or
    "sampler" values) is stored through ``_package_parameter`` in
    ``message.parameter``.

    Args:
        operation (dict): Operation dict.
        message (Operation): Operation proto message.
    """
    for name, content in operation.items():
        if not content or name not in ("operations", "sampler"):
            _package_parameter(name, content, message.parameter)
            continue

        if name == "sampler":
            _package_enhancement_operation(content, message.sampler)
            continue

        for enhancement in content:
            _package_enhancement_operation(enhancement, message.operations.add())
def _package_enhancement_operation(operation, message):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for key, value in operation.items():
if isinstance(value, list):
if all(isinstance(ele, int) for ele in value):
message.size.extend(value)
else:
message.weights.extend(value)
else:
_package_parameter(key, value, message.operationParam)
def _package_parameter(key, value, message):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if isinstance(value, str):
message.mapStr[key] = value
elif isinstance(value, bool):
message.mapBool[key] = value
elif isinstance(value, int):
message.mapInt[key] = value
elif isinstance(value, float):
message.mapDouble[key] = value
elif isinstance(value, list) and key != "operations":
if value:
replace_value_list = list(map(lambda x: "" if x is None else x, value))
message.mapStrList[key].strValue.extend(replace_value_list)
elif value is None:
message.mapStr[key] = "None"
else:
error_msg = "Parameter {} is not supported " \
"in event package.".format(key)
log.error(error_msg)
raise LineageParamTypeError(error_msg)
def package_user_defined_info(user_dict):
    """
    Package user defined info into a lineage event.

    A new LineageEvent is stamped with the current wall time and its
    ``user_defined_info`` field is filled from ``user_dict`` by
    ``_package_user_defined_info``.

    Args:
        user_dict (dict): User defined info dict.

    Returns:
        LineageEvent, the proto message event that contains the user defined info.
    """
    event = LineageEvent()
    event.wall_time = time.time()
    _package_user_defined_info(user_dict, event.user_defined_info)
    return event
def _package_user_defined_info(user_defined_dict, user_defined_message):
"""
Setting attribute in user defined proto message.
Args:
user_defined_dict (dict): User define info dict.
user_defined_message (LineageEvent): Proto message of user defined info.
Raises:
LineageParamValueError: When the value is out of range.
LineageParamTypeError: When given a type not support yet.
"""
for key, value in user_defined_dict.items():
if not isinstance(key, str):
error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)
if isinstance(value, int):
attr_name = "map_int32"
elif isinstance(value, float):
attr_name = "map_double"
elif isinstance(value, str):
attr_name = "map_str"
else:
attr_name = "attr_name"
add_user_defined_info = user_defined_message.user_info.add()
try:
getattr(add_user_defined_info, attr_name)[key] = value
except AttributeError:
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
def get_lineage_file_name():
    """
    Build the name of a lineage event log file.

    The name is made of, in order: EVENT_FILE_NAME_MARK, the text "summary.",
    the current time in whole seconds, a dot, the host name and
    LINEAGE_FILE_NAME_MARK.

    Returns:
        str, the name of event log file.
    """
    seconds = str(int(time.time()))
    host = socket.gethostname()
    name_parts = (
        EVENT_FILE_NAME_MARK,
        "summary.",
        seconds,
        ".",
        host,
        LINEAGE_FILE_NAME_MARK,
    )
    return "".join(name_parts)
......@@ -17,8 +17,7 @@ import os
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.validator.validate import validate_file_path
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from ._event_writer import EventWriter
from ._summary_adapter import package_dataset_graph, package_user_defined_info, get_lineage_file_name
......@@ -40,11 +39,10 @@ class LineageSummary:
>>> lineage_summary.record_train_lineage(train_lineage)
"""
def __init__(self,
lineage_log_dir=None,
lineage_log_dir,
override=False):
lineage_log_name = get_lineage_file_name()
self.lineage_log_path = os.path.join(lineage_log_dir, lineage_log_name)
validate_file_path(self.lineage_log_path)
self.event_writer = EventWriter(self.lineage_log_path, override)
@staticmethod
......
......@@ -17,22 +17,20 @@ import json
import os
import numpy as np
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
from mindinsight.utils.exceptions import \
MindInsightException
from mindinsight.lineagemgr.common.validator.validate import validate_train_run_context, \
validate_eval_run_context, validate_file_path, validate_network, \
validate_int_params, validate_summary_record, validate_raise_exception,\
validate_user_defined_info
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
LineageGetModelFileError, LineageLogError
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, LineageErrors
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
LineageParamRunContextError)
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.utils import try_except, make_directory
from mindinsight.lineagemgr.common.validator.model_parameter import RunContextArgs, \
EvalParameter
from mindinsight.lineagemgr.collection.model.base import Metadata
from mindinsight.lineagemgr.common.utils import make_directory, try_except
from mindinsight.lineagemgr.common.validator.model_parameter import EvalParameter
from mindinsight.lineagemgr.common.validator.validate import (validate_eval_run_context, validate_file_path,
validate_int_params,
validate_raise_exception,
validate_user_defined_info)
from mindinsight.utils.exceptions import MindInsightException
from ._summary_record import LineageSummary
from .base import Metadata
try:
from mindspore.common.tensor import Tensor
......@@ -91,7 +89,6 @@ class TrainLineage(Callback):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
......@@ -145,7 +142,6 @@ class TrainLineage(Callback):
log.debug('initial_learning_rate: %s', self.initial_learning_rate)
else:
network = run_context_args.get('train_network')
validate_network(network)
optimizer = AnalyzeObject.get_optimizer_by_network(network)
self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
log.debug('initial_learning_rate: %s', self.initial_learning_rate)
......@@ -183,7 +179,6 @@ class TrainLineage(Callback):
raise LineageParamRunContextError(error_msg)
run_context_args = run_context.original_args()
validate_train_run_context(RunContextArgs, run_context_args)
train_lineage = dict()
train_lineage = AnalyzeObject.get_network_args(
......@@ -277,7 +272,6 @@ class EvalLineage(Callback):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
......@@ -639,7 +633,6 @@ class AnalyzeObject:
dict, the lineage metadata.
"""
network = run_context_args.get('train_network')
validate_network(network)
optimizer = run_context_args.get('optimizer')
if not optimizer:
optimizer = AnalyzeObject.get_optimizer_by_network(network)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册