diff --git a/example/graph_to_mindrecord/sns/__init__.py b/example/graph_to_mindrecord/sns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/graph_to_mindrecord/sns/mr_api.py b/example/graph_to_mindrecord/sns/mr_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4e01441601d1d544c9d39cd3e3a3948c7975f3d8 --- /dev/null +++ b/example/graph_to_mindrecord/sns/mr_api.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord GNN writer. +""" +social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], + [348, 336], [348, 337], [348, 338], [348, 340], [348, 341], + [348, 342], [348, 343], [348, 344], [348, 345], [348, 346], + [348, 347], [347, 351], [347, 327], [347, 329], [347, 331], + [347, 335], [347, 341], [347, 345], [347, 346], [346, 335], + [346, 340], [346, 339], [346, 349], [346, 353], [346, 354], + [346, 341], [346, 345], [345, 335], [345, 336], [345, 341], + [344, 338], [344, 342], [343, 332], [343, 338], [343, 342], + [342, 332], [340, 349], [334, 349], [333, 349], [330, 349], + [328, 349], [359, 349], [358, 352], [358, 349], [358, 354], + [358, 356], [357, 350], [357, 354], [357, 356], [356, 350], + [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] + +# profile: (num_features, feature_data_types, feature_shapes) +node_profile = (0, [], []) +edge_profile = (0, [], []) + + +def yield_nodes(task_id=0): + """ + Generate node data + + Yields: + data (dict): data row which is dict. + """ + print("Node task is {}".format(task_id)) + node_list = [] + for edge in social_data: + src, dst = edge + if src not in node_list: + node_list.append(src) + if dst not in node_list: + node_list.append(dst) + node_list.sort() + print(node_list) + for node_id in node_list: + node = {'id': node_id, 'type': 1} + yield node + + +def yield_edges(task_id=0): + """ + Generate edge data + + Yields: + data (dict): data row which is dict. + """ + print("Edge task is {}".format(task_id)) + line_count = 0 + for undirected_edge in social_data: + line_count += 1 + edge = { + 'id': line_count, + 'src_id': undirected_edge[0], + 'dst_id': undirected_edge[1], + 'type': 1} + yield edge + line_count += 1 + edge = { + 'id': line_count, + 'src_id': undirected_edge[1], + 'dst_id': undirected_edge[0], + 'type': 1} + yield edge diff --git a/example/graph_to_mindrecord/write_sns.sh b/example/graph_to_mindrecord/write_sns.sh new file mode 100644 index 0000000000000000000000000000000000000000..f564ddc8ffcd7a046bda596937d2dfedb69d438d --- /dev/null +++ b/example/graph_to_mindrecord/write_sns.sh @@ -0,0 +1,10 @@ +#!/bin/bash +MINDRECORD_PATH=/tmp/sns + +rm -f $MINDRECORD_PATH/* + +python writer.py --mindrecord_script sns \ +--mindrecord_file "$MINDRECORD_PATH/sns" \ +--mindrecord_partitions 1 \ +--mindrecord_header_size_by_bit 14 \ +--mindrecord_page_size_by_bit 15 diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 493e12e8f5dc86db2d09cf87f5d5a2b2159065a1..9e6940c5a36a8cf69515539846b98dcf22a3f1a2 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -656,9 +656,16 @@ void bindGraphData(py::module *m) { THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); return out.getRow(); }) - .def("graph_info", [](gnn::Graph &g) { - py::dict out; - THROW_IF_ERROR(g.GraphInfo(&out)); + .def("graph_info", + [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }) + .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + float step_home_param, float step_away_param, gnn::NodeIdType default_node) { + std::shared_ptr out; + THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); return out; }); } diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index 15222d3b2388dae002382b36997d3e0e40760e49..10176573973d4aac581740c7cabff79e2cf236cd 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -29,7 +29,7 @@ namespace dataset { namespace gnn { Graph::Graph(std::string dataset_file, int32_t num_workers) - : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { + : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { rnd_.seed(GetSeed()); MS_LOG(INFO) << "num_workers:" << num_workers; } @@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector &node_list, N return Status::OK(); } -Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, - float q, NodeIdType default_node, std::shared_ptr *out) { +Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { + RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); + std::vector> walks; + RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); + RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); return Status::OK(); } @@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { return Status::OK(); } +Graph::RandomWalkBase::RandomWalkBase(Graph *graph) + : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} + +Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, const NodeIdType default_node, + int32_t num_walks, int32_t num_workers) { + node_list_ = node_list; + if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { + std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + + ". The size of input path is " + std::to_string(meta_path.size()); + RETURN_STATUS_UNEXPECTED(err_msg); + } + meta_path_ = meta_path; + if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { + std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + + std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + + ", step_away_param: " + std::to_string(step_away_param); + RETURN_STATUS_UNEXPECTED(err_msg); + } + step_home_param_ = step_home_param; + step_away_param_ = step_away_param; + default_node_ = default_node; + num_walks_ = num_walks; + num_workers_ = num_workers; + return Status::OK(); +} + +Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { + // Simulate a random walk starting from start node. + auto walk = std::vector(1, start_node); // walk is an vector + // walk simulate + while (walk.size() - 1 < meta_path_.size()) { + // current nodE + auto cur_node_id = walk.back(); + std::shared_ptr cur_node; + RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); + + // current neighbors + std::vector cur_neighbors; + RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); + std::sort(cur_neighbors.begin(), cur_neighbors.end()); + + // break if no neighbors + if (cur_neighbors.empty()) { + break; + } + + // walk by the fist node, then by the previous 2 nodes + std::shared_ptr stochastic_index; + if (walk.size() == 1) { + RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); + } else { + NodeIdType prev_node_id = walk[walk.size() - 2]; + RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); + } + NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; + walk.push_back(next_node_id); + } + + while (walk.size() - 1 < meta_path_.size()) { + walk.push_back(default_node_); + } + + *walk_path = std::move(walk); + return Status::OK(); +} + +Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { + // Repeatedly simulate random walks from each node + std::vector permutation(node_list_.size()); + std::iota(permutation.begin(), permutation.end(), 0); + for (int32_t i = 0; i < num_walks_; i++) { + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); + for (const auto &i_perm : permutation) { + std::vector walk; + RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); + walks->push_back(walk); + } + } + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability) { + // Generate alias nodes + std::shared_ptr node; + graph_->GetNodeByNodeId(node_id, &node); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); + std::sort(neighbors.begin(), neighbors.end()); + auto non_normalized_probability = std::vector(neighbors.size(), 1.0); + *node_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability) { + // Get the alias edge setup lists for a given edge. + std::shared_ptr src_node; + graph_->GetNodeByNodeId(src, &src_node); + std::vector src_neighbors; + RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); + + std::shared_ptr dst_node; + graph_->GetNodeByNodeId(dst, &dst_node); + std::vector dst_neighbors; + RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); + + std::sort(dst_neighbors.begin(), dst_neighbors.end()); + std::vector non_normalized_probability; + for (const auto &dst_nbr : dst_neighbors) { + if (dst_nbr == src) { + non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + continue; + } + auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); + if (it != src_neighbors.end()) { + // stay close, this node connect both src and dst + non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] + } else { + // step far away + non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + } + } + + *edge_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { + uint32_t K = probability.size(); + std::vector switch_to_large_index(K, 0); + std::vector weight(K, .0); + std::vector smaller; + std::vector larger; + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); + float accumulate_threshold = 0.0; + for (uint32_t i = 0; i < K; i++) { + float threshold_one = distribution(random_device); + accumulate_threshold += threshold_one; + weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; + weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); + } + + while ((!smaller.empty()) && (!larger.empty())) { + uint32_t small = smaller.back(); + smaller.pop_back(); + uint32_t large = larger.back(); + larger.pop_back(); + switch_to_large_index[small] = large; + weight[large] = weight[large] + weight[small] - 1.0; + weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large); + } + return StochasticIndex(switch_to_large_index, weight); +} + +uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { + auto switch_to_large_index = stochastic_index.first; + auto weight = stochastic_index.second; + const uint32_t size_of_index = switch_to_large_index.size(); + + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(0.0, 1.0); + + // Generate random integer between [0, K) + uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); + + if (distribution(random_device) < weight[random_idx]) { + return random_idx; + } + return switch_to_large_index[random_idx]; +} + +template +std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { + float sum_probability = + 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); + if (sum_probability < kGnnEpsilon) { + sum_probability = 1.0; + } + std::vector normalized_probability; + std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), + std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); + return normalized_probability; +} } // namespace gnn } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index 13de3e524b2e4bf5ddd186843911365a862da272..ea1036305361556a96650163ccda0b29167fe3fa 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -16,12 +16,14 @@ #ifndef DATASET_ENGINE_GNN_GRAPH_H_ #define DATASET_ENGINE_GNN_GRAPH_H_ +#include #include #include #include #include #include #include +#include #include "dataset/core/tensor.h" #include "dataset/core/tensor_row.h" @@ -35,6 +37,10 @@ namespace mindspore { namespace dataset { namespace gnn { +const float kGnnEpsilon = 0.0001; +const uint32_t kMaxNumWalks = 80; +using StochasticIndex = std::pair, std::vector>; + struct MetaInfo { std::vector node_type; std::vector edge_type; @@ -99,8 +105,17 @@ class Graph { Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, NodeType neg_neighbor_type, std::shared_ptr *out); - Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, float q, - NodeIdType default_node, std::shared_ptr *out); + // Node2vec random walk. + // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out); // Get the feature of a node // @param std::shared_ptr nodes - List of nodes @@ -131,6 +146,45 @@ class Graph { Status Init(); private: + class RandomWalkBase { + public: + explicit RandomWalkBase(Graph *graph); + + Status Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, + int32_t num_walks = 1, int32_t num_workers = 1); + + ~RandomWalkBase() = default; + + Status SimulateWalk(std::vector> *walks); + + private: + Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); + + Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability); + + Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability); + + static StochasticIndex GenerateProbability(const std::vector &probability); + + static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); + + template + std::vector Normalize(const std::vector &non_normalized_probability); + + Graph *graph_; + std::vector node_list_; + std::vector meta_path_; + float step_home_param_; // Return hyper parameter. Default is 1.0 + float step_away_param_; // Inout hyper parameter. Default is 1.0 + NodeIdType default_node_; + + int32_t num_walks_; // Number of walks per source. Default is 10 + int32_t num_workers_; // The number of worker threads. Default is 1 + }; + // Load graph data from mindrecord file // @return Status - The error code return Status LoadNodeAndEdge(); @@ -175,6 +229,7 @@ class Graph { std::string dataset_file_; int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; + RandomWalkBase random_walk_; std::unordered_map> node_type_map_; std::unordered_map> node_id_map_; diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc index e091a52faa1183176fb9aeaf4c12dafb7bb1f9b1..c829f8e8caf79389b1c267b25372e32ea3a89589 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc @@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr } } -Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) { +Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { std::vector neighbors; auto itr = neighbor_nodes_.find(neighbor_type); if (itr != neighbor_nodes_.end()) { - neighbors.resize(itr->second.size() + 1); - neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, - [](const std::shared_ptr node) { return node->id(); }); + if (exclude_itself) { + neighbors.resize(itr->second.size()); + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + [](const std::shared_ptr node) { return node->id(); }); + } else { + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); + } } else { MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; - neighbors.emplace_back(id_); + if (!exclude_itself) { + neighbors.emplace_back(id_); + } } *out_neighbors = std::move(neighbors); return Status::OK(); diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h index b9b007c420ec2263efb91c0c4e63f01dadf3936c..bc069d073fda92ecaa0a8a130758411e57ec1d17 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.h @@ -47,7 +47,8 @@ class LocalNode : public Node { // @param NodeType neighbor_type - type of neighbor // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) override; + Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) override; // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h index f0136e92d7b308b70c280a1777fdf9f9f969594e..282f85679719f42386abf67fff7dad0c6b171d7b 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/node.h @@ -56,7 +56,8 @@ class Node { // @param NodeType neighbor_type - type of neighbor // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) = 0; + virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) = 0; // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index e6ff22dd0d3dc17ed19aa11a4bf513b71877aae1..838dd53f0a0c58a16e003fc16cd3e10a2507dc27 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ - check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature + check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk class GraphData: @@ -148,7 +148,8 @@ class GraphData: TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. """ - return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() + return self._graph.get_sampled_neighbors( + node_list, neighbor_nums, neighbor_types).as_array() @check_gnn_get_neg_sampled_neighbors def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): @@ -174,7 +175,8 @@ class GraphData: TypeError: If `neg_neighbor_num` is not integer. TypeError: If `neg_neighbor_type` is not integer. """ - return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() + return self._graph.get_neg_sampled_neighbors( + node_list, neg_neighbor_num, neg_neighbor_type).as_array() @check_gnn_get_node_feature def get_node_feature(self, node_list, feature_types): @@ -200,7 +202,10 @@ class GraphData: """ if isinstance(node_list, list): node_list = np.array(node_list, dtype=np.int32) - return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] + return [ + t.as_array() for t in self._graph.get_node_feature( + Tensor(node_list), + feature_types)] def graph_info(self): """ @@ -212,3 +217,36 @@ class GraphData: node_feature_type and edge_feature_type. """ return self._graph.graph_info() + + @check_gnn_random_walk + def random_walk( + self, + target_nodes, + meta_path, + step_home_param=1.0, + step_away_param=1.0, + default_node=-1): + """ + Random walk in nodes. + + Args: + target_nodes (list[int]): Start node list in random walk + meta_path (list[int]): node type for each walk step + step_home_param (float): return hyper parameter in node2vec algorithm + step_away_param (float): inout hyper parameter in node2vec algorithm + default_node (int): default node if no more neighbors found + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1]) + + Raises: + TypeError: If `target_nodes` is not list or ndarray. + TypeError: If `meta_path` is not list or ndarray. + """ + return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, + default_node).as_array() diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4c8388ea639f40951f66a31a4614fe427f6ee2e2..d32a00e1752731825f542ad6d655cd465b700ba4 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1395,6 +1395,24 @@ def check_gnn_get_neg_sampled_neighbors(method): return new_method +def check_gnn_random_walk(method): + """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes') + + # check meta_path; required argument + check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') + + return method(*args, **kwargs) + + return new_method + + def check_aligned_list(param, param_name, membor_type): """Check whether the structure of each member of the list is the same.""" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 7c644a3ae7aac8f6f0e063c1b96d927dee550a39..ce2aca4ffd0ab713933012c6f57e505f49eba874 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -27,6 +27,13 @@ using namespace mindspore::dataset; using namespace mindspore::dataset::gnn; +#define print_int_vec(_i, _str) \ + do { \ + std::stringstream ss; \ + std::copy(_i.begin(), _i.end(), std::ostream_iterator(ss, " ")); \ + MS_LOG(INFO) << _str << " " << ss.str(); \ + } while (false) + class MindDataTestGNNGraph : public UT::Common { protected: MindDataTestGNNGraph() = default; @@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); } + +TEST_F(MindDataTestGNNGraph, TestRandomWalk) { + std::string path = "data/mindrecord/testGraphData/sns"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + + std::shared_ptr nodes; + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + } + + print_int_vec(node_list, "node list "); + std::vector meta_path(59, 1); + std::shared_ptr walk_path; + s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); +} \ No newline at end of file diff --git a/tests/ut/data/mindrecord/testGraphData/sns b/tests/ut/data/mindrecord/testGraphData/sns new file mode 100644 index 0000000000000000000000000000000000000000..37a2c3dd30b7344083968973520f83702ec6266b Binary files /dev/null and b/tests/ut/data/mindrecord/testGraphData/sns differ diff --git a/tests/ut/data/mindrecord/testGraphData/sns.db b/tests/ut/data/mindrecord/testGraphData/sns.db new file mode 100644 index 0000000000000000000000000000000000000000..14d0b4f6b95cb242cf71821ff03f4db8c2c431bf Binary files /dev/null and b/tests/ut/data/mindrecord/testGraphData/sns.db differ diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 9b4ff66ac177d8d0b3b183cf3a2a781ee00a718e..408333662330fb3eb70d276f172198e548da3732 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -19,6 +19,7 @@ import mindspore.dataset as ds from mindspore import log as logger DATASET_FILE = "../data/mindrecord/testGraphData/testdata" +SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" def test_graphdata_getfullneighbor(): @@ -172,6 +173,17 @@ def test_graphdata_generatordataset(): assert i == 40 +def test_graphdata_randomwalk(): + g = ds.GraphData(SOCIAL_DATA_FILE, 1) + nodes = g.get_all_nodes(1) + print(len(nodes)) + assert len(nodes) == 33 + + meta_path = [1 for _ in range(39)] + walks = g.random_walk(nodes, meta_path) + assert walks.shape == (33, 40) + + if __name__ == '__main__': test_graphdata_getfullneighbor() logger.info('test_graphdata_getfullneighbor Ended.\n') @@ -185,3 +197,5 @@ if __name__ == '__main__': logger.info('test_graphdata_graphinfo Ended.\n') test_graphdata_generatordataset() logger.info('test_graphdata_generatordataset Ended.\n') + test_graphdata_randomwalk() + logger.info('test_graphdata_randomwalk Ended.\n')