提交 2ddb2b9c 编写于 作者: L luopengting

merge model_lineage and dataset_lineage, modify/add st and ut for lineage api

上级 15a7ad78
...@@ -27,52 +27,20 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError ...@@ -27,52 +27,20 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError
BLUEPRINT = Blueprint("lineage", __name__, url_prefix=settings.URL_PREFIX.rstrip("/")) BLUEPRINT = Blueprint("lineage", __name__, url_prefix=settings.URL_PREFIX.rstrip("/"))
@BLUEPRINT.route("/models/model_lineage", methods=["POST"]) @BLUEPRINT.route("/lineagemgr/lineages", methods=["POST"])
def search_model(): def get_lineage():
""" """
Get model lineage info. Get lineage.
Get model info by summary base dir return a model lineage information list of dict
contains model's all kinds of param and count of summary log.
Returns:
str, the model lineage information.
Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails.
Examples:
>>> POST http://xxxx/v1/mindinsight/models/model_lineage
"""
search_condition = request.stream.read()
try:
search_condition = json.loads(search_condition if search_condition else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")
model_lineage_info = _get_lineage_info(
lineage_type="model",
search_condition=search_condition
)
return jsonify(model_lineage_info)
@BLUEPRINT.route("/datasets/dataset_lineage", methods=["POST"])
def get_datasets_lineage():
"""
Get dataset lineage.
Returns: Returns:
str, the dataset lineage information. str, the lineage information.
Raises: Raises:
MindInsightException: If method fails to be called. MindInsightException: If method fails to be called.
ParamValueError: If parsing json data search_condition fails. ParamValueError: If parsing json data search_condition fails.
Examples: Examples:
>>> POST http://xxxx/v1/minddata/datasets/dataset_lineage >>> POST http://xxxx/v1/mindinsight/lineagemgr/lineages
""" """
search_condition = request.stream.read() search_condition = request.stream.read()
try: try:
...@@ -80,20 +48,16 @@ def get_datasets_lineage(): ...@@ -80,20 +48,16 @@ def get_datasets_lineage():
except Exception: except Exception:
raise ParamValueError("Json data parse failed.") raise ParamValueError("Json data parse failed.")
dataset_lineage_info = _get_lineage_info( lineage_info = _get_lineage_info(search_condition=search_condition)
lineage_type="dataset",
search_condition=search_condition
)
return jsonify(dataset_lineage_info) return jsonify(lineage_info)
def _get_lineage_info(lineage_type, search_condition): def _get_lineage_info(search_condition):
""" """
Get lineage info for dataset or model. Get lineage info for dataset or model.
Args: Args:
lineage_type (str): Lineage type, 'dataset' or 'model'.
search_condition (dict): Search condition. search_condition (dict): Search condition.
Returns: Returns:
...@@ -102,10 +66,6 @@ def _get_lineage_info(lineage_type, search_condition): ...@@ -102,10 +66,6 @@ def _get_lineage_info(lineage_type, search_condition):
Raises: Raises:
MindInsightException: If method fails to be called. MindInsightException: If method fails to be called.
""" """
if 'lineage_type' in search_condition:
raise ParamValueError("Lineage type does not need to be assigned in a specific interface.")
if lineage_type == 'dataset':
search_condition.update({'lineage_type': 'dataset'})
summary_base_dir = str(settings.SUMMARY_BASE_DIR) summary_base_dir = str(settings.SUMMARY_BASE_DIR)
try: try:
lineage_info = filter_summary_lineage( lineage_info = filter_summary_lineage(
......
...@@ -262,8 +262,6 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition): ...@@ -262,8 +262,6 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
return search_condition return search_condition
summary_dir_condition = search_condition.get("summary_dir") summary_dir_condition = search_condition.get("summary_dir")
if not set(summary_dir_condition.keys()).issubset(['in', 'eq']):
raise LineageParamValueError("Invalid operation of summary dir.")
if 'in' in summary_dir_condition: if 'in' in summary_dir_condition:
summary_paths = [] summary_paths = []
......
...@@ -193,7 +193,7 @@ class LineageErrorMsg(Enum): ...@@ -193,7 +193,7 @@ class LineageErrorMsg(Enum):
"It should be a string." "It should be a string."
LINEAGE_PARAM_LINEAGE_TYPE_ERROR = "The parameter lineage_type is invalid. " \ LINEAGE_PARAM_LINEAGE_TYPE_ERROR = "The parameter lineage_type is invalid. " \
"It should be None, 'dataset' or 'model'." "It should be 'dataset' or 'model'."
SUMMARY_ANALYZE_ERROR = "Failed to analyze summary log. {}" SUMMARY_ANALYZE_ERROR = "Failed to analyze summary log. {}"
SUMMARY_VERIFICATION_ERROR = "Verification failed in summary analysis. {}" SUMMARY_VERIFICATION_ERROR = "Verification failed in summary analysis. {}"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""Define schema of model lineage input parameters.""" """Define schema of model lineage input parameters."""
from marshmallow import Schema, fields, ValidationError, pre_load, validates from marshmallow import Schema, fields, ValidationError, pre_load, validates
from marshmallow.validate import Range, OneOf from marshmallow.validate import Range
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \ from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
LineageErrors LineageErrors
...@@ -129,10 +129,7 @@ class SearchModelConditionParameter(Schema): ...@@ -129,10 +129,7 @@ class SearchModelConditionParameter(Schema):
offset = fields.Int(validate=lambda n: 0 <= n <= 100000) offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
sorted_name = fields.Str() sorted_name = fields.Str()
sorted_type = fields.Str(allow_none=True) sorted_type = fields.Str(allow_none=True)
lineage_type = fields.Str( lineage_type = fields.Dict()
validate=OneOf(enum_to_list(LineageType)),
allow_none=True
)
@staticmethod @staticmethod
def check_dict_value_type(data, value_type): def check_dict_value_type(data, value_type):
...@@ -174,53 +171,79 @@ class SearchModelConditionParameter(Schema): ...@@ -174,53 +171,79 @@ class SearchModelConditionParameter(Schema):
@validates("loss_function") @validates("loss_function")
def check_loss_function(self, data): def check_loss_function(self, data):
"""Check loss function."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("train_dataset_path") @validates("train_dataset_path")
def check_train_dataset_path(self, data): def check_train_dataset_path(self, data):
"""Check train dataset path."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("train_dataset_count") @validates("train_dataset_count")
def check_train_dataset_count(self, data): def check_train_dataset_count(self, data):
"""Check train dataset count."""
SearchModelConditionParameter.check_dict_value_type(data, int) SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("test_dataset_path") @validates("test_dataset_path")
def check_test_dataset_path(self, data): def check_test_dataset_path(self, data):
"""Check test dataset path."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("test_dataset_count") @validates("test_dataset_count")
def check_test_dataset_count(self, data): def check_test_dataset_count(self, data):
"""Check test dataset count."""
SearchModelConditionParameter.check_dict_value_type(data, int) SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("network") @validates("network")
def check_network(self, data): def check_network(self, data):
"""Check network."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("optimizer") @validates("optimizer")
def check_optimizer(self, data): def check_optimizer(self, data):
"""Check optimizer."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("epoch") @validates("epoch")
def check_epoch(self, data): def check_epoch(self, data):
"""Check epoch."""
SearchModelConditionParameter.check_dict_value_type(data, int) SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("batch_size") @validates("batch_size")
def check_batch_size(self, data): def check_batch_size(self, data):
"""Check batch size."""
SearchModelConditionParameter.check_dict_value_type(data, int) SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("model_size") @validates("model_size")
def check_model_size(self, data): def check_model_size(self, data):
"""Check model size."""
SearchModelConditionParameter.check_dict_value_type(data, int) SearchModelConditionParameter.check_dict_value_type(data, int)
@validates("summary_dir") @validates("summary_dir")
def check_summary_dir(self, data): def check_summary_dir(self, data):
"""Check summary dir."""
SearchModelConditionParameter.check_dict_value_type(data, str) SearchModelConditionParameter.check_dict_value_type(data, str)
@validates("lineage_type")
def check_lineage_type(self, data):
"""Check lineage type."""
SearchModelConditionParameter.check_dict_value_type(data, str)
recv_types = []
for key, value in data.items():
if key == "in":
recv_types = value
else:
recv_types.append(value)
lineage_types = enum_to_list(LineageType)
if not set(recv_types).issubset(lineage_types):
raise ValidationError("Given lineage type should be one of %s." % lineage_types)
@pre_load @pre_load
def check_comparision(self, data, **kwargs): def check_comparision(self, data, **kwargs):
"""Check comparision for all parameters in schema.""" """Check comparision for all parameters in schema."""
for attr, condition in data.items(): for attr, condition in data.items():
if attr in ["limit", "offset", "sorted_name", "sorted_type", "lineage_type"]: if attr in ["limit", "offset", "sorted_name", "sorted_type"]:
continue continue
if not isinstance(attr, str): if not isinstance(attr, str):
...@@ -233,6 +256,13 @@ class SearchModelConditionParameter(Schema): ...@@ -233,6 +256,13 @@ class SearchModelConditionParameter(Schema):
raise LineageParamTypeError("The search_condition element {} should be dict." raise LineageParamTypeError("The search_condition element {} should be dict."
.format(attr)) .format(attr))
if attr in ["summary_dir", "lineage_type"]:
if not set(condition.keys()).issubset(['in', 'eq']):
raise LineageParamValueError("Invalid operation of %s." % attr)
if len(condition.keys()) > 1:
raise LineageParamValueError("More than one operation of %s." % attr)
continue
for key in condition.keys(): for key in condition.keys():
if key not in ["eq", "lt", "gt", "le", "ge", "in"]: if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
raise LineageParamValueError("The compare condition should be in " raise LineageParamValueError("The compare condition should be in "
......
...@@ -23,6 +23,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \ ...@@ -23,6 +23,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageEventNotExistException, LineageQuerierParamException, \ LineageEventNotExistException, LineageQuerierParamException, \
LineageSummaryParseException, LineageEventFieldNotExistException LineageSummaryParseException, LineageEventFieldNotExistException
from mindinsight.lineagemgr.common.log import logger from mindinsight.lineagemgr.common.log import logger
from mindinsight.lineagemgr.common.utils import enum_to_list
from mindinsight.lineagemgr.querier.query_model import LineageObj, FIELD_MAPPING from mindinsight.lineagemgr.querier.query_model import LineageObj, FIELD_MAPPING
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \
LineageSummaryAnalyzer LineageSummaryAnalyzer
...@@ -318,18 +319,46 @@ class Querier: ...@@ -318,18 +319,46 @@ class Querier:
customized[label]["required"] = True customized[label]["required"] = True
customized[label]["type"] = type(value).__name__ customized[label]["type"] = type(value).__name__
search_type = condition.get(ConditionParam.LINEAGE_TYPE.value) lineage_types = condition.get(ConditionParam.LINEAGE_TYPE.value)
lineage_types = self._get_lineage_types(lineage_types)
object_items = []
for item in offset_results:
lineage_object = dict()
if LineageType.MODEL.value in lineage_types:
lineage_object.update(item.to_model_lineage_dict())
if LineageType.DATASET.value in lineage_types:
lineage_object.update(item.to_dataset_lineage_dict())
object_items.append(lineage_object)
lineage_info = { lineage_info = {
'customized': customized, 'customized': customized,
'object': [ 'object': object_items,
item.to_dataset_lineage_dict() if search_type == LineageType.DATASET.value
else item.to_filtration_dict() for item in offset_results
],
'count': len(results) 'count': len(results)
} }
return lineage_info return lineage_info
def _get_lineage_types(self, lineage_type_param):
"""
Get lineage types.
Args:
lineage_type_param (dict): A dict contains "in" or "eq".
Returns:
list, lineage type.
"""
# lineage_type_param is None or an empty dict
if not lineage_type_param:
return enum_to_list(LineageType)
if lineage_type_param.get("in") is not None:
return lineage_type_param.get("in")
return [lineage_type_param.get("eq")]
def _is_valid_field(self, field_name): def _is_valid_field(self, field_name):
""" """
Check if field name is valid. Check if field name is valid.
......
...@@ -38,6 +38,7 @@ FIELD_MAPPING = { ...@@ -38,6 +38,7 @@ FIELD_MAPPING = {
"loss": Field('algorithm', 'loss'), "loss": Field('algorithm', 'loss'),
"model_size": Field('model', 'size'), "model_size": Field('model', 'size'),
"dataset_mark": Field('dataset_mark', None), "dataset_mark": Field('dataset_mark', None),
"lineage_type": Field(None, None)
} }
...@@ -75,6 +76,7 @@ class LineageObj: ...@@ -75,6 +76,7 @@ class LineageObj:
_name_dataset_graph = 'dataset_graph' _name_dataset_graph = 'dataset_graph'
_name_dataset_mark = 'dataset_mark' _name_dataset_mark = 'dataset_mark'
_name_user_defined = 'user_defined' _name_user_defined = 'user_defined'
_name_model_lineage = 'model_lineage'
def __init__(self, summary_dir, **kwargs): def __init__(self, summary_dir, **kwargs):
self._lineage_info = { self._lineage_info = {
...@@ -227,15 +229,6 @@ class LineageObj: ...@@ -227,15 +229,6 @@ class LineageObj:
result[key] = getattr(self, key) result[key] = getattr(self, key)
return result return result
def to_filtration_dict(self):
"""
Returns the lineage information required by filtering interface.
Returns:
dict, the lineage information required by filtering interface.
"""
return self._filtration_result
def to_dataset_lineage_dict(self): def to_dataset_lineage_dict(self):
""" """
Returns the dataset part lineage information. Returns the dataset part lineage information.
...@@ -250,6 +243,22 @@ class LineageObj: ...@@ -250,6 +243,22 @@ class LineageObj:
return dataset_lineage return dataset_lineage
def to_model_lineage_dict(self):
"""
Returns the model part lineage information.
Returns:
dict, the model lineage information.
"""
filtration_result = dict(self._filtration_result)
filtration_result.pop(self._name_dataset_graph)
model_lineage = dict()
model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
model_lineage.update({self._name_model_lineage: filtration_result})
return model_lineage
def get_value_by_key(self, key): def get_value_by_key(self, key):
""" """
Get the value based on the key in `FIELD_MAPPING` or Get the value based on the key in `FIELD_MAPPING` or
......
...@@ -20,7 +20,6 @@ Usage: ...@@ -20,7 +20,6 @@ Usage:
The query module test should be run after lineagemgr/collection/model/test_model_lineage.py The query module test should be run after lineagemgr/collection/model/test_model_lineage.py
pytest lineagemgr pytest lineagemgr
""" """
import os import os
from unittest import TestCase from unittest import TestCase
...@@ -66,64 +65,70 @@ LINEAGE_INFO_RUN1 = { ...@@ -66,64 +65,70 @@ LINEAGE_INFO_RUN1 = {
} }
LINEAGE_FILTRATION_EXCEPT_RUN = { LINEAGE_FILTRATION_EXCEPT_RUN = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'model_lineage': {
'train_dataset_path': None, 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_count': 1024, 'train_dataset_path': None,
'user_defined': {}, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': None, 'test_dataset_count': None,
'network': 'ResNet', 'user_defined': {},
'optimizer': 'Momentum', 'network': 'ResNet',
'learning_rate': 0.11999999731779099, 'optimizer': 'Momentum',
'epoch': 10, 'learning_rate': 0.11999999731779099,
'batch_size': 32, 'epoch': 10,
'loss': 0.029999999329447746, 'batch_size': 32,
'model_size': 64, 'loss': 0.029999999329447746,
'metric': {}, 'model_size': 64,
'dataset_graph': DATASET_GRAPH, 'metric': {},
'dataset_mark': 2 'dataset_mark': 2
},
'dataset_graph': DATASET_GRAPH
} }
LINEAGE_FILTRATION_RUN1 = { LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'model_lineage': {
'train_dataset_path': None, 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_count': 731, 'train_dataset_path': None,
'test_dataset_path': None, 'train_dataset_count': 731,
'user_defined': {}, 'test_dataset_path': None,
'test_dataset_count': 10240, 'test_dataset_count': 10240,
'network': 'ResNet', 'user_defined': {},
'optimizer': 'Momentum', 'network': 'ResNet',
'learning_rate': 0.11999999731779099, 'optimizer': 'Momentum',
'epoch': 14, 'learning_rate': 0.11999999731779099,
'batch_size': 32, 'epoch': 14,
'loss': None, 'batch_size': 32,
'model_size': 64, 'loss': None,
'metric': { 'model_size': 64,
'accuracy': 0.78 'metric': {
'accuracy': 0.78
},
'dataset_mark': 2
}, },
'dataset_graph': DATASET_GRAPH, 'dataset_graph': DATASET_GRAPH
'dataset_mark': 2
} }
LINEAGE_FILTRATION_RUN2 = { LINEAGE_FILTRATION_RUN2 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None, 'model_lineage': {
'train_dataset_path': None, 'loss_function': None,
'train_dataset_count': None, 'train_dataset_path': None,
'user_defined': {}, 'train_dataset_count': None,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 10240, 'test_dataset_count': 10240,
'network': None, 'user_defined': {},
'optimizer': None, 'network': None,
'learning_rate': None, 'optimizer': None,
'epoch': None, 'learning_rate': None,
'batch_size': None, 'epoch': None,
'loss': None, 'batch_size': None,
'model_size': None, 'loss': None,
'metric': { 'model_size': None,
'accuracy': 2.7800000000000002 'metric': {
'accuracy': 2.7800000000000002
},
'dataset_mark': 3
}, },
'dataset_graph': {}, 'dataset_graph': {}
'dataset_mark': 3
} }
...@@ -150,6 +155,14 @@ class TestModelApi(TestCase): ...@@ -150,6 +155,14 @@ class TestModelApi(TestCase):
cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir') cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir')
os.makedirs(cls.empty_dir) os.makedirs(cls.empty_dir)
def generate_lineage_object(self, lineage):
lineage = dict(lineage)
lineage_object = dict()
lineage_object.update({'summary_dir': lineage.pop('summary_dir')})
lineage_object.update({'dataset_graph': lineage.pop('dataset_graph')})
lineage_object.update({'model_lineage': lineage})
return lineage_object
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
...@@ -337,7 +350,7 @@ class TestModelApi(TestCase): ...@@ -337,7 +350,7 @@ class TestModelApi(TestCase):
res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition) res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res assert expect_result == res
expect_result = { expect_result = {
...@@ -347,7 +360,7 @@ class TestModelApi(TestCase): ...@@ -347,7 +360,7 @@ class TestModelApi(TestCase):
res = filter_summary_lineage(self.dir_with_empty_lineage) res = filter_summary_lineage(self.dir_with_empty_lineage)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res assert expect_result == res
@pytest.mark.level0 @pytest.mark.level0
...@@ -385,7 +398,7 @@ class TestModelApi(TestCase): ...@@ -385,7 +398,7 @@ class TestModelApi(TestCase):
partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition) partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res assert expect_result == partial_res
@pytest.mark.level0 @pytest.mark.level0
...@@ -423,7 +436,7 @@ class TestModelApi(TestCase): ...@@ -423,7 +436,7 @@ class TestModelApi(TestCase):
partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition) partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res assert expect_result == partial_res
@pytest.mark.level0 @pytest.mark.level0
...@@ -439,7 +452,6 @@ class TestModelApi(TestCase): ...@@ -439,7 +452,6 @@ class TestModelApi(TestCase):
'ge': 30 'ge': 30
}, },
'sorted_name': 'metric/accuracy', 'sorted_name': 'metric/accuracy',
'lineage_type': None
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__0,
...@@ -452,14 +464,16 @@ class TestModelApi(TestCase): ...@@ -452,14 +464,16 @@ class TestModelApi(TestCase):
partial_res1 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition1) partial_res1 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition1)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')): for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res1 assert expect_result == partial_res1
search_condition2 = { search_condition2 = {
'batch_size': { 'batch_size': {
'lt': 30 'lt': 30
}, },
'lineage_type': 'model' 'lineage_type': {
'eq': 'model'
},
} }
expect_result = { expect_result = {
'customized': {}, 'customized': {},
...@@ -469,7 +483,7 @@ class TestModelApi(TestCase): ...@@ -469,7 +483,7 @@ class TestModelApi(TestCase):
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res2.get('object')): for idx, res_object in enumerate(partial_res2.get('object')):
expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res2 assert expect_result == partial_res2
@pytest.mark.level0 @pytest.mark.level0
...@@ -485,7 +499,9 @@ class TestModelApi(TestCase): ...@@ -485,7 +499,9 @@ class TestModelApi(TestCase):
'summary_dir': { 'summary_dir': {
'in': [summary_dir] 'in': [summary_dir]
}, },
'lineage_type': 'dataset' 'lineage_type': {
'eq': 'dataset'
},
} }
expect_result = { expect_result = {
'customized': {}, 'customized': {},
...@@ -705,15 +721,29 @@ class TestModelApi(TestCase): ...@@ -705,15 +721,29 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
# the condition type not supported in summary dir
search_condition = { search_condition = {
'summary_dir': { 'lineage_type': {
'lt': '/xxx' 'in': [
'xxx'
]
} }
} }
self.assertRaisesRegex( self.assertRaisesRegex(
LineageParamSummaryPathError, LineageSearchConditionParamError,
'Invalid operation of summary dir.', "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
search_condition = {
'lineage_type': {
'eq': None
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
"The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
filter_summary_lineage, filter_summary_lineage,
BASE_SUMMARY_DIR, BASE_SUMMARY_DIR,
search_condition search_condition
...@@ -779,3 +809,42 @@ class TestModelApi(TestCase): ...@@ -779,3 +809,42 @@ class TestModelApi(TestCase):
} }
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
assert expect_result == partial_res2 assert expect_result == partial_res2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_7(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
condition_keys = ["summary_dir", "lineage_type"]
for condition_key in condition_keys:
# the condition type not supported in summary_dir and lineage_type
search_condition = {
condition_key: {
'lt': '/xxx'
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
f'Invalid operation of {condition_key}.',
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
# more than one operation in summary_dir and lineage_type
search_condition = {
condition_key: {
'in': ['/xxx', '/yyy'],
'eq': '/zzz',
}
}
self.assertRaisesRegex(
LineageSearchConditionParamError,
f'More than one operation of {condition_key}.',
filter_summary_lineage,
BASE_SUMMARY_DIR,
search_condition
)
...@@ -67,7 +67,7 @@ class TestSearchModel(TestCase): ...@@ -67,7 +67,7 @@ class TestSearchModel(TestCase):
"""Test init.""" """Test init."""
APP.response_class = Response APP.response_class = Response
self.app_client = APP.test_client() self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage' self.url = '/v1/mindinsight/lineagemgr/lineages'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings') @mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage') @mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
...@@ -78,11 +78,11 @@ class TestSearchModel(TestCase): ...@@ -78,11 +78,11 @@ class TestSearchModel(TestCase):
'object': [ 'object': [
{ {
'summary_dir': base_dir, 'summary_dir': base_dir,
**LINEAGE_FILTRATION_BASE 'model_lineage': LINEAGE_FILTRATION_BASE
}, },
{ {
'summary_dir': os.path.join(base_dir, 'run1'), 'summary_dir': os.path.join(base_dir, 'run1'),
**LINEAGE_FILTRATION_RUN1 'model_lineage': LINEAGE_FILTRATION_RUN1
} }
], ],
'count': 2 'count': 2
...@@ -101,11 +101,11 @@ class TestSearchModel(TestCase): ...@@ -101,11 +101,11 @@ class TestSearchModel(TestCase):
'object': [ 'object': [
{ {
'summary_dir': './', 'summary_dir': './',
**LINEAGE_FILTRATION_BASE 'model_lineage': LINEAGE_FILTRATION_BASE
}, },
{ {
'summary_dir': './run1', 'summary_dir': './run1',
**LINEAGE_FILTRATION_RUN1 'model_lineage': LINEAGE_FILTRATION_RUN1
} }
], ],
'count': 2 'count': 2
......
...@@ -131,18 +131,6 @@ class TestModel(TestCase): ...@@ -131,18 +131,6 @@ class TestModel(TestCase):
self.assertDictEqual( self.assertDictEqual(
result, search_condition result, search_condition
) )
search_condition = {
'summary_dir': {
'gt': 3
}
}
self.assertRaisesRegex(
LineageParamValueError,
'Invalid operation of summary dir',
_convert_relative_path_to_abspath,
summary_base_dir,
search_condition
)
class TestFilterAPI(TestCase): class TestFilterAPI(TestCase):
......
...@@ -82,22 +82,24 @@ def create_filtration_result(summary_dir, train_event_dict, ...@@ -82,22 +82,24 @@ def create_filtration_result(summary_dir, train_event_dict,
""" """
filtration_result = { filtration_result = {
"summary_dir": summary_dir, "summary_dir": summary_dir,
"loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'], "model_lineage": {
"train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'], "loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'], "train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'],
"test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'], "train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'], "test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"network": train_event_dict['train_lineage']['algorithm']['network'], "test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'], "network": train_event_dict['train_lineage']['algorithm']['network'],
"learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'], "optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'],
"epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'], "learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'],
"batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'], "epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'],
"loss": train_event_dict['train_lineage']['algorithm']['loss'], "batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'],
"model_size": train_event_dict['train_lineage']['model']['size'], "loss": train_event_dict['train_lineage']['algorithm']['loss'],
"metric": metric_dict, "model_size": train_event_dict['train_lineage']['model']['size'],
"metric": metric_dict,
"dataset_mark": '2',
"user_defined": {}
},
"dataset_graph": dataset_dict, "dataset_graph": dataset_dict,
"dataset_mark": '2',
"user_defined": {}
} }
return filtration_result return filtration_result
...@@ -192,47 +194,50 @@ LINEAGE_FILTRATION_4 = create_filtration_result( ...@@ -192,47 +194,50 @@ LINEAGE_FILTRATION_4 = create_filtration_result(
) )
LINEAGE_FILTRATION_5 = { LINEAGE_FILTRATION_5 = {
"summary_dir": '/path/to/summary5', "summary_dir": '/path/to/summary5',
"loss_function": "model_lineage": {
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], "loss_function":
"train_dataset_path": None, event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_count": "train_dataset_path": None,
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], "train_dataset_count":
"test_dataset_path": None, event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_count": None, "test_dataset_path": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], "test_dataset_count": None,
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"learning_rate": "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], "learning_rate":
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"metric": {}, "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"dataset_graph": event_data.DATASET_DICT_0, "metric": {},
"dataset_mark": '2', "dataset_mark": '2',
"user_defined": {} "user_defined": {}
},
"dataset_graph": event_data.DATASET_DICT_0
} }
LINEAGE_FILTRATION_6 = { LINEAGE_FILTRATION_6 = {
"summary_dir": '/path/to/summary6', "summary_dir": '/path/to/summary6',
"loss_function": None, "model_lineage": {
"train_dataset_path": None, "loss_function": None,
"train_dataset_count": None, "train_dataset_path": None,
"test_dataset_path": "train_dataset_count": None,
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], "test_dataset_path":
"test_dataset_count": event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], "test_dataset_count":
"network": None, event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"optimizer": None, "network": None,
"learning_rate": None, "optimizer": None,
"epoch": None, "learning_rate": None,
"batch_size": None, "epoch": None,
"loss": None, "batch_size": None,
"model_size": None, "loss": None,
"metric": event_data.METRIC_5, "model_size": None,
"dataset_graph": event_data.DATASET_DICT_0, "metric": event_data.METRIC_5,
"dataset_mark": '2', "dataset_mark": '2',
"user_defined": {} "user_defined": {}
},
"dataset_graph": event_data.DATASET_DICT_0
} }
......
...@@ -108,8 +108,8 @@ class TestLineageObj(TestCase): ...@@ -108,8 +108,8 @@ class TestLineageObj(TestCase):
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result) self.assertDictEqual(expected_result, result)
def test_to_filtration_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_filtration_dict.""" """Test the function of to_model_lineage_dict."""
expected_result = create_filtration_result( expected_result = create_filtration_result(
self.summary_dir, self.summary_dir,
event_data.EVENT_TRAIN_DICT_0, event_data.EVENT_TRAIN_DICT_0,
...@@ -117,8 +117,18 @@ class TestLineageObj(TestCase): ...@@ -117,8 +117,18 @@ class TestLineageObj(TestCase):
event_data.METRIC_0, event_data.METRIC_0,
event_data.DATASET_DICT_0 event_data.DATASET_DICT_0
) )
expected_result['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
result = self.lineage_obj.to_filtration_dict() expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict."""
expected_result = {
"summary_dir": self.summary_dir,
"dataset_graph": event_data.DATASET_DICT_0
}
result = self.lineage_obj.to_dataset_lineage_dict()
self.assertDictEqual(expected_result, result) self.assertDictEqual(expected_result, result)
def test_get_value_by_key(self): def test_get_value_by_key(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册