Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0d52888f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0d52888f
编写于
6月 23, 2020
作者:
H
heleiwang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix misspell and check parameters
上级
5b14292f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
73 addition
and
22 deletion
+73
-22
mindspore/ccsrc/dataset/engine/gnn/graph.cc
mindspore/ccsrc/dataset/engine/gnn/graph.cc
+29
-0
mindspore/ccsrc/dataset/engine/gnn/graph.h
mindspore/ccsrc/dataset/engine/gnn/graph.h
+2
-0
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+21
-21
tests/ut/cpp/dataset/gnn_graph_test.cc
tests/ut/cpp/dataset/gnn_graph_test.cc
+21
-1
未找到文件。
mindspore/ccsrc/dataset/engine/gnn/graph.cc
浏览文件 @
0d52888f
...
...
@@ -149,14 +149,37 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
return
Status
::
OK
();
}
Status
Graph
::
CheckSamplesNum
(
NodeIdType
samples_num
)
{
NodeIdType
all_nodes_number
=
std
::
accumulate
(
node_type_map_
.
begin
(),
node_type_map_
.
end
(),
0
,
[](
NodeIdType
t1
,
const
auto
&
t2
)
->
NodeIdType
{
return
t1
+
t2
.
second
.
size
();
});
if
((
samples_num
<
1
)
||
(
samples_num
>
all_nodes_number
))
{
std
::
string
err_msg
=
"Wrong samples number, should be between 1 and "
+
std
::
to_string
(
all_nodes_number
)
+
", got "
+
std
::
to_string
(
samples_num
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
return
Status
::
OK
();
}
Status
Graph
::
GetSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeIdType
>
&
neighbor_nums
,
const
std
::
vector
<
NodeType
>
&
neighbor_types
,
std
::
shared_ptr
<
Tensor
>
*
out
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
!
node_list
.
empty
(),
"Input node_list is empty."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
neighbor_nums
.
size
()
==
neighbor_types
.
size
(),
"The sizes of neighbor_nums and neighbor_types are inconsistent."
);
for
(
const
auto
&
num
:
neighbor_nums
)
{
RETURN_IF_NOT_OK
(
CheckSamplesNum
(
num
));
}
for
(
const
auto
&
type
:
neighbor_types
)
{
if
(
node_type_map_
.
find
(
type
)
==
node_type_map_
.
end
())
{
std
::
string
err_msg
=
"Invalid neighbor type:"
+
std
::
to_string
(
type
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
}
std
::
vector
<
std
::
vector
<
NodeIdType
>>
neighbors_vec
(
node_list
.
size
());
for
(
size_t
node_idx
=
0
;
node_idx
<
node_list
.
size
();
++
node_idx
)
{
std
::
shared_ptr
<
Node
>
input_node
;
RETURN_IF_NOT_OK
(
GetNodeByNodeId
(
node_list
[
node_idx
],
&
input_node
));
neighbors_vec
[
node_idx
].
emplace_back
(
node_list
[
node_idx
]);
std
::
vector
<
NodeIdType
>
input_list
=
{
node_list
[
node_idx
]};
for
(
size_t
i
=
0
;
i
<
neighbor_nums
.
size
();
++
i
)
{
...
...
@@ -204,6 +227,12 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno
Status
Graph
::
GetNegSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeIdType
samples_num
,
NodeType
neg_neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
!
node_list
.
empty
(),
"Input node_list is empty."
);
RETURN_IF_NOT_OK
(
CheckSamplesNum
(
samples_num
));
if
(
node_type_map_
.
find
(
neg_neighbor_type
)
==
node_type_map_
.
end
())
{
std
::
string
err_msg
=
"Invalid neighbor type:"
+
std
::
to_string
(
neg_neighbor_type
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
std
::
vector
<
std
::
vector
<
NodeIdType
>>
neighbors_vec
;
neighbors_vec
.
resize
(
node_list
.
size
());
for
(
size_t
node_idx
=
0
;
node_idx
<
node_list
.
size
();
++
node_idx
)
{
...
...
mindspore/ccsrc/dataset/engine/gnn/graph.h
浏览文件 @
0d52888f
...
...
@@ -226,6 +226,8 @@ class Graph {
Status
NegativeSample
(
const
std
::
vector
<
NodeIdType
>
&
input_data
,
const
std
::
unordered_set
<
NodeIdType
>
&
exclude_data
,
int32_t
samples_num
,
std
::
vector
<
NodeIdType
>
*
out_samples
);
Status
CheckSamplesNum
(
NodeIdType
samples_num
);
std
::
string
dataset_file_
;
int32_t
num_workers_
;
// The number of worker threads
std
::
mt19937
rnd_
;
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
0d52888f
...
...
@@ -1110,10 +1110,10 @@ def check_gnn_list_or_ndarray(param, param_name):
for
m
in
param
:
if
not
isinstance
(
m
,
int
):
raise
TypeError
(
"Each memb
o
r in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
m
)))
"Each memb
e
r in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
m
)))
elif
isinstance
(
param
,
np
.
ndarray
):
if
not
param
.
dtype
==
np
.
int32
:
raise
TypeError
(
"Each memb
o
r in {0} should be of type int32. Got {1}."
.
format
(
raise
TypeError
(
"Each memb
e
r in {0} should be of type int32. Got {1}."
.
format
(
param_name
,
param
.
dtype
))
else
:
raise
TypeError
(
"Wrong input type for {0}, should be list or numpy.ndarray, got {1}"
.
format
(
...
...
@@ -1196,15 +1196,15 @@ def check_gnn_get_sampled_neighbors(method):
# check neighbor_nums; required argument
neighbor_nums
=
param_dict
.
get
(
"neighbor_nums"
)
check_gnn_list_or_ndarray
(
neighbor_nums
,
'neighbor_nums'
)
if
len
(
neighbor_nums
)
>
6
:
raise
ValueError
(
"Wrong number of input members for {0}, should be
less than or equal to
6, got {1}"
.
format
(
if
not
neighbor_nums
or
len
(
neighbor_nums
)
>
6
:
raise
ValueError
(
"Wrong number of input members for {0}, should be
between 1 and
6, got {1}"
.
format
(
'neighbor_nums'
,
len
(
neighbor_nums
)))
# check neighbor_types; required argument
neighbor_types
=
param_dict
.
get
(
"neighbor_types"
)
check_gnn_list_or_ndarray
(
neighbor_types
,
'neighbor_types'
)
if
len
(
neighbor_num
s
)
>
6
:
raise
ValueError
(
"Wrong number of input members for {0}, should be
less than or equal to
6, got {1}"
.
format
(
if
not
neighbor_types
or
len
(
neighbor_type
s
)
>
6
:
raise
ValueError
(
"Wrong number of input members for {0}, should be
between 1 and
6, got {1}"
.
format
(
'neighbor_types'
,
len
(
neighbor_types
)))
if
len
(
neighbor_nums
)
!=
len
(
neighbor_types
):
...
...
@@ -1256,7 +1256,7 @@ def check_gnn_random_walk(method):
return
new_method
def
check_aligned_list
(
param
,
param_name
,
memb
o
r_type
):
def
check_aligned_list
(
param
,
param_name
,
memb
e
r_type
):
"""Check whether the structure of each member of the list is the same."""
if
not
isinstance
(
param
,
list
):
...
...
@@ -1264,27 +1264,27 @@ def check_aligned_list(param, param_name, membor_type):
if
not
param
:
raise
TypeError
(
"Parameter {0} or its members are empty"
.
format
(
param_name
))
memb
o
r_have_list
=
None
memb
e
r_have_list
=
None
list_len
=
None
for
memb
o
r
in
param
:
if
isinstance
(
memb
o
r
,
list
):
check_aligned_list
(
memb
or
,
param_name
,
membo
r_type
)
if
memb
o
r_have_list
not
in
(
None
,
True
):
for
memb
e
r
in
param
:
if
isinstance
(
memb
e
r
,
list
):
check_aligned_list
(
memb
er
,
param_name
,
membe
r_type
)
if
memb
e
r_have_list
not
in
(
None
,
True
):
raise
TypeError
(
"The type of each member of the parameter {0} is inconsistent"
.
format
(
param_name
))
if
list_len
is
not
None
and
len
(
memb
o
r
)
!=
list_len
:
if
list_len
is
not
None
and
len
(
memb
e
r
)
!=
list_len
:
raise
TypeError
(
"The size of each member of parameter {0} is inconsistent"
.
format
(
param_name
))
memb
o
r_have_list
=
True
list_len
=
len
(
memb
o
r
)
memb
e
r_have_list
=
True
list_len
=
len
(
memb
e
r
)
else
:
if
not
isinstance
(
memb
or
,
membo
r_type
):
raise
TypeError
(
"Each memb
o
r in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
memb
o
r
)))
if
memb
o
r_have_list
not
in
(
None
,
False
):
if
not
isinstance
(
memb
er
,
membe
r_type
):
raise
TypeError
(
"Each memb
e
r in {0} should be of type int. Got {1}."
.
format
(
param_name
,
type
(
memb
e
r
)))
if
memb
e
r_have_list
not
in
(
None
,
False
):
raise
TypeError
(
"The type of each member of the parameter {0} is inconsistent"
.
format
(
param_name
))
memb
o
r_have_list
=
False
memb
e
r_have_list
=
False
def
check_gnn_get_node_feature
(
method
):
...
...
@@ -1300,7 +1300,7 @@ def check_gnn_get_node_feature(method):
check_aligned_list
(
node_list
,
'node_list'
,
int
)
elif
isinstance
(
node_list
,
np
.
ndarray
):
if
not
node_list
.
dtype
==
np
.
int32
:
raise
TypeError
(
"Each memb
o
r in {0} should be of type int32. Got {1}."
.
format
(
raise
TypeError
(
"Each memb
e
r in {0} should be of type int32. Got {1}."
.
format
(
node_list
,
node_list
.
dtype
))
else
:
raise
TypeError
(
"Wrong input type for {0}, should be list or numpy.ndarray, got {1}"
.
format
(
...
...
tests/ut/cpp/dataset/gnn_graph_test.cc
浏览文件 @
0d52888f
...
...
@@ -158,6 +158,18 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
s
=
graph
.
GetSampledNeighbors
({},
{
10
},
{
meta_info
.
node_type
[
1
]},
&
neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Input node_list is empty."
)
!=
std
::
string
::
npos
);
neighbors
.
reset
();
s
=
graph
.
GetSampledNeighbors
({
-
1
,
1
},
{
10
},
{
meta_info
.
node_type
[
1
]},
&
neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Invalid node id"
)
!=
std
::
string
::
npos
);
neighbors
.
reset
();
s
=
graph
.
GetSampledNeighbors
(
node_list
,
{
2
,
50
},
{
meta_info
.
node_type
[
0
],
meta_info
.
node_type
[
1
]},
&
neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Wrong samples number"
)
!=
std
::
string
::
npos
);
neighbors
.
reset
();
s
=
graph
.
GetSampledNeighbors
(
node_list
,
{
2
},
{
5
},
&
neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Invalid neighbor type"
)
!=
std
::
string
::
npos
);
neighbors
.
reset
();
s
=
graph
.
GetSampledNeighbors
(
node_list
,
{
2
,
3
,
4
},
{
meta_info
.
node_type
[
1
],
meta_info
.
node_type
[
0
]},
&
neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"The sizes of neighbor_nums and neighbor_types are inconsistent."
)
!=
...
...
@@ -198,9 +210,17 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
s
=
graph
.
GetNegSampledNeighbors
({},
3
,
meta_info
.
node_type
[
1
],
&
neg_neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Input node_list is empty."
)
!=
std
::
string
::
npos
);
neg_neighbors
.
reset
();
s
=
graph
.
GetNegSampledNeighbors
({
-
1
,
1
},
3
,
meta_info
.
node_type
[
1
],
&
neg_neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Invalid node id"
)
!=
std
::
string
::
npos
);
neg_neighbors
.
reset
();
s
=
graph
.
GetNegSampledNeighbors
(
node_list
,
50
,
meta_info
.
node_type
[
1
],
&
neg_neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Wrong samples number"
)
!=
std
::
string
::
npos
);
neg_neighbors
.
reset
();
s
=
graph
.
GetNegSampledNeighbors
(
node_list
,
3
,
3
,
&
neg_neighbors
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Invalid n
ode type:3
"
)
!=
std
::
string
::
npos
);
EXPECT_TRUE
(
s
.
ToString
().
find
(
"Invalid n
eighbor type
"
)
!=
std
::
string
::
npos
);
}
TEST_F
(
MindDataTestGNNGraph
,
TestRandomWalk
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录