提交 7a046a1d 编写于 作者: H heleiwang

support get_edge_feature

上级 87722b9e
......@@ -820,6 +820,12 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("get_edge_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow();
})
.def("graph_info",
[](gnn::Graph &g) {
py::dict out;
......
......@@ -125,13 +125,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s
Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) {
if (node_list.empty()) {
RETURN_STATUS_UNEXPECTED("Input node_list is empty.");
}
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
std::vector<std::vector<NodeIdType>> neighbors;
size_t max_neighbor_num = 0;
......@@ -161,6 +156,14 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) {
return Status::OK();
}
Status Graph::CheckNeighborType(NodeType neighbor_type) {
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
......@@ -171,10 +174,7 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
RETURN_IF_NOT_OK(CheckSamplesNum(num));
}
for (const auto &type : neighbor_types) {
if (node_type_map_.find(type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
RETURN_IF_NOT_OK(CheckNeighborType(type));
}
std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
......@@ -228,44 +228,36 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
std::vector<std::vector<NodeIdType>> neighbors_vec;
neighbors_vec.resize(node_list.size());
std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
neg_neighbors_vec.resize(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node));
std::vector<NodeIdType> neighbors;
RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors));
std::unordered_set<NodeIdType> exclude_node;
std::unordered_set<NodeIdType> exclude_nodes;
std::transform(neighbors.begin(), neighbors.end(),
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()),
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
[](const NodeIdType node) { return node; });
auto itr = node_type_map_.find(neg_neighbor_type);
if (itr == node_type_map_.end()) {
std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg);
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
neg_neighbors_vec[node_idx].emplace_back(node->id());
if (all_nodes.size() > exclude_nodes.size()) {
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(),
&neg_neighbors_vec[node_idx]));
}
} else {
neighbors_vec[node_idx].emplace_back(node->id());
if (itr->second.size() > exclude_node.size()) {
while (neighbors_vec[node_idx].size() < samples_num + 1) {
RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(),
&neighbors_vec[node_idx]));
}
} else {
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
<< " neg_neighbor_type:" << neg_neighbor_type;
// If there are no negative neighbors, they are filled with kDefaultNodeId
for (int32_t i = 0; i < samples_num; ++i) {
neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
}
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
<< " neg_neighbor_type:" << neg_neighbor_type;
// If there are no negative neighbors, they are filled with kDefaultNodeId
for (int32_t i = 0; i < samples_num; ++i) {
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
}
}
}
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
return Status::OK();
}
......@@ -280,8 +272,19 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve
}
Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = default_feature_map_.find(feature_type);
if (itr == default_feature_map_.end()) {
auto itr = default_node_feature_map_.find(feature_type);
if (itr == default_node_feature_map_.end()) {
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
*out_feature = itr->second;
}
return Status::OK();
}
Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = default_edge_feature_map_.find(feature_type);
if (itr == default_edge_feature_map_.end()) {
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
......@@ -295,7 +298,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
if (!nodes || nodes->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
}
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty");
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
TensorRow tensors;
for (const auto &f_type : feature_types) {
std::shared_ptr<Feature> default_feature;
......@@ -340,6 +343,45 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) {
if (!edges || edges->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input edges is empty");
}
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
TensorRow tensors;
for (const auto &f_type : feature_types) {
std::shared_ptr<Feature> default_feature;
// If no feature can be obtained, fill in the default value
RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature));
TensorShape shape(default_feature->Value()->shape());
auto shape_vec = edges->shape().AsVector();
dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
shape = shape.PrependDim(size);
std::shared_ptr<Tensor> fea_tensor;
RETURN_IF_NOT_OK(
Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr));
dsize_t index = 0;
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
std::shared_ptr<Edge> edge;
RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge));
std::shared_ptr<Feature> feature;
if (!edge->GetFeatures(f_type, &feature).IsOk()) {
feature = default_feature;
}
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));
index++;
}
TensorShape reshape(edges->shape());
for (auto s : default_feature->Value()->shape().AsVector()) {
reshape = reshape.AppendDim(s);
}
RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
fea_tensor->Squeeze();
tensors.push_back(fea_tensor);
}
*out = std::move(tensors);
return Status::OK();
}
......@@ -405,7 +447,8 @@ Status Graph::LoadNodeAndEdge() {
RETURN_IF_NOT_OK(gl.InitAndLoad());
// get all maps
RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_,
&node_feature_map_, &edge_feature_map_, &default_feature_map_));
&node_feature_map_, &edge_feature_map_, &default_node_feature_map_,
&default_edge_feature_map_));
return Status::OK();
}
......@@ -420,18 +463,33 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
return Status::OK();
}
Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
auto itr = edge_id_map_.find(id);
if (itr == edge_id_map_.end()) {
std::string err_msg = "Invalid edge id:" + std::to_string(id);
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
*edge = itr->second;
}
return Status::OK();
}
Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, const NodeIdType default_node,
int32_t num_walks, int32_t num_workers) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
node_list_ = node_list;
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) +
". The size of input path is " + std::to_string(meta_path.size());
RETURN_STATUS_UNEXPECTED(err_msg);
}
for (const auto &type : meta_path) {
RETURN_IF_NOT_OK(graph_->CheckNeighborType(type));
}
meta_path_ = meta_path;
if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) {
std::string err_msg = "Failed, step_home_param and step_away_param required greater than " +
......@@ -500,15 +558,10 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve
}
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
// Repeatedly simulate random walks from each node
std::vector<uint32_t> permutation(node_list_.size());
std::iota(permutation.begin(), permutation.end(), 0);
for (int32_t i = 0; i < num_walks_; i++) {
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
for (const auto &i_perm : permutation) {
for (const auto &node : node_list_) {
std::vector<NodeIdType> walk;
RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk));
RETURN_IF_NOT_OK(Node2vecWalk(node, &walk));
walks->push_back(walk);
}
}
......
......@@ -211,12 +211,24 @@ class Graph {
// @return Status - The error code return
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
// Get the default feature of a edge
// @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
// Find node object using node id
// @param NodeIdType id -
// @param std::shared_ptr<Node> *node - Returned node object
// @return Status - The error code return
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
// Find edge object using edge id
// @param EdgeIdType id -
// @param std::shared_ptr<Node> *edge - Returned edge object
// @return Status - The error code return
Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge);
// Negative sampling
// @param std::vector<NodeIdType> &input_data - The data set to be sampled
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
......@@ -228,6 +240,8 @@ class Graph {
Status CheckSamplesNum(NodeIdType samples_num);
Status CheckNeighborType(NodeType neighbor_type);
std::string dataset_file_;
int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;
......@@ -242,7 +256,8 @@ class Graph {
std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_;
std::unordered_map<EdgeType, std::unordered_set<FeatureType>> edge_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_node_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_edge_feature_map_;
};
} // namespace gnn
} // namespace dataset
......
......@@ -41,7 +41,8 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map,
DefaultEdgeFeatureMap *default_edge_feature_map) {
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
while (dq.empty() == false) {
std::shared_ptr<Node> node_ptr = dq.front();
......@@ -70,7 +71,7 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N
for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map);
return Status::OK();
}
......@@ -81,7 +82,8 @@ Status GraphLoader::InitAndLoad() {
e_deques_.resize(num_workers_);
n_feature_maps_.resize(num_workers_);
e_feature_maps_.resize(num_workers_);
default_feature_maps_.resize(num_workers_);
default_node_feature_maps_.resize(num_workers_);
default_edge_feature_maps_.resize(num_workers_);
TaskGroup vg;
shard_reader_ = std::make_unique<ShardReader>();
......@@ -109,7 +111,7 @@ Status GraphLoader::InitAndLoad() {
Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
DefaultFeatureMap *default_feature) {
DefaultNodeFeatureMap *default_feature) {
NodeIdType node_id = col_jsn["first_id"];
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type);
......@@ -133,7 +135,7 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
DefaultFeatureMap *default_feature) {
DefaultEdgeFeatureMap *default_feature) {
EdgeIdType edge_id = col_jsn["first_id"];
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
......@@ -214,13 +216,13 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
std::string attr = col_jsn["attribute"];
if (attr == "n") {
std::shared_ptr<Node> node_ptr;
RETURN_IF_NOT_OK(
LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]),
&default_node_feature_maps_[worker_id]));
n_deques_[worker_id].emplace_back(node_ptr);
} else if (attr == "e") {
std::shared_ptr<Edge> edge_ptr;
RETURN_IF_NOT_OK(
LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]),
&default_edge_feature_maps_[worker_id]));
e_deques_[worker_id].emplace_back(edge_ptr);
} else {
MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
......@@ -233,7 +235,8 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
}
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
DefaultFeatureMap *default_feature_map) {
DefaultNodeFeatureMap *default_node_feature_map,
DefaultEdgeFeatureMap *default_edge_feature_map) {
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
for (auto &m : n_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
......@@ -241,8 +244,11 @@ void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap
for (auto &m : e_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
}
for (auto &m : default_feature_maps_[wkr_id]) {
(*default_feature_map)[m.first] = m.second;
for (auto &m : default_node_feature_maps_[wkr_id]) {
(*default_node_feature_map)[m.first] = m.second;
}
for (auto &m : default_edge_feature_maps_[wkr_id]) {
(*default_edge_feature_map)[m.first] = m.second;
}
}
n_feature_maps_.clear();
......
......@@ -43,7 +43,8 @@ using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>;
using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>;
using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>;
using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>;
using DefaultFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
// this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges
......@@ -63,7 +64,7 @@ class GraphLoader {
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
DefaultFeatureMap *);
DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
private:
//
......@@ -77,19 +78,19 @@ class GraphLoader {
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Node> *node - return value
// @param NodeFeatureMap *feature_map -
// @param DefaultFeatureMap *default_feature -
// @param DefaultNodeFeatureMap *default_feature -
// @return Status - the status code
Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node,
NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature);
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected
// @param FeatureMap *feature_map
// @param DefaultFeatureMap *default_feature -
// @param DefaultEdgeFeatureMap *default_feature -
// @return Status - the status code
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
......@@ -108,7 +109,7 @@ class GraphLoader {
std::shared_ptr<Tensor> *tensor);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *);
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
const int32_t num_workers_;
std::atomic_int row_id_;
......@@ -118,7 +119,8 @@ class GraphLoader {
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
std::vector<NodeFeatureMap> n_feature_maps_;
std::vector<EdgeFeatureMap> e_feature_maps_;
std::vector<DefaultFeatureMap> default_feature_maps_;
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
const std::vector<std::string> keys_;
};
} // namespace gnn
......
......@@ -22,7 +22,8 @@ from mindspore._c_dataengine import Tensor
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \
check_gnn_random_walk
class GraphData:
......@@ -127,7 +128,13 @@ class GraphData:
@check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
"""
Get sampled neighbor information, maximum support 6-hop sampling.
Get sampled neighbor information.
The api supports multi-hop neighbor sampling. That is, the previous sampling result is used as the input of
next-hop sampling. A maximum of 6-hop are allowed.
The sampling result is tiled into a list in the format of [input node, 1-hop sampling result,
2-hop samling result ...]
Args:
node_list (list or numpy.ndarray): The given list of nodes.
......@@ -207,6 +214,35 @@ class GraphData:
Tensor(node_list),
feature_types)]
@check_gnn_get_edge_feature
def get_edge_feature(self, edge_list, feature_types):
"""
Get `feature_types` feature of the edges in `edge_list`.
Args:
edge_list (list or numpy.ndarray): The given list of edges.
feature_types (list or ndarray): The given list of feature types.
Returns:
numpy.ndarray: array of features.
Examples:
>>> import mindspore.dataset as ds
>>> data_graph = ds.GraphData('dataset_file', 2)
>>> edges = data_graph.get_all_edges(0)
>>> features = data_graph.get_edge_feature(edges, [1])
Raises:
TypeError: If `edge_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray.
"""
if isinstance(edge_list, list):
edge_list = np.array(edge_list, dtype=np.int32)
return [
t.as_array() for t in self._graph.get_edge_feature(
Tensor(edge_list),
feature_types)]
def graph_info(self):
"""
Get the meta information of the graph, including the number of nodes, the type of nodes,
......
......@@ -797,7 +797,7 @@ def check_gnn_graphdata(method):
check_file(dataset_file)
if num_parallel_workers is not None:
type_check(num_parallel_workers, (int,), "num_parallel_workers")
check_num_parallel_workers(num_parallel_workers)
return method(self, *args, **kwargs)
return new_method
......@@ -970,6 +970,28 @@ def check_gnn_get_node_feature(method):
return new_method
def check_gnn_get_edge_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
[edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
type_check(edge_list, (list, np.ndarray), "edge_list")
if isinstance(edge_list, list):
check_aligned_list(edge_list, 'edge_list', int)
elif isinstance(edge_list, np.ndarray):
if not edge_list.dtype == np.int32:
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
edge_list, edge_list.dtype))
check_gnn_list_or_ndarray(feature_types, 'feature_types')
return method(self, *args, **kwargs)
return new_method
def check_numpyslicesdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset)."""
......
......@@ -49,9 +49,10 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
EdgeTypeMap e_type_map;
NodeFeatureMap n_feature_map;
EdgeFeatureMap e_feature_map;
DefaultFeatureMap default_feature_map;
DefaultNodeFeatureMap default_node_feature_map;
DefaultEdgeFeatureMap default_edge_feature_map;
EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map,
&default_feature_map)
&default_node_feature_map, &default_edge_feature_map)
.IsOk());
EXPECT_EQ(n_id_map.size(), 20);
EXPECT_EQ(e_id_map.size(), 40);
......@@ -119,6 +120,17 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
[](const EdgeIdType edge) { return edge; });
TensorRow edge_features;
s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(edge_features[0]->ToString() ==
"Tensor (shape: <40>, Type: int32)\n"
"[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]");
EXPECT_TRUE(edge_features[1]->ToString() ==
"Tensor (shape: <40>, Type: float32)\n"
"[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2."
"7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]");
std::shared_ptr<Tensor> nodes;
s = graph.GetNodesFromEdges(edge_list, &nodes);
EXPECT_TRUE(s.IsOk());
......
......@@ -125,7 +125,7 @@ def test_graphdata_graphinfo():
assert graph_info['node_num'] == {1: 10, 2: 10}
assert graph_info['edge_num'] == {0: 40}
assert graph_info['node_feature_type'] == [1, 2, 3, 4]
assert graph_info['edge_feature_type'] == []
assert graph_info['edge_feature_type'] == [1, 2]
class RandomBatchedSampler(ds.Sampler):
......@@ -204,7 +204,6 @@ def test_graphdata_randomwalkdefault():
logger.info('test randomwalk with default parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
assert len(nodes) == 33
meta_path = [1 for _ in range(39)]
......@@ -219,7 +218,6 @@ def test_graphdata_randomwalk():
logger.info('test random walk with given parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
assert len(nodes) == 33
meta_path = [1 for _ in range(39)]
......@@ -227,6 +225,18 @@ def test_graphdata_randomwalk():
assert walks.shape == (33, 40)
def test_graphdata_getedgefeature():
"""
Test get edge feature
"""
logger.info('test get_edge_feature.\n')
g = ds.GraphData(DATASET_FILE)
edges = g.get_all_edges(0)
features = g.get_edge_feature(edges, [1, 2])
assert features[0].shape == (40,)
assert features[1].shape == (40,)
if __name__ == '__main__':
test_graphdata_getfullneighbor()
test_graphdata_getnodefeature_input_check()
......@@ -236,3 +246,4 @@ if __name__ == '__main__':
test_graphdata_generatordataset()
test_graphdata_randomwalkdefault()
test_graphdata_randomwalk()
test_graphdata_getedgefeature()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册