提交 625f2421 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1977 random walk in gnn node2vec

Merge pull request !1977 from JonathanY/randomwalk
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
[348, 347], [347, 351], [347, 327], [347, 329], [347, 331],
[347, 335], [347, 341], [347, 345], [347, 346], [346, 335],
[346, 340], [346, 339], [346, 349], [346, 353], [346, 354],
[346, 341], [346, 345], [345, 335], [345, 336], [345, 341],
[344, 338], [344, 342], [343, 332], [343, 338], [343, 342],
[342, 332], [340, 349], [334, 349], [333, 349], [330, 349],
[328, 349], [359, 349], [358, 352], [358, 349], [358, 354],
[358, 356], [357, 350], [357, 354], [357, 356], [356, 350],
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (0, [], [])
edge_profile = (0, [], [])
def yield_nodes(task_id=0):
"""
Generate node data
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
node_list = []
for edge in social_data:
src, dst = edge
if src not in node_list:
node_list.append(src)
if dst not in node_list:
node_list.append(dst)
node_list.sort()
print(node_list)
for node_id in node_list:
node = {'id': node_id, 'type': 1}
yield node
def yield_edges(task_id=0):
"""
Generate edge data
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
line_count = 0
for undirected_edge in social_data:
line_count += 1
edge = {
'id': line_count,
'src_id': undirected_edge[0],
'dst_id': undirected_edge[1],
'type': 1}
yield edge
line_count += 1
edge = {
'id': line_count,
'src_id': undirected_edge[1],
'dst_id': undirected_edge[0],
'type': 1}
yield edge
#!/bin/bash
MINDRECORD_PATH=/tmp/sns
rm -f $MINDRECORD_PATH/*
python writer.py --mindrecord_script sns \
--mindrecord_file "$MINDRECORD_PATH/sns" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 14 \
--mindrecord_page_size_by_bit 15
......@@ -656,9 +656,16 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("graph_info", [](gnn::Graph &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
.def("graph_info",
[](gnn::Graph &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
})
.def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out;
});
}
......
......@@ -29,7 +29,7 @@ namespace dataset {
namespace gnn {
Graph::Graph(std::string dataset_file, int32_t num_workers)
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) {
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) {
rnd_.seed(GetSeed());
MS_LOG(INFO) << "num_workers:" << num_workers;
}
......@@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
return Status::OK();
}
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p,
float q, NodeIdType default_node, std::shared_ptr<Tensor> *out) {
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) {
RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
std::vector<std::vector<NodeIdType>> walks;
RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({walks}, DataType(DataType::DE_INT32), out));
return Status::OK();
}
......@@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
return Status::OK();
}
Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, const NodeIdType default_node,
int32_t num_walks, int32_t num_workers) {
node_list_ = node_list;
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) +
". The size of input path is " + std::to_string(meta_path.size());
RETURN_STATUS_UNEXPECTED(err_msg);
}
meta_path_ = meta_path;
if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) {
std::string err_msg = "Failed, step_home_param and step_away_param required greater than " +
std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) +
", step_away_param: " + std::to_string(step_away_param);
RETURN_STATUS_UNEXPECTED(err_msg);
}
step_home_param_ = step_home_param;
step_away_param_ = step_away_param;
default_node_ = default_node;
num_walks_ = num_walks;
num_workers_ = num_workers;
return Status::OK();
}
Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
// Simulate a random walk starting from start node.
auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector
// walk simulate
while (walk.size() - 1 < meta_path_.size()) {
// current nodE
auto cur_node_id = walk.back();
std::shared_ptr<Node> cur_node;
RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node));
// current neighbors
std::vector<NodeIdType> cur_neighbors;
RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true));
std::sort(cur_neighbors.begin(), cur_neighbors.end());
// break if no neighbors
if (cur_neighbors.empty()) {
break;
}
// walk by the fist node, then by the previous 2 nodes
std::shared_ptr<StochasticIndex> stochastic_index;
if (walk.size() == 1) {
RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index));
} else {
NodeIdType prev_node_id = walk[walk.size() - 2];
RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index));
}
NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)];
walk.push_back(next_node_id);
}
while (walk.size() - 1 < meta_path_.size()) {
walk.push_back(default_node_);
}
*walk_path = std::move(walk);
return Status::OK();
}
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
// Repeatedly simulate random walks from each node
std::vector<uint32_t> permutation(node_list_.size());
std::iota(permutation.begin(), permutation.end(), 0);
for (int32_t i = 0; i < num_walks_; i++) {
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
for (const auto &i_perm : permutation) {
std::vector<NodeIdType> walk;
RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk));
walks->push_back(walk);
}
}
return Status::OK();
}
Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
std::shared_ptr<StochasticIndex> *node_probability) {
// Generate alias nodes
std::shared_ptr<Node> node;
graph_->GetNodeByNodeId(node_id, &node);
std::vector<NodeIdType> neighbors;
RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true));
std::sort(neighbors.begin(), neighbors.end());
auto non_normalized_probability = std::vector<float>(neighbors.size(), 1.0);
*node_probability =
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
return Status::OK();
}
Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
std::shared_ptr<StochasticIndex> *edge_probability) {
// Get the alias edge setup lists for a given edge.
std::shared_ptr<Node> src_node;
graph_->GetNodeByNodeId(src, &src_node);
std::vector<NodeIdType> src_neighbors;
RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true));
std::shared_ptr<Node> dst_node;
graph_->GetNodeByNodeId(dst, &dst_node);
std::vector<NodeIdType> dst_neighbors;
RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true));
std::sort(dst_neighbors.begin(), dst_neighbors.end());
std::vector<float> non_normalized_probability;
for (const auto &dst_nbr : dst_neighbors) {
if (dst_nbr == src) {
non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
continue;
}
auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr);
if (it != src_neighbors.end()) {
// stay close, this node connect both src and dst
non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight']
} else {
// step far away
non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
}
}
*edge_probability =
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
return Status::OK();
}
StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
uint32_t K = probability.size();
std::vector<int32_t> switch_to_large_index(K, 0);
std::vector<float> weight(K, .0);
std::vector<int32_t> smaller;
std::vector<int32_t> larger;
auto random_device = GetRandomDevice();
std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon);
float accumulate_threshold = 0.0;
for (uint32_t i = 0; i < K; i++) {
float threshold_one = distribution(random_device);
accumulate_threshold += threshold_one;
weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold;
weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i);
}
while ((!smaller.empty()) && (!larger.empty())) {
uint32_t small = smaller.back();
smaller.pop_back();
uint32_t large = larger.back();
larger.pop_back();
switch_to_large_index[small] = large;
weight[large] = weight[large] + weight[small] - 1.0;
weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large);
}
return StochasticIndex(switch_to_large_index, weight);
}
uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
auto switch_to_large_index = stochastic_index.first;
auto weight = stochastic_index.second;
const uint32_t size_of_index = switch_to_large_index.size();
auto random_device = GetRandomDevice();
std::uniform_real_distribution<> distribution(0.0, 1.0);
// Generate random integer between [0, K)
uint32_t random_idx = std::floor(distribution(random_device) * size_of_index);
if (distribution(random_device) < weight[random_idx]) {
return random_idx;
}
return switch_to_large_index[random_idx];
}
template <typename T>
std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
float sum_probability =
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
if (sum_probability < kGnnEpsilon) {
sum_probability = 1.0;
}
std::vector<float> normalized_probability;
std::transform(non_normalized_probability.begin(), non_normalized_probability.end(),
std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; });
return normalized_probability;
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
......@@ -16,12 +16,14 @@
#ifndef DATASET_ENGINE_GNN_GRAPH_H_
#define DATASET_ENGINE_GNN_GRAPH_H_
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
......@@ -35,6 +37,10 @@ namespace mindspore {
namespace dataset {
namespace gnn {
const float kGnnEpsilon = 0.0001;
const uint32_t kMaxNumWalks = 80;
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
struct MetaInfo {
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
......@@ -99,8 +105,17 @@ class Graph {
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
NodeIdType default_node, std::shared_ptr<Tensor> *out);
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out);
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
......@@ -131,6 +146,45 @@ class Graph {
Status Init();
private:
class RandomWalkBase {
public:
explicit RandomWalkBase(Graph *graph);
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
int32_t num_walks = 1, int32_t num_workers = 1);
~RandomWalkBase() = default;
Status SimulateWalk(std::vector<std::vector<NodeIdType>> *walks);
private:
Status Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path);
Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
std::shared_ptr<StochasticIndex> *node_probability);
Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
std::shared_ptr<StochasticIndex> *edge_probability);
static StochasticIndex GenerateProbability(const std::vector<float> &probability);
static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index);
template <typename T>
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
Graph *graph_;
std::vector<NodeIdType> node_list_;
std::vector<NodeType> meta_path_;
float step_home_param_; // Return hyper parameter. Default is 1.0
float step_away_param_; // Inout hyper parameter. Default is 1.0
NodeIdType default_node_;
int32_t num_walks_; // Number of walks per source. Default is 10
int32_t num_workers_; // The number of worker threads. Default is 1
};
// Load graph data from mindrecord file
// @return Status - The error code return
Status LoadNodeAndEdge();
......@@ -175,6 +229,7 @@ class Graph {
std::string dataset_file_;
int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;
RandomWalkBase random_walk_;
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
......
......@@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
}
}
Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) {
std::vector<NodeIdType> neighbors;
auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) {
neighbors.resize(itr->second.size() + 1);
neighbors[0] = id_;
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
[](const std::shared_ptr<Node> node) { return node->id(); });
if (exclude_itself) {
neighbors.resize(itr->second.size());
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(),
[](const std::shared_ptr<Node> node) { return node->id(); });
} else {
neighbors.resize(itr->second.size() + 1);
neighbors[0] = id_;
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
[](const std::shared_ptr<Node> node) { return node->id(); });
}
} else {
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
neighbors.emplace_back(id_);
if (!exclude_itself) {
neighbors.emplace_back(id_);
}
}
*out_neighbors = std::move(neighbors);
return Status::OK();
......
......@@ -47,7 +47,8 @@ class LocalNode : public Node {
// @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) override;
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
......
......@@ -56,7 +56,8 @@ class Node {
// @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) = 0;
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
......
......@@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk
class GraphData:
......@@ -148,7 +148,8 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray.
"""
return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
return self._graph.get_sampled_neighbors(
node_list, neighbor_nums, neighbor_types).as_array()
@check_gnn_get_neg_sampled_neighbors
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
......@@ -174,7 +175,8 @@ class GraphData:
TypeError: If `neg_neighbor_num` is not integer.
TypeError: If `neg_neighbor_type` is not integer.
"""
return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
return self._graph.get_neg_sampled_neighbors(
node_list, neg_neighbor_num, neg_neighbor_type).as_array()
@check_gnn_get_node_feature
def get_node_feature(self, node_list, feature_types):
......@@ -200,7 +202,10 @@ class GraphData:
"""
if isinstance(node_list, list):
node_list = np.array(node_list, dtype=np.int32)
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
return [
t.as_array() for t in self._graph.get_node_feature(
Tensor(node_list),
feature_types)]
def graph_info(self):
"""
......@@ -212,3 +217,36 @@ class GraphData:
node_feature_type and edge_feature_type.
"""
return self._graph.graph_info()
@check_gnn_random_walk
def random_walk(
self,
target_nodes,
meta_path,
step_home_param=1.0,
step_away_param=1.0,
default_node=-1):
"""
Random walk in nodes.
Args:
target_nodes (list[int]): Start node list in random walk
meta_path (list[int]): node type for each walk step
step_home_param (float): return hyper parameter in node2vec algorithm
step_away_param (float): inout hyper parameter in node2vec algorithm
default_node (int): default node if no more neighbors found
Returns:
numpy.ndarray: array of nodes.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1])
Raises:
TypeError: If `target_nodes` is not list or ndarray.
TypeError: If `meta_path` is not list or ndarray.
"""
return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
default_node).as_array()
......@@ -1395,6 +1395,24 @@ def check_gnn_get_neg_sampled_neighbors(method):
return new_method
def check_gnn_random_walk(method):
"""A wrapper that wrap a parameter checker to the GNN `random_walk` function."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
# check node_list; required argument
check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes')
# check meta_path; required argument
check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')
return method(*args, **kwargs)
return new_method
def check_aligned_list(param, param_name, membor_type):
"""Check whether the structure of each member of the list is the same."""
......
......@@ -27,6 +27,13 @@
using namespace mindspore::dataset;
using namespace mindspore::dataset::gnn;
#define print_int_vec(_i, _str) \
do { \
std::stringstream ss; \
std::copy(_i.begin(), _i.end(), std::ostream_iterator<int>(ss, " ")); \
MS_LOG(INFO) << _str << " " << ss.str(); \
} while (false)
class MindDataTestGNNGraph : public UT::Common {
protected:
MindDataTestGNNGraph() = default;
......@@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
}
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());
std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
}
print_int_vec(node_list, "node list ");
std::vector<NodeType> meta_path(59, 1);
std::shared_ptr<Tensor> walk_path;
s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}
\ No newline at end of file
......@@ -19,6 +19,7 @@ import mindspore.dataset as ds
from mindspore import log as logger
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
def test_graphdata_getfullneighbor():
......@@ -172,6 +173,17 @@ def test_graphdata_generatordataset():
assert i == 40
def test_graphdata_randomwalk():
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
assert len(nodes) == 33
meta_path = [1 for _ in range(39)]
walks = g.random_walk(nodes, meta_path)
assert walks.shape == (33, 40)
if __name__ == '__main__':
test_graphdata_getfullneighbor()
logger.info('test_graphdata_getfullneighbor Ended.\n')
......@@ -185,3 +197,5 @@ if __name__ == '__main__':
logger.info('test_graphdata_graphinfo Ended.\n')
test_graphdata_generatordataset()
logger.info('test_graphdata_generatordataset Ended.\n')
test_graphdata_randomwalk()
logger.info('test_graphdata_randomwalk Ended.\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册