提交 2ddb2b9c 编写于 作者: L luopengting

merge model_lineage and dataset_lineage, modify/add st and ut for lineage api

上级 15a7ad78
......@@ -27,52 +27,20 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError
BLUEPRINT = Blueprint("lineage", __name__, url_prefix=settings.URL_PREFIX.rstrip("/"))
@BLUEPRINT.route("/models/model_lineage", methods=["POST"])
def search_model():
@BLUEPRINT.route("/lineagemgr/lineages", methods=["POST"])
def get_lineage():
"""
Get model lineage info.
Get model info by summary base dir return a model lineage information list of dict
contains model's all kinds of param and count of summary log.
Returns:
str, the model lineage information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails.
Examples:
>>> POST http://xxxx/v1/mindinsight/models/model_lineage
"""
search_condition = request.stream.read()
try:
search_condition = json.loads(search_condition if search_condition else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")
model_lineage_info = _get_lineage_info(
lineage_type="model",
search_condition=search_condition
)
return jsonify(model_lineage_info)
@BLUEPRINT.route("/datasets/dataset_lineage", methods=["POST"])
def get_datasets_lineage():
"""
Get dataset lineage.
Get lineage.
Returns:
str, the dataset lineage information.
str, the lineage information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails.
Examples:
>>> POST http://xxxx/v1/minddata/datasets/dataset_lineage
>>> POST http://xxxx/v1/mindinsight/lineagemgr/lineages
"""
search_condition = request.stream.read()
try:
......@@ -80,20 +48,16 @@ def get_datasets_lineage():
except Exception:
raise ParamValueError("Json data parse failed.")
dataset_lineage_info = _get_lineage_info(
lineage_type="dataset",
search_condition=search_condition
)
lineage_info = _get_lineage_info(search_condition=search_condition)
return jsonify(dataset_lineage_info)
return jsonify(lineage_info)
def _get_lineage_info(lineage_type, search_condition):
def _get_lineage_info(search_condition):
"""
Get lineage info for dataset or model.
Args:
lineage_type (str): Lineage type, 'dataset' or 'model'.
search_condition (dict): Search condition.
Returns:
......@@ -102,10 +66,6 @@ def _get_lineage_info(lineage_type, search_condition):
Raises:
MindInsightException: If method fails to be called.
"""
if 'lineage_type' in search_condition:
raise ParamValueError("Lineage type does not need to be assigned in a specific interface.")
if lineage_type == 'dataset':
search_condition.update({'lineage_type': 'dataset'})
summary_base_dir = str(settings.SUMMARY_BASE_DIR)
try:
lineage_info = filter_summary_lineage(
......
......@@ -262,8 +262,6 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
return search_condition
summary_dir_condition = search_condition.get("summary_dir")
if not set(summary_dir_condition.keys()).issubset(['in', 'eq']):
raise LineageParamValueError("Invalid operation of summary dir.")
if 'in' in summary_dir_condition:
summary_paths = []
......
......@@ -193,7 +193,7 @@ class LineageErrorMsg(Enum):
"It should be a string."
LINEAGE_PARAM_LINEAGE_TYPE_ERROR = "The parameter lineage_type is invalid. " \
"It should be None, 'dataset' or 'model'."
"It should be 'dataset' or 'model'."
SUMMARY_ANALYZE_ERROR = "Failed to analyze summary log. {}"
SUMMARY_VERIFICATION_ERROR = "Verification failed in summary analysis. {}"
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""Define schema of model lineage input parameters."""
from marshmallow import Schema, fields, ValidationError, pre_load, validates
from marshmallow.validate import Range, OneOf
from marshmallow.validate import Range
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
LineageErrors
......@@ -129,10 +129,7 @@ class SearchModelConditionParameter(Schema):
offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
sorted_name = fields.Str()
sorted_type = fields.Str(allow_none=True)
lineage_type = fields.Str(
validate=OneOf(enum_to_list(LineageType)),
allow_none=True
)
lineage_type = fields.Dict()
@staticmethod
def check_dict_value_type(data, value_type):
......@@ -174,53 +171,79 @@ class SearchModelConditionParameter(Schema):
@validates("loss_function")
def check_loss_function(self, data):
"""Check loss function."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("train_dataset_path")
def check_train_dataset_path(self, data):
"""Check train dataset path."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("train_dataset_count")
def check_train_dataset_count(self, data):
"""Check train dataset count."""
SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("test_dataset_path")
def check_test_dataset_path(self, data):
"""Check test dataset path."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("test_dataset_count")
def check_test_dataset_count(self, data):
"""Check test dataset count."""
SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("network")
def check_network(self, data):
"""Check network."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("optimizer")
def check_optimizer(self, data):
"""Check optimizer."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("epoch")
def check_epoch(self, data):
"""Check epoch."""
SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("batch_size")
def check_batch_size(self, data):
"""Check batch size."""
SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("model_size")
def check_model_size(self, data):
"""Check model size."""
SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("summary_dir")
def check_summary_dir(self, data):
"""Check summary dir."""
SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("lineage_type")
def check_lineage_type(self, data):
"""Check lineage type."""
SearchModelConditionParameter.check_dict_value_type(data, str)
recv_types = []
for key, value in data.items():
if key == "in":
recv_types = value
else:
recv_types.append(value)
lineage_types = enum_to_list(LineageType)
if not set(recv_types).issubset(lineage_types):
raise ValidationError("Given lineage type should be one of %s." % lineage_types)
@pre_load
def check_comparision(self, data, **kwargs):
"""Check comparision for all parameters in schema."""
for attr, condition in data.items():
if attr in ["limit", "offset", "sorted_name", "sorted_type", "lineage_type"]:
if attr in ["limit", "offset", "sorted_name", "sorted_type"]:
continue
if not isinstance(attr, str):
......@@ -233,6 +256,13 @@ class SearchModelConditionParameter(Schema):
raise LineageParamTypeError("The search_condition element {} should be dict."
.format(attr))
if attr in ["summary_dir", "lineage_type"]:
if not set(condition.keys()).issubset(['in', 'eq']):
raise LineageParamValueError("Invalid operation of %s." % attr)
if len(condition.keys()) > 1:
raise LineageParamValueError("More than one operation of %s." % attr)
continue
for key in condition.keys():
if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
raise LineageParamValueError("The compare condition should be in "
......
......@@ -23,6 +23,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageEventNotExistException, LineageQuerierParamException, \
LineageSummaryParseException, LineageEventFieldNotExistException
from mindinsight.lineagemgr.common.log import logger
from mindinsight.lineagemgr.common.utils import enum_to_list
from mindinsight.lineagemgr.querier.query_model import LineageObj, FIELD_MAPPING
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \
LineageSummaryAnalyzer
......@@ -318,18 +319,46 @@ class Querier:
customized[label]["required"] = True
customized[label]["type"] = type(value).__name__
search_type = condition.get(ConditionParam.LINEAGE_TYPE.value)
lineage_types = condition.get(ConditionParam.LINEAGE_TYPE.value)
lineage_types = self._get_lineage_types(lineage_types)
object_items = []
for item in offset_results:
lineage_object = dict()
if LineageType.MODEL.value in lineage_types:
lineage_object.update(item.to_model_lineage_dict())
if LineageType.DATASET.value in lineage_types:
lineage_object.update(item.to_dataset_lineage_dict())
object_items.append(lineage_object)
lineage_info = {
'customized': customized,
'object': [
item.to_dataset_lineage_dict() if search_type == LineageType.DATASET.value
else item.to_filtration_dict() for item in offset_results
],
'object': object_items,
'count': len(results)
}
return lineage_info
def _get_lineage_types(self, lineage_type_param):
"""
Get lineage types.
Args:
lineage_type_param (dict): A dict contains "in" or "eq".
Returns:
list, lineage type.
"""
# lineage_type_param is None or an empty dict
if not lineage_type_param:
return enum_to_list(LineageType)
if lineage_type_param.get("in") is not None:
return lineage_type_param.get("in")
return [lineage_type_param.get("eq")]
def _is_valid_field(self, field_name):
"""
Check if field name is valid.
......
......@@ -38,6 +38,7 @@ FIELD_MAPPING = {
"loss": Field('algorithm', 'loss'),
"model_size": Field('model', 'size'),
"dataset_mark": Field('dataset_mark', None),
"lineage_type": Field(None, None)
}
......@@ -75,6 +76,7 @@ class LineageObj:
_name_dataset_graph = 'dataset_graph'
_name_dataset_mark = 'dataset_mark'
_name_user_defined = 'user_defined'
_name_model_lineage = 'model_lineage'
def __init__(self, summary_dir, **kwargs):
self._lineage_info = {
......@@ -227,15 +229,6 @@ class LineageObj:
result[key] = getattr(self, key)
return result
def to_filtration_dict(self):
"""
Returns the lineage information required by filtering interface.
Returns:
dict, the lineage information required by filtering interface.
"""
return self._filtration_result
def to_dataset_lineage_dict(self):
"""
Returns the dataset part lineage information.
......@@ -250,6 +243,22 @@ class LineageObj:
return dataset_lineage
def to_model_lineage_dict(self):
"""
Returns the model part lineage information.
Returns:
dict, the model lineage information.
"""
filtration_result = dict(self._filtration_result)
filtration_result.pop(self._name_dataset_graph)
model_lineage = dict()
model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
model_lineage.update({self._name_model_lineage: filtration_result})
return model_lineage
def get_value_by_key(self, key):
"""
Get the value based on the key in `FIELD_MAPPING` or
......
......@@ -20,7 +20,6 @@ Usage:
The query module test should be run after lineagemgr/collection/model/test_model_lineage.py
pytest lineagemgr
"""
import os
from unittest import TestCase
......@@ -66,64 +65,70 @@ LINEAGE_INFO_RUN1 = {
}
LINEAGE_FILTRATION_EXCEPT_RUN = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'user_defined': {},
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
'model_lineage': {
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'user_defined': {},
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_mark': 2
},
'dataset_graph': DATASET_GRAPH
}
LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'user_defined': {},
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
'model_lineage': {
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'user_defined': {},
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_mark': 2
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
'dataset_graph': DATASET_GRAPH
}
LINEAGE_FILTRATION_RUN2 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'user_defined': {},
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
'model_lineage': {
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'user_defined': {},
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_mark': 3
},
'dataset_graph': {},
'dataset_mark': 3
'dataset_graph': {}
}
......@@ -150,6 +155,14 @@ class TestModelApi(TestCase):
cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir')
os.makedirs(cls.empty_dir)
def generate_lineage_object(self, lineage):
lineage = dict(lineage)
lineage_object = dict()
lineage_object.update({'summary_dir': lineage.pop('summary_dir')})
lineage_object.update({'dataset_graph': lineage.pop('dataset_graph')})
lineage_object.update({'model_lineage': lineage})
return lineage_object
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......@@ -337,7 +350,7 @@ class TestModelApi(TestCase):
res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
expect_result = {
......@@ -347,7 +360,7 @@ class TestModelApi(TestCase):
res = filter_summary_lineage(self.dir_with_empty_lineage)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
@pytest.mark.level0
......@@ -385,7 +398,7 @@ class TestModelApi(TestCase):
partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
@pytest.mark.level0
......@@ -423,7 +436,7 @@ class TestModelApi(TestCase):
partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
@pytest.mark.level0
......@@ -439,7 +452,6 @@ class TestModelApi(TestCase):
'ge': 30
},
'sorted_name': 'metric/accuracy',
'lineage_type': None
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
......@@ -452,14 +464,16 @@ class TestModelApi(TestCase):
partial_res1 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition1)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res1
search_condition2 = {
'batch_size': {
'lt': 30
},
'lineage_type': 'model'
'lineage_type': {
'eq': 'model'
},
}
expect_result = {
'customized': {},
......@@ -469,7 +483,7 @@ class TestModelApi(TestCase):
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res2.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res2
@pytest.mark.level0
......@@ -485,7 +499,9 @@ class TestModelApi(TestCase):
'summary_dir': {
'in': [summary_dir]
},
'lineage_type': 'dataset'
'lineage_type': {
'eq': 'dataset'
},
}
expect_result = {
'customized': {},
......@@ -705,15 +721,29 @@ class TestModelApi(TestCase):
search_condition
)
# the condition type not supported in summary dir
search_condition = {
'summary_dir': {
'lt': '/xxx'
'lineage_type': {
'in': [
'xxx'
]
}
}
self.assertRaisesRegex(
LineageParamSummaryPathError,
'Invalid operation of summary dir.',
LineageSearchConditionParamError,
"The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
search_condition = {
'lineage_type': {
'eq': None
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
"The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
......@@ -779,3 +809,42 @@ class TestModelApi(TestCase):
}
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
assert expect_result == partial_res2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_7(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
condition_keys = ["summary_dir", "lineage_type"]
for condition_key in condition_keys:
# the condition type not supported in summary_dir and lineage_type
search_condition = {
condition_key: {
'lt': '/xxx'
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
f'Invalid operation of {condition_key}.',
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
# more than one operation in summary_dir and lineage_type
search_condition = {
condition_key: {
'in': ['/xxx', '/yyy'],
'eq': '/zzz',
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
f'More than one operation of {condition_key}.',
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
......@@ -67,7 +67,7 @@ class TestSearchModel(TestCase):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage'
self.url = '/v1/mindinsight/lineagemgr/lineages'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
......@@ -78,11 +78,11 @@ class TestSearchModel(TestCase):
'object': [
{
'summary_dir': base_dir,
**LINEAGE_FILTRATION_BASE
'model_lineage': LINEAGE_FILTRATION_BASE
},
{
'summary_dir': os.path.join(base_dir, 'run1'),
**LINEAGE_FILTRATION_RUN1
'model_lineage': LINEAGE_FILTRATION_RUN1
}
],
'count': 2
......@@ -101,11 +101,11 @@ class TestSearchModel(TestCase):
'object': [
{
'summary_dir': './',
**LINEAGE_FILTRATION_BASE
'model_lineage': LINEAGE_FILTRATION_BASE
},
{
'summary_dir': './run1',
**LINEAGE_FILTRATION_RUN1
'model_lineage': LINEAGE_FILTRATION_RUN1
}
],
'count': 2
......
......@@ -131,18 +131,6 @@ class TestModel(TestCase):
self.assertDictEqual(
result, search_condition
)
search_condition = {
'summary_dir': {
'gt': 3
}
}
self.assertRaisesRegex(
LineageParamValueError,
'Invalid operation of summary dir',
_convert_relative_path_to_abspath,
summary_base_dir,
search_condition
)
class TestFilterAPI(TestCase):
......
......@@ -82,22 +82,24 @@ def create_filtration_result(summary_dir, train_event_dict,
"""
filtration_result = {
"summary_dir": summary_dir,
"loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'],
"train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": train_event_dict['train_lineage']['algorithm']['network'],
"optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'],
"batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'],
"loss": train_event_dict['train_lineage']['algorithm']['loss'],
"model_size": train_event_dict['train_lineage']['model']['size'],
"metric": metric_dict,
"model_lineage": {
"loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'],
"train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": train_event_dict['train_lineage']['algorithm']['network'],
"optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'],
"batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'],
"loss": train_event_dict['train_lineage']['algorithm']['loss'],
"model_size": train_event_dict['train_lineage']['model']['size'],
"metric": metric_dict,
"dataset_mark": '2',
"user_defined": {}
},
"dataset_graph": dataset_dict,
"dataset_mark": '2',
"user_defined": {}
}
return filtration_result
......@@ -192,47 +194,50 @@ LINEAGE_FILTRATION_4 = create_filtration_result(
)
LINEAGE_FILTRATION_5 = {
"summary_dir": '/path/to/summary5',
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2',
"user_defined": {}
"model_lineage": {
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_mark": '2',
"user_defined": {}
},
"dataset_graph": event_data.DATASET_DICT_0
}
LINEAGE_FILTRATION_6 = {
"summary_dir": '/path/to/summary6',
"loss_function": None,
"train_dataset_path": None,
"train_dataset_count": None,
"test_dataset_path":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2',
"user_defined": {}
"model_lineage": {
"loss_function": None,
"train_dataset_path": None,
"train_dataset_count": None,
"test_dataset_path":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_mark": '2',
"user_defined": {}
},
"dataset_graph": event_data.DATASET_DICT_0
}
......
......@@ -108,8 +108,8 @@ class TestLineageObj(TestCase):
result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result)
def test_to_filtration_dict(self):
"""Test the function of to_filtration_dict."""
def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict."""
expected_result = create_filtration_result(
self.summary_dir,
event_data.EVENT_TRAIN_DICT_0,
......@@ -117,8 +117,18 @@ class TestLineageObj(TestCase):
event_data.METRIC_0,
event_data.DATASET_DICT_0
)
expected_result['dataset_mark'] = None
result = self.lineage_obj.to_filtration_dict()
expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict."""
expected_result = {
"summary_dir": self.summary_dir,
"dataset_graph": event_data.DATASET_DICT_0
}
result = self.lineage_obj.to_dataset_lineage_dict()
self.assertDictEqual(expected_result, result)
def test_get_value_by_key(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册