提交 14899a14 编写于 作者: T tony_liu2

fix gnn random walk pr 1977 comments

add fix to random resize decode crop test case

fix pylint issues
上级 9991df86
......@@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
", step_away_param: " + std::to_string(step_away_param);
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (default_node < -1) {
std::string err_msg = "Failed, default_node required to be greater or equal to -1.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (num_walks <= 0) {
std::string err_msg = "Failed, num_walks parameter required to be greater than 0";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (num_workers <= 0) {
std::string err_msg = "Failed, num_workers parameter required to be greater than 0";
RETURN_STATUS_UNEXPECTED(err_msg);
}
step_home_param_ = step_home_param;
step_away_param_ = step_away_param;
default_node_ = default_node;
......
......@@ -181,7 +181,7 @@ class Graph {
float step_away_param_; // Inout hyper parameter. Default is 1.0
NodeIdType default_node_;
int32_t num_walks_; // Number of walks per source. Default is 10
int32_t num_walks_; // Number of walks per source. Default is 1
int32_t num_workers_; // The number of worker threads. Default is 1
};
......
......@@ -232,9 +232,10 @@ class GraphData:
Args:
target_nodes (list[int]): Start node list in random walk
meta_path (list[int]): node type for each walk step
step_home_param (float): return hyper parameter in node2vec algorithm
step_away_param (float): inout hyper parameter in node2vec algorithm
default_node (int): default node if no more neighbors found
step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0).
step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0).
default_node (int, optional): default node if no more neighbors found (Default = -1).
A default value of -1 indicates that no node is given.
Returns:
numpy.ndarray: array of nodes.
......
......@@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method):
# check meta_path; required argument
check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')
check_type(param_dict.get("step_home_param"), 'step_home_param', float)
check_type(param_dict.get("step_away_param"), 'step_away_param', float)
check_type(param_dict.get("default_node"), 'default_node', int)
return method(*args, **kwargs)
return new_method
......
......@@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}
\ No newline at end of file
}
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());
std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
}
print_int_vec(node_list, "node list ");
std::vector<NodeType> meta_path(59, 1);
std::shared_ptr<Tensor> walk_path;
s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}
......@@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) {
auto decode_and_crop = static_cast<RandomCropAndResizeOp>(crop_and_decode_copy);
EXPECT_TRUE(crop_and_decode.OneToOne());
GlobalContext::config_manager()->set_seed(42);
for (int k = 0; k < 100; k++) {
for (int k = 0; k < 10; k++) {
(void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output);
(void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output);
cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone();
......@@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) {
int mse_sum, m1, m2, count;
double mse;
for (int k = 0; k < 100; ++k) {
for (int k = 0; k < 10; ++k) {
mse_sum = 0;
count = 0;
for (auto i = 0; i < 100; i++) {
for (auto i = 0; i < 10; i++) {
scale = rd_scale(rd);
aspect = rd_aspect(rd);
crop_width = std::round(std::sqrt(h * w * scale / aspect));
......
......@@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
def test_graphdata_getfullneighbor():
"""
Test get all neighbors
"""
logger.info('test get all neighbors.\n')
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
......@@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor():
def test_graphdata_getnodefeature_input_check():
"""
Test get node feature input check
"""
logger.info('test getnodefeature input check.\n')
g = ds.GraphData(DATASET_FILE)
with pytest.raises(TypeError):
input_list = [1, [1, 1]]
......@@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check():
def test_graphdata_getsampledneighbors():
"""
Test sampled neighbors
"""
logger.info('test get sampled neighbors.\n')
g = ds.GraphData(DATASET_FILE, 1)
edges = g.get_all_edges(0)
nodes = g.get_nodes_from_edges(edges)
......@@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors():
def test_graphdata_getnegsampledneighbors():
"""
Test neg sampled neighbors
"""
logger.info('test get negative sampled neighbors.\n')
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
......@@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors():
def test_graphdata_graphinfo():
"""
Test graph info
"""
logger.info('test graph info.\n')
g = ds.GraphData(DATASET_FILE, 2)
graph_info = g.graph_info()
assert graph_info['node_type'] == [1, 2]
......@@ -155,6 +175,10 @@ class GNNGraphDataset():
def test_graphdata_generatordataset():
"""
Test generator dataset
"""
logger.info('test generator dataset.\n')
g = ds.GraphData(DATASET_FILE)
batch_num = 2
edge_num = g.graph_info()['edge_num'][0]
......@@ -173,7 +197,11 @@ def test_graphdata_generatordataset():
assert i == 40
def test_graphdata_randomwalk():
def test_graphdata_randomwalkdefault():
"""
Test random walk defaults
"""
logger.info('test randomwalk with default parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
......@@ -184,18 +212,27 @@ def test_graphdata_randomwalk():
assert walks.shape == (33, 40)
def test_graphdata_randomwalk():
"""
Test random walk
"""
logger.info('test random walk with given parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
assert len(nodes) == 33
meta_path = [1 for _ in range(39)]
walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
assert walks.shape == (33, 40)
if __name__ == '__main__':
test_graphdata_getfullneighbor()
logger.info('test_graphdata_getfullneighbor Ended.\n')
test_graphdata_getnodefeature_input_check()
logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
test_graphdata_getsampledneighbors()
logger.info('test_graphdata_getsampledneighbors Ended.\n')
test_graphdata_getnegsampledneighbors()
logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
test_graphdata_graphinfo()
logger.info('test_graphdata_graphinfo Ended.\n')
test_graphdata_generatordataset()
logger.info('test_graphdata_generatordataset Ended.\n')
test_graphdata_randomwalkdefault()
test_graphdata_randomwalk()
logger.info('test_graphdata_randomwalk Ended.\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册