提交 7ce9f9bd 编写于 作者: L Li Hongzhang

remove TrainLineage, EvalLineage and related ut/st

上级 ed594114
......@@ -15,19 +15,12 @@
"""
Lineagemgr Module Introduction.
This module provides Python APIs to collect and query the lineage of models.
Users can add the TrainLineage/EvalLineage callback to the MindSpore train/eval callback list to
collect the key parameters and results, such as, the name of the network and optimizer, the
evaluation metric and results.
This module provides Python APIs to query the lineage of models.
The APIs can be used to get the lineage information of the models. For example,
what hyperparameter is used in the model training, which model has the highest
accuracy among all the versions, etc.
"""
from mindinsight.lineagemgr.api.model import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.common.log import logger
try:
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage
except (ModuleNotFoundError, NameError, ImportError):
logger.warning('Not found MindSpore!')
__all__ = ["TrainLineage", "EvalLineage", "get_summary_lineage", "filter_summary_lineage"]
__all__ = ["get_summary_lineage", "filter_summary_lineage"]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
......@@ -37,16 +37,6 @@ class LineageParamValueError(MindInsightException):
)
class LineageParamMissingError(MindInsightException):
"""The parameter missing error in lineage module."""
def __init__(self, msg):
super(LineageParamMissingError, self).__init__(
error=LineageErrors.PARAM_MISSING_ERROR,
message=LineageErrorMsg.PARAM_MISSING_ERROR.value.format(msg)
)
class LineageParamRunContextError(MindInsightException):
"""The input parameter run_context error in lineage module."""
......@@ -67,15 +57,6 @@ class LineageGetModelFileError(MindInsightException):
)
class LineageSearchModelParamError(MindInsightException):
"""The lineage search model param error."""
def __init__(self, msg):
super(LineageSearchModelParamError, self).__init__(
error=LineageErrors.LINEAGE_PARAM_NOT_SUPPORT_ERROR,
message=LineageErrorMsg.LINEAGE_PARAM_NOT_SUPPORT_ERROR.value.format(msg)
)
class LineageSummaryAnalyzeException(MindInsightException):
"""The summary analyze error in lineage module."""
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""Define schema of model lineage input parameters."""
from marshmallow import Schema, fields, ValidationError, pre_load, validates
from marshmallow.validate import Range
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
LineageErrors
......@@ -28,72 +27,10 @@ from mindinsight.utils.exceptions import MindInsightException
try:
from mindspore.dataset.engine import Dataset
from mindspore.nn import Cell, Optimizer
from mindspore.common.tensor import Tensor
from mindspore.train.callback import _ListCallback
except (ImportError, ModuleNotFoundError):
logger.error('MindSpore Not Found!')
class RunContextArgs(Schema):
"""Define the parameter schema for RunContext."""
optimizer = fields.Function(allow_none=True)
loss_fn = fields.Function(allow_none=True)
net_outputs = fields.Function(allow_none=True)
train_network = fields.Function(allow_none=True)
train_dataset = fields.Function(allow_none=True)
epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
batch_num = fields.Int(allow_none=True, validate=Range(min=0))
cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
parallel_mode = fields.Str(allow_none=True)
device_number = fields.Int(allow_none=True, validate=Range(min=1))
list_callback = fields.Function(allow_none=True)
@pre_load
def check_optimizer(self, data, **kwargs):
optimizer = data.get("optimizer")
if optimizer and not isinstance(optimizer, Optimizer):
raise ValidationError({'optimizer': [
"Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
]})
return data
@pre_load
def check_train_network(self, data, **kwargs):
train_network = data.get("train_network")
if train_network and not isinstance(train_network, Cell):
raise ValidationError({'train_network': [
"Parameter train_network must be an instance of mindspore.nn.Cell."]})
return data
@pre_load
def check_train_dataset(self, data, **kwargs):
train_dataset = data.get("train_dataset")
if train_dataset and not isinstance(train_dataset, Dataset):
raise ValidationError({'train_dataset': [
"Parameter train_dataset must be an instance of "
"mindspore.dataengine.datasets.Dataset"]})
return data
@pre_load
def check_loss(self, data, **kwargs):
net_outputs = data.get("net_outputs")
if net_outputs and not isinstance(net_outputs, Tensor):
raise ValidationError({'net_outpus': [
"The parameter net_outputs is invalid. It should be a Tensor."
]})
return data
@pre_load
def check_list_callback(self, data, **kwargs):
list_callback = data.get("list_callback")
if list_callback and not isinstance(list_callback, _ListCallback):
raise ValidationError({'list_callback': [
"Parameter list_callback must be an instance of "
"mindspore.train.callback._ListCallback."
]})
return data
class EvalParameter(Schema):
"""Define the parameter schema for Evaluation job."""
......
......@@ -18,18 +18,13 @@ import re
from marshmallow import ValidationError
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamMissingError, \
LineageParamTypeError, LineageParamValueError, LineageDirNotExistError
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, \
LineageParamValueError, LineageDirNotExistError
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
from mindinsight.utils.exceptions import MindInsightException, ParamValueError
try:
from mindspore.nn import Cell
from mindspore.train.summary import SummaryRecord
except (ImportError, ModuleNotFoundError):
log.warning('MindSpore Not Found!')
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
......@@ -144,31 +139,6 @@ def validate_int_params(int_param, param_name):
message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
def validate_network(network):
"""
Verify if the network is valid.
Args:
network (Cell): See mindspore.nn.Cell.
Raises:
LineageParamMissingError: If the network is None.
MindInsightException: If the network is invalid.
"""
if not network:
error_msg = "The input network for TrainLineage should not be None."
log.error(error_msg)
raise LineageParamMissingError(error_msg)
if not isinstance(network, Cell):
log.error("Invalid network. Network should be an instance"
"of mindspore.nn.Cell.")
raise MindInsightException(
error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value
)
def validate_file_path(file_path, allow_empty=False):
"""
Verify that the file_path is valid.
......@@ -190,28 +160,6 @@ def validate_file_path(file_path, allow_empty=False):
message=str(error))
def validate_train_run_context(schema, data):
"""
Validate mindspore train run_context data according to schema.
Args:
schema (Schema): data schema.
data (dict): data to check schema.
Raises:
MindInsightException: If the parameters are invalid.
"""
errors = schema().validate(data)
for error_key, error_msg in errors.items():
if error_key in TRAIN_RUN_CONTEXT_ERROR_MAPPING.keys():
error_code = TRAIN_RUN_CONTEXT_ERROR_MAPPING.get(error_key)
if TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key):
error_msg = TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING.get(error_key)
log.error(error_msg)
raise MindInsightException(error=error_code, message=error_msg)
def validate_eval_run_context(schema, data):
"""
Validate mindspore evaluation job run_context data according to schema.
......@@ -257,27 +205,6 @@ def validate_search_model_condition(schema, data):
raise MindInsightException(error=error_code, message=error_msg)
def validate_summary_record(summary_record):
"""
Validate summary_record.
Args:
summary_record (SummaryRecord): SummaryRecord is used to record
the summary value, and summary_record is an instance of SummaryRecord,
see mindspore.train.summary.SummaryRecord
Raises:
MindInsightException: If the parameters are invalid.
"""
if not isinstance(summary_record, SummaryRecord):
log.error("Invalid summary_record. It should be an instance "
"of mindspore.train.summary.SummaryRecord.")
raise MindInsightException(
error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value
)
def validate_raise_exception(raise_exception):
"""
Validate raise_exception.
......
......@@ -13,131 +13,6 @@
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import socket
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
from mindinsight.lineagemgr.common.log import logger as log
# Set the Event mark
EVENT_FILE_NAME_MARK = "out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK = "_lineage"
def package_dataset_graph(graph):
"""
Package dataset graph.
Args:
graph (dict): Dataset graph.
Returns:
LineageEvent, the proto message event contains dataset graph.
"""
dataset_graph_event = LineageEvent()
dataset_graph_event.wall_time = time.time()
dataset_graph = dataset_graph_event.dataset_graph
if "children" in graph:
children = graph.pop("children")
if children:
_package_children(children=children, message=dataset_graph)
_package_current_dataset(operation=graph, message=dataset_graph)
return dataset_graph_event
def _package_children(children, message):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for child in children:
if child:
child_graph_message = getattr(message, "children").add()
grandson = child.pop("children")
if grandson:
_package_children(children=grandson, message=child_graph_message)
# package other parameters
_package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(operation, message):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for key, value in operation.items():
if value and key == "operations":
for operator in value:
_package_enhancement_operation(
operator,
message.operations.add()
)
elif value and key == "sampler":
_package_enhancement_operation(
value,
message.sampler
)
else:
_package_parameter(key, value, message.parameter)
def _package_enhancement_operation(operation, message):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for key, value in operation.items():
if isinstance(value, list):
if all(isinstance(ele, int) for ele in value):
message.size.extend(value)
else:
message.weights.extend(value)
else:
_package_parameter(key, value, message.operationParam)
def _package_parameter(key, value, message):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if isinstance(value, str):
message.mapStr[key] = value
elif isinstance(value, bool):
message.mapBool[key] = value
elif isinstance(value, int):
message.mapInt[key] = value
elif isinstance(value, float):
message.mapDouble[key] = value
elif isinstance(value, list) and key != "operations":
if value:
replace_value_list = list(map(lambda x: "" if x is None else x, value))
message.mapStrList[key].strValue.extend(replace_value_list)
elif value is None:
message.mapStr[key] = "None"
else:
error_msg = "Parameter {} is not supported " \
"in event package.".format(key)
log.error(error_msg)
raise LineageParamTypeError(error_msg)
def organize_graph(graph_message):
"""
......@@ -296,76 +171,3 @@ def _organize_parameter(parameter):
parameter_result.update(result_str_list_para)
return parameter_result
def package_user_defined_info(user_dict):
"""
Package user defined info.
Args:
user_dict(dict): User defined info dict.
Returns:
LineageEvent, the proto message event contains user defined info.
"""
user_event = LineageEvent()
user_event.wall_time = time.time()
user_defined_info = user_event.user_defined_info
_package_user_defined_info(user_dict, user_defined_info)
return user_event
def _package_user_defined_info(user_defined_dict, user_defined_message):
"""
Setting attribute in user defined proto message.
Args:
user_defined_dict (dict): User define info dict.
user_defined_message (LineageEvent): Proto message of user defined info.
Raises:
LineageParamValueError: When the value is out of range.
LineageParamTypeError: When given a type not support yet.
"""
for key, value in user_defined_dict.items():
if not isinstance(key, str):
error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)
if isinstance(value, int):
attr_name = "map_int32"
elif isinstance(value, float):
attr_name = "map_double"
elif isinstance(value, str):
attr_name = "map_str"
else:
attr_name = "attr_name"
add_user_defined_info = user_defined_message.user_info.add()
try:
getattr(add_user_defined_info, attr_name)[key] = value
except AttributeError:
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
def get_lineage_file_name():
"""
Get lineage file name.
Lineage filename format is:
EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.
Returns:
str, the name of event log file.
"""
time_second = str(int(time.time()))
hostname = socket.gethostname()
file_name = f'{EVENT_FILE_NAME_MARK}summary.{time_second}.{hostname}{LINEAGE_FILE_NAME_MARK}'
return file_name
......@@ -23,25 +23,23 @@ Usage:
import os
import shutil
import time
from unittest import mock, TestCase
from unittest import TestCase, mock
import numpy as np
import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory
from mindinsight.lineagemgr import get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError
from mindinsight.utils.exceptions import MindInsightException
from mindspore.application.model_zoo.resnet import ResNet
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits, WithLossCell
from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep, _ListCallback
from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep, _ListCallback
from mindspore.train.summary import SummaryRecord
from ...conftest import SUMMARY_DIR, SUMMARY_DIR_2, SUMMARY_DIR_3, BASE_SUMMARY_DIR
from tests.utils.lineage_writer.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
from ...conftest import SUMMARY_DIR, SUMMARY_DIR_2, SUMMARY_DIR_3
from .train_one_step import TrainOneStep
......@@ -78,65 +76,6 @@ class TestModelLineage(TestCase):
"version": "v1"
}
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin(self):
"""Test the begin function in TrainLineage."""
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info."""
user_defined_info = {
"info": "info1",
"version": "v1",
"network": "LeNet"
}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined']
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_lineage_with_log_dir(self):
"""Test TrainLineage with log_dir."""
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'log_dir')
train_callback = TrainLineage(summary_record=summary_dir)
train_callback.begin(RunContext(self.run_context))
assert summary_dir == train_callback.lineage_log_dir
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -148,7 +87,7 @@ class TestModelLineage(TestCase):
def test_training_end(self, *args):
"""Test the end function in TrainLineage."""
args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
train_callback.initial_learning_rate = 0.12
train_callback.end(RunContext(self.run_context))
res = get_summary_lineage(SUMMARY_DIR)
......@@ -175,33 +114,6 @@ class TestModelLineage(TestCase):
eval_run_context['step_num'] = 32
eval_callback.end(RunContext(eval_run_context))
@pytest.mark.scene_eval(3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_eval_only(self):
"""Test record evaluation event only."""
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'eval_only_dir')
summary_record = SummaryRecord(summary_dir)
eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.58}
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
eval_run_context['step_num'] = 32
eval_only_callback = EvalLineage(summary_record)
eval_only_callback.end(RunContext(eval_run_context))
res = get_summary_lineage(summary_dir,
['metric', 'dataset_graph'])
expect_res = {
'summary_dir': summary_dir,
'dataset_graph': {},
'metric': {'accuracy': 0.58}
}
assert res == expect_res
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -209,7 +121,7 @@ class TestModelLineage(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('tests.utils.lineage_writer._summary_record.get_lineage_file_name')
@mock.patch('os.path.getsize')
def test_multiple_trains(self, *args):
"""
......@@ -247,89 +159,6 @@ class TestModelLineage(TestCase):
file_num = os.listdir(SUMMARY_DIR_2)
assert len(file_num) == 8
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_train_eval(self, *args):
"""Callback for train once and eval once."""
args[0].return_value = 10
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_eval')
make_directory(summary_dir)
args[1].return_value = os.path.join(
summary_dir,
f'train_out.events.summary.{str(int(time.time()))}.ubuntu_lineage'
)
train_callback = TrainLineage(summary_dir)
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))
args[1].return_value = os.path.join(
summary_dir,
f'eval_out.events.summary.{str(int(time.time())+1)}.ubuntu_lineage'
)
eval_callback = EvalLineage(summary_dir)
eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78}
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
eval_run_context['step_num'] = 32
eval_callback.end(RunContext(eval_run_context))
res = get_summary_lineage(summary_dir)
assert res.get('hyper_parameters', {}).get('loss_function') \
== 'SoftmaxCrossEntropyWithLogits'
assert res.get('algorithm', {}).get('network') == 'ResNet'
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_train_multi_eval(self, *args):
"""Callback for train once and eval twice."""
args[0].return_value = 10
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_multi_eval')
make_directory(summary_dir)
args[1].return_value = os.path.join(
summary_dir,
'train_out.events.summary.1590107366.ubuntu_lineage')
train_callback = TrainLineage(summary_dir, True)
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))
args[1].return_value = os.path.join(
summary_dir,
'eval_out.events.summary.1590107367.ubuntu_lineage')
eval_callback = EvalLineage(summary_dir, True)
eval_run_context = self.run_context
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
eval_run_context['metrics'] = {'accuracy': 0.79}
eval_callback.end(RunContext(eval_run_context))
res = get_summary_lineage(summary_dir)
assert res.get('metric', {}).get('accuracy') == 0.79
args[1].return_value = os.path.join(
summary_dir,
'eval_out.events.summary.1590107368.ubuntu_lineage')
eval_callback = EvalLineage(summary_dir, True)
eval_run_context = self.run_context
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
eval_run_context['metrics'] = {'accuracy': 0.80}
eval_callback.end(RunContext(eval_run_context))
res = get_summary_lineage(summary_dir)
assert res.get('metric', {}).get('accuracy') == 0.80
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......@@ -341,7 +170,7 @@ class TestModelLineage(TestCase):
def test_train_with_customized_network(self, *args):
"""Test train with customized network."""
args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
run_context_customized = self.run_context
del run_context_customized['optimizer']
del run_context_customized['net_outputs']
......@@ -363,27 +192,6 @@ class TestModelLineage(TestCase):
assert res.get('algorithm', {}).get('network') == 'ResNet'
assert res.get('hyper_parameters', {}).get('optimizer') == 'Momentum'
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception(self):
"""Test exception when raise_exception is set True."""
summary_record = SummaryRecord(SUMMARY_DIR_3)
full_file_name = summary_record.full_file_name
assert os.path.isfile(full_file_name) is True
assert os.path.isfile(full_file_name + "_lineage") is False
train_callback = TrainLineage(summary_record, True)
eval_callback = EvalLineage(summary_record, False)
with self.assertRaises(LineageParamRunContextError):
train_callback.begin(self.run_context)
eval_callback.end(self.run_context)
file_num = os.listdir(SUMMARY_DIR_3)
assert len(file_num) == 1
assert os.path.isfile(full_file_name + "_lineage") is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
......@@ -392,53 +200,7 @@ class TestModelLineage(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_init(self):
"""Test exception when error happened during the initialization process."""
if os.path.exists(SUMMARY_DIR_3):
shutil.rmtree(SUMMARY_DIR_3)
summary_record = SummaryRecord(SUMMARY_DIR_3)
train_callback = TrainLineage('fake_summary_record', False)
eval_callback = EvalLineage('fake_summary_record', False)
train_callback.begin(RunContext(self.run_context))
eval_callback.end(RunContext(self.run_context))
file_num = os.listdir(SUMMARY_DIR_3)
full_file_name = summary_record.full_file_name
assert len(file_num) == 1
assert os.path.isfile(full_file_name + "_lineage") is False
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_create_file(self):
"""Test exception when error happened after creating file."""
if os.path.exists(SUMMARY_DIR_3):
shutil.rmtree(SUMMARY_DIR_3)
summary_record = SummaryRecord(SUMMARY_DIR_3)
eval_callback = EvalLineage(summary_record, False)
full_file_name = summary_record.full_file_name + "_lineage"
eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78}
eval_run_context['step_num'] = 32
eval_run_context['valid_dataset'] = self.run_context['train_dataset']
with open(full_file_name, 'ab'):
with mock.patch('builtins.open') as mock_handler:
mock_handler.return_value.__enter__.return_value.write.side_effect = IOError
eval_callback.end(RunContext(eval_run_context))
assert os.path.isfile(full_file_name) is True
assert os.path.getsize(full_file_name) == 0
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context')
@mock.patch('tests.utils.lineage_writer.model_lineage.validate_eval_run_context')
@mock.patch.object(AnalyzeObject, 'get_file_size', return_value=64)
def test_raise_exception_record_trainlineage(self, *args):
"""Test exception when error happened after recording training infos."""
......@@ -446,32 +208,14 @@ class TestModelLineage(TestCase):
shutil.rmtree(SUMMARY_DIR_3)
args[1].side_effect = MindInsightException(error=LineageErrors.PARAM_RUN_CONTEXT_ERROR,
message="RunContext error.")
summary_record = SummaryRecord(SUMMARY_DIR_3)
train_callback = TrainLineage(summary_record, True)
train_callback = TrainLineage(SUMMARY_DIR_3, True)
train_callback.begin(RunContext(self.run_context))
full_file_name = train_callback.lineage_summary.lineage_log_path
file_size1 = os.path.getsize(full_file_name)
train_callback.end(RunContext(self.run_context))
file_size2 = os.path.getsize(full_file_name)
assert file_size2 > file_size1
eval_callback = EvalLineage(summary_record, False)
eval_callback = EvalLineage(SUMMARY_DIR_3, False)
eval_callback.end(RunContext(self.run_context))
file_size3 = os.path.getsize(full_file_name)
assert file_size3 == file_size2
@pytest.mark.scene_exception(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_raise_exception_non_lineage_file(self):
"""Test exception when lineage summary file cannot be found."""
summary_dir = os.path.join(BASE_SUMMARY_DIR, 'run4')
if os.path.exists(summary_dir):
shutil.rmtree(summary_dir)
summary_record = SummaryRecord(summary_dir, file_suffix='_MS_lineage_none')
full_file_name = summary_record.full_file_name
assert full_file_name.endswith('_lineage_none')
assert os.path.isfile(full_file_name)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unittest for model_lineage.py"""
import os
import shutil
import unittest
from unittest import TestCase, mock
from unittest.mock import MagicMock
from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
MindInsightException)
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell
from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep
from mindspore.train.summary import SummaryRecord
@mock.patch('builtins.open')
@mock.patch('os.makedirs')
class TestModelLineage(TestCase):
"""Test TrainLineage and EvalLineage class in model_lineage.py."""
@classmethod
def setUpClass(cls):
cls.lineage_list = ['train_network', 'loss_fn', 'optimizer', 'train_dataset',
'valid_dataset', 'epoch', 'valid_step',
'hybrid_parallel', 'data_parallel_size', 'auto_parallel',
'device_number', 'batch_num', 'summary_log_path',
'model_ckpt']
cls.run_context = {key: None for key in cls.lineage_list}
cls.run_context['net_outputs'] = Tensor()
cls.my_run_context = RunContext
cls.my_train_module = TrainLineage
cls.my_eval_module = EvalLineage
cls.my_analyze_module = AnalyzeObject
cls.my_summary_record = SummaryRecord
cls.summary_log_path = '/path/to/summary_log'
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_summary_record_exception(self, *args):
"""Test SummaryRecord with exception."""
args[0].return_value = None
summary_record = self.my_summary_record(self.summary_log_path)
with self.assertRaises(MindInsightException) as context:
self.my_train_module(summary_record=summary_record, raise_exception=1)
self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin(self, *args):
"""Test TrainLineage.begin method."""
args[1].return_value = None
args[2].return_value = Optimizer(Tensor(0.1))
args[3].return_value = None
args[5].serialize.return_value = {}
run_context = {'optimizer': Optimizer(Tensor(0.1)),
'epoch_num': 10}
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path))
train_lineage.begin(self.my_run_context(run_context))
args[4].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin_error(self, *args):
"""Test TrainLineage.begin method."""
args[1].return_value = None
args[2].return_value = Optimizer(Tensor(0.1))
args[3].return_value = None
args[4].side_effect = Exception
args[5].serialize.return_value = {}
run_context = {'optimizer': Optimizer(Tensor(0.1)),
'epoch_num': 10}
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaisesRegex(LineageLogError, 'Dataset graph log error'):
train_lineage.begin(self.my_run_context(run_context))
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path))
train_lineage.begin(self.my_run_context(run_context))
args[4].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_begin_exception(self, *args):
"""Test TrainLineage.begin method with exception."""
args[0].return_value = None
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
train_lineage.begin(self.run_context)
self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception))
run_context = {key: None for key in self.lineage_list}
run_context['optimizer'] = 1
with self.assertRaises(Exception) as context:
train_lineage.begin(self.my_run_context(run_context))
self.assertTrue('The parameter optimizer is invalid.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end(self, *args):
"""Test TrainLineage.end method."""
args[1].return_value = 2.0
args[2].return_value = True
args[3].return_value = True
args[4].return_value = None
args[5].return_value = None
args[6].return_value = None
args[7].return_value = (None, None)
args[8].return_value = 10
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
train_lineage.end(self.my_run_context(self.run_context))
args[6].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end_exception(self, *args):
"""Test TrainLineage.end method when exception."""
args[0].return_value = True
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
train_lineage.end(self.run_context)
self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end_exception_log_error(self, *args):
"""Test TrainLineage.end method with logging errors."""
args[1].return_value = 2.0
args[2].return_value = True
args[3].return_value = True
args[4].return_value = None
args[5].return_value = None
args[6].side_effect = Exception
args[7].return_value = (None, None)
args[8].return_value = 10
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
train_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end_exception_log_error2(self, *args):
"""Test TrainLineage.end method with logging errors."""
args[1].return_value = 2.0
args[2].return_value = True
args[3].return_value = True
args[4].return_value = None
args[5].return_value = None
args[6].side_effect = IOError
args[7].return_value = (None, None)
args[8].return_value = 10
run_context = {key: None for key in self.lineage_list}
run_context['loss_fn'] = MagicMock()
run_context['net_outputs'] = Tensor(0.11)
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
train_lineage.end(self.my_run_context(run_context))
self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_exception_train_id_none(self, *args):
"""Test EvalLineage.end method with initialization error."""
args[0].return_value = True
with self.assertRaises(MindInsightException) as context:
self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2)
self.assertTrue('Invalid value for raise_exception.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'LineageSummary.record_evaluation_lineage')
def test_eval_end(self, *args):
"""Test EvalLineage.end method."""
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
args[0].return_value = None
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path))
eval_lineage.end(self.my_run_context(self.run_context))
args[0].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_end_except_run_context(self, *args):
"""Test EvalLineage.end method when run_context is invalid.."""
args[0].return_value = True
args[1].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
eval_lineage.end(self.run_context)
self.assertTrue('Invalid EvalLineage run_context.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'LineageSummary.record_evaluation_lineage')
def test_eval_end_except_log_error(self, *args):
"""Test EvalLineage.end method with logging error."""
args[0].side_effect = Exception
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
eval_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in EvalLineage' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.make_directory')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_dataset')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'LineageSummary.record_evaluation_lineage')
def test_eval_end_except_log_error2(self, *args):
"""Test EvalLineage.end method with logging error."""
args[0].side_effect = IOError
args[1].return_value = True
args[2].return_value = True
args[3].return_value = None
args[4].return_value = '/path/to/lineage/log/dir'
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(LineageLogError) as context:
eval_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in EvalLineage' in str(context.exception))
def test_epoch_is_zero(self, *args):
"""Test TrainLineage.end method."""
args[0].return_value = None
run_context = self.run_context
run_context['epoch_num'] = 0
with self.assertRaises(MindInsightException):
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
train_lineage.end(self.my_run_context(run_context))
def tearDown(self):
"""Teardown."""
if os.path.exists(self.summary_log_path):
try:
shutil.rmtree(self.summary_log_path)
except IOError:
pass
class TestAnalyzer(TestCase):
"""Test Analyzer class in model_lineage.py."""
def setUp(self):
"""SetUp config."""
self.analyzer = AnalyzeObject()
def test_analyze_optimizer(self):
"""Test analyze_optimizer method."""
optimizer = Optimizer(Tensor(0.12))
res = self.analyzer.analyze_optimizer(optimizer)
assert res == 0.12
def test_get_dataset_path(self):
"""Test get_dataset_path method."""
dataset = MindDataset(
dataset_file='/path/to/mindrecord'
)
res = self.analyzer.get_dataset_path(dataset)
assert res == '/path/to/mindrecord'
def test_get_dataset_path_wrapped(self):
"""Test get_dataset_path_wrapped method."""
dataset = Dataset()
dataset.input.append(
MindDataset(
dataset_size=10,
dataset_file='/path/to/cifar10'
))
res = self.analyzer.get_dataset_path_wrapped(dataset)
assert res == '/path/to/cifar10'
@mock.patch('os.path.isfile')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_dataset_path_wrapped')
def test_analyze_dataset(self, mock_get_path, mock_isfile):
"""Test analyze_dataset method."""
mock_get_path.return_value = '/path/to/mindinsightset'
mock_isfile.return_value = True
dataset = MindDataset(
dataset_size=10,
dataset_file='/path/to/mindinsightset'
)
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train')
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid')
# batch_size is mocked as 32.
assert res1 == {'step_num': 10,
'train_dataset_path': '/path/to',
'train_dataset_size': 320,
'epoch': 2}
assert res2 == {'step_num': 5, 'valid_dataset_path': '/path/to',
'valid_dataset_size': 320}
def test_get_dataset_path_dataset(self):
"""Test get_dataset_path method with Dataset."""
dataset = Dataset(
dataset_size=10,
dataset_path='/path/to/cifar10'
)
with self.assertRaises(IndexError):
self.analyzer.get_dataset_path(output_dataset=dataset)
def test_get_dataset_path_mindrecord(self):
"""Test get_dataset_path method with MindDataset."""
dataset = MindDataset(
dataset_file='/path/to/cifar10'
)
dataset_path = self.analyzer.get_dataset_path(output_dataset=dataset)
self.assertEqual(dataset_path, '/path/to/cifar10')
def test_get_file_path(self):
"""Test get_file_path method."""
model_ckpt = ModelCheckpoint(prefix='', directory='/path/to')
summary_step = SummaryStep(MagicMock(full_file_name='/path/to/summary.log'))
list_callback = [model_ckpt, summary_step]
ckpt_file_path, _ = AnalyzeObject.get_file_path(list_callback)
self.assertEqual(ckpt_file_path, '/path/to/test_model.ckpt')
@mock.patch('os.path.getsize')
def test_get_file_size(self, os_get_size_mock):
"""Test get_file_size method."""
os_get_size_mock.return_value = 128
file_size = AnalyzeObject.get_file_size('/file/path')
self.assertEqual(file_size, 128)
@mock.patch('os.path.getsize')
def test_get_file_size_except(self, os_get_size_mock):
"""Test failed to get the size of file."""
os_get_size_mock.side_effect = OSError
analyzer = AnalyzeObject
with self.assertRaises(LineageGetModelFileError) as context:
analyzer.get_file_size('/file/path')
self.assertTrue('Error when get model file size:' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_get_model_size(self, get_file_size_mock):
"""Test get_model_size method."""
get_file_size_mock.return_value = 128
analyzer = AnalyzeObject
file_size = analyzer.get_model_size(ckpt_file_path='/file/path')
self.assertEqual(file_size, 128)
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size')
def test_get_model_size_no_ckpt(self, get_file_size_mock):
"""Test get_model_size method without ckpt file."""
get_file_size_mock.return_value = 0
analyzer = AnalyzeObject
file_size = analyzer.get_model_size(ckpt_file_path='')
self.assertEqual(file_size, 0)
@mock.patch('builtins.vars')
def test_get_optimizer_by_network(self, mock_vars):
"""Test get_optimizer_by_network."""
mock_optimizer = Optimizer(Tensor(0.1))
mock_cells = MagicMock()
mock_cells.items.return_value = [{'key': mock_optimizer}]
mock_vars.return_value = {
'_cells': {
'key': mock_optimizer
}
}
res = AnalyzeObject.get_optimizer_by_network(MagicMock())
self.assertEqual(res, mock_optimizer)
@mock.patch('builtins.vars')
def test_get_loss_fn_by_network(self, mock_vars):
"""Test get_loss_fn_by_network."""
mock_cell1 = {'_cells': {'key': SoftmaxCrossEntropyWithLogits(0.2)}}
mock_cell2 = {'_cells': {'opt': Optimizer(Tensor(0.1))}}
mock_cell3 = {'_cells': {'loss': SoftmaxCrossEntropyWithLogits(0.1)}}
mock_vars.side_effect = [mock_cell1, mock_cell2, mock_cell3]
res = AnalyzeObject.get_loss_fn_by_network(MagicMock())
self.assertEqual(res, mock_cell3['_cells']['loss'])
@mock.patch('builtins.vars')
def test_get_backbone_network_with_loss_cell(self, mock_vars):
"""Test get_backbone_network with loss_cell."""
mock_cell = {'_cells': {'key': WithLossCell(MagicMock(),
SoftmaxCrossEntropyWithLogits(0.1))}
}
mock_vars.return_value = mock_cell
res = AnalyzeObject.get_backbone_network(MagicMock())
self.assertEqual(res, 'MagicMock')
@mock.patch('builtins.vars')
def test_get_backbone_network(self, mock_vars):
"""Test get_backbone_network."""
mock_net = TrainOneStepWithLossScaleCell()
mock_net.network = MagicMock()
mock_cell = {
'_cells': {
'key': mock_net
}
}
mock_vars.return_value = mock_cell
res = AnalyzeObject.get_backbone_network(MagicMock())
self.assertEqual(res, 'MagicMock')
if __name__ == '__main__':
unittest.main(verbosity=2)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class EventWriter."""
import os
from unittest import TestCase
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer
class TestEventWriter(TestCase):
"""Test write_event_to_file."""
def setUp(self):
"""The setup of test."""
self.log_path = "./test.log"
def test_write_event_to_file(self):
"""Test write event to file."""
run_context_args = {"train_network": "res"}
content = LineageSummary.package_train_message(run_context_args).SerializeToString()
event_writer = EventWriter(self.log_path, True)
event_writer.write_event_to_file(content)
lineage_info = LineageSummaryAnalyzer.get_summary_infos(self.log_path)
self.assertEqual(
lineage_info.train_lineage.train_lineage.algorithm.network,
run_context_args.get("train_network")
)
def tearDown(self):
"""The setup of test."""
if os.path.exists(self.log_path):
try:
os.remove(self.log_path)
except IOError:
pass
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class SummaryRecord."""
import json
from unittest import TestCase, mock
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
class TestSummaryRecord(TestCase):
"""Test summary record."""
def setUp(self):
"""The setup of test."""
self.run_context_args = dict()
self.run_context_args["train_network"] = "test_train_network"
self.run_context_args["loss"] = 0.1
self.run_context_args["learning_rate"] = 0.1
self.run_context_args["optimizer"] = "test_optimizer"
self.run_context_args["loss_function"] = "test_loss_function"
self.run_context_args["epoch"] = 1
self.run_context_args["parallel_mode"] = "test_parallel_mode"
self.run_context_args["device_num"] = 1
self.run_context_args["batch_size"] = 1
self.run_context_args["train_dataset_path"] = "test_train_dataset_path"
self.run_context_args["train_dataset_size"] = 1
self.run_context_args["model_path"] = "test_model_path"
self.run_context_args["model_size"] = 1
self.eval_args = dict()
self.eval_args["metrics"] = json.dumps({"acc": "test"})
self.eval_args["valid_dataset_path"] = "test_valid_dataset_path"
self.eval_args["valid_dataset_size"] = 1
self.hard_info_args = dict()
self.hard_info_args["pid"] = 1
self.hard_info_args["process_start_time"] = 921226.0
def test_package_train_message(self):
"""Test package_train_message."""
event = LineageSummary.package_train_message(self.run_context_args)
self.assertEqual(
event.train_lineage.algorithm.network, self.run_context_args.get("train_network"))
self.assertEqual(
event.train_lineage.hyper_parameters.optimizer, self.run_context_args.get("optimizer"))
self.assertEqual(
event.train_lineage.train_dataset.train_dataset_path,
self.run_context_args.get("train_dataset_path")
)
@mock.patch.object(EventWriter, "write_event_to_file")
def test_record_train_lineage(self, write_file):
"""Test record_train_lineage."""
write_file.return_value = True
lineage_summray = LineageSummary(lineage_log_dir="test.log")
lineage_summray.record_train_lineage(self.run_context_args)
def test_package_evaluation_message(self):
"""Test package_evaluation_message."""
event = LineageSummary.package_evaluation_message(self.eval_args)
self.assertEqual(event.evaluation_lineage.metric, self.eval_args.get("metrics"))
@mock.patch.object(EventWriter, "write_event_to_file")
def test_record_eval_lineage(self, write_file):
"""Test record_eval_lineage."""
write_file.return_value = True
lineage_summray = LineageSummary(lineage_log_dir="test.log")
lineage_summray.record_evaluation_lineage(self.eval_args)
......@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lineage writer module."""
from ._summary_record import LineageSummary
__all__ = ["LineageSummary"]
......@@ -17,7 +17,7 @@ import os
import stat
import struct
from mindinsight.datavisual.utils import crc32
from tests.utils import crc32
class EventWriter:
......@@ -84,6 +84,6 @@ class EventWriter:
Returns:
bytes, crc of content, 4 bytes.
"""
crc_value = crc32.GetMaskCrc32cValue(content, len(content))
crc_value = crc32.get_mask_from_string(content)
return struct.pack("<L", crc_value)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import socket
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
from mindinsight.lineagemgr.common.log import logger as log
# Set the Event mark
EVENT_FILE_NAME_MARK = "out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK = "_lineage"
def package_dataset_graph(graph):
"""
Package dataset graph.
Args:
graph (dict): Dataset graph.
Returns:
LineageEvent, the proto message event contains dataset graph.
"""
dataset_graph_event = LineageEvent()
dataset_graph_event.wall_time = time.time()
dataset_graph = dataset_graph_event.dataset_graph
if "children" in graph:
children = graph.pop("children")
if children:
_package_children(children=children, message=dataset_graph)
_package_current_dataset(operation=graph, message=dataset_graph)
return dataset_graph_event
def _package_children(children, message):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for child in children:
if child:
child_graph_message = getattr(message, "children").add()
grandson = child.pop("children")
if grandson:
_package_children(children=grandson, message=child_graph_message)
# package other parameters
_package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(operation, message):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for key, value in operation.items():
if value and key == "operations":
for operator in value:
_package_enhancement_operation(
operator,
message.operations.add()
)
elif value and key == "sampler":
_package_enhancement_operation(
value,
message.sampler
)
else:
_package_parameter(key, value, message.parameter)
def _package_enhancement_operation(operation, message):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for key, value in operation.items():
if isinstance(value, list):
if all(isinstance(ele, int) for ele in value):
message.size.extend(value)
else:
message.weights.extend(value)
else:
_package_parameter(key, value, message.operationParam)
def _package_parameter(key, value, message):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if isinstance(value, str):
message.mapStr[key] = value
elif isinstance(value, bool):
message.mapBool[key] = value
elif isinstance(value, int):
message.mapInt[key] = value
elif isinstance(value, float):
message.mapDouble[key] = value
elif isinstance(value, list) and key != "operations":
if value:
replace_value_list = list(map(lambda x: "" if x is None else x, value))
message.mapStrList[key].strValue.extend(replace_value_list)
elif value is None:
message.mapStr[key] = "None"
else:
error_msg = "Parameter {} is not supported " \
"in event package.".format(key)
log.error(error_msg)
raise LineageParamTypeError(error_msg)
def package_user_defined_info(user_dict):
"""
Package user defined info.
Args:
user_dict(dict): User defined info dict.
Returns:
LineageEvent, the proto message event contains user defined info.
"""
user_event = LineageEvent()
user_event.wall_time = time.time()
user_defined_info = user_event.user_defined_info
_package_user_defined_info(user_dict, user_defined_info)
return user_event
def _package_user_defined_info(user_defined_dict, user_defined_message):
"""
Setting attribute in user defined proto message.
Args:
user_defined_dict (dict): User define info dict.
user_defined_message (LineageEvent): Proto message of user defined info.
Raises:
LineageParamValueError: When the value is out of range.
LineageParamTypeError: When given a type not support yet.
"""
for key, value in user_defined_dict.items():
if not isinstance(key, str):
error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)
if isinstance(value, int):
attr_name = "map_int32"
elif isinstance(value, float):
attr_name = "map_double"
elif isinstance(value, str):
attr_name = "map_str"
else:
attr_name = "attr_name"
add_user_defined_info = user_defined_message.user_info.add()
try:
getattr(add_user_defined_info, attr_name)[key] = value
except AttributeError:
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
def get_lineage_file_name():
"""
Get lineage file name.
Lineage filename format is:
EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.
Returns:
str, the name of event log file.
"""
time_second = str(int(time.time()))
hostname = socket.gethostname()
file_name = f'{EVENT_FILE_NAME_MARK}summary.{time_second}.{hostname}{LINEAGE_FILE_NAME_MARK}'
return file_name
......@@ -17,8 +17,7 @@ import os
import time
from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.validator.validate import validate_file_path
from mindinsight.lineagemgr.summary.event_writer import EventWriter
from ._event_writer import EventWriter
from ._summary_adapter import package_dataset_graph, package_user_defined_info, get_lineage_file_name
......@@ -40,11 +39,10 @@ class LineageSummary:
>>> lineage_summary.record_train_lineage(train_lineage)
"""
def __init__(self,
lineage_log_dir=None,
lineage_log_dir,
override=False):
lineage_log_name = get_lineage_file_name()
self.lineage_log_path = os.path.join(lineage_log_dir, lineage_log_name)
validate_file_path(self.lineage_log_path)
self.event_writer = EventWriter(self.lineage_log_path, override)
@staticmethod
......
......@@ -17,22 +17,20 @@ import json
import os
import numpy as np
from mindinsight.lineagemgr.summary.summary_record import LineageSummary
from mindinsight.utils.exceptions import \
MindInsightException
from mindinsight.lineagemgr.common.validator.validate import validate_train_run_context, \
validate_eval_run_context, validate_file_path, validate_network, \
validate_int_params, validate_summary_record, validate_raise_exception,\
validate_user_defined_info
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
LineageGetModelFileError, LineageLogError
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, LineageErrors
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
LineageParamRunContextError)
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.utils import try_except, make_directory
from mindinsight.lineagemgr.common.validator.model_parameter import RunContextArgs, \
EvalParameter
from mindinsight.lineagemgr.collection.model.base import Metadata
from mindinsight.lineagemgr.common.utils import make_directory, try_except
from mindinsight.lineagemgr.common.validator.model_parameter import EvalParameter
from mindinsight.lineagemgr.common.validator.validate import (validate_eval_run_context, validate_file_path,
validate_int_params,
validate_raise_exception,
validate_user_defined_info)
from mindinsight.utils.exceptions import MindInsightException
from ._summary_record import LineageSummary
from .base import Metadata
try:
from mindspore.common.tensor import Tensor
......@@ -91,7 +89,6 @@ class TrainLineage(Callback):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
......@@ -145,7 +142,6 @@ class TrainLineage(Callback):
log.debug('initial_learning_rate: %s', self.initial_learning_rate)
else:
network = run_context_args.get('train_network')
validate_network(network)
optimizer = AnalyzeObject.get_optimizer_by_network(network)
self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
log.debug('initial_learning_rate: %s', self.initial_learning_rate)
......@@ -183,7 +179,6 @@ class TrainLineage(Callback):
raise LineageParamRunContextError(error_msg)
run_context_args = run_context.original_args()
validate_train_run_context(RunContextArgs, run_context_args)
train_lineage = dict()
train_lineage = AnalyzeObject.get_network_args(
......@@ -277,7 +272,6 @@ class EvalLineage(Callback):
# make directory if not exist
self.lineage_log_dir = make_directory(summary_record)
else:
validate_summary_record(summary_record)
summary_log_path = summary_record.full_file_name
validate_file_path(summary_log_path)
self.lineage_log_dir = os.path.dirname(summary_log_path)
......@@ -639,7 +633,6 @@ class AnalyzeObject:
dict, the lineage metadata.
"""
network = run_context_args.get('train_network')
validate_network(network)
optimizer = run_context_args.get('optimizer')
if not optimizer:
optimizer = AnalyzeObject.get_optimizer_by_network(network)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册