Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ad9db524
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad9db524
编写于
5月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1394 Fix comment error and mod parameter check in graphdata
Merge pull request !1394 from heleiwang/fix_comments_error
上级
7241f3a4
f28f883c
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
72 addition
and
15 deletion
+72
-15
mindspore/dataset/engine/graphdata.py
mindspore/dataset/engine/graphdata.py
+6
-4
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+56
-9
tests/ut/python/dataset/test_graphdata.py
tests/ut/python/dataset/test_graphdata.py
+10
-2
未找到文件。
mindspore/dataset/engine/graphdata.py
浏览文件 @
ad9db524
...
...
@@ -20,12 +20,13 @@ import numpy as np
from
mindspore._c_dataengine
import
Graph
from
mindspore._c_dataengine
import
Tensor
from
.validators
import
check_gnn_get_all_nodes
,
check_gnn_get_all_neighbors
,
check_gnn_get_node_feature
from
.validators
import
check_gnn_graphdata
,
check_gnn_get_all_nodes
,
check_gnn_get_all_neighbors
,
\
check_gnn_get_node_feature
class
GraphData
:
"""
Reads th graph dataset used for GNN training from the shared file and database.
Reads th
e
graph dataset used for GNN training from the shared file and database.
Args:
dataset_file (str): One of file names in dataset.
...
...
@@ -33,6 +34,7 @@ class GraphData:
(default=None).
"""
@
check_gnn_graphdata
def
__init__
(
self
,
dataset_file
,
num_parallel_workers
=
None
):
self
.
_dataset_file
=
dataset_file
if
num_parallel_workers
is
None
:
...
...
@@ -45,7 +47,7 @@ class GraphData:
Get all nodes in the graph.
Args:
node_type (int): Specify the t
py
e of node.
node_type (int): Specify the t
yp
e of node.
Returns:
numpy.ndarray: array of nodes.
...
...
@@ -67,7 +69,7 @@ class GraphData:
Args:
node_list (list or numpy.ndarray): The given list of nodes.
neighbor_type (int): Specify the t
py
e of neighbor.
neighbor_type (int): Specify the t
yp
e of neighbor.
Returns:
numpy.ndarray: array of nodes.
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
ad9db524
...
...
@@ -19,6 +19,7 @@ import inspect as ins
import
os
from
functools
import
wraps
from
multiprocessing
import
cpu_count
import
numpy
as
np
from
mindspore._c_expression
import
typing
from
.
import
samplers
from
.
import
datasets
...
...
@@ -1075,14 +1076,48 @@ def check_split(method):
return
new_method
def
check_list_or_ndarray
(
param
,
param_name
):
if
(
not
isinstance
(
param
,
list
))
and
(
not
hasattr
(
param
,
'tolist'
)):
raise
TypeError
(
"Wrong input type for {0}, should be list, got {1}"
.
format
(
def
check_gnn_graphdata
(
method
):
"""check the input arguments of graphdata."""
@
wraps
(
method
)
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
# check dataset_file; required argument
dataset_file
=
param_dict
.
get
(
'dataset_file'
)
if
dataset_file
is
None
:
raise
ValueError
(
"dataset_file is not provided."
)
check_dataset_file
(
dataset_file
)
nreq_param_int
=
[
'num_parallel_workers'
]
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
def
check_gnn_list_or_ndarray
(
param
,
param_name
):
"""Check if the input parameter is list or numpy.ndarray."""
if
isinstance
(
param
,
list
):
for
m
in
param
:
if
not
isinstance
(
m
,
int
):
raise
TypeError
(
"Each membor in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
m
)))
elif
isinstance
(
param
,
np
.
ndarray
):
if
not
param
.
dtype
==
np
.
int32
:
raise
TypeError
(
"Each membor in {0} should be of type int32. Got {1}."
.
format
(
param_name
,
param
.
dtype
))
else
:
raise
TypeError
(
"Wrong input type for {0}, should be list or numpy.ndarray, got {1}"
.
format
(
param_name
,
type
(
param
)))
def
check_gnn_get_all_nodes
(
method
):
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
@
wraps
(
method
)
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
...
...
@@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
# check node_list; required argument
check_list_or_ndarray
(
param_dict
.
get
(
"node_list"
),
'node_list'
)
check_
gnn_
list_or_ndarray
(
param_dict
.
get
(
"node_list"
),
'node_list'
)
# check neighbor_type; required argument
check_type
(
param_dict
.
get
(
"neighbor_type"
),
'neighbor_type'
,
int
)
...
...
@@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method):
return
new_method
def
check_aligned_list
(
param
,
param_name
):
def
check_aligned_list
(
param
,
param_name
,
membor_type
):
"""Check whether the structure of each member of the list is the same."""
if
not
isinstance
(
param
,
list
):
raise
TypeError
(
"Parameter {0} is not a list"
.
format
(
param_name
))
membor_have_list
=
None
list_len
=
None
for
membor
in
param
:
if
isinstance
(
membor
,
list
):
check_aligned_list
(
membor
,
param_name
)
check_aligned_list
(
membor
,
param_name
,
membor_type
)
if
membor_have_list
not
in
(
None
,
True
):
raise
TypeError
(
"The type of each member of the parameter {0} is inconsistent"
.
format
(
param_name
))
...
...
@@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name):
membor_have_list
=
True
list_len
=
len
(
membor
)
else
:
if
not
isinstance
(
membor
,
membor_type
):
raise
TypeError
(
"Each membor in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
membor
)))
if
membor_have_list
not
in
(
None
,
False
):
raise
TypeError
(
"The type of each member of the parameter {0} is inconsistent"
.
format
(
param_name
))
...
...
@@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name):
def
check_gnn_get_node_feature
(
method
):
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
@
wraps
(
method
)
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
# check node_list; required argument
node_list
=
param_dict
.
get
(
"node_list"
)
check_list_or_ndarray
(
node_list
,
'node_list'
)
if
isinstance
(
node_list
,
list
):
check_aligned_list
(
node_list
,
'node_list'
)
check_aligned_list
(
node_list
,
'node_list'
,
int
)
elif
isinstance
(
node_list
,
np
.
ndarray
):
if
not
node_list
.
dtype
==
np
.
int32
:
raise
TypeError
(
"Each membor in {0} should be of type int32. Got {1}."
.
format
(
node_list
,
node_list
.
dtype
))
else
:
raise
TypeError
(
"Wrong input type for {0}, should be list or numpy.ndarray, got {1}"
.
format
(
'node_list'
,
type
(
node_list
)))
# check feature_types; required argument
check_list_or_ndarray
(
param_dict
.
get
(
"feature_types"
),
'feature_types'
)
check_gnn_list_or_ndarray
(
param_dict
.
get
(
"feature_types"
),
'feature_types'
)
return
method
(
*
args
,
**
kwargs
)
...
...
tests/ut/python/dataset/test_graphdata.py
浏览文件 @
ad9db524
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import
pytest
import
numpy
as
np
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
...
...
@@ -23,8 +24,7 @@ def test_graphdata_getfullneighbor():
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
nodes
=
g
.
get_all_nodes
(
1
)
assert
len
(
nodes
)
==
10
nodes_list
=
nodes
.
tolist
()
neighbor
=
g
.
get_all_neighbors
(
nodes_list
,
2
)
neighbor
=
g
.
get_all_neighbors
(
nodes
,
2
)
assert
neighbor
.
shape
==
(
10
,
6
)
row_tensor
=
g
.
get_node_feature
(
neighbor
.
tolist
(),
[
2
,
3
])
assert
row_tensor
[
0
].
shape
==
(
10
,
6
)
...
...
@@ -60,6 +60,14 @@ def test_graphdata_getnodefeature_input_check():
input_list
=
[[
1
,
1
],
[
1
,
1
]]
g
.
get_node_feature
(
input_list
,
1
)
with
pytest
.
raises
(
TypeError
):
input_list
=
[[
1
,
0.1
],
[
1
,
1
]]
g
.
get_node_feature
(
input_list
,
1
)
with
pytest
.
raises
(
TypeError
):
input_list
=
np
.
array
([[
1
,
0.1
],
[
1
,
1
]])
g
.
get_node_feature
(
input_list
,
1
)
with
pytest
.
raises
(
TypeError
):
input_list
=
[[
1
,
1
],
[
1
,
1
]]
g
.
get_node_feature
(
input_list
,
[
"a"
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录