Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
14899a14
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
14899a14
编写于
7月 06, 2020
作者:
T
tony_liu2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gnn random walk pr 1977 comments
add fix to random resize decode crop test case fix pylint issues
上级
9991df86
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
96 addition
and
16 deletion
+96
-16
mindspore/ccsrc/dataset/engine/gnn/graph.cc
mindspore/ccsrc/dataset/engine/gnn/graph.cc
+12
-0
mindspore/ccsrc/dataset/engine/gnn/graph.h
mindspore/ccsrc/dataset/engine/gnn/graph.h
+1
-1
mindspore/dataset/engine/graphdata.py
mindspore/dataset/engine/graphdata.py
+4
-3
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+4
-0
tests/ut/cpp/dataset/gnn_graph_test.cc
tests/ut/cpp/dataset/gnn_graph_test.cc
+27
-1
tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
+3
-3
tests/ut/python/dataset/test_graphdata.py
tests/ut/python/dataset/test_graphdata.py
+45
-8
未找到文件。
mindspore/ccsrc/dataset/engine/gnn/graph.cc
浏览文件 @
14899a14
...
@@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
...
@@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
", step_away_param: "
+
std
::
to_string
(
step_away_param
);
", step_away_param: "
+
std
::
to_string
(
step_away_param
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
}
if
(
default_node
<
-
1
)
{
std
::
string
err_msg
=
"Failed, default_node required to be greater or equal to -1."
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
if
(
num_walks
<=
0
)
{
std
::
string
err_msg
=
"Failed, num_walks parameter required to be greater than 0"
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
if
(
num_workers
<=
0
)
{
std
::
string
err_msg
=
"Failed, num_workers parameter required to be greater than 0"
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
step_home_param_
=
step_home_param
;
step_home_param_
=
step_home_param
;
step_away_param_
=
step_away_param
;
step_away_param_
=
step_away_param
;
default_node_
=
default_node
;
default_node_
=
default_node
;
...
...
mindspore/ccsrc/dataset/engine/gnn/graph.h
浏览文件 @
14899a14
...
@@ -181,7 +181,7 @@ class Graph {
...
@@ -181,7 +181,7 @@ class Graph {
float
step_away_param_
;
// Inout hyper parameter. Default is 1.0
float
step_away_param_
;
// Inout hyper parameter. Default is 1.0
NodeIdType
default_node_
;
NodeIdType
default_node_
;
int32_t
num_walks_
;
// Number of walks per source. Default is 1
0
int32_t
num_walks_
;
// Number of walks per source. Default is 1
int32_t
num_workers_
;
// The number of worker threads. Default is 1
int32_t
num_workers_
;
// The number of worker threads. Default is 1
};
};
...
...
mindspore/dataset/engine/graphdata.py
浏览文件 @
14899a14
...
@@ -232,9 +232,10 @@ class GraphData:
...
@@ -232,9 +232,10 @@ class GraphData:
Args:
Args:
target_nodes (list[int]): Start node list in random walk
target_nodes (list[int]): Start node list in random walk
meta_path (list[int]): node type for each walk step
meta_path (list[int]): node type for each walk step
step_home_param (float): return hyper parameter in node2vec algorithm
step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0).
step_away_param (float): inout hyper parameter in node2vec algorithm
step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0).
default_node (int): default node if no more neighbors found
default_node (int, optional): default node if no more neighbors found (Default = -1).
A default value of -1 indicates that no node is given.
Returns:
Returns:
numpy.ndarray: array of nodes.
numpy.ndarray: array of nodes.
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
14899a14
...
@@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method):
...
@@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method):
# check meta_path; required argument
# check meta_path; required argument
check_gnn_list_or_ndarray
(
param_dict
.
get
(
"meta_path"
),
'meta_path'
)
check_gnn_list_or_ndarray
(
param_dict
.
get
(
"meta_path"
),
'meta_path'
)
check_type
(
param_dict
.
get
(
"step_home_param"
),
'step_home_param'
,
float
)
check_type
(
param_dict
.
get
(
"step_away_param"
),
'step_away_param'
,
float
)
check_type
(
param_dict
.
get
(
"default_node"
),
'default_node'
,
int
)
return
method
(
*
args
,
**
kwargs
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
return
new_method
...
...
tests/ut/cpp/dataset/gnn_graph_test.cc
浏览文件 @
14899a14
...
@@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
...
@@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
s
=
graph
.
RandomWalk
(
node_list
,
meta_path
,
2.0
,
0.5
,
-
1
,
&
walk_path
);
s
=
graph
.
RandomWalk
(
node_list
,
meta_path
,
2.0
,
0.5
,
-
1
,
&
walk_path
);
EXPECT_TRUE
(
s
.
IsOk
());
EXPECT_TRUE
(
s
.
IsOk
());
EXPECT_TRUE
(
walk_path
->
shape
().
ToString
()
==
"<33,60>"
);
EXPECT_TRUE
(
walk_path
->
shape
().
ToString
()
==
"<33,60>"
);
}
}
\ No newline at end of file
TEST_F
(
MindDataTestGNNGraph
,
TestRandomWalkDefaults
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/sns"
;
Graph
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
MetaInfo
meta_info
;
s
=
graph
.
GetMetaInfo
(
&
meta_info
);
EXPECT_TRUE
(
s
.
IsOk
());
std
::
shared_ptr
<
Tensor
>
nodes
;
s
=
graph
.
GetAllNodes
(
meta_info
.
node_type
[
0
],
&
nodes
);
EXPECT_TRUE
(
s
.
IsOk
());
std
::
vector
<
NodeIdType
>
node_list
;
for
(
auto
itr
=
nodes
->
begin
<
NodeIdType
>
();
itr
!=
nodes
->
end
<
NodeIdType
>
();
++
itr
)
{
node_list
.
push_back
(
*
itr
);
}
print_int_vec
(
node_list
,
"node list "
);
std
::
vector
<
NodeType
>
meta_path
(
59
,
1
);
std
::
shared_ptr
<
Tensor
>
walk_path
;
s
=
graph
.
RandomWalk
(
node_list
,
meta_path
,
1.0
,
1.0
,
-
1
,
&
walk_path
);
EXPECT_TRUE
(
s
.
IsOk
());
EXPECT_TRUE
(
walk_path
->
shape
().
ToString
()
==
"<33,60>"
);
}
tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
浏览文件 @
14899a14
...
@@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) {
...
@@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) {
auto
decode_and_crop
=
static_cast
<
RandomCropAndResizeOp
>
(
crop_and_decode_copy
);
auto
decode_and_crop
=
static_cast
<
RandomCropAndResizeOp
>
(
crop_and_decode_copy
);
EXPECT_TRUE
(
crop_and_decode
.
OneToOne
());
EXPECT_TRUE
(
crop_and_decode
.
OneToOne
());
GlobalContext
::
config_manager
()
->
set_seed
(
42
);
GlobalContext
::
config_manager
()
->
set_seed
(
42
);
for
(
int
k
=
0
;
k
<
10
0
;
k
++
)
{
for
(
int
k
=
0
;
k
<
10
;
k
++
)
{
(
void
)
crop_and_decode
.
Compute
(
raw_input_tensor_
,
&
crop_and_decode_output
);
(
void
)
crop_and_decode
.
Compute
(
raw_input_tensor_
,
&
crop_and_decode_output
);
(
void
)
decode_and_crop
.
Compute
(
input_tensor_
,
&
decode_and_crop_output
);
(
void
)
decode_and_crop
.
Compute
(
input_tensor_
,
&
decode_and_crop_output
);
cv
::
Mat
output1
=
CVTensor
::
AsCVTensor
(
crop_and_decode_output
)
->
mat
().
clone
();
cv
::
Mat
output1
=
CVTensor
::
AsCVTensor
(
crop_and_decode_output
)
->
mat
().
clone
();
...
@@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) {
...
@@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) {
int
mse_sum
,
m1
,
m2
,
count
;
int
mse_sum
,
m1
,
m2
,
count
;
double
mse
;
double
mse
;
for
(
int
k
=
0
;
k
<
10
0
;
++
k
)
{
for
(
int
k
=
0
;
k
<
10
;
++
k
)
{
mse_sum
=
0
;
mse_sum
=
0
;
count
=
0
;
count
=
0
;
for
(
auto
i
=
0
;
i
<
10
0
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
10
;
i
++
)
{
scale
=
rd_scale
(
rd
);
scale
=
rd_scale
(
rd
);
aspect
=
rd_aspect
(
rd
);
aspect
=
rd_aspect
(
rd
);
crop_width
=
std
::
round
(
std
::
sqrt
(
h
*
w
*
scale
/
aspect
));
crop_width
=
std
::
round
(
std
::
sqrt
(
h
*
w
*
scale
/
aspect
));
...
...
tests/ut/python/dataset/test_graphdata.py
浏览文件 @
14899a14
...
@@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
...
@@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
def
test_graphdata_getfullneighbor
():
def
test_graphdata_getfullneighbor
():
"""
Test get all neighbors
"""
logger
.
info
(
'test get all neighbors.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
nodes
=
g
.
get_all_nodes
(
1
)
nodes
=
g
.
get_all_nodes
(
1
)
assert
len
(
nodes
)
==
10
assert
len
(
nodes
)
==
10
...
@@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor():
...
@@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor():
def
test_graphdata_getnodefeature_input_check
():
def
test_graphdata_getnodefeature_input_check
():
"""
Test get node feature input check
"""
logger
.
info
(
'test getnodefeature input check.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
)
g
=
ds
.
GraphData
(
DATASET_FILE
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
input_list
=
[
1
,
[
1
,
1
]]
input_list
=
[
1
,
[
1
,
1
]]
...
@@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check():
...
@@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check():
def
test_graphdata_getsampledneighbors
():
def
test_graphdata_getsampledneighbors
():
"""
Test sampled neighbors
"""
logger
.
info
(
'test get sampled neighbors.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
1
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
1
)
edges
=
g
.
get_all_edges
(
0
)
edges
=
g
.
get_all_edges
(
0
)
nodes
=
g
.
get_nodes_from_edges
(
edges
)
nodes
=
g
.
get_nodes_from_edges
(
edges
)
...
@@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors():
...
@@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors():
def
test_graphdata_getnegsampledneighbors
():
def
test_graphdata_getnegsampledneighbors
():
"""
Test neg sampled neighbors
"""
logger
.
info
(
'test get negative sampled neighbors.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
nodes
=
g
.
get_all_nodes
(
1
)
nodes
=
g
.
get_all_nodes
(
1
)
assert
len
(
nodes
)
==
10
assert
len
(
nodes
)
==
10
...
@@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors():
...
@@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors():
def
test_graphdata_graphinfo
():
def
test_graphdata_graphinfo
():
"""
Test graph info
"""
logger
.
info
(
'test graph info.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
2
)
graph_info
=
g
.
graph_info
()
graph_info
=
g
.
graph_info
()
assert
graph_info
[
'node_type'
]
==
[
1
,
2
]
assert
graph_info
[
'node_type'
]
==
[
1
,
2
]
...
@@ -155,6 +175,10 @@ class GNNGraphDataset():
...
@@ -155,6 +175,10 @@ class GNNGraphDataset():
def
test_graphdata_generatordataset
():
def
test_graphdata_generatordataset
():
"""
Test generator dataset
"""
logger
.
info
(
'test generator dataset.
\n
'
)
g
=
ds
.
GraphData
(
DATASET_FILE
)
g
=
ds
.
GraphData
(
DATASET_FILE
)
batch_num
=
2
batch_num
=
2
edge_num
=
g
.
graph_info
()[
'edge_num'
][
0
]
edge_num
=
g
.
graph_info
()[
'edge_num'
][
0
]
...
@@ -173,7 +197,11 @@ def test_graphdata_generatordataset():
...
@@ -173,7 +197,11 @@ def test_graphdata_generatordataset():
assert
i
==
40
assert
i
==
40
def
test_graphdata_randomwalk
():
def
test_graphdata_randomwalkdefault
():
"""
Test random walk defaults
"""
logger
.
info
(
'test randomwalk with default parameters.
\n
'
)
g
=
ds
.
GraphData
(
SOCIAL_DATA_FILE
,
1
)
g
=
ds
.
GraphData
(
SOCIAL_DATA_FILE
,
1
)
nodes
=
g
.
get_all_nodes
(
1
)
nodes
=
g
.
get_all_nodes
(
1
)
print
(
len
(
nodes
))
print
(
len
(
nodes
))
...
@@ -184,18 +212,27 @@ def test_graphdata_randomwalk():
...
@@ -184,18 +212,27 @@ def test_graphdata_randomwalk():
assert
walks
.
shape
==
(
33
,
40
)
assert
walks
.
shape
==
(
33
,
40
)
def
test_graphdata_randomwalk
():
"""
Test random walk
"""
logger
.
info
(
'test random walk with given parameters.
\n
'
)
g
=
ds
.
GraphData
(
SOCIAL_DATA_FILE
,
1
)
nodes
=
g
.
get_all_nodes
(
1
)
print
(
len
(
nodes
))
assert
len
(
nodes
)
==
33
meta_path
=
[
1
for
_
in
range
(
39
)]
walks
=
g
.
random_walk
(
nodes
,
meta_path
,
2.0
,
0.5
,
-
1
)
assert
walks
.
shape
==
(
33
,
40
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_graphdata_getfullneighbor
()
test_graphdata_getfullneighbor
()
logger
.
info
(
'test_graphdata_getfullneighbor Ended.
\n
'
)
test_graphdata_getnodefeature_input_check
()
test_graphdata_getnodefeature_input_check
()
logger
.
info
(
'test_graphdata_getnodefeature_input_check Ended.
\n
'
)
test_graphdata_getsampledneighbors
()
test_graphdata_getsampledneighbors
()
logger
.
info
(
'test_graphdata_getsampledneighbors Ended.
\n
'
)
test_graphdata_getnegsampledneighbors
()
test_graphdata_getnegsampledneighbors
()
logger
.
info
(
'test_graphdata_getnegsampledneighbors Ended.
\n
'
)
test_graphdata_graphinfo
()
test_graphdata_graphinfo
()
logger
.
info
(
'test_graphdata_graphinfo Ended.
\n
'
)
test_graphdata_generatordataset
()
test_graphdata_generatordataset
()
logger
.
info
(
'test_graphdata_generatordataset Ended.
\n
'
)
test_graphdata_randomwalkdefault
(
)
test_graphdata_randomwalk
()
test_graphdata_randomwalk
()
logger
.
info
(
'test_graphdata_randomwalk Ended.
\n
'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录