Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
f83eadc9
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f83eadc9
编写于
6月 01, 2020
作者:
L
luopengting
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance float cmp in tests.lineagemgr, fix probabilistic failure in st
上级
56e27233
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
112 addition
and
117 deletion
+112
-117
tests/st/func/lineagemgr/api/test_model_api.py
tests/st/func/lineagemgr/api/test_model_api.py
+22
-20
tests/st/func/lineagemgr/cache/test_lineage_cache.py
tests/st/func/lineagemgr/cache/test_lineage_cache.py
+4
-5
tests/st/func/lineagemgr/collection/model/test_model_lineage.py
...st/func/lineagemgr/collection/model/test_model_lineage.py
+9
-29
tests/ut/lineagemgr/querier/event_data.py
tests/ut/lineagemgr/querier/event_data.py
+2
-1
tests/ut/lineagemgr/querier/test_querier.py
tests/ut/lineagemgr/querier/test_querier.py
+16
-30
tests/ut/lineagemgr/querier/test_query_model.py
tests/ut/lineagemgr/querier/test_query_model.py
+37
-28
tests/utils/tools.py
tests/utils/tools.py
+22
-4
未找到文件。
tests/st/func/lineagemgr/api/test_model_api.py
浏览文件 @
f83eadc9
...
...
@@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
LineageSearchConditionParamError
)
from
..conftest
import
BASE_SUMMARY_DIR
,
DATASET_GRAPH
,
SUMMARY_DIR
,
SUMMARY_DIR_2
from
.....ut.lineagemgr.querier
import
event_data
from
.....utils.tools
import
assert_equal_lineages
LINEAGE_INFO_RUN1
=
{
'summary_dir'
:
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'run1'
),
...
...
@@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = {
},
'hyper_parameters'
:
{
'optimizer'
:
'Momentum'
,
'learning_rate'
:
0.1
1999999731779099
,
'learning_rate'
:
0.1
2
,
'loss_function'
:
'SoftmaxCrossEntropyWithLogits'
,
'epoch'
:
14
,
'parallel_mode'
:
'stand_alone'
,
...
...
@@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = {
'user_defined'
:
{},
'network'
:
'ResNet'
,
'optimizer'
:
'Momentum'
,
'learning_rate'
:
0.1
1999999731779099
,
'learning_rate'
:
0.1
2
,
'epoch'
:
10
,
'batch_size'
:
32
,
'device_num'
:
2
,
'loss'
:
0.0
29999999329447746
,
'loss'
:
0.0
3
,
'model_size'
:
64
,
'metric'
:
{},
'dataset_mark'
:
2
...
...
@@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count'
:
1024
,
'test_dataset_path'
:
None
,
'test_dataset_count'
:
1024
,
'user_defined'
:
{
'info'
:
'info1'
,
'version'
:
'v1'
},
'user_defined'
:
{
'info'
:
'info1'
,
'version'
:
'v1'
,
'eval_version'
:
'version2'
},
'network'
:
'ResNet'
,
'optimizer'
:
'Momentum'
,
'learning_rate'
:
0.1
1999999731779099
,
'learning_rate'
:
0.1
2
,
'epoch'
:
14
,
'batch_size'
:
32
,
'device_num'
:
2
,
...
...
@@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = {
'user_defined'
:
{},
'network'
:
"ResNet"
,
'optimizer'
:
"Momentum"
,
'learning_rate'
:
0.1
1999999731779099
,
'learning_rate'
:
0.1
2
,
'epoch'
:
10
,
'batch_size'
:
32
,
'device_num'
:
2
,
'loss'
:
0.0
29999999329447746
,
'loss'
:
0.0
3
,
'model_size'
:
10
,
'metric'
:
{
'accuracy'
:
2.78
00000000000002
'accuracy'
:
2.78
},
'dataset_mark'
:
3
},
...
...
@@ -173,7 +178,7 @@ class TestModelApi(TestCase):
'summary_dir'
:
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'run1'
),
'hyper_parameters'
:
{
'optimizer'
:
'Momentum'
,
'learning_rate'
:
0.1
1999999731779099
,
'learning_rate'
:
0.1
2
,
'loss_function'
:
'SoftmaxCrossEntropyWithLogits'
,
'epoch'
:
14
,
'parallel_mode'
:
'stand_alone'
,
...
...
@@ -190,9 +195,9 @@ class TestModelApi(TestCase):
'network'
:
'ResNet'
}
}
assert
expect_total_res
==
total_res
assert
expect_partial_res1
==
partial_res1
assert
expect_partial_res2
==
partial_res2
assert
_equal_lineages
(
expect_total_res
,
total_res
,
self
.
assertDictEqual
)
assert
_equal_lineages
(
expect_partial_res1
,
partial_res1
,
self
.
assertDictEqual
)
assert
_equal_lineages
(
expect_partial_res2
,
partial_res2
,
self
.
assertDictEqual
)
# the lineage summary file is empty
result
=
get_summary_lineage
(
self
.
dir_with_empty_lineage
)
...
...
@@ -345,7 +350,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
res
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
res
assert
_equal_lineages
(
expect_result
,
res
,
self
.
assertDictEqual
)
expect_result
=
{
'customized'
:
{},
...
...
@@ -356,7 +361,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
res
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
res
assert
_equal_lineages
(
expect_result
,
res
,
self
.
assertDictEqual
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -394,7 +399,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
partial_res
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
partial_res
assert
_equal_lineages
(
expect_result
,
partial_res
,
self
.
assertDictEqual
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -432,7 +437,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
partial_res
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
partial_res
assert
_equal_lineages
(
expect_result
,
partial_res
,
self
.
assertDictEqual
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -461,7 +466,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
partial_res1
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
partial_res1
assert
_equal_lineages
(
expect_result
,
partial_res1
,
self
.
assertDictEqual
)
search_condition2
=
{
'batch_size'
:
{
...
...
@@ -477,9 +482,6 @@ class TestModelApi(TestCase):
'count'
:
0
}
partial_res2
=
filter_summary_lineage
(
BASE_SUMMARY_DIR
,
search_condition2
)
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
partial_res2
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
partial_res2
@
pytest
.
mark
.
level0
...
...
tests/st/func/lineagemgr/cache/test_lineage_cache.py
浏览文件 @
f83eadc9
...
...
@@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU
LINEAGE_FILTRATION_RUN1
,
LINEAGE_FILTRATION_RUN2
from
..conftest
import
BASE_SUMMARY_DIR
from
.....ut.lineagemgr.querier
import
event_data
from
.....utils.tools
import
check_loading_done
from
.....utils.tools
import
check_loading_done
,
assert_equal_lineages
@
pytest
.
mark
.
usefixtures
(
"create_summary_dir"
)
...
...
@@ -58,8 +58,7 @@ class TestModelApi(TestCase):
"""Test the interface of get_summary_lineage."""
total_res
=
general_get_summary_lineage
(
data_manager
=
self
.
_data_manger
,
summary_dir
=
"./run1"
)
expect_total_res
=
LINEAGE_INFO_RUN1
assert
expect_total_res
==
total_res
assert_equal_lineages
(
expect_total_res
,
total_res
,
self
.
assertDictEqual
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -86,7 +85,7 @@ class TestModelApi(TestCase):
expect_objects
=
expect_result
.
get
(
'object'
)
for
idx
,
res_object
in
enumerate
(
res
.
get
(
'object'
)):
expect_objects
[
idx
][
'model_lineage'
][
'dataset_mark'
]
=
res_object
[
'model_lineage'
].
get
(
'dataset_mark'
)
assert
expect_result
==
res
assert
_equal_lineages
(
expect_result
,
res
,
self
.
assertDictEqual
)
expect_result
=
{
'customized'
:
{},
...
...
@@ -100,4 +99,4 @@ class TestModelApi(TestCase):
}
}
res
=
general_filter_summary_lineage
(
data_manager
=
self
.
_data_manger
,
search_condition
=
search_condition
)
assert
expect_result
==
res
assert
_equal_lineages
(
expect_result
,
res
,
self
.
assertDictEqual
)
tests/st/func/lineagemgr/collection/model/test_model_lineage.py
浏览文件 @
f83eadc9
...
...
@@ -73,6 +73,10 @@ class TestModelLineage(TestCase):
TrainLineage
(
cls
.
summary_record
)
]
cls
.
run_context
[
'list_callback'
]
=
_ListCallback
(
callback
)
cls
.
user_defined_info
=
{
"info"
:
"info1"
,
"version"
:
"v1"
}
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
...
...
@@ -83,7 +87,7 @@ class TestModelLineage(TestCase):
@
pytest
.
mark
.
env_single
def
test_train_begin
(
self
):
"""Test the begin function in TrainLineage."""
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
)
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
assert
train_callback
.
initial_learning_rate
==
0.12
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
...
...
@@ -98,30 +102,6 @@ class TestModelLineage(TestCase):
@
pytest
.
mark
.
env_single
def
test_train_begin_with_user_defined_info
(
self
):
"""Test TrainLineage with nested user defined info."""
user_defined_info
=
{
"info"
:
{
"version"
:
"v1"
}}
train_callback
=
TrainLineage
(
self
.
summary_record
,
False
,
user_defined_info
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
assert
train_callback
.
initial_learning_rate
==
0.12
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
assert
os
.
path
.
isfile
(
lineage_log_path
)
is
True
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_train_begin_with_user_defined_key_in_lineage
(
self
):
"""Test TrainLineage with nested user defined info."""
expected_res
=
{
"info"
:
"info1"
,
"version"
:
"v1"
}
user_defined_info
=
{
"info"
:
"info1"
,
"version"
:
"v1"
,
...
...
@@ -137,7 +117,7 @@ class TestModelLineage(TestCase):
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
assert
os
.
path
.
isfile
(
lineage_log_path
)
is
True
res
=
filter_summary_lineage
(
os
.
path
.
dirname
(
lineage_log_path
))
assert
expected_res
==
res
[
'object'
][
0
][
'model_lineage'
][
'user_defined'
]
assert
self
.
user_defined_info
==
res
[
'object'
][
0
][
'model_lineage'
][
'user_defined'
]
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
...
...
@@ -168,7 +148,7 @@ class TestModelLineage(TestCase):
def
test_training_end
(
self
,
*
args
):
"""Test the end function in TrainLineage."""
args
[
0
].
return_value
=
64
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
)
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
train_callback
.
initial_learning_rate
=
0.12
train_callback
.
end
(
RunContext
(
self
.
run_context
))
res
=
get_summary_lineage
(
SUMMARY_DIR
)
...
...
@@ -188,7 +168,7 @@ class TestModelLineage(TestCase):
@
pytest
.
mark
.
env_single
def
test_eval_end
(
self
):
"""Test the end function in EvalLineage."""
eval_callback
=
EvalLineage
(
self
.
summary_record
,
True
)
eval_callback
=
EvalLineage
(
self
.
summary_record
,
True
,
{
'eval_version'
:
'version2'
}
)
eval_run_context
=
self
.
run_context
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.78
}
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
...
...
@@ -361,7 +341,7 @@ class TestModelLineage(TestCase):
def
test_train_with_customized_network
(
self
,
*
args
):
"""Test train with customized network."""
args
[
0
].
return_value
=
64
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
)
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
run_context_customized
=
self
.
run_context
del
run_context_customized
[
'optimizer'
]
del
run_context_customized
[
'net_outputs'
]
...
...
tests/ut/lineagemgr/querier/event_data.py
浏览文件 @
f83eadc9
...
...
@@ -195,7 +195,8 @@ CUSTOMIZED__0 = {
CUSTOMIZED__1
=
{
**
CUSTOMIZED__0
,
'user_defined/info'
:
{
'label'
:
'user_defined/info'
,
'required'
:
False
,
'type'
:
'str'
},
'user_defined/version'
:
{
'label'
:
'user_defined/version'
,
'required'
:
False
,
'type'
:
'str'
}
'user_defined/version'
:
{
'label'
:
'user_defined/version'
,
'required'
:
False
,
'type'
:
'str'
},
'user_defined/eval_version'
:
{
'label'
:
'user_defined/eval_version'
,
'required'
:
False
,
'type'
:
'str'
}
}
CUSTOMIZED_0
=
{
...
...
tests/ut/lineagemgr/querier/test_querier.py
浏览文件 @
f83eadc9
...
...
@@ -27,7 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from
mindinsight.lineagemgr.summary.lineage_summary_analyzer
import
LineageInfo
from
.
import
event_data
from
....utils.tools
import
deal_float_for_dict
from
....utils.tools
import
assert_equal_lineages
def
create_lineage_info
(
train_event_dict
,
eval_event_dict
,
dataset_event_dict
):
...
...
@@ -282,31 +282,17 @@ class TestQuerier(TestCase):
lineage_objects
=
LineageOrganizer
(
summary_base_dir
=
summary_base_dir
).
super_lineage_objs
self
.
multi_querier
=
Querier
(
lineage_objects
)
def
_deal_float_for_list
(
self
,
list1
,
list2
):
index
=
0
for
_
in
list1
:
deal_float_for_dict
(
list1
[
index
],
list2
[
index
])
index
+=
1
def
_assert_list_equal
(
self
,
list1
,
list2
):
self
.
_deal_float_for_list
(
list1
,
list2
)
self
.
assertListEqual
(
list1
,
list2
)
def
_assert_lineages_equal
(
self
,
lineages1
,
lineages2
):
self
.
_deal_float_for_list
(
lineages1
[
'object'
],
lineages2
[
'object'
])
self
.
assertDictEqual
(
lineages1
,
lineages2
)
def
test_get_summary_lineage_success_1
(
self
):
"""Test the success of get_summary_lineage."""
expected_result
=
[
LINEAGE_INFO_0
]
result
=
self
.
single_querier
.
get_summary_lineage
()
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_success_2
(
self
):
"""Test the success of get_summary_lineage."""
expected_result
=
[
LINEAGE_INFO_0
]
result
=
self
.
single_querier
.
get_summary_lineage
()
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_success_3
(
self
):
"""Test the success of get_summary_lineage."""
...
...
@@ -320,7 +306,7 @@ class TestQuerier(TestCase):
result
=
self
.
single_querier
.
get_summary_lineage
(
filter_keys
=
[
'model'
,
'algorithm'
]
)
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_success_4
(
self
):
"""Test the success of get_summary_lineage."""
...
...
@@ -367,7 +353,7 @@ class TestQuerier(TestCase):
}
]
result
=
self
.
multi_querier
.
get_summary_lineage
()
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_success_5
(
self
):
"""Test the success of get_summary_lineage."""
...
...
@@ -375,7 +361,7 @@ class TestQuerier(TestCase):
result
=
self
.
multi_querier
.
get_summary_lineage
(
summary_dir
=
'/path/to/summary1'
)
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_success_6
(
self
):
"""Test the success of get_summary_lineage."""
...
...
@@ -394,7 +380,7 @@ class TestQuerier(TestCase):
result
=
self
.
multi_querier
.
get_summary_lineage
(
summary_dir
=
'/path/to/summary0'
,
filter_keys
=
filter_keys
)
self
.
_assert_list_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertListEqual
)
def
test_get_summary_lineage_fail
(
self
):
"""Test the function of get_summary_lineage with exception."""
...
...
@@ -437,7 +423,7 @@ class TestQuerier(TestCase):
'count'
:
2
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_2
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -462,7 +448,7 @@ class TestQuerier(TestCase):
'count'
:
2
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_3
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -479,7 +465,7 @@ class TestQuerier(TestCase):
'count'
:
7
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_4
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -497,7 +483,7 @@ class TestQuerier(TestCase):
'count'
:
7
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
()
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_5
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -512,7 +498,7 @@ class TestQuerier(TestCase):
'count'
:
1
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_6
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -534,7 +520,7 @@ class TestQuerier(TestCase):
'count'
:
7
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_7
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -556,7 +542,7 @@ class TestQuerier(TestCase):
'count'
:
7
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_8
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -572,7 +558,7 @@ class TestQuerier(TestCase):
'count'
:
1
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_success_9
(
self
):
"""Test the success of filter_summary_lineage."""
...
...
@@ -586,7 +572,7 @@ class TestQuerier(TestCase):
'count'
:
7
,
}
result
=
self
.
multi_querier
.
filter_summary_lineage
(
condition
=
condition
)
self
.
_assert_lineages_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_filter_summary_lineage_fail
(
self
):
"""Test the function of filter_summary_lineage with exception."""
...
...
tests/ut/lineagemgr/querier/test_query_model.py
浏览文件 @
f83eadc9
...
...
@@ -21,7 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj
from
.
import
event_data
from
.test_querier
import
create_filtration_result
,
create_lineage_info
from
....utils.tools
import
deal_float_for_dict
from
....utils.tools
import
assert_equal_lineages
class
TestLineageObj
(
TestCase
):
...
...
@@ -51,56 +51,65 @@ class TestLineageObj(TestCase):
evaluation_lineage
=
lineage_info
.
eval_lineage
)
def
_assert_dict_equal
(
self
,
dict1
,
dict2
):
deal_float_for_dict
(
dict1
,
dict2
)
self
.
assertDictEqual
(
dict1
,
dict2
)
def
test_property
(
self
):
"""Test the function of getting property."""
self
.
assertEqual
(
self
.
summary_dir
,
self
.
lineage_obj
.
summary_dir
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'algorithm'
],
self
.
lineage_obj
.
algorithm
self
.
lineage_obj
.
algorithm
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'model'
],
self
.
lineage_obj
.
model
self
.
lineage_obj
.
model
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'train_dataset'
],
self
.
lineage_obj
.
train_dataset
self
.
lineage_obj
.
train_dataset
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'hyper_parameters'
],
self
.
lineage_obj
.
hyper_parameters
self
.
lineage_obj
.
hyper_parameters
,
self
.
assertDictEqual
)
assert_equal_lineages
(
event_data
.
METRIC_0
,
self
.
lineage_obj
.
metric
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
event_data
.
METRIC_0
,
self
.
lineage_obj
.
metric
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_EVAL_DICT_0
[
'evaluation_lineage'
][
'valid_dataset'
],
self
.
lineage_obj
.
valid_dataset
self
.
lineage_obj
.
valid_dataset
,
self
.
assertDictEqual
)
def
test_property_eval_not_exist
(
self
):
"""Test the function of getting property with no evaluation event."""
self
.
assertEqual
(
self
.
summary_dir
,
self
.
lineage_obj
.
summary_dir
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'algorithm'
],
self
.
lineage_obj_no_eval
.
algorithm
self
.
lineage_obj_no_eval
.
algorithm
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'model'
],
self
.
lineage_obj_no_eval
.
model
self
.
lineage_obj_no_eval
.
model
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'train_dataset'
],
self
.
lineage_obj_no_eval
.
train_dataset
self
.
lineage_obj_no_eval
.
train_dataset
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
(
assert_equal_lineages
(
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'hyper_parameters'
],
self
.
lineage_obj_no_eval
.
hyper_parameters
self
.
lineage_obj_no_eval
.
hyper_parameters
,
self
.
assertDictEqual
)
self
.
_assert_dict_equal
({},
self
.
lineage_obj_no_eval
.
metric
)
self
.
_assert_dict_equal
({},
self
.
lineage_obj_no_eval
.
valid_dataset
)
assert_equal_lineages
({},
self
.
lineage_obj_no_eval
.
metric
,
self
.
assertDictEqual
)
assert_equal_lineages
({},
self
.
lineage_obj_no_eval
.
valid_dataset
,
self
.
assertDictEqual
)
def
test_get_summary_info
(
self
):
"""Test the function of get_summary_info."""
...
...
@@ -111,7 +120,7 @@ class TestLineageObj(TestCase):
'model'
:
event_data
.
EVENT_TRAIN_DICT_0
[
'train_lineage'
][
'model'
]
}
result
=
self
.
lineage_obj
.
get_summary_info
(
filter_keys
)
self
.
_assert_dict_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_to_model_lineage_dict
(
self
):
"""Test the function of to_model_lineage_dict."""
...
...
@@ -125,7 +134,7 @@ class TestLineageObj(TestCase):
expected_result
[
'model_lineage'
][
'dataset_mark'
]
=
None
expected_result
.
pop
(
'dataset_graph'
)
result
=
self
.
lineage_obj
.
to_model_lineage_dict
()
self
.
_assert_dict_equal
(
expected_result
,
result
)
assert_equal_lineages
(
expected_result
,
result
,
self
.
assertDictEqual
)
def
test_to_dataset_lineage_dict
(
self
):
"""Test the function of to_dataset_lineage_dict."""
...
...
tests/utils/tools.py
浏览文件 @
f83eadc9
...
...
@@ -83,9 +83,9 @@ def compare_result_with_file(result, expected_file_path):
assert
result
==
expected_results
def
deal_float_for_dict
(
res
:
dict
,
expected_res
:
dict
):
def
deal_float_for_dict
(
res
:
dict
,
expected_res
:
dict
,
decimal_num
=
5
):
"""
Deal float rounded to
five
decimals in dict.
Deal float rounded to
specified
decimals in dict.
For example:
res:{
...
...
@@ -125,10 +125,9 @@ def deal_float_for_dict(res: dict, expected_res: dict):
"metric": {"acc": 0.1234562}
}
}
decimal_num (int): decimal rounded digits.
"""
decimal_num
=
5
for
key
in
res
:
value
=
res
[
key
]
expected_value
=
expected_res
[
key
]
...
...
@@ -137,3 +136,22 @@ def deal_float_for_dict(res: dict, expected_res: dict):
elif
isinstance
(
value
,
float
):
res
[
key
]
=
round
(
value
,
decimal_num
)
expected_res
[
key
]
=
round
(
expected_value
,
decimal_num
)
def
_deal_float_for_list
(
list1
,
list2
,
decimal_num
):
"""Deal float for list1 and list2."""
index
=
0
for
_
in
list1
:
deal_float_for_dict
(
list1
[
index
],
list2
[
index
],
decimal_num
)
index
+=
1
def
assert_equal_lineages
(
lineages1
,
lineages2
,
assert_func
,
decimal_num
=
2
):
"""Assert lineages."""
if
isinstance
(
lineages1
,
list
)
and
isinstance
(
lineages2
,
list
):
_deal_float_for_list
(
lineages1
,
lineages2
,
decimal_num
)
elif
lineages1
.
get
(
'object'
)
is
not
None
and
lineages2
.
get
(
'object'
)
is
not
None
:
_deal_float_for_list
(
lineages1
[
'object'
],
lineages2
[
'object'
],
decimal_num
)
else
:
deal_float_for_dict
(
lineages1
,
lineages2
,
decimal_num
)
assert_func
(
lineages1
,
lineages2
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录