diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 0ae64db6713ea27f0fe2ecc4bb5edffee153358f..86c98406a44d20618e1b744067e7f072a0d1e784 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -820,6 +820,12 @@ void bindGraphData(py::module *m) { THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); return out.getRow(); }) + .def("get_edge_feature", + [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); + return out.getRow(); + }) .def("graph_info", [](gnn::Graph &g) { py::dict out; diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index b3a8aed8f583cdb5d19e74f7784595970c7a87ea..aa5abd41339fbbeafcd90dfd27d64d88089c97d8 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -125,13 +125,8 @@ Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::s Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out) { - if (node_list.empty()) { - RETURN_STATUS_UNEXPECTED("Input node_list is empty."); - } - if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); std::vector> neighbors; size_t max_neighbor_num = 0; @@ -161,6 +156,14 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) { return Status::OK(); } +Status Graph::CheckNeighborType(NodeType neighbor_type) { + if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + Status Graph::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, const std::vector &neighbor_types, std::shared_ptr *out) { @@ -171,10 +174,7 @@ Status Graph::GetSampledNeighbors(const std::vector &node_list, RETURN_IF_NOT_OK(CheckSamplesNum(num)); } for (const auto &type : neighbor_types) { - if (node_type_map_.find(type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(type); - RETURN_STATUS_UNEXPECTED(err_msg); - } + RETURN_IF_NOT_OK(CheckNeighborType(type)); } std::vector> neighbors_vec(node_list.size()); for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { @@ -228,44 +228,36 @@ Status Graph::GetNegSampledNeighbors(const std::vector &node_list, N NodeType neg_neighbor_type, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); - if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } + RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); - std::vector> neighbors_vec; - neighbors_vec.resize(node_list.size()); + std::vector> neg_neighbors_vec; + neg_neighbors_vec.resize(node_list.size()); for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { std::shared_ptr node; RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); std::vector neighbors; RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); - std::unordered_set exclude_node; + std::unordered_set exclude_nodes; std::transform(neighbors.begin(), neighbors.end(), - std::insert_iterator>(exclude_node, exclude_node.begin()), + std::insert_iterator>(exclude_nodes, exclude_nodes.begin()), [](const NodeIdType node) { return node; }); - auto itr = node_type_map_.find(neg_neighbor_type); - if (itr == node_type_map_.end()) { - std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); + const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; + neg_neighbors_vec[node_idx].emplace_back(node->id()); + if (all_nodes.size() > exclude_nodes.size()) { + while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { + RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), + &neg_neighbors_vec[node_idx])); + } } else { - neighbors_vec[node_idx].emplace_back(node->id()); - if (itr->second.size() > exclude_node.size()) { - while (neighbors_vec[node_idx].size() < samples_num + 1) { - RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), - &neighbors_vec[node_idx])); - } - } else { - MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() - << " neg_neighbor_type:" << neg_neighbor_type; - // If there are no negative neighbors, they are filled with kDefaultNodeId - for (int32_t i = 0; i < samples_num; ++i) { - neighbors_vec[node_idx].emplace_back(kDefaultNodeId); - } + MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() + << " neg_neighbor_type:" << neg_neighbor_type; + // If there are no negative neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); } } } - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); + RETURN_IF_NOT_OK(CreateTensorByVector(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); return Status::OK(); } @@ -280,8 +272,19 @@ Status Graph::RandomWalk(const std::vector &node_list, const std::ve } Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = default_feature_map_.find(feature_type); - if (itr == default_feature_map_.end()) { + auto itr = default_node_feature_map_.find(feature_type); + if (itr == default_node_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_edge_feature_map_.find(feature_type); + if (itr == default_edge_feature_map_.end()) { std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); RETURN_STATUS_UNEXPECTED(err_msg); } else { @@ -295,7 +298,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve if (!nodes || nodes->Size() == 0) { RETURN_STATUS_UNEXPECTED("Input nodes is empty"); } - CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); TensorRow tensors; for (const auto &f_type : feature_types) { std::shared_ptr default_feature; @@ -340,6 +343,45 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, TensorRow *out) { + if (!edges || edges->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input edges is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + TensorRow tensors; + for (const auto &f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = edges->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { + std::shared_ptr edge; + RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); + std::shared_ptr feature; + if (!edge->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(edges->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); return Status::OK(); } @@ -405,7 +447,8 @@ Status Graph::LoadNodeAndEdge() { RETURN_IF_NOT_OK(gl.InitAndLoad()); // get all maps RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, - &node_feature_map_, &edge_feature_map_, &default_feature_map_)); + &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, + &default_edge_feature_map_)); return Status::OK(); } @@ -420,18 +463,33 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { return Status::OK(); } +Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { + auto itr = edge_id_map_.find(id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *edge = itr->second; + } + return Status::OK(); +} + Graph::RandomWalkBase::RandomWalkBase(Graph *graph) : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, float step_home_param, float step_away_param, const NodeIdType default_node, int32_t num_walks, int32_t num_workers) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); node_list_ = node_list; if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + ". The size of input path is " + std::to_string(meta_path.size()); RETURN_STATUS_UNEXPECTED(err_msg); } + for (const auto &type : meta_path) { + RETURN_IF_NOT_OK(graph_->CheckNeighborType(type)); + } meta_path_ = meta_path; if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + @@ -500,15 +558,10 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve } Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { - // Repeatedly simulate random walks from each node - std::vector permutation(node_list_.size()); - std::iota(permutation.begin(), permutation.end(), 0); for (int32_t i = 0; i < num_walks_; i++) { - unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); - std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); - for (const auto &i_perm : permutation) { + for (const auto &node : node_list_) { std::vector walk; - RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); + RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); walks->push_back(walk); } } diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index 68bdfcc9dc96b56e6962b9c353bbc94c52542ca7..426903829436ccb4af197f7792bdb46804e3037a 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -211,12 +211,24 @@ class Graph { // @return Status - The error code return Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + // Get the default feature of a edge + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + // Find node object using node id // @param NodeIdType id - // @param std::shared_ptr *node - Returned node object // @return Status - The error code return Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); + // Find edge object using edge id + // @param EdgeIdType id - + // @param std::shared_ptr *edge - Returned edge object + // @return Status - The error code return + Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge); + // Negative sampling // @param std::vector &input_data - The data set to be sampled // @param std::unordered_set &exclude_data - Data to be excluded @@ -228,6 +240,8 @@ class Graph { Status CheckSamplesNum(NodeIdType samples_num); + Status CheckNeighborType(NodeType neighbor_type); + std::string dataset_file_; int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; @@ -242,7 +256,8 @@ class Graph { std::unordered_map> node_feature_map_; std::unordered_map> edge_feature_map_; - std::unordered_map> default_feature_map_; + std::unordered_map> default_node_feature_map_; + std::unordered_map> default_edge_feature_map_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc index 6504d088bf21f04f5d6c0f46241a9b5c28904013..f3374954b6b2ee48273da5026ba0d240fae0e16c 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc @@ -41,7 +41,8 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, - EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) { + EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { for (std::deque> &dq : n_deques_) { while (dq.empty() == false) { std::shared_ptr node_ptr = dq.front(); @@ -70,7 +71,7 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); - MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map); + MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); return Status::OK(); } @@ -81,7 +82,8 @@ Status GraphLoader::InitAndLoad() { e_deques_.resize(num_workers_); n_feature_maps_.resize(num_workers_); e_feature_maps_.resize(num_workers_); - default_feature_maps_.resize(num_workers_); + default_node_feature_maps_.resize(num_workers_); + default_edge_feature_maps_.resize(num_workers_); TaskGroup vg; shard_reader_ = std::make_unique(); @@ -109,7 +111,7 @@ Status GraphLoader::InitAndLoad() { Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, std::shared_ptr *node, NodeFeatureMap *feature_map, - DefaultFeatureMap *default_feature) { + DefaultNodeFeatureMap *default_feature) { NodeIdType node_id = col_jsn["first_id"]; NodeType node_type = static_cast(col_jsn["type"]); (*node) = std::make_shared(node_id, node_type); @@ -133,7 +135,7 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, std::shared_ptr *edge, EdgeFeatureMap *feature_map, - DefaultFeatureMap *default_feature) { + DefaultEdgeFeatureMap *default_feature) { EdgeIdType edge_id = col_jsn["first_id"]; EdgeType edge_type = static_cast(col_jsn["type"]); NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; @@ -214,13 +216,13 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { std::string attr = col_jsn["attribute"]; if (attr == "n") { std::shared_ptr node_ptr; - RETURN_IF_NOT_OK( - LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); + RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), + &default_node_feature_maps_[worker_id])); n_deques_[worker_id].emplace_back(node_ptr); } else if (attr == "e") { std::shared_ptr edge_ptr; - RETURN_IF_NOT_OK( - LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); + RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), + &default_edge_feature_maps_[worker_id])); e_deques_[worker_id].emplace_back(edge_ptr); } else { MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; @@ -233,7 +235,8 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { } void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, - DefaultFeatureMap *default_feature_map) { + DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { for (auto &m : n_feature_maps_[wkr_id]) { for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); @@ -241,8 +244,11 @@ void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap for (auto &m : e_feature_maps_[wkr_id]) { for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); } - for (auto &m : default_feature_maps_[wkr_id]) { - (*default_feature_map)[m.first] = m.second; + for (auto &m : default_node_feature_maps_[wkr_id]) { + (*default_node_feature_map)[m.first] = m.second; + } + for (auto &m : default_edge_feature_maps_[wkr_id]) { + (*default_edge_feature_map)[m.first] = m.second; } } n_feature_maps_.clear(); diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h index 0ad54bae6d80c78272b4920875fc534255477d88..141816d633809448c0551a03c7aef1f72a81937c 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h @@ -43,7 +43,8 @@ using NodeTypeMap = std::unordered_map>; using EdgeTypeMap = std::unordered_map>; using NodeFeatureMap = std::unordered_map>; using EdgeFeatureMap = std::unordered_map>; -using DefaultFeatureMap = std::unordered_map>; +using DefaultNodeFeatureMap = std::unordered_map>; +using DefaultEdgeFeatureMap = std::unordered_map>; // this class interfaces with the underlying storage format (mindrecord) // it returns raw nodes and edges via GetNodesAndEdges @@ -63,7 +64,7 @@ class GraphLoader { // random order. src_node and dst_node in Edge are node_id only with -1 as type. // features attached to each node and edge are expected to be filled correctly Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, - DefaultFeatureMap *); + DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); private: // @@ -77,19 +78,19 @@ class GraphLoader { // @param mindrecord::json &jsn - contains raw data // @param std::shared_ptr *node - return value // @param NodeFeatureMap *feature_map - - // @param DefaultFeatureMap *default_feature - + // @param DefaultNodeFeatureMap *default_feature - // @return Status - the status code Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, - NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature); + NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); // @param std::vector &blob - contains data in blob field in mindrecord // @param mindrecord::json &jsn - contains raw data // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected // @param FeatureMap *feature_map - // @param DefaultFeatureMap *default_feature - + // @param DefaultEdgeFeatureMap *default_feature - // @return Status - the status code Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, - EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature); + EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); // @param std::string key - column name // @param std::vector &blob - contains data in blob field in mindrecord @@ -108,7 +109,7 @@ class GraphLoader { std::shared_ptr *tensor); // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 - void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *); + void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); const int32_t num_workers_; std::atomic_int row_id_; @@ -118,7 +119,8 @@ class GraphLoader { std::vector>> e_deques_; std::vector n_feature_maps_; std::vector e_feature_maps_; - std::vector default_feature_maps_; + std::vector default_node_feature_maps_; + std::vector default_edge_feature_maps_; const std::vector keys_; }; } // namespace gnn diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 5a9506080a7a615234cd24b624b6e314f77bd7fa..81314b4373456919d1e195e70dbf905eaec565b3 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -22,7 +22,8 @@ from mindspore._c_dataengine import Tensor from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ - check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk + check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ + check_gnn_random_walk class GraphData: @@ -127,7 +128,13 @@ class GraphData: @check_gnn_get_sampled_neighbors def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): """ - Get sampled neighbor information, maximum support 6-hop sampling. + Get sampled neighbor information. + + The api supports multi-hop neighbor sampling. That is, the previous sampling result is used as the input of + next-hop sampling. A maximum of 6-hop are allowed. + + The sampling result is tiled into a list in the format of [input node, 1-hop sampling result, + 2-hop samling result ...] Args: node_list (list or numpy.ndarray): The given list of nodes. @@ -207,6 +214,35 @@ class GraphData: Tensor(node_list), feature_types)] + @check_gnn_get_edge_feature + def get_edge_feature(self, edge_list, feature_types): + """ + Get `feature_types` feature of the edges in `edge_list`. + + Args: + edge_list (list or numpy.ndarray): The given list of edges. + feature_types (list or ndarray): The given list of feature types. + + Returns: + numpy.ndarray: array of features. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> edges = data_graph.get_all_edges(0) + >>> features = data_graph.get_edge_feature(edges, [1]) + + Raises: + TypeError: If `edge_list` is not list or ndarray. + TypeError: If `feature_types` is not list or ndarray. + """ + if isinstance(edge_list, list): + edge_list = np.array(edge_list, dtype=np.int32) + return [ + t.as_array() for t in self._graph.get_edge_feature( + Tensor(edge_list), + feature_types)] + def graph_info(self): """ Get the meta information of the graph, including the number of nodes, the type of nodes, diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ab7cc6ac54387060c1cdcde3aaa801beace06c4b..f3b79f9db7f117aed6aedef42e1d57e12bd7a9e8 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -797,7 +797,7 @@ def check_gnn_graphdata(method): check_file(dataset_file) if num_parallel_workers is not None: - type_check(num_parallel_workers, (int,), "num_parallel_workers") + check_num_parallel_workers(num_parallel_workers) return method(self, *args, **kwargs) return new_method @@ -970,6 +970,28 @@ def check_gnn_get_node_feature(method): return new_method +def check_gnn_get_edge_feature(method): + """A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs) + + type_check(edge_list, (list, np.ndarray), "edge_list") + if isinstance(edge_list, list): + check_aligned_list(edge_list, 'edge_list', int) + elif isinstance(edge_list, np.ndarray): + if not edge_list.dtype == np.int32: + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( + edge_list, edge_list.dtype)) + + check_gnn_list_or_ndarray(feature_types, 'feature_types') + + return method(self, *args, **kwargs) + + return new_method + + def check_numpyslicesdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 96cbcb0c7db83fcc859fe3c0d64c362d37b9ab1f..584fde5cefde88c8ed80b83210bb2fbbe29c9e47 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -49,9 +49,10 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { EdgeTypeMap e_type_map; NodeFeatureMap n_feature_map; EdgeFeatureMap e_feature_map; - DefaultFeatureMap default_feature_map; + DefaultNodeFeatureMap default_node_feature_map; + DefaultEdgeFeatureMap default_edge_feature_map; EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, - &default_feature_map) + &default_node_feature_map, &default_edge_feature_map) .IsOk()); EXPECT_EQ(n_id_map.size(), 20); EXPECT_EQ(e_id_map.size(), 40); @@ -119,6 +120,17 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { std::transform(edges->begin(), edges->end(), edge_list.begin(), [](const EdgeIdType edge) { return edge; }); + TensorRow edge_features; + s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(edge_features[0]->ToString() == + "Tensor (shape: <40>, Type: int32)\n" + "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]"); + EXPECT_TRUE(edge_features[1]->ToString() == + "Tensor (shape: <40>, Type: float32)\n" + "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2." + "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]"); + std::shared_ptr nodes; s = graph.GetNodesFromEdges(edge_list, &nodes); EXPECT_TRUE(s.IsOk()); diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index e206469ac693d2f0073caaf3293ec3c0dde8be74..52359734692a4652a40c5a49d8b43c32f58e856c 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata and b/tests/ut/data/mindrecord/testGraphData/testdata differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db index 541da0e998e6385f79902bcf281ac325baf2f85d..0f022589f4c0c123c8a38fdb69b8abbcbaaa9668 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata.db and b/tests/ut/data/mindrecord/testGraphData/testdata.db differ diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index abcc643cc927c89e597c69108c3ec1972a7462e9..0f78cfd03a8b681340dcb7d3a4ae100ee1c7571f 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -125,7 +125,7 @@ def test_graphdata_graphinfo(): assert graph_info['node_num'] == {1: 10, 2: 10} assert graph_info['edge_num'] == {0: 40} assert graph_info['node_feature_type'] == [1, 2, 3, 4] - assert graph_info['edge_feature_type'] == [] + assert graph_info['edge_feature_type'] == [1, 2] class RandomBatchedSampler(ds.Sampler): @@ -204,7 +204,6 @@ def test_graphdata_randomwalkdefault(): logger.info('test randomwalk with default parameters.\n') g = ds.GraphData(SOCIAL_DATA_FILE, 1) nodes = g.get_all_nodes(1) - print(len(nodes)) assert len(nodes) == 33 meta_path = [1 for _ in range(39)] @@ -219,7 +218,6 @@ def test_graphdata_randomwalk(): logger.info('test random walk with given parameters.\n') g = ds.GraphData(SOCIAL_DATA_FILE, 1) nodes = g.get_all_nodes(1) - print(len(nodes)) assert len(nodes) == 33 meta_path = [1 for _ in range(39)] @@ -227,6 +225,18 @@ def test_graphdata_randomwalk(): assert walks.shape == (33, 40) +def test_graphdata_getedgefeature(): + """ + Test get edge feature + """ + logger.info('test get_edge_feature.\n') + g = ds.GraphData(DATASET_FILE) + edges = g.get_all_edges(0) + features = g.get_edge_feature(edges, [1, 2]) + assert features[0].shape == (40,) + assert features[1].shape == (40,) + + if __name__ == '__main__': test_graphdata_getfullneighbor() test_graphdata_getnodefeature_input_check() @@ -236,3 +246,4 @@ if __name__ == '__main__': test_graphdata_generatordataset() test_graphdata_randomwalkdefault() test_graphdata_randomwalk() + test_graphdata_getedgefeature()