magicwindyyd / mindspore (forked from MindSpore / mindspore)
eff90369
编写于
6月 10, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1866 GNN data supports sampling API
Merge pull request !1866 from heleiwang/gnn_june
Parents:
2aaf1f31
3ece8dd0
Showing 17 changed files with 867 additions and 227 deletions (+867 -227)
example/graph_to_mindrecord/README.md                  +0    -3
example/graph_to_mindrecord/citeseer/mr_api.py         +68   -49
example/graph_to_mindrecord/cora/mr_api.py             +44   -52
example/graph_to_mindrecord/write_citeseer.sh          +1    -1
example/graph_to_mindrecord/write_cora.sh              +1    -1
mindspore/ccsrc/dataset/api/python_bindings.cc         +35   -4
mindspore/ccsrc/dataset/engine/gnn/graph.cc            +193  -57
mindspore/ccsrc/dataset/engine/gnn/graph.h             +57   -31
mindspore/ccsrc/dataset/engine/gnn/local_node.cc       +40   -11
mindspore/ccsrc/dataset/engine/gnn/local_node.h        +14   -3
mindspore/ccsrc/dataset/engine/gnn/node.h              +10   -3
mindspore/dataset/engine/graphdata.py                  +103  -3
mindspore/dataset/engine/validators.py                 +85   -0
tests/ut/cpp/dataset/gnn_graph_test.cc                 +113  -9
tests/ut/data/mindrecord/testGraphData/testdata        +0    -0
tests/ut/data/mindrecord/testGraphData/testdata.db     +0    -0
tests/ut/python/dataset/test_graphdata.py              +103  -0
example/graph_to_mindrecord/README.md

@@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t
 1. Download and prepare the Cora dataset as required.
-   > [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
 2. Edit write_cora.sh and modify the parameters
    ```
    --mindrecord_file: output MindRecord file.
    ...
example/graph_to_mindrecord/citeseer/mr_api.py

@@ -15,29 +15,26 @@
 """
 User-defined API for MindRecord GNN writer.
 """
-import csv
 import os
+import pickle as pkl
 import numpy as np
 import scipy.sparse as sp
 
 # parse args from command line parameter 'graph_api_args'
 # args delimiter is ':'
 args = os.environ['graph_api_args'].split(':')
-CITESEER_CONTENT_FILE = args[0]
-CITESEER_CITES_FILE = args[1]
-CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
-CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
-
-node_id_map = {}
+CITESEER_PATH = args[0]
+dataset_str = 'citeseer'
 
 # profile: (num_features, feature_data_types, feature_shapes)
-node_profile = (2, ["float32", "int64"], [[-1], [-1]])
+node_profile = (2, ["float32", "int32"], [[-1], [-1]])
 edge_profile = (0, [], [])
 
+node_ids = []
+
 
 def _normalize_citeseer_features(features):
-    features = np.array(features)
     row_sum = np.array(features.sum(1))
     r_inv = np.power(row_sum * 1.0, -1).flatten()
     r_inv[np.isinf(r_inv)] = 0.

@@ -46,6 +43,14 @@ def _normalize_citeseer_features(features):
     return features
 
 
+def _parse_index_file(filename):
+    """Parse index file."""
+    index = []
+    for line in open(filename):
+        index.append(int(line.strip()))
+    return index
+
+
 def yield_nodes(task_id=0):
     """
     Generate node data

@@ -54,29 +59,46 @@ def yield_nodes(task_id=0):
         data (dict): data row which is dict.
     """
     print("Node task is {}".format(task_id))
-    label_types = {}
-    label_size = 0
-    node_num = 0
-    with open(CITESEER_CONTENT_FILE) as content_file:
-        content_reader = csv.reader(content_file, delimiter='\t')
-        line_count = 0
-        for row in content_reader:
-            if not row[-1] in label_types:
-                label_types[row[-1]] = label_size
-                label_size += 1
-            if not row[0] in node_id_map:
-                node_id_map[row[0]] = node_num
-                node_num += 1
-            raw_features = [[int(x) for x in row[1:-1]]]
-            node = {'id': node_id_map[row[0]], 'type': 0,
-                    'feature_1': _normalize_citeseer_features(raw_features),
-                    'feature_2': [label_types[row[-1]]]}
-            yield node
-            line_count += 1
+    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
+    objects = []
+    for name in names:
+        with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
+            objects.append(pkl.load(f, encoding='latin1'))
+    x, y, tx, ty, allx, ally = tuple(objects)
+    test_idx_reorder = _parse_index_file(
+        "{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
+    test_idx_range = np.sort(test_idx_reorder)
+
+    tx = _normalize_citeseer_features(tx)
+    allx = _normalize_citeseer_features(allx)
+
+    # Fix citeseer dataset (there are some isolated nodes in the graph)
+    # Find isolated nodes, add them as zero-vecs into the right position
+    test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
+    tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
+    tx_extended[test_idx_range - min(test_idx_range), :] = tx
+    tx = tx_extended
+    ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
+    ty_extended[test_idx_range - min(test_idx_range), :] = ty
+    ty = ty_extended
+
+    features = sp.vstack((allx, tx)).tolil()
+    features[test_idx_reorder, :] = features[test_idx_range, :]
+    features = features.A
+    labels = np.vstack((ally, ty))
+    labels[test_idx_reorder, :] = labels[test_idx_range, :]
+
+    line_count = 0
+    for i, label in enumerate(labels):
+        if not 1 in label.tolist():
+            continue
+        node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
+                'feature_2': label.tolist().index(1)}
+        line_count += 1
+        node_ids.append(i)
+        yield node
     print('Processed {} lines for nodes.'.format(line_count))
-    # print('label types {}.'.format(label_types))
-    with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
-        for k in label_types:
-            print(k + ',' + str(label_types[k]), file=f)

@@ -87,23 +109,20 @@ def yield_edges(task_id=0):
         data (dict): data row which is dict.
     """
     print("Edge task is {}".format(task_id))
-    # print(map_string_int)
-    with open(CITESEER_CITES_FILE) as cites_file:
-        cites_reader = csv.reader(cites_file, delimiter='\t')
-        line_count = 0
-        for row in cites_reader:
-            if not row[0] in node_id_map:
-                print('Source node {} does not exist.'.format(row[0]))
-                continue
-            if not row[1] in node_id_map:
-                print('Destination node {} does not exist.'.format(row[1]))
-                continue
-            line_count += 1
-            edge = {'id': line_count,
-                    'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]],
-                    'type': 0}
-            yield edge
-    with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
-        for k in node_id_map:
-            print(k + ',' + str(node_id_map[k]), file=f)
+    with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
+        graph = pkl.load(f, encoding='latin1')
+    line_count = 0
+    for i in graph:
+        for dst_id in graph[i]:
+            if not i in node_ids:
+                print('Source node {} does not exist.'.format(i))
+                continue
+            if not dst_id in node_ids:
+                print('Destination node {} does not exist.'.format(dst_id))
+                continue
+            edge = {'id': line_count,
+                    'src_id': i, 'dst_id': dst_id, 'type': 0}
+            line_count += 1
+            yield edge
     print('Processed {} lines for edges.'.format(line_count))
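For readers following the citeseer/cora rewrite above, the row normalization used by `_normalize_citeseer_features` (and its cora counterpart) can be tried in isolation. The sketch below reproduces the visible steps on a toy matrix; the final `sp.diags(r_inv).dot(features)` scaling is an assumption about the part of the function the diff does not show, and all values are invented.

```python
import numpy as np
import scipy.sparse as sp

# Toy feature matrix standing in for the citeseer allx/tx blocks (values invented).
features = sp.lil_matrix(np.array([[1., 0., 3.],
                                   [0., 0., 0.],   # an all-zero row, like an isolated node
                                   [2., 2., 0.]]))

row_sum = np.array(features.sum(1))            # per-row sums, as in the diff above
r_inv = np.power(row_sum * 1.0, -1).flatten()  # 1 / row_sum
r_inv[np.isinf(r_inv)] = 0.                    # zero rows would otherwise produce inf
# Assumption: the elided tail of the function scales each row by r_inv this way.
normalized = sp.diags(r_inv).dot(features)
print(normalized.toarray())                    # every non-empty row now sums to 1
```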
example/graph_to_mindrecord/cora/mr_api.py

@@ -15,29 +15,24 @@
 """
 User-defined API for MindRecord GNN writer.
 """
-import csv
 import os
+import pickle as pkl
 import numpy as np
 import scipy.sparse as sp
 
 # parse args from command line parameter 'graph_api_args'
 # args delimiter is ':'
 args = os.environ['graph_api_args'].split(':')
-CORA_CONTENT_FILE = args[0]
-CORA_CITES_FILE = args[1]
-CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
-CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
-
-node_id_map = {}
+CORA_PATH = args[0]
+dataset_str = 'cora'
 
 # profile: (num_features, feature_data_types, feature_shapes)
-node_profile = (2, ["float32", "int64"], [[-1], [-1]])
+node_profile = (2, ["float32", "int32"], [[-1], [-1]])
 edge_profile = (0, [], [])
 
 
 def _normalize_cora_features(features):
-    features = np.array(features)
     row_sum = np.array(features.sum(1))
     r_inv = np.power(row_sum * 1.0, -1).flatten()
     r_inv[np.isinf(r_inv)] = 0.

@@ -46,6 +41,14 @@ def _normalize_cora_features(features):
     return features
 
 
+def _parse_index_file(filename):
+    """Parse index file."""
+    index = []
+    for line in open(filename):
+        index.append(int(line.strip()))
+    return index
+
+
 def yield_nodes(task_id=0):
     """
     Generate node data

@@ -54,32 +57,32 @@ def yield_nodes(task_id=0):
         data (dict): data row which is dict.
     """
     print("Node task is {}".format(task_id))
-    label_types = {}
-    label_size = 0
-    node_num = 0
-    with open(CORA_CONTENT_FILE) as content_file:
-        content_reader = csv.reader(content_file, delimiter=',')
-        line_count = 0
-        for row in content_reader:
-            if line_count == 0:
-                line_count += 1
-                continue
-            if not row[0] in node_id_map:
-                node_id_map[row[0]] = node_num
-                node_num += 1
-            if not row[-1] in label_types:
-                label_types[row[-1]] = label_size
-                label_size += 1
-            raw_features = [[int(x) for x in row[1:-1]]]
-            node = {'id': node_id_map[row[0]], 'type': 0,
-                    'feature_1': _normalize_cora_features(raw_features),
-                    'feature_2': [label_types[row[-1]]]}
-            yield node
-            line_count += 1
+    names = ['tx', 'ty', 'allx', 'ally']
+    objects = []
+    for name in names:
+        with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
+            objects.append(pkl.load(f, encoding='latin1'))
+    tx, ty, allx, ally = tuple(objects)
+    test_idx_reorder = _parse_index_file(
+        "{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
+    test_idx_range = np.sort(test_idx_reorder)
+
+    features = sp.vstack((allx, tx)).tolil()
+    features[test_idx_reorder, :] = features[test_idx_range, :]
+    features = _normalize_cora_features(features)
+    features = features.A
+    labels = np.vstack((ally, ty))
+    labels[test_idx_reorder, :] = labels[test_idx_range, :]
+
+    line_count = 0
+    for i, label in enumerate(labels):
+        node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
+                'feature_2': label.tolist().index(1)}
+        line_count += 1
+        yield node
     print('Processed {} lines for nodes.'.format(line_count))
-    print('label types {}.'.format(label_types))
-    with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
-        for k in label_types:
-            print(k + ',' + str(label_types[k]), file=f)

@@ -90,24 +93,13 @@ def yield_edges(task_id=0):
         data (dict): data row which is dict.
     """
     print("Edge task is {}".format(task_id))
-    with open(CORA_CITES_FILE) as cites_file:
-        cites_reader = csv.reader(cites_file, delimiter=',')
-        line_count = 0
-        for row in cites_reader:
-            if line_count == 0:
-                line_count += 1
-                continue
-            if not row[0] in node_id_map:
-                print('Source node {} does not exist.'.format(row[0]))
-                continue
-            if not row[1] in node_id_map:
-                print('Destination node {} does not exist.'.format(row[1]))
-                continue
-            edge = {'id': line_count,
-                    'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]],
-                    'type': 0}
-            yield edge
-            line_count += 1
+    with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
+        graph = pkl.load(f, encoding='latin1')
+    line_count = 0
+    for i in graph:
+        for dst_id in graph[i]:
+            edge = {'id': line_count,
+                    'src_id': i, 'dst_id': dst_id, 'type': 0}
+            line_count += 1
+            yield edge
     print('Processed {} lines for edges.'.format(line_count))
-    with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
-        for k in node_id_map:
-            print(k + ',' + str(node_id_map[k]), file=f)
example/graph_to_mindrecord/write_citeseer.sh

@@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \
 --mindrecord_partitions 1 \
 --mindrecord_header_size_by_bit 18 \
 --mindrecord_page_size_by_bit 20 \
---graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites"
+--graph_api_args "$SRC_PATH"
example/graph_to_mindrecord/write_cora.sh

@@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \
 --mindrecord_partitions 1 \
 --mindrecord_header_size_by_bit 18 \
 --mindrecord_page_size_by_bit 20 \
---graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv"
+--graph_api_args "$SRC_PATH"
mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -527,10 +527,22 @@ void bindGraphData(py::module *m) {
                      THROW_IF_ERROR(g_out->Init());
                      return g_out;
                    }))
-    .def("get_nodes",
-         [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
+    .def("get_all_nodes",
+         [](gnn::Graph &g, gnn::NodeType node_type) {
           std::shared_ptr<Tensor> out;
-           THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
+           THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
+           return out;
+         })
+    .def("get_all_edges",
+         [](gnn::Graph &g, gnn::EdgeType edge_type) {
+           std::shared_ptr<Tensor> out;
+           THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
+           return out;
+         })
+    .def("get_nodes_from_edges",
+         [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
+           std::shared_ptr<Tensor> out;
+           THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
           return out;
         })
     .def("get_all_neighbors",

@@ -539,12 +551,31 @@ void bindGraphData(py::module *m) {
           THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
           return out;
         })
+    .def("get_sampled_neighbors",
+         [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
+            std::vector<gnn::NodeType> neighbor_types) {
+           std::shared_ptr<Tensor> out;
+           THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
+           return out;
+         })
+    .def("get_neg_sampled_neighbors",
+         [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
+            gnn::NodeType neg_neighbor_type) {
+           std::shared_ptr<Tensor> out;
+           THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
+           return out;
+         })
     .def("get_node_feature",
         [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
           TensorRow out;
           THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
           return out;
-         });
+         })
+    .def("graph_info", [](gnn::Graph &g) {
+      py::dict out;
+      THROW_IF_ERROR(g.GraphInfo(&out));
+      return out;
+    });
 }
 
 // This is where we externalize the C logic as python modules
mindspore/ccsrc/dataset/engine/gnn/graph.cc

@@ -17,29 +17,30 @@
 #include <algorithm>
 #include <functional>
+#include <iterator>
 #include <numeric>
 #include <utility>
 
 #include "dataset/core/tensor_shape.h"
+#include "dataset/util/random.h"
 
 namespace mindspore {
 namespace dataset {
 namespace gnn {
 
-Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) {
+Graph::Graph(std::string dataset_file, int32_t num_workers)
+    : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) {
+  rnd_.seed(GetSeed());
   MS_LOG(INFO) << "num_workers:" << num_workers;
 }
 
-Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out) {
+Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
   auto itr = node_type_map_.find(node_type);
   if (itr == node_type_map_.end()) {
     std::string err_msg = "Invalid node type:" + std::to_string(node_type);
     RETURN_STATUS_UNEXPECTED(err_msg);
   } else {
-    if (node_num == -1) {
-      RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
-    } else {
-    }
+    RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
   }
   return Status::OK();
 }

@@ -59,9 +60,9 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data
   RETURN_IF_NOT_OK(Tensor::CreateTensor(
     &tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr));
   T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer());
-  for (auto id_m : data) {
+  for (const auto &id_m : data) {
     CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size");
-    for (auto id_n : id_m) {
+    for (const auto &id_n : id_m) {
       *ptr = id_n;
       ptr++;
     }

@@ -89,7 +90,38 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz
   return Status::OK();
 }
 
-Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out) { return Status::OK(); }
+Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
+  auto itr = edge_type_map_.find(edge_type);
+  if (itr == edge_type_map_.end()) {
+    std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else {
+    RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
+  }
+  return Status::OK();
+}
+
+Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
+  if (edge_list.empty()) {
+    RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
+  }
+
+  std::vector<std::vector<NodeIdType>> node_list;
+  node_list.reserve(edge_list.size());
+  for (const auto &edge_id : edge_list) {
+    auto itr = edge_id_map_.find(edge_id);
+    if (itr == edge_id_map_.end()) {
+      std::string err_msg = "Invalid edge id:" + std::to_string(edge_id);
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    } else {
+      std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes;
+      RETURN_IF_NOT_OK(itr->second->GetNode(&nodes));
+      node_list.push_back({nodes.first->id(), nodes.second->id()});
+    }
+  }
+  RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out));
+  return Status::OK();
+}
 
 Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                               std::shared_ptr<Tensor> *out) {

@@ -105,14 +137,10 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
   size_t max_neighbor_num = 0;
   neighbors.resize(node_list.size());
   for (size_t i = 0; i < node_list.size(); ++i) {
-    auto itr = node_id_map_.find(node_list[i]);
-    if (itr != node_id_map_.end()) {
-      RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i]));
-      max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
-    } else {
-      std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]);
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
+    std::shared_ptr<Node> node;
+    RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node));
+    RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i]));
+    max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
   }
 
   RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));

@@ -121,13 +149,94 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
   return Status::OK();
 }
 
-Status Graph::GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
-                                 const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
-  return Status::OK();
-}
-
-Status Graph::GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
-                                    NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
-  return Status::OK();
-}
+Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
+                                  const std::vector<NodeIdType> &neighbor_nums,
+                                  const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
+  CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
+  CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
+                               "The sizes of neighbor_nums and neighbor_types are inconsistent.");
+  std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
+  for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
+    neighbors_vec[node_idx].emplace_back(node_list[node_idx]);
+    std::vector<NodeIdType> input_list = {node_list[node_idx]};
+    for (size_t i = 0; i < neighbor_nums.size(); ++i) {
+      std::vector<NodeIdType> neighbors;
+      neighbors.reserve(input_list.size() * neighbor_nums[i]);
+      for (const auto &node_id : input_list) {
+        if (node_id == kDefaultNodeId) {
+          for (int32_t j = 0; j < neighbor_nums[i]; ++j) {
+            neighbors.emplace_back(kDefaultNodeId);
+          }
+        } else {
+          std::shared_ptr<Node> node;
+          RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
+          std::vector<NodeIdType> out;
+          RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out));
+          neighbors.insert(neighbors.end(), out.begin(), out.end());
+        }
+      }
+      neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end());
+      input_list = std::move(neighbors);
+    }
+  }
+  RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
+  return Status::OK();
+}
+
+Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data,
+                             int32_t samples_num, std::vector<NodeIdType> *out_samples) {
+  CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
+  std::vector<NodeIdType> shuffled_id(data.size());
+  std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
+  std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
+  for (const auto &index : shuffled_id) {
+    if (exclude_data.find(data[index]) != exclude_data.end()) {
+      continue;
+    }
+    out_samples->emplace_back(data[index]);
+    if (out_samples->size() >= samples_num) {
+      break;
+    }
+  }
+  return Status::OK();
+}
+
+Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
+                                     NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
+  CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
+  std::vector<std::vector<NodeIdType>> neighbors_vec;
+  neighbors_vec.resize(node_list.size());
+  for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
+    std::shared_ptr<Node> node;
+    RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node));
+    std::vector<NodeIdType> neighbors;
+    RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors));
+    std::unordered_set<NodeIdType> exclude_node;
+    std::transform(neighbors.begin(), neighbors.end(),
+                   std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()),
+                   [](const NodeIdType node) { return node; });
+    auto itr = node_type_map_.find(neg_neighbor_type);
+    if (itr == node_type_map_.end()) {
+      std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type);
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    } else {
+      neighbors_vec[node_idx].emplace_back(node->id());
+      if (itr->second.size() > exclude_node.size()) {
+        while (neighbors_vec[node_idx].size() < samples_num + 1) {
+          RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node,
                                           samples_num - neighbors_vec[node_idx].size(), &neighbors_vec[node_idx]));
+        }
+      } else {
+        MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
+                      << " neg_neighbor_type:" << neg_neighbor_type;
+        // If there are no negative neighbors, they are filled with kDefaultNodeId
+        for (int32_t i = 0; i < samples_num; ++i) {
+          neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
+        }
+      }
+    }
+  }
+  RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
+  return Status::OK();
+}

@@ -154,7 +263,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
   }
   CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty");
   TensorRow tensors;
-  for (auto f_type : feature_types) {
+  for (const auto &f_type : feature_types) {
     std::shared_ptr<Feature> default_feature;
     // If no feature can be obtained, fill in the default value
     RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature));

@@ -169,18 +278,14 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
     dsize_t index = 0;
     for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
-      auto itr = node_id_map_.find(*node_itr);
       std::shared_ptr<Feature> feature;
-      if (itr != node_id_map_.end()) {
-        if (!itr->second->GetFeatures(f_type, &feature).IsOk()) {
-          feature = default_feature;
-        }
+      if (*node_itr == kDefaultNodeId) {
+        feature = default_feature;
       } else {
-        if (*node_itr == kDefaultNodeId) {
+        std::shared_ptr<Node> node;
+        RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
+        if (!node->GetFeatures(f_type, &feature).IsOk()) {
           feature = default_feature;
-        } else {
-          std::string err_msg = "Invalid node id:" + std::to_string(*node_itr);
-          RETURN_STATUS_UNEXPECTED(err_msg);
         }
       }
       RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));

@@ -209,35 +314,54 @@ Status Graph::Init() {
   return Status::OK();
 }
 
-Status Graph::GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info) {
-  node_info->reserve(node_type_map_.size());
-  for (auto node : node_type_map_) {
-    NodeMetaInfo n_info;
-    n_info.type = node.first;
-    n_info.num = node.second.size();
-    auto itr = node_feature_map_.find(node.first);
-    if (itr != node_feature_map_.end()) {
-      for (auto f_type : itr->second) {
-        n_info.feature_type.push_back(f_type);
-      }
-      std::sort(n_info.feature_type.begin(), n_info.feature_type.end());
-    }
-    node_info->push_back(n_info);
-  }
-
-  edge_info->reserve(edge_type_map_.size());
-  for (auto edge : edge_type_map_) {
-    EdgeMetaInfo e_info;
-    e_info.type = edge.first;
-    e_info.num = edge.second.size();
-    auto itr = edge_feature_map_.find(edge.first);
-    if (itr != edge_feature_map_.end()) {
-      for (auto f_type : itr->second) {
-        e_info.feature_type.push_back(f_type);
-      }
-    }
-    edge_info->push_back(e_info);
-  }
+Status Graph::GetMetaInfo(MetaInfo *meta_info) {
+  meta_info->node_type.resize(node_type_map_.size());
+  std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
+                 [](auto itr) { return itr.first; });
+  std::sort(meta_info->node_type.begin(), meta_info->node_type.end());
+
+  meta_info->edge_type.resize(edge_type_map_.size());
+  std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(),
+                 [](auto itr) { return itr.first; });
+  std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end());
+
+  for (const auto &node : node_type_map_) {
+    meta_info->node_num[node.first] = node.second.size();
+  }
+  for (const auto &edge : edge_type_map_) {
+    meta_info->edge_num[edge.first] = edge.second.size();
+  }
+
+  for (const auto &node_feature : node_feature_map_) {
+    for (auto type : node_feature.second) {
+      meta_info->node_feature_type.emplace_back(type);
+    }
+  }
+  std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
+  auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
+  meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end());
+
+  for (const auto &edge_feature : edge_feature_map_) {
+    for (const auto &type : edge_feature.second) {
+      meta_info->edge_feature_type.emplace_back(type);
+    }
+  }
+  std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
+  auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
+  meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end());
+  return Status::OK();
+}
+
+Status Graph::GraphInfo(py::dict *out) {
+  MetaInfo meta_info;
+  RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
+  (*out)["node_type"] = py::cast(meta_info.node_type);
+  (*out)["edge_type"] = py::cast(meta_info.edge_type);
+  (*out)["node_num"] = py::cast(meta_info.node_num);
+  (*out)["edge_num"] = py::cast(meta_info.edge_num);
+  (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
+  (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
   return Status::OK();
 }

@@ -250,6 +374,18 @@ Status Graph::LoadNodeAndEdge() {
                                &node_feature_map_, &edge_feature_map_, &default_feature_map_));
   return Status::OK();
 }
+
+Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
+  auto itr = node_id_map_.find(id);
+  if (itr == node_id_map_.end()) {
+    std::string err_msg = "Invalid node id:" + std::to_string(id);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else {
+    *node = itr->second;
+  }
+  return Status::OK();
+}
 }  // namespace gnn
 }  // namespace dataset
 }  // namespace mindspore
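The two new routines in graph.cc, GetSampledNeighbors and NegativeSample, are easier to read once the control flow is seen in isolation. Below is a small pure-Python sketch of the same flow; the adjacency dict, DEFAULT_NODE_ID, and the function names are illustrative stand-ins, not part of the patch.

```python
import random

DEFAULT_NODE_ID = -1  # stand-in for kDefaultNodeId


def get_sampled_neighbors(adj, node_list, neighbor_nums):
    """Mirror of Graph::GetSampledNeighbors: per start node, emit the node itself,
    then hop by hop append sampled neighbours, feeding each hop's output into the
    next hop (a single node type is assumed in this toy graph)."""
    result = []
    for start in node_list:
        row = [start]
        input_list = [start]
        for num in neighbor_nums:
            hop = []
            for node_id in input_list:
                if node_id == DEFAULT_NODE_ID:
                    hop.extend([DEFAULT_NODE_ID] * num)  # keep the row rectangular
                else:
                    picked = []
                    pool = adj.get(node_id, [])
                    while len(picked) < num:
                        cand = pool[:] if pool else [DEFAULT_NODE_ID]
                        random.shuffle(cand)
                        picked.extend(cand[:num - len(picked)])
                    hop.extend(picked)
            row.extend(hop)
            input_list = hop
        result.append(row)
    return result


def negative_sample(candidates, exclude, samples_num):
    """Mirror of Graph::NegativeSample: shuffle the candidate pool and keep ids
    that are not excluded until samples_num ids have been collected."""
    order = list(range(len(candidates)))
    random.shuffle(order)
    out = []
    for idx in order:
        if candidates[idx] in exclude:
            continue
        out.append(candidates[idx])
        if len(out) >= samples_num:
            break
    return out


adj = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}
print(get_sampled_neighbors(adj, [0, 3], [2, 2]))  # one row per start node, length 1 + 2 + 4
print(negative_sample(list(adj), {1, 2}, 2))
```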
mindspore/ccsrc/dataset/engine/gnn/graph.h

@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
+#include <map>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>

@@ -33,24 +34,13 @@ namespace mindspore {
 namespace dataset {
 namespace gnn {
 
-struct NodeMetaInfo {
-  NodeType type;
-  NodeIdType num;
-  std::vector<FeatureType> feature_type;
-  NodeMetaInfo() {
-    type = 0;
-    num = 0;
-  }
-};
-
-struct EdgeMetaInfo {
-  EdgeType type;
-  EdgeIdType num;
-  std::vector<FeatureType> feature_type;
-  EdgeMetaInfo() {
-    type = 0;
-    num = 0;
-  }
+struct MetaInfo {
+  std::vector<NodeType> node_type;
+  std::vector<EdgeType> edge_type;
+  std::map<NodeType, NodeIdType> node_num;
+  std::map<EdgeType, EdgeIdType> edge_num;
+  std::vector<FeatureType> node_feature_type;
+  std::vector<FeatureType> edge_feature_type;
 };
 
 class Graph {

@@ -62,19 +52,23 @@ class Graph {
   ~Graph() = default;
 
-  // Get the nodes from the graph.
+  // Get all nodes from the graph.
   // @param NodeType node_type - type of node
-  // @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired
   // @param std::shared_ptr<Tensor> *out - Returned nodes id
   // @return Status - The error code return
-  Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out);
+  Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
 
-  // Get the edges from the graph.
+  // Get all edges from the graph.
   // @param NodeType edge_type - type of edge
-  // @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired
   // @param std::shared_ptr<Tensor> *out - Returned edge ids
   // @return Status - The error code return
-  Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out);
+  Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
+
+  // Get the node id from the edge.
+  // @param std::vector<EdgeIdType> edge_list - List of edges
+  // @param std::shared_ptr<Tensor> *out - Returned node ids
+  // @return Status - The error code return
+  Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
 
   // All neighbors of the acquisition node.
   // @param std::vector<NodeType> node_list - List of nodes

@@ -86,10 +80,24 @@ class Graph {
   Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                          std::shared_ptr<Tensor> *out);
 
-  Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
-                            const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
-
-  Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
-                               NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
+  // Get sampled neighbors.
+  // @param std::vector<NodeType> node_list - List of nodes
+  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
+  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
+  // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
+  // @return Status - The error code return
+  Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
+                             const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
+
+  // Get negative sampled neighbors.
+  // @param std::vector<NodeType> node_list - List of nodes
+  // @param NodeIdType samples_num - Number of neighbors sampled
+  // @param NodeType neg_neighbor_type - The type of negative neighbor.
+  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
+  // @return Status - The error code return
+  Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
+                                NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
 
   Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
                     NodeIdType default_node, std::shared_ptr<Tensor> *out);

@@ -112,10 +120,12 @@ class Graph {
                         TensorRow *out);
 
   // Get meta information of graph
-  // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node
-  // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge
+  // @param MetaInfo *meta_info - Returned meta information
   // @return Status - The error code return
-  Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info);
+  Status GetMetaInfo(MetaInfo *meta_info);
+
+  // Return meta information to python layer
+  Status GraphInfo(py::dict *out);
 
   Status Init();

@@ -146,8 +156,24 @@ class Graph {
   // @return Status - The error code return
   Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
 
+  // Find node object using node id
+  // @param NodeIdType id -
+  // @param std::shared_ptr<Node> *node - Returned node object
+  // @return Status - The error code return
+  Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
+
+  // Negative sampling
+  // @param std::vector<NodeIdType> &input_data - The data set to be sampled
+  // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
+  // @param int32_t samples_num -
+  // @param std::vector<NodeIdType> *out_samples - Sampling results returned
+  // @return Status - The error code return
+  Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
+                        int32_t samples_num, std::vector<NodeIdType> *out_samples);
+
   std::string dataset_file_;
   int32_t num_workers_;  // The number of worker threads
+  std::mt19937 rnd_;
   std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
   std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
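For orientation, the new MetaInfo struct and Graph::GraphInfo translate into a plain dictionary on the Python side. A hypothetical example of the returned shape follows; the keys are taken from GraphInfo, the values are invented for illustration.

```python
# Hypothetical graph_info() result; keys mirror the MetaInfo fields cast in Graph::GraphInfo.
graph_info_example = {
    'node_type': [0],
    'edge_type': [0],
    'node_num': {0: 3312},
    'edge_num': {0: 9228},
    'node_feature_type': [1, 2],
    'edge_feature_type': [],
}
print(graph_info_example['node_num'][0])
```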
mindspore/ccsrc/dataset/engine/gnn/local_node.cc

@@ -20,12 +20,13 @@
 #include <utility>
 
 #include "dataset/engine/gnn/edge.h"
+#include "dataset/util/random.h"
 
 namespace mindspore {
 namespace dataset {
 namespace gnn {
 
-LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {}
+LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); }
 
 Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
   auto itr = features_.find(feature_type);

@@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
   }
 }
 
-Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) {
+Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
   std::vector<NodeIdType> neighbors;
   auto itr = neighbor_nodes_.find(neighbor_type);
   if (itr != neighbor_nodes_.end()) {
-    if (samples_num == -1) {
-      // Return all neighbors
-      neighbors.resize(itr->second.size() + 1);
-      neighbors[0] = id_;
-      std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
-                     [](const std::shared_ptr<Node> node) { return node->id(); });
-    } else {
-    }
+    neighbors.resize(itr->second.size() + 1);
+    neighbors[0] = id_;
+    std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
+                   [](const std::shared_ptr<Node> node) { return node->id(); });
   } else {
-    neighbors.push_back(id_);
     MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
+    neighbors.emplace_back(id_);
   }
+  *out_neighbors = std::move(neighbors);
+  return Status::OK();
+}
+
+Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
+                                      std::vector<NodeIdType> *out) {
+  std::vector<NodeIdType> shuffled_id(neighbors.size());
+  std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
+  std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
+  int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size()));
+  for (int32_t i = 0; i < num; ++i) {
+    out->emplace_back(neighbors[shuffled_id[i]]->id());
+  }
+  return Status::OK();
+}
+
+Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
+                                      std::vector<NodeIdType> *out_neighbors) {
+  std::vector<NodeIdType> neighbors;
+  neighbors.reserve(samples_num);
+  auto itr = neighbor_nodes_.find(neighbor_type);
+  if (itr != neighbor_nodes_.end()) {
+    while (neighbors.size() < samples_num) {
+      RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors));
+    }
+  } else {
+    MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
+    // If there are no neighbors, they are filled with kDefaultNodeId
+    for (int32_t i = 0; i < samples_num; ++i) {
+      neighbors.emplace_back(kDefaultNodeId);
+    }
+  }
   *out_neighbors = std::move(neighbors);
   return Status::OK();
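LocalNode::GetAllNeighbors has one behaviour worth calling out: the node's own id is always returned at position 0, ahead of its neighbours, and an unknown neighbour type yields just the id. A toy Python mirror (names and data are illustrative, not part of the patch):

```python
def get_all_neighbors(node_id, neighbor_nodes, neighbor_type):
    """Mirror of LocalNode::GetAllNeighbors: index 0 is the node's own id, followed
    by every neighbour of the requested type; an unknown type yields just the id."""
    nodes = neighbor_nodes.get(neighbor_type)
    if nodes is not None:
        return [node_id] + list(nodes)
    return [node_id]


print(get_all_neighbors(7, {0: [1, 4, 9]}, 0))  # [7, 1, 4, 9]
print(get_all_neighbors(7, {0: [1, 4, 9]}, 5))  # [7]
```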
mindspore/ccsrc/dataset/engine/gnn/local_node.h

@@ -43,12 +43,19 @@ class LocalNode : public Node {
   // @return Status - The error code return
   Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
 
-  // Get the neighbors of a node
+  // Get the all neighbors of a node
   // @param NodeType neighbor_type - type of neighbor
-  // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
   // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
   // @return Status - The error code return
-  Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override;
+  Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
+
+  // Get the sampled neighbors of a node
+  // @param NodeType neighbor_type - type of neighbor
+  // @param int32_t samples_num - Number of neighbors to be acquired
+  // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
+  // @return Status - The error code return
+  Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
+                             std::vector<NodeIdType> *out_neighbors) override;
 
   // Add neighbor of node
   // @param std::shared_ptr<Node> node -

@@ -61,6 +68,10 @@ class LocalNode : public Node {
   Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
 
  private:
+  Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
+                             std::vector<NodeIdType> *out);
+
+  std::mt19937 rnd_;
   std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
   std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
 };
mindspore/ccsrc/dataset/engine/gnn/node.h

@@ -52,12 +52,19 @@ class Node {
   // @return Status - The error code return
   virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
 
-  // Get the neighbors of a node
+  // Get the all neighbors of a node
   // @param NodeType neighbor_type - type of neighbor
-  // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
   // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
   // @return Status - The error code return
-  virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0;
+  virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
+
+  // Get the sampled neighbors of a node
+  // @param NodeType neighbor_type - type of neighbor
+  // @param int32_t samples_num - Number of neighbors to be acquired
+  // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
+  // @return Status - The error code return
+  virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
+                                     std::vector<NodeIdType> *out_neighbors) = 0;
 
   // Add neighbor of node
   // @param std::shared_ptr<Node> node -
mindspore/dataset/engine/graphdata.py
浏览文件 @
eff90369
...
@@ -20,8 +20,9 @@ import numpy as np
...
@@ -20,8 +20,9 @@ import numpy as np
from
mindspore._c_dataengine
import
Graph
from
mindspore._c_dataengine
import
Graph
from
mindspore._c_dataengine
import
Tensor
from
mindspore._c_dataengine
import
Tensor
from
.validators
import
check_gnn_graphdata
,
check_gnn_get_all_nodes
,
check_gnn_get_all_neighbors
,
\
from
.validators
import
check_gnn_graphdata
,
check_gnn_get_all_nodes
,
check_gnn_get_all_edges
,
\
check_gnn_get_node_feature
check_gnn_get_nodes_from_edges
,
check_gnn_get_all_neighbors
,
check_gnn_get_sampled_neighbors
,
\
check_gnn_get_neg_sampled_neighbors
,
check_gnn_get_node_feature
class
GraphData
:
class
GraphData
:
...
@@ -60,7 +61,44 @@ class GraphData:
...
@@ -60,7 +61,44 @@ class GraphData:
Raises:
Raises:
TypeError: If `node_type` is not integer.
TypeError: If `node_type` is not integer.
"""
"""
return
self
.
_graph
.
get_nodes
(
node_type
,
-
1
).
as_array
()
return
self
.
_graph
.
get_all_nodes
(
node_type
).
as_array
()
@
check_gnn_get_all_edges
def
get_all_edges
(
self
,
edge_type
):
"""
Get all edges in the graph.
Args:
edge_type (int): Specify the type of edge.
Returns:
numpy.ndarray: array of edges.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> nodes = data_graph.get_all_edges(0)
Raises:
TypeError: If `edge_type` is not integer.
"""
return
self
.
_graph
.
get_all_edges
(
edge_type
).
as_array
()
@
check_gnn_get_nodes_from_edges
def
get_nodes_from_edges
(
self
,
edge_list
):
"""
Get nodes from the edges.
Args:
edge_list (list or numpy.ndarray): The given list of edges.
Returns:
numpy.ndarray: array of nodes.
Raises:
TypeError: If `edge_list` is not list or ndarray.
"""
return
self
.
_graph
.
get_nodes_from_edges
(
edge_list
).
as_array
()
@
check_gnn_get_all_neighbors
@
check_gnn_get_all_neighbors
def
get_all_neighbors
(
self
,
node_list
,
neighbor_type
):
def
get_all_neighbors
(
self
,
node_list
,
neighbor_type
):
...
@@ -86,6 +124,58 @@ class GraphData:
...
@@ -86,6 +124,58 @@ class GraphData:
"""
"""
return
self
.
_graph
.
get_all_neighbors
(
node_list
,
neighbor_type
).
as_array
()
return
self
.
_graph
.
get_all_neighbors
(
node_list
,
neighbor_type
).
as_array
()
+
+    @check_gnn_get_sampled_neighbors
+    def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
+        """
+        Get sampled neighbor information; at most 6-hop sampling is supported.
+
+        Args:
+            node_list (list or numpy.ndarray): The given list of nodes.
+            neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop.
+            neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop.
+
+        Returns:
+            numpy.ndarray: array of nodes.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> data_graph = ds.GraphData('dataset_file', 2)
+            >>> nodes = data_graph.get_all_nodes(0)
+            >>> neighbors = data_graph.get_sampled_neighbors(nodes, [2, 2], [0, 0])
+
+        Raises:
+            TypeError: If `node_list` is not list or ndarray.
+            TypeError: If `neighbor_nums` is not list or ndarray.
+            TypeError: If `neighbor_types` is not list or ndarray.
+        """
+        return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
+
+    @check_gnn_get_neg_sampled_neighbors
+    def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
+        """
+        Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`.
+
+        Args:
+            node_list (list or numpy.ndarray): The given list of nodes.
+            neg_neighbor_num (int): Number of neighbors sampled.
+            neg_neighbor_type (int): Specify the type of negative neighbor.
+
+        Returns:
+            numpy.ndarray: array of nodes.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> data_graph = ds.GraphData('dataset_file', 2)
+            >>> nodes = data_graph.get_all_nodes(0)
+            >>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0)
+
+        Raises:
+            TypeError: If `node_list` is not list or ndarray.
+            TypeError: If `neg_neighbor_num` is not integer.
+            TypeError: If `neg_neighbor_type` is not integer.
+        """
+        return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
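Judging from the shapes asserted in the C++ and Python tests further below, each output row of get_sampled_neighbors appears to start with the seed node id, followed by the hop-1 samples, then the hop-2 samples of every hop-1 node, and so on, giving a row width of 1 + n1 + n1*n2 + ...; get_neg_sampled_neighbors likewise returns the seed followed by neg_neighbor_num negatives. A minimal sketch on the bundled test graph (type ids are those of the test data, not general constants):

import mindspore.dataset as ds

g = ds.GraphData("testdata", 2)     # placeholder path
seeds = g.get_all_nodes(1)          # 10 nodes of type 1 in the test graph

# Two hops: 2 neighbors of type 2, then 3 type-1 neighbors of each of those.
nbrs = g.get_sampled_neighbors(seeds, [2, 3], [2, 1])
# Expected width: 1 (seed) + 2 + 2*3 = 9, i.e. nbrs.shape == (len(seeds), 9)

# Five negatives of type 2 per seed; row = seed followed by the 5 negatives.
neg = g.get_neg_sampled_neighbors(seeds, 5, 2)
# Expected: neg.shape == (len(seeds), 6)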

     @check_gnn_get_node_feature
     def get_node_feature(self, node_list, feature_types):
         """
@@ -111,3 +201,13 @@ class GraphData:
         if isinstance(node_list, list):
             node_list = np.array(node_list, dtype=np.int32)
         return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
+
+    def graph_info(self):
+        """
+        Get the meta information of the graph, including the number of nodes, the type of nodes,
+        the feature information of nodes, the number of edges, the type of edges, and the feature
+        information of edges.
+
+        Returns:
+            Dict: Meta information of the graph. The keys are node_type, edge_type, node_num,
+            edge_num, node_feature_type and edge_feature_type.
+        """
+        return self._graph.graph_info()
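For the graph bundled with the unit tests, graph_info returns the following (values taken from test_graphdata_graphinfo below; other datasets will differ):

import mindspore.dataset as ds

g = ds.GraphData("testdata", 2)   # placeholder path
info = g.graph_info()
# info['node_type']         == [1, 2]
# info['node_num']          == {1: 10, 2: 10}
# info['edge_type']         == [0]
# info['edge_num']          == {0: 40}
# info['node_feature_type'] == [1, 2, 3, 4]
# info['edge_feature_type'] == []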
mindspore/dataset/engine/validators.py
View file @ eff90369

@@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method):
     return new_method
+
+
+def check_gnn_get_all_edges(method):
+    """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        # check edge_type; required argument
+        check_type(param_dict.get("edge_type"), 'edge_type', int)
+
+        return method(*args, **kwargs)
+
+    return new_method
+
+
+def check_gnn_get_nodes_from_edges(method):
+    """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        # check edge_list; required argument
+        check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list')
+
+        return method(*args, **kwargs)
+
+    return new_method
 def check_gnn_get_all_neighbors(method):
     """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
@@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method):
     return new_method
+
+
+def check_gnn_get_sampled_neighbors(method):
+    """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        # check node_list; required argument
+        check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
+
+        # check neighbor_nums; required argument
+        neighbor_nums = param_dict.get("neighbor_nums")
+        check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
+        if len(neighbor_nums) > 6:
+            raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
+                'neighbor_nums', len(neighbor_nums)))
+
+        # check neighbor_types; required argument
+        neighbor_types = param_dict.get("neighbor_types")
+        check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
+        if len(neighbor_types) > 6:
+            raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
+                'neighbor_types', len(neighbor_types)))
+
+        if len(neighbor_nums) != len(neighbor_types):
+            raise ValueError(
+                "The number of members of neighbor_nums and neighbor_types is inconsistent")
+
+        return method(*args, **kwargs)
+
+    return new_method
+
+
+def check_gnn_get_neg_sampled_neighbors(method):
+    """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        # check node_list; required argument
+        check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
+
+        # check neg_neighbor_num; required argument
+        check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int)
+
+        # check neg_neighbor_type; required argument
+        check_type(param_dict.get("neg_neighbor_type"), 'neg_neighbor_type', int)
+
+        return method(*args, **kwargs)
+
+    return new_method
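A hedged sketch of what these checks accept and reject in practice (placeholder path and type ids; the failing calls are left as comments because each one raises):

import mindspore.dataset as ds

g = ds.GraphData("testdata", 2)
seeds = g.get_all_nodes(1)

# Accepted: per-hop lists have the same length and at most 6 entries.
nbrs = g.get_sampled_neighbors(seeds, [2, 2], [2, 1])

# Rejected by the validators above:
#   g.get_sampled_neighbors(seeds, [2, 2, 2], [2, 1])   -> ValueError (length mismatch)
#   g.get_sampled_neighbors(seeds, [2] * 7, [1] * 7)     -> ValueError (more than 6 hops)
#   g.get_neg_sampled_neighbors(seeds, 5.0, 1)           -> TypeError (neg_neighbor_num must be an int)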


 def check_aligned_list(param, param_name, membor_type):
     """Check whether the structure of each member of the list is the same."""
...
tests/ut/cpp/dataset/gnn_graph_test.cc
View file @ eff90369

@@ -13,8 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <algorithm>
 #include <string>
 #include <memory>
+#include <unordered_set>
 #include "common/common.h"
 #include "gtest/gtest.h"
@@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
                           &default_feature_map)
                .IsOk());
   EXPECT_EQ(n_id_map.size(), 20);
-  EXPECT_EQ(e_id_map.size(), 20);
+  EXPECT_EQ(e_id_map.size(), 40);
   EXPECT_EQ(n_type_map[2].size(), 10);
   EXPECT_EQ(n_type_map[1].size(), 10);
 }
@@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
   Status s = graph.Init();
   EXPECT_TRUE(s.IsOk());

-  std::vector<NodeMetaInfo> node_info;
-  std::vector<EdgeMetaInfo> edge_info;
-  s = graph.GetMetaInfo(&node_info, &edge_info);
+  MetaInfo meta_info;
+  s = graph.GetMetaInfo(&meta_info);
   EXPECT_TRUE(s.IsOk());
-  EXPECT_TRUE(node_info.size() == 2);
+  EXPECT_TRUE(meta_info.node_type.size() == 2);

   std::shared_ptr<Tensor> nodes;
-  s = graph.GetNodes(node_info[1].type, -1, &nodes);
+  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
   EXPECT_TRUE(s.IsOk());
   std::vector<NodeIdType> node_list;
   for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
@@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
     }
   }
   std::shared_ptr<Tensor> neighbors;
-  s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors);
+  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors);
   EXPECT_TRUE(s.IsOk());
   EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
   TensorRow features;
-  s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features);
+  s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
   EXPECT_TRUE(s.IsOk());
-  EXPECT_TRUE(features.size() == 3);
+  EXPECT_TRUE(features.size() == 4);
   EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>");
   EXPECT_TRUE(features[0]->ToString() ==
             "Tensor (shape: <10,5>, Type: int32)\n"
@@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
   EXPECT_TRUE(features[2]->shape().ToString() == "<10>");
   EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
 }
+
+TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
+  std::string path = "data/mindrecord/testGraphData/testdata";
+  Graph graph(path, 1);
+  Status s = graph.Init();
+  EXPECT_TRUE(s.IsOk());
+
+  MetaInfo meta_info;
+  s = graph.GetMetaInfo(&meta_info);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(meta_info.node_type.size() == 2);
+
+  std::shared_ptr<Tensor> edges;
+  s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
+  EXPECT_TRUE(s.IsOk());
+  std::vector<EdgeIdType> edge_list;
+  edge_list.resize(edges->Size());
+  std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
+                 [](const EdgeIdType edge) { return edge; });
+
+  std::shared_ptr<Tensor> nodes;
+  s = graph.GetNodesFromEdges(edge_list, &nodes);
+  EXPECT_TRUE(s.IsOk());
+  std::unordered_set<NodeIdType> node_set;
+  std::vector<NodeIdType> node_list;
+  int index = 0;
+  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
+    index++;
+    if (index % 2 == 0) {
+      continue;
+    }
+    node_set.emplace(*itr);
+    if (node_set.size() >= 5) {
+      break;
+    }
+  }
+  node_list.resize(node_set.size());
+  std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
+
+  std::shared_ptr<Tensor> neighbors;
+  s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
+
+  neighbors.reset();
+  s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
+
+  neighbors.reset();
+  s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
+                                {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
+
+  neighbors.reset();
+  s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
+  EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
+
+  neighbors.reset();
+  s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
+  EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
+              std::string::npos);
+
+  neighbors.reset();
+  s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors);
+  EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
+}
+
+TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
+  std::string path = "data/mindrecord/testGraphData/testdata";
+  Graph graph(path, 1);
+  Status s = graph.Init();
+  EXPECT_TRUE(s.IsOk());
+
+  MetaInfo meta_info;
+  s = graph.GetMetaInfo(&meta_info);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(meta_info.node_type.size() == 2);
+
+  std::shared_ptr<Tensor> nodes;
+  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
+  EXPECT_TRUE(s.IsOk());
+  std::vector<NodeIdType> node_list;
+  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
+    node_list.push_back(*itr);
+    if (node_list.size() >= 10) {
+      break;
+    }
+  }
+
+  std::shared_ptr<Tensor> neg_neighbors;
+  s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
+  EXPECT_TRUE(s.IsOk());
+  EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>");
+
+  neg_neighbors.reset();
+  s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
+  EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
+
+  neg_neighbors.reset();
+  s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
+  EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
+}
tests/ut/data/mindrecord/testGraphData/testdata
View file @ eff90369
This file type cannot be previewed.

tests/ut/data/mindrecord/testGraphData/testdata.db
View file @ eff90369
This file type cannot be previewed.
tests/ut/python/dataset/test_graphdata.py
View file @ eff90369

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import random
 import pytest
 import numpy as np
 import mindspore.dataset as ds
@@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check():
         g.get_node_feature(input_list, [1, "a"])
+
+
+def test_graphdata_getsampledneighbors():
+    g = ds.GraphData(DATASET_FILE, 1)
+    edges = g.get_all_edges(0)
+    nodes = g.get_nodes_from_edges(edges)
+    assert len(nodes) == 40
+    neighbor = g.get_sampled_neighbors(np.unique(nodes[0:21, 0]), [2, 3], [2, 1])
+    assert neighbor.shape == (10, 9)
+
+
+def test_graphdata_getnegsampledneighbors():
+    g = ds.GraphData(DATASET_FILE, 2)
+    nodes = g.get_all_nodes(1)
+    assert len(nodes) == 10
+    neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
+    assert neighbor.shape == (10, 6)
+
+
+def test_graphdata_graphinfo():
+    g = ds.GraphData(DATASET_FILE, 2)
+    graph_info = g.graph_info()
+    assert graph_info['node_type'] == [1, 2]
+    assert graph_info['edge_type'] == [0]
+    assert graph_info['node_num'] == {1: 10, 2: 10}
+    assert graph_info['edge_num'] == {0: 40}
+    assert graph_info['node_feature_type'] == [1, 2, 3, 4]
+    assert graph_info['edge_feature_type'] == []
+
+
+class RandomBatchedSampler(ds.Sampler):
+    # RandomBatchedSampler generates a random sequence without replacement in a batched manner
+    def __init__(self, index_range, num_edges_per_sample):
+        super().__init__()
+        self.index_range = index_range
+        self.num_edges_per_sample = num_edges_per_sample
+
+    def __iter__(self):
+        indices = [i + 1 for i in range(self.index_range)]
+        # Reset random seed here if necessary
+        # random.seed(0)
+        random.shuffle(indices)
+        for i in range(0, self.index_range, self.num_edges_per_sample):
+            # Drop the remainder
+            if i + self.num_edges_per_sample <= self.index_range:
+                yield indices[i: i + self.num_edges_per_sample]
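Iterated on its own, outside any dataset pipeline, the sampler simply yields shuffled batches of 1-based edge indices and drops the trailing partial batch; a small sketch:

sampler = RandomBatchedSampler(index_range=10, num_edges_per_sample=4)
for batch in sampler:
    print(batch)
# Prints two lists of 4 indices drawn from 1..10, e.g. [7, 2, 9, 4] then [1, 10, 3, 6];
# the last 2 indices are dropped because they do not fill a whole batch.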
+
+
+class GNNGraphDataset():
+    def __init__(self, g, batch_num):
+        self.g = g
+        self.batch_num = batch_num
+
+    def __len__(self):
+        # Total sample size of GNN dataset
+        # In this case, the size should be total_num_edges/num_edges_per_sample
+        return self.g.graph_info()['edge_num'][0] // self.batch_num
+
+    def __getitem__(self, index):
+        # index will be a list of indices yielded from RandomBatchedSampler
+        # Fetch edges/nodes/samples/features based on indices
+        nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
+        nodes = nodes[:, 0]
+        neg_nodes = self.g.get_neg_sampled_neighbors(
+            node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
+        nodes_neighbors = self.g.get_sampled_neighbors(
+            node_list=nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])
+        neg_nodes_neighbors = self.g.get_sampled_neighbors(
+            node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
+        nodes_neighbors_features = self.g.get_node_feature(
+            node_list=nodes_neighbors, feature_types=[2, 3])
+        neg_neighbors_features = self.g.get_node_feature(
+            node_list=neg_nodes_neighbors, feature_types=[2, 3])
+        return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
+
+
+def test_graphdata_generatordataset():
+    g = ds.GraphData(DATASET_FILE)
+    batch_num = 2
+    edge_num = g.graph_info()['edge_num'][0]
+    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
+    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
+                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
+    dataset = dataset.repeat(2)
+    itr = dataset.create_dict_iterator()
+    i = 0
+    for data in itr:
+        assert data['neighbors'].shape == (2, 7)
+        assert data['neg_neighbors'].shape == (6, 7)
+        assert data['neighbors_features'].shape == (2, 7)
+        assert data['neg_neighbors_features'].shape == (6, 7)
+        i += 1
+    assert i == 40
 if __name__ == '__main__':
     test_graphdata_getfullneighbor()
     logger.info('test_graphdata_getfullneighbor Ended.\n')
     test_graphdata_getnodefeature_input_check()
     logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
+    test_graphdata_getsampledneighbors()
+    logger.info('test_graphdata_getsampledneighbors Ended.\n')
+    test_graphdata_getnegsampledneighbors()
+    logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
+    test_graphdata_graphinfo()
+    logger.info('test_graphdata_graphinfo Ended.\n')
+    test_graphdata_generatordataset()
+    logger.info('test_graphdata_generatordataset Ended.\n')