Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
8ea4d18c
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8ea4d18c
编写于
3月 30, 2020
作者:
O
ougongchang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extract the common function methods and reduced cyclomatic complexity of functions
上级
b91233a9
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
86 addition
and
95 deletion
+86
-95
mindinsight/datavisual/data_transform/graph/graph.py
mindinsight/datavisual/data_transform/graph/graph.py
+48
-51
mindinsight/datavisual/data_transform/graph/msgraph.py
mindinsight/datavisual/data_transform/graph/msgraph.py
+4
-4
mindinsight/datavisual/data_transform/graph/node.py
mindinsight/datavisual/data_transform/graph/node.py
+1
-1
tests/st/func/datavisual/graph/test_query_nodes_restful_api.py
.../st/func/datavisual/graph/test_query_nodes_restful_api.py
+4
-9
tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py
...c/datavisual/graph/test_query_single_nodes_restful_api.py
+3
-9
tests/st/func/datavisual/graph/test_search_nodes_restful_api.py
...st/func/datavisual/graph/test_search_nodes_restful_api.py
+4
-9
tests/ut/datavisual/processors/test_graph_processor.py
tests/ut/datavisual/processors/test_graph_processor.py
+10
-11
tests/utils/tools.py
tests/utils/tools.py
+12
-1
未找到文件。
mindinsight/datavisual/data_transform/graph/graph.py
浏览文件 @
8ea4d18c
...
@@ -18,19 +18,21 @@ This file is used to define the basic graph.
...
@@ -18,19 +18,21 @@ This file is used to define the basic graph.
import
copy
import
copy
import
time
import
time
from
enum
import
Enum
from
mindinsight.datavisual.common.log
import
logger
from
mindinsight.datavisual.common.log
import
logger
from
mindinsight.datavisual.common
import
exceptions
from
mindinsight.datavisual.common
import
exceptions
from
.node
import
NodeTypeEnum
from
.node
import
NodeTypeEnum
from
.node
import
Node
from
.node
import
Node
class
EdgeTypeEnum
:
class
EdgeTypeEnum
(
Enum
)
:
"""Node edge type enum."""
"""Node edge type enum."""
control
=
'control'
CONTROL
=
'control'
data
=
'data'
DATA
=
'data'
class
DataTypeEnum
:
class
DataTypeEnum
(
Enum
)
:
"""Data type enum."""
"""Data type enum."""
DT_TENSOR
=
13
DT_TENSOR
=
13
...
@@ -292,70 +294,65 @@ class Graph:
...
@@ -292,70 +294,65 @@ class Graph:
output_attr
[
'scope'
]
=
NodeTypeEnum
.
POLYMERIC_SCOPE
.
value
output_attr
[
'scope'
]
=
NodeTypeEnum
.
POLYMERIC_SCOPE
.
value
node
.
update_output
({
dst_name
:
output_attr
})
node
.
update_output
({
dst_name
:
output_attr
})
def
_
calc
_polymeric_input_output
(
self
):
def
_
update
_polymeric_input_output
(
self
):
"""Calc polymeric input and output after build polymeric node."""
"""Calc polymeric input and output after build polymeric node."""
for
name
,
node
in
self
.
_normal_nodes
.
items
():
for
node
in
self
.
_normal_nodes
.
values
():
polymeric_input
=
{}
polymeric_input
=
self
.
_calc_polymeric_attr
(
node
,
'input'
)
for
src_name
in
node
.
input
:
src_node
=
self
.
_polymeric_nodes
.
get
(
src_name
)
if
node
.
node_type
==
NodeTypeEnum
.
POLYMERIC_SCOPE
.
value
:
src_name
=
src_name
if
not
src_node
else
src_node
.
polymeric_scope_name
output_name
=
self
.
_calc_dummy_node_name
(
name
,
src_name
)
polymeric_input
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
continue
if
not
src_node
:
continue
if
not
node
.
name_scope
and
src_node
.
name_scope
:
# if current node is in first layer, and the src node is not in
# the first layer, the src node will not be the polymeric input of current node.
continue
if
node
.
name_scope
==
src_node
.
name_scope
\
or
node
.
name_scope
.
startswith
(
src_node
.
name_scope
):
polymeric_input
.
update
(
{
src_node
.
polymeric_scope_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
node
.
update_polymeric_input
(
polymeric_input
)
node
.
update_polymeric_input
(
polymeric_input
)
polymeric_output
=
{}
polymeric_output
=
self
.
_calc_polymeric_attr
(
node
,
'output'
)
for
dst_name
in
node
.
output
:
dst_node
=
self
.
_polymeric_nodes
.
get
(
dst_name
)
if
node
.
node_type
==
NodeTypeEnum
.
POLYMERIC_SCOPE
.
value
:
dst_name
=
dst_name
if
not
dst_node
else
dst_node
.
polymeric_scope_name
output_name
=
self
.
_calc_dummy_node_name
(
name
,
dst_name
)
polymeric_output
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
continue
if
not
dst_node
:
continue
if
not
node
.
name_scope
and
dst_node
.
name_scope
:
continue
if
node
.
name_scope
==
dst_node
.
name_scope
\
or
node
.
name_scope
.
startswith
(
dst_node
.
name_scope
):
polymeric_output
.
update
(
{
dst_node
.
polymeric_scope_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
node
.
update_polymeric_output
(
polymeric_output
)
node
.
update_polymeric_output
(
polymeric_output
)
for
name
,
node
in
self
.
_polymeric_nodes
.
items
():
for
name
,
node
in
self
.
_polymeric_nodes
.
items
():
polymeric_input
=
{}
polymeric_input
=
{}
for
src_name
in
node
.
input
:
for
src_name
in
node
.
input
:
output_name
=
self
.
_calc_dummy_node_name
(
name
,
src_name
)
output_name
=
self
.
_calc_dummy_node_name
(
name
,
src_name
)
polymeric_input
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
polymeric_input
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
DATA
.
value
}})
node
.
update_polymeric_input
(
polymeric_input
)
node
.
update_polymeric_input
(
polymeric_input
)
polymeric_output
=
{}
polymeric_output
=
{}
for
dst_name
in
node
.
output
:
for
dst_name
in
node
.
output
:
polymeric_output
=
{}
polymeric_output
=
{}
output_name
=
self
.
_calc_dummy_node_name
(
name
,
dst_name
)
output_name
=
self
.
_calc_dummy_node_name
(
name
,
dst_name
)
polymeric_output
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
data
}})
polymeric_output
.
update
({
output_name
:
{
'edge_type'
:
EdgeTypeEnum
.
DATA
.
value
}})
node
.
update_polymeric_output
(
polymeric_output
)
node
.
update_polymeric_output
(
polymeric_output
)
def
_calc_polymeric_attr
(
self
,
node
,
attr
):
"""
Calc polymeric input or polymeric output after build polymeric node.
Args:
node (Node): Computes the polymeric input for a given node.
attr (str): The polymeric attr, optional value is `input` or `output`.
Returns:
dict, return polymeric input or polymeric output of the given node.
"""
polymeric_attr
=
{}
for
node_name
in
getattr
(
node
,
attr
):
polymeric_node
=
self
.
_polymeric_nodes
.
get
(
node_name
)
if
node
.
node_type
==
NodeTypeEnum
.
POLYMERIC_SCOPE
.
value
:
node_name
=
node_name
if
not
polymeric_node
else
polymeric_node
.
polymeric_scope_name
dummy_node_name
=
self
.
_calc_dummy_node_name
(
node
.
name
,
node_name
)
polymeric_attr
.
update
({
dummy_node_name
:
{
'edge_type'
:
EdgeTypeEnum
.
DATA
.
value
}})
continue
if
not
polymeric_node
:
continue
if
not
node
.
name_scope
and
polymeric_node
.
name_scope
:
# If current node is in top-level layer, and the polymeric_node node is not in
# the top-level layer, the polymeric node will not be the polymeric input
# or polymeric output of current node.
continue
if
node
.
name_scope
==
polymeric_node
.
name_scope
\
or
node
.
name_scope
.
startswith
(
polymeric_node
.
name_scope
+
'/'
):
polymeric_attr
.
update
(
{
polymeric_node
.
polymeric_scope_name
:
{
'edge_type'
:
EdgeTypeEnum
.
DATA
.
value
}})
return
polymeric_attr
def
_calc_dummy_node_name
(
self
,
current_node_name
,
other_node_name
):
def
_calc_dummy_node_name
(
self
,
current_node_name
,
other_node_name
):
"""
"""
Calc dummy node name.
Calc dummy node name.
...
...
mindinsight/datavisual/data_transform/graph/msgraph.py
浏览文件 @
8ea4d18c
...
@@ -39,7 +39,7 @@ class MSGraph(Graph):
...
@@ -39,7 +39,7 @@ class MSGraph(Graph):
self
.
_build_leaf_nodes
(
graph_proto
)
self
.
_build_leaf_nodes
(
graph_proto
)
self
.
_build_polymeric_nodes
()
self
.
_build_polymeric_nodes
()
self
.
_build_name_scope_nodes
()
self
.
_build_name_scope_nodes
()
self
.
_
calc
_polymeric_input_output
()
self
.
_
update
_polymeric_input_output
()
logger
.
info
(
"Build graph end, normal node count: %s, polymeric node "
logger
.
info
(
"Build graph end, normal node count: %s, polymeric node "
"count: %s."
,
len
(
self
.
_normal_nodes
),
len
(
self
.
_polymeric_nodes
))
"count: %s."
,
len
(
self
.
_normal_nodes
),
len
(
self
.
_polymeric_nodes
))
...
@@ -90,9 +90,9 @@ class MSGraph(Graph):
...
@@ -90,9 +90,9 @@ class MSGraph(Graph):
node_name
=
leaf_node_id_map_name
[
node_def
.
name
]
node_name
=
leaf_node_id_map_name
[
node_def
.
name
]
node
=
self
.
_leaf_nodes
[
node_name
]
node
=
self
.
_leaf_nodes
[
node_name
]
for
input_def
in
node_def
.
input
:
for
input_def
in
node_def
.
input
:
edge_type
=
EdgeTypeEnum
.
data
edge_type
=
EdgeTypeEnum
.
DATA
.
value
if
input_def
.
type
==
"CONTROL_EDGE"
:
if
input_def
.
type
==
"CONTROL_EDGE"
:
edge_type
=
EdgeTypeEnum
.
control
edge_type
=
EdgeTypeEnum
.
CONTROL
.
value
if
const_nodes_map
.
get
(
input_def
.
name
):
if
const_nodes_map
.
get
(
input_def
.
name
):
const_node
=
copy
.
deepcopy
(
const_nodes_map
[
input_def
.
name
])
const_node
=
copy
.
deepcopy
(
const_nodes_map
[
input_def
.
name
])
...
@@ -218,7 +218,7 @@ class MSGraph(Graph):
...
@@ -218,7 +218,7 @@ class MSGraph(Graph):
node
=
Node
(
name
=
const
.
key
,
node_id
=
const_node_id
)
node
=
Node
(
name
=
const
.
key
,
node_id
=
const_node_id
)
node
.
node_type
=
NodeTypeEnum
.
CONST
.
value
node
.
node_type
=
NodeTypeEnum
.
CONST
.
value
node
.
update_attr
({
const
.
key
:
str
(
const
.
value
)})
node
.
update_attr
({
const
.
key
:
str
(
const
.
value
)})
if
const
.
value
.
dtype
==
DataTypeEnum
.
DT_TENSOR
:
if
const
.
value
.
dtype
==
DataTypeEnum
.
DT_TENSOR
.
value
:
shape
=
[]
shape
=
[]
for
dim
in
const
.
value
.
tensor_val
.
dims
:
for
dim
in
const
.
value
.
tensor_val
.
dims
:
shape
.
append
(
dim
)
shape
.
append
(
dim
)
...
...
mindinsight/datavisual/data_transform/graph/node.py
浏览文件 @
8ea4d18c
...
@@ -172,7 +172,7 @@ class Node:
...
@@ -172,7 +172,7 @@ class Node:
Args:
Args:
polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name:
polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name:
{'edge_type': EdgeTypeEnum.
data
}}).
{'edge_type': EdgeTypeEnum.
DATA.value
}}).
"""
"""
self
.
_polymeric_output
.
update
(
polymeric_output
)
self
.
_polymeric_output
.
update
(
polymeric_output
)
...
...
tests/st/func/datavisual/graph/test_query_nodes_restful_api.py
浏览文件 @
8ea4d18c
...
@@ -19,11 +19,11 @@ Usage:
...
@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual
pytest tests/st/func/datavisual
"""
"""
import
os
import
os
import
json
import
pytest
import
pytest
from
..
import
globals
as
gbl
from
..
import
globals
as
gbl
from
.....utils.tools
import
get_url
from
.....utils.tools
import
get_url
,
compare_result_with_file
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/nodes'
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/nodes'
...
@@ -33,12 +33,6 @@ class TestQueryNodes:
...
@@ -33,12 +33,6 @@ class TestQueryNodes:
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
def
compare_result_with_file
(
self
,
result
,
filename
):
"""Compare result with file which contain the expected results."""
with
open
(
os
.
path
.
join
(
self
.
graph_results_dir
,
filename
),
'r'
)
as
fp
:
expected_results
=
json
.
load
(
fp
)
assert
result
==
expected_results
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
platform_x86_cpu
...
@@ -65,4 +59,5 @@ class TestQueryNodes:
...
@@ -65,4 +59,5 @@ class TestQueryNodes:
url
=
get_url
(
BASE_URL
,
params
)
url
=
get_url
(
BASE_URL
,
params
)
response
=
client
.
get
(
url
)
response
=
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
self
.
compare_result_with_file
(
response
.
get_json
(),
result_file
)
file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
response
.
get_json
(),
file_path
)
tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py
浏览文件 @
8ea4d18c
...
@@ -19,12 +19,11 @@ Usage:
...
@@ -19,12 +19,11 @@ Usage:
pytest tests/st/func/datavisual
pytest tests/st/func/datavisual
"""
"""
import
os
import
os
import
json
import
pytest
import
pytest
from
..
import
globals
as
gbl
from
..
import
globals
as
gbl
from
.....utils.tools
import
get_url
from
.....utils.tools
import
get_url
,
compare_result_with_file
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/single-node'
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/single-node'
...
@@ -34,12 +33,6 @@ class TestQuerySingleNode:
...
@@ -34,12 +33,6 @@ class TestQuerySingleNode:
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
def
compare_result_with_file
(
self
,
result
,
filename
):
"""Compare result with file which contain the expected results."""
with
open
(
os
.
path
.
join
(
self
.
graph_results_dir
,
filename
),
'r'
)
as
fp
:
expected_results
=
json
.
load
(
fp
)
assert
result
==
expected_results
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
platform_x86_cpu
...
@@ -59,4 +52,5 @@ class TestQuerySingleNode:
...
@@ -59,4 +52,5 @@ class TestQuerySingleNode:
url
=
get_url
(
BASE_URL
,
params
)
url
=
get_url
(
BASE_URL
,
params
)
response
=
client
.
get
(
url
)
response
=
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
self
.
compare_result_with_file
(
response
.
get_json
(),
result_file
)
file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
response
.
get_json
(),
file_path
)
tests/st/func/datavisual/graph/test_search_nodes_restful_api.py
浏览文件 @
8ea4d18c
...
@@ -19,11 +19,11 @@ Usage:
...
@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual
pytest tests/st/func/datavisual
"""
"""
import
os
import
os
import
json
import
pytest
import
pytest
from
..
import
globals
as
gbl
from
..
import
globals
as
gbl
from
.....utils.tools
import
get_url
from
.....utils.tools
import
get_url
,
compare_result_with_file
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/nodes/names'
BASE_URL
=
'/v1/mindinsight/datavisual/graphs/nodes/names'
...
@@ -33,12 +33,6 @@ class TestSearchNodes:
...
@@ -33,12 +33,6 @@ class TestSearchNodes:
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
graph_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'graph_results'
)
def
compare_result_with_file
(
self
,
result
,
filename
):
"""Compare result with file which contain the expected results."""
with
open
(
os
.
path
.
join
(
self
.
graph_results_dir
,
filename
),
'r'
)
as
fp
:
expected_results
=
json
.
load
(
fp
)
assert
result
==
expected_results
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
env_single
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
platform_x86_cpu
...
@@ -59,4 +53,5 @@ class TestSearchNodes:
...
@@ -59,4 +53,5 @@ class TestSearchNodes:
url
=
get_url
(
BASE_URL
,
params
)
url
=
get_url
(
BASE_URL
,
params
)
response
=
client
.
get
(
url
)
response
=
client
.
get
(
url
)
assert
response
.
status_code
==
200
assert
response
.
status_code
==
200
self
.
compare_result_with_file
(
response
.
get_json
(),
result_file
)
file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
response
.
get_json
(),
file_path
)
tests/ut/datavisual/processors/test_graph_processor.py
浏览文件 @
8ea4d18c
...
@@ -29,7 +29,7 @@ import pytest
...
@@ -29,7 +29,7 @@ import pytest
from
..mock
import
MockLogger
from
..mock
import
MockLogger
from
....utils.log_operations
import
LogOperations
from
....utils.log_operations
import
LogOperations
from
....utils.tools
import
check_loading_done
,
delete_files_or_dirs
from
....utils.tools
import
check_loading_done
,
delete_files_or_dirs
,
compare_result_with_file
from
mindinsight.datavisual.common
import
exceptions
from
mindinsight.datavisual.common
import
exceptions
from
mindinsight.datavisual.common.enums
import
PluginNameEnum
from
mindinsight.datavisual.common.enums
import
PluginNameEnum
...
@@ -103,12 +103,6 @@ class TestGraphProcessor:
...
@@ -103,12 +103,6 @@ class TestGraphProcessor:
# wait for loading done
# wait for loading done
check_loading_done
(
self
.
_mock_data_manager
,
time_limit
=
5
)
check_loading_done
(
self
.
_mock_data_manager
,
time_limit
=
5
)
def
compare_result_with_file
(
self
,
result
,
filename
):
"""Compare result with file which contain the expected results."""
with
open
(
os
.
path
.
join
(
self
.
graph_results_dir
,
filename
),
'r'
)
as
fp
:
expected_results
=
json
.
load
(
fp
)
assert
result
==
expected_results
def
test_get_nodes_with_not_exist_train_id
(
self
,
load_graph_record
):
def
test_get_nodes_with_not_exist_train_id
(
self
,
load_graph_record
):
"""Test getting nodes with not exist train id."""
"""Test getting nodes with not exist train id."""
test_train_id
=
"not_exist_train_id"
test_train_id
=
"not_exist_train_id"
...
@@ -152,7 +146,9 @@ class TestGraphProcessor:
...
@@ -152,7 +146,9 @@ class TestGraphProcessor:
graph_processor
=
GraphProcessor
(
self
.
_train_id
,
graph_processor
=
GraphProcessor
(
self
.
_train_id
,
self
.
_mock_data_manager
)
self
.
_mock_data_manager
)
results
=
graph_processor
.
get_nodes
(
name
,
node_type
)
results
=
graph_processor
.
get_nodes
(
name
,
node_type
)
self
.
compare_result_with_file
(
results
,
result_file
)
expected_file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
results
,
expected_file_path
)
@
pytest
.
mark
.
parametrize
(
"search_content, result_file"
,
[
@
pytest
.
mark
.
parametrize
(
"search_content, result_file"
,
[
(
None
,
'test_search_node_names_with_search_content_expected_results1.json'
),
(
None
,
'test_search_node_names_with_search_content_expected_results1.json'
),
...
@@ -175,7 +171,8 @@ class TestGraphProcessor:
...
@@ -175,7 +171,8 @@ class TestGraphProcessor:
expected_results
=
{
'names'
:
[]}
expected_results
=
{
'names'
:
[]}
assert
results
==
expected_results
assert
results
==
expected_results
else
:
else
:
self
.
compare_result_with_file
(
results
,
result_file
)
expected_file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
results
,
expected_file_path
)
@
pytest
.
mark
.
parametrize
(
"offset"
,
[
-
100
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"offset"
,
[
-
100
,
-
1
])
def
test_search_node_names_with_negative_offset
(
self
,
load_graph_record
,
offset
):
def
test_search_node_names_with_negative_offset
(
self
,
load_graph_record
,
offset
):
...
@@ -203,7 +200,8 @@ class TestGraphProcessor:
...
@@ -203,7 +200,8 @@ class TestGraphProcessor:
results
=
graph_processor
.
search_node_names
(
test_search_content
,
results
=
graph_processor
.
search_node_names
(
test_search_content
,
test_offset
,
test_offset
,
test_limit
)
test_limit
)
self
.
compare_result_with_file
(
results
,
result_file
)
expected_file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
results
,
expected_file_path
)
def
test_search_node_names_with_wrong_limit
(
self
,
load_graph_record
):
def
test_search_node_names_with_wrong_limit
(
self
,
load_graph_record
):
"""Test search node names with wrong limit."""
"""Test search node names with wrong limit."""
...
@@ -227,7 +225,8 @@ class TestGraphProcessor:
...
@@ -227,7 +225,8 @@ class TestGraphProcessor:
graph_processor
=
GraphProcessor
(
self
.
_train_id
,
graph_processor
=
GraphProcessor
(
self
.
_train_id
,
self
.
_mock_data_manager
)
self
.
_mock_data_manager
)
results
=
graph_processor
.
search_single_node
(
name
)
results
=
graph_processor
.
search_single_node
(
name
)
self
.
compare_result_with_file
(
results
,
result_file
)
expected_file_path
=
os
.
path
.
join
(
self
.
graph_results_dir
,
result_file
)
compare_result_with_file
(
results
,
expected_file_path
)
def
test_search_single_node_with_not_exist_name
(
self
,
load_graph_record
):
def
test_search_single_node_with_not_exist_name
(
self
,
load_graph_record
):
...
...
tests/utils/tools.py
浏览文件 @
8ea4d18c
# Copyright 20
19
Huawei Technologies Co., Ltd
# Copyright 20
20
Huawei Technologies Co., Ltd
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -19,9 +19,13 @@ import io
...
@@ -19,9 +19,13 @@ import io
import
os
import
os
import
shutil
import
shutil
import
time
import
time
import
json
from
urllib.parse
import
urlencode
from
urllib.parse
import
urlencode
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
from
mindinsight.datavisual.common.enums
import
DataManagerStatus
from
mindinsight.datavisual.common.enums
import
DataManagerStatus
...
@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string):
...
@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string):
image_tensor
=
np
.
array
(
img
)
image_tensor
=
np
.
array
(
img
)
return
image_tensor
return
image_tensor
def
compare_result_with_file
(
result
,
expected_file_path
):
"""Compare result with file which contain the expected results."""
with
open
(
expected_file_path
,
'r'
)
as
file
:
expected_results
=
json
.
load
(
file
)
assert
result
==
expected_results
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录