diff --git a/example/graph_to_mindrecord/write_citeseer.sh b/example/graph_to_mindrecord/write_citeseer.sh index 0d5093f18d236f03b6893b6eef1a6a58e47e85c3..33235372fa58398c69fd2d7917bd54f8ab29cd3c 100644 --- a/example/graph_to_mindrecord/write_citeseer.sh +++ b/example/graph_to_mindrecord/write_citeseer.sh @@ -1,9 +1,12 @@ #!/bin/bash -rm /tmp/citeseer/mindrecord/* +SRC_PATH=/tmp/citeseer/dataset +MINDRECORD_PATH=/tmp/citeseer/mindrecord + +rm -f $MINDRECORD_PATH/* python writer.py --mindrecord_script citeseer \ ---mindrecord_file "/tmp/citeseer/mindrecord/citeseer_mr" \ +--mindrecord_file "$MINDRECORD_PATH/citeseer_mr" \ --mindrecord_partitions 1 \ --mindrecord_header_size_by_bit 18 \ --mindrecord_page_size_by_bit 20 \ ---graph_api_args "/tmp/citeseer/dataset/citeseer.content:/tmp/citeseer/dataset/citeseer.cites" +--graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites" diff --git a/example/graph_to_mindrecord/write_cora.sh b/example/graph_to_mindrecord/write_cora.sh index 6ba321ef03597a2850d08fa1e5a1f34b8279dd01..84ccf34f5e714ce15bd9cbc8638820dbe0d8de57 100644 --- a/example/graph_to_mindrecord/write_cora.sh +++ b/example/graph_to_mindrecord/write_cora.sh @@ -1,9 +1,12 @@ #!/bin/bash -rm /tmp/cora/mindrecord/* +SRC_PATH=/tmp/cora/dataset +MINDRECORD_PATH=/tmp/cora/mindrecord + +rm -f $MINDRECORD_PATH/* python writer.py --mindrecord_script cora \ ---mindrecord_file "/tmp/cora/mindrecord/cora_mr" \ +--mindrecord_file "$MINDRECORD_PATH/cora_mr" \ --mindrecord_partitions 1 \ --mindrecord_header_size_by_bit 18 \ --mindrecord_page_size_by_bit 20 \ ---graph_api_args "/tmp/cora/dataset/cora_content.csv:/tmp/cora/dataset/cora_cites.csv" +--graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv" diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index 9dcca72339659adf360dd3619c9479ce44d42a91..a0278761229543b1e2818e58539aad686b2e9272 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -51,7 +51,7 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data RETURN_STATUS_UNEXPECTED("Data type not compatible"); } if (data.empty()) { - RETURN_STATUS_UNEXPECTED("Input data is emply"); + RETURN_STATUS_UNEXPECTED("Input data is empty"); } std::shared_ptr tensor; size_t m = data.size(); @@ -74,7 +74,7 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data template Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { if (!data || data->empty()) { - RETURN_STATUS_UNEXPECTED("Input data is emply"); + RETURN_STATUS_UNEXPECTED("Input data is empty"); } for (std::vector &vec : *data) { size_t size = vec.size(); @@ -93,6 +93,9 @@ Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr< Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out) { + if (node_list.empty()) { + RETURN_STATUS_UNEXPECTED("Input node_list is empty."); + } if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); RETURN_STATUS_UNEXPECTED(err_msg); @@ -147,7 +150,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr &nodes, const std::vector &feature_types, TensorRow *out) { if (!nodes || nodes->Size() == 0) { - RETURN_STATUS_UNEXPECTED("Inpude nodes is empty"); + RETURN_STATUS_UNEXPECTED("Input nodes is empty"); } TensorRow tensors; for (auto f_type : feature_types) { diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index 027ba53aeb1ee5ce5b2a575a18e4947d5d0d3ac4..3dd644480792e7426b8de8a5381aea3053eb7e5a 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -156,7 +156,7 @@ class Graph { std::unordered_map> edge_id_map_; std::unordered_map> node_feature_map_; - std::unordered_map> edge_feature_map_; + std::unordered_map> edge_feature_map_; std::unordered_map> default_feature_map_; }; diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 9f70204e2f726d50b0083387c84da73d723a0d82..23f8dbda6a5e442ce7430c1c50c3069d89651815 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -78,7 +78,7 @@ class GraphData: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) - >>> neighbors = data_graph.get_all_neighbors(nodes[0], 0) + >>> neighbors = data_graph.get_all_neighbors(nodes, 0) Raises: TypeError: If `node_list` is not list or ndarray. @@ -102,7 +102,7 @@ class GraphData: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) - >>> features = data_graph.get_node_feature(nodes[0], [1]) + >>> features = data_graph.get_node_feature(nodes, [1]) Raises: TypeError: If `node_list` is not list or ndarray.