提交 256dccc6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4498 Gnn data processing supports distributed scenarios

Merge pull request !4498 from heleiwang/gnn_distributed
......@@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
# MS_BUILD_GRPC decides whether gRPC and its dependencies are built; it is off by default.
SET(MS_BUILD_GRPC 0)
# The debugger, serving and the test cases each switch the gRPC build on.
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
SET(MS_BUILD_GRPC 1)
endif()
# MindData also switches the gRPC build on, except when targeting Windows.
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
SET(MS_BUILD_GRPC 1)
endif()
if ("${MS_BUILD_GRPC}")
# build dependencies of gRPC
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)
......
......@@ -83,6 +83,7 @@ endif()
# When the TDT queue is enabled, make engine-tdt build after the core target.
if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core)
endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
......@@ -182,3 +183,7 @@ else()
set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
endif ()
endif()
# Link _c_dataengine against gRPC C++ on every platform except Windows.
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif()
\ No newline at end of file
......@@ -18,83 +18,103 @@
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_data_client.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
Graph, 0, ([](const py::module *m) {
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
.def(py::init([](std::string dataset_file, int32_t num_workers) {
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
THROW_IF_ERROR(g_out->Init());
return g_out;
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
                 const std::string &hostname, int32_t port) {
  // Factory behind the Python "GraphDataClient" constructor.
  // @param dataset_file - path handed to the selected backend
  // @param num_workers - worker count, only used by the "local" backend
  // @param working_mode - "local" builds a GraphDataImpl, "client" builds a GraphDataClient
  // @param hostname, port - server address, only used by the "client" backend
  // @return the initialized backend, exposed through the GraphData interface
  // Raises ValueError for an unknown working_mode and a runtime error when Init() fails.
  std::shared_ptr<gnn::GraphData> out;
  if (working_mode == "local") {
    out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
  } else if (working_mode == "client") {
    out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
  } else {
    // Without this branch `out` stays null and the Init() call below dereferences a null pointer.
    throw py::value_error("Invalid working_mode '" + working_mode + "', expected 'local' or 'client'.");
  }
  THROW_IF_ERROR(out->Init());
  return out;
}))
.def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type) {
[](gnn::GraphData &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out;
})
.def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) {
[](gnn::GraphData &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out;
})
.def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out;
})
.def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out;
})
.def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out;
})
.def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
[](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("get_edge_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
[](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow();
})
.def("graph_info",
[](gnn::Graph &g) {
[](gnn::GraphData &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
})
.def("random_walk",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out;
});
})
.def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
// Python class "GraphDataServer": construction also runs Init(), and a failing Status is rethrown
// through THROW_IF_ERROR. "stop" forwards to Stop(); "is_stoped" reports IsStoped().
(void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
  .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
                   int32_t client_num, bool auto_shutdown) {
    auto server = std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num,
                                                         auto_shutdown);
    THROW_IF_ERROR(server->Init());
    return server;
  }))
  .def("stop", [](gnn::GraphDataServer &server) { THROW_IF_ERROR(server.Stop()); })
  .def("is_stoped", [](gnn::GraphDataServer &server) { return server.IsStoped(); });
}));
} // namespace dataset
......
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-gnn OBJECT
graph.cc
set(DATASET_ENGINE_GNN_SRC_FILES
graph_data_impl.cc
graph_data_client.cc
graph_data_server.cc
graph_loader.cc
graph_feature_parser.cc
local_node.cc
local_edge.cc
feature.cc
)
)
# Windows: build engine-gnn from the sources already listed in DATASET_ENGINE_GNN_SRC_FILES only.
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
else()
# Other platforms: append the tensor-proto, gRPC server, service and shared-memory sources.
set(DATASET_ENGINE_GNN_SRC_FILES
${DATASET_ENGINE_GNN_SRC_FILES}
tensor_proto.cc
grpc_async_server.cc
graph_data_service_impl.cc
graph_shared_memory.cc)
# Presumably these helpers generate the C++ sources/headers for the two .proto files
# (plain protobuf for gnn_tensor.proto, gRPC for gnn_graph_data.proto) - confirm in the cmake helpers.
ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
# Make engine-gnn build after the protobuf target.
add_dependencies(engine-gnn mindspore::protobuf)
endif()
......@@ -19,7 +19,8 @@ namespace mindspore {
namespace dataset {
namespace gnn {
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory)
: type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {}
} // namespace gnn
} // namespace dataset
......
......@@ -31,7 +31,7 @@ class Feature {
// Constructor
// @param FeatureType type_name - feature type
// @param std::shared_ptr<Tensor> value - feature value
Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false);
~Feature() = default;
......@@ -45,6 +45,7 @@ class Feature {
private:
FeatureType type_name_;
std::shared_ptr<Tensor> value_;
bool is_shared_memory_;
};
} // namespace gnn
} // namespace dataset
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
import "gnn_tensor.proto";
// Request of the ClientRegister rpc.
message GnnClientRegisterRequestPb {
int32 pid = 1;  // presumably the process id of the registering client - confirm against the client code
}
// A feature type paired with a tensor; used for the default node/edge features in the register response.
message GnnFeatureInfoPb {
int32 type = 1;  // feature type
TensorPb feature = 2;  // feature value
}
// Response of the ClientRegister rpc.
message GnnClientRegisterResponsePb {
string error_msg = 1;  // presumably empty on success - confirm against the service implementation
string data_schema = 2;  // presumably the dataset schema serialized as text (the server side dumps a json) - confirm
int64 shared_memory_key = 3;  // key of the shared memory segment
int64 shared_memory_size = 4;  // size of the shared memory segment
repeated GnnFeatureInfoPb default_node_feature = 5;  // default node features, one entry per feature type
repeated GnnFeatureInfoPb default_edge_feature = 6;  // default edge features, one entry per feature type
}
// Request of the ClientUnRegister rpc.
message GnnClientUnRegisterRequestPb {
int32 pid = 1;  // presumably the process id of the client that is leaving - confirm against the client code
}
// Response of the ClientUnRegister rpc.
message GnnClientUnRegisterResponsePb {
string error_msg = 1;  // presumably empty on success - confirm against the service implementation
}
// Operation selector carried in GnnGraphDataRequestPb.op_name.
// The names mirror the query methods of the C++ GraphData interface.
enum GnnOpName {
GET_ALL_NODES = 0;
GET_ALL_EDGES = 1;
GET_NODES_FROM_EDGES = 2;
GET_ALL_NEIGHBORS = 3;
GET_SAMPLED_NEIGHBORS = 4;
GET_NEG_SAMPLED_NEIGHBORS = 5;
RANDOM_WALK = 6;
GET_NODE_FEATURE = 7;
GET_EDGE_FEATURE = 8;
}
// Extra parameters of a RANDOM_WALK request.
// NOTE(review): p and q presumably carry the node2vec return / in-out hyper parameters
// (step_home_param / step_away_param of GraphData::RandomWalk) - confirm against the client code.
message GnnRandomWalkPb {
float p = 1;
float q = 2;
int32 default_id = 3;  // presumably the default node id of the walk - confirm
}
// Request of the GetGraphData rpc; op_name selects the operation.
// NOTE(review): which of the remaining fields are read presumably depends on op_name - confirm against the service.
message GnnGraphDataRequestPb {
GnnOpName op_name = 1;  // operation to run
repeated int32 id = 2;  // node ids or edge ids
repeated int32 type = 3;  // node type, edge type, neighbor type or feature type
repeated int32 number = 4;  // number of samples
TensorPb id_tensor = 5;  // input ids (node ids or edge ids) as a tensor
GnnRandomWalkPb random_walk = 6;  // random walk parameters
}
// Response of the GetGraphData rpc.
message GnnGraphDataResponsePb {
string error_msg = 1;  // presumably empty on success - confirm against the service implementation
repeated TensorPb result_data = 2;  // result tensors of the requested operation
}
// Request of the GetMetaInfo rpc; it carries no fields.
message GnnMetaInfoRequestPb {
}
// One (type, num) entry of the meta information; used for both nodes and edges.
message GnnNodeEdgeInfoPb {
int32 type = 1;  // node type or edge type
int32 num = 2;  // presumably the number of nodes/edges of that type - confirm
}
// Response of the GetMetaInfo rpc; the fields parallel the C++ MetaInfo struct.
message GnnMetaInfoResponsePb {
string error_msg = 1;  // presumably empty on success - confirm against the service implementation
repeated GnnNodeEdgeInfoPb node_info = 2;  // per node type information
repeated GnnNodeEdgeInfoPb edge_info = 3;  // per edge type information
repeated int32 node_feature_type = 4;  // node feature types
repeated int32 edge_feature_type = 5;  // edge feature types
}
// gRPC service for graph data; all four rpcs are unary (one request, one response).
service GnnGraphData {
// Registers a client; the response carries the data schema, shared memory info and default features.
rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
// Unregisters a client.
rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
// Runs the graph operation selected by the request's op_name.
rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
// Returns node/edge information and the node/edge feature types.
rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
}
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
// Element type of the data carried by a TensorPb.
// DE_PB_UNKNOWN is 0, so it is the proto3 default when the field is not set.
enum DataTypePb {
DE_PB_UNKNOWN = 0;
DE_PB_BOOL = 1;
DE_PB_INT8 = 2;
DE_PB_UINT8 = 3;
DE_PB_INT16 = 4;
DE_PB_UINT16 = 5;
DE_PB_INT32 = 6;
DE_PB_UINT32 = 7;
DE_PB_INT64 = 8;
DE_PB_UINT64 = 9;
DE_PB_FLOAT16 = 10;
DE_PB_FLOAT32 = 11;
DE_PB_FLOAT64 = 12;
DE_PB_STRING = 13;
}
// Serialized tensor: shape, element type and raw content.
message TensorPb {
repeated int64 dims = 1;  // tensor shape info
DataTypePb tensor_type = 2;  // tensor content data type
bytes data = 3;  // tensor data
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Meta information of a graph: the node/edge types, a number per type, and the feature types.
struct MetaInfo {
std::vector<NodeType> node_type;  // node types
std::vector<EdgeType> edge_type;  // edge types
std::map<NodeType, NodeIdType> node_num;  // presumably the number of nodes of each node type - confirm
std::map<EdgeType, EdgeIdType> edge_num;  // presumably the number of edges of each edge type - confirm
std::vector<FeatureType> node_feature_type;  // node feature types
std::vector<FeatureType> edge_feature_type;  // edge feature types
};
// Abstract interface for graph data access. The Python bindings hold the concrete backends
// (GraphDataImpl, GraphDataClient) through a std::shared_ptr<GraphData>.
class GraphData {
 public:
  // Virtual destructor: implementations are owned and may be destroyed through a GraphData pointer.
  virtual ~GraphData() = default;

  // Get all nodes from the graph.
  // @param NodeType node_type - type of node
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code return
  virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;

  // Get all edges from the graph.
  // @param EdgeType edge_type - type of edge
  // @param std::shared_ptr<Tensor> *out - Returned edge ids
  // @return Status - The error code return
  virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;

  // Get the node ids from the edges.
  // @param std::vector<EdgeIdType> edge_list - List of edges
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code return
  virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;

  // Get all neighbors of the given nodes.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because the number of neighbors differs between
  //     nodes, the returned tensor is sized according to the maximum number of neighbors. If a node has fewer
  //     neighbors, the remaining entries are filled with -1.
  // @return Status - The error code return
  virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                                 std::shared_ptr<Tensor> *out) = 0;

  // Get sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids.
  // @return Status - The error code return
  virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
                                     const std::vector<NodeIdType> &neighbor_nums,
                                     const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;

  // Get negative sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeIdType samples_num - Number of neighbors sampled
  // @param NodeType neg_neighbor_type - The type of negative neighbor.
  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids.
  // @return Status - The error code return
  virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                        NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;

  // Node2vec random walk.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeType> meta_path - node type of each step
  // @param float step_home_param - return hyper parameter in node2vec algorithm
  // @param float step_away_param - in-out hyper parameter in node2vec algorithm
  // @param NodeIdType default_node - default node id
  // @param std::shared_ptr<Tensor> *out - Returned node ids in walk path
  // @return Status - The error code return
  virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                            float step_home_param, float step_away_param, NodeIdType default_node,
                            std::shared_ptr<Tensor> *out) = 0;

  // Get the features of nodes.
  // @param std::shared_ptr<Tensor> nodes - List of nodes
  // @param std::vector<FeatureType> feature_types - Types of features. An error will be reported if the feature type
  //     does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code return
  virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Get the features of edges.
  // @param std::shared_ptr<Tensor> edges - List of edges
  // @param std::vector<FeatureType> feature_types - Types of features. An error will be reported if the feature type
  //     does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code return
  virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Return meta information to python layer
  // @param py::dict *out - Returned meta information
  // @return Status - The error code return
  virtual Status GraphInfo(py::dict *out) = 0;

  // Initialize the object; the Python bindings call this right after construction.
  // @return Status - The error code return
  virtual Status Init() = 0;

  // Stop the object; exposed to Python as "stop".
  // @return Status - The error code return
  virtual Status Stop() = 0;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
#if !defined(_WIN32) && !defined(_WIN64)
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
#endif
#include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// GraphData implementation constructed from a dataset file plus a server hostname and port.
// The gRPC stub, shared memory and feature parser parts are only compiled on non-Windows platforms.
class GraphDataClient : public GraphData {
public:
// Constructor
// @param std::string dataset_file - dataset file of the graph
// @param std::string hostname - hostname of the graph data server
// @param int32_t port - port of the graph data server
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
~GraphDataClient();
// Initialize the client; the Python bindings call this right after construction.
Status Init() override;
// Stop the client; exposed to Python as "stop".
Status Stop() override;
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param EdgeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node ids from the edges.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// Get all neighbors of the given nodes.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because the number of neighbors differs between
// nodes, the returned tensor is sized according to the maximum number of neighbors. If a node has fewer
// neighbors, the remaining entries are filled with -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor ids.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - in-out hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned node ids in walk path
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override;
// Get the features of nodes.
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features. An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Get the features of edges.
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features. An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Return meta information to python layer
Status GraphInfo(py::dict *out) override;
private:
#if !defined(_WIN32) && !defined(_WIN64)
// NOTE(review): presumably builds the feature tensor for `nodes` from data located via memory_tensor
// (shared memory) - confirm against graph_data_client.cc
Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
// NOTE(review): edge counterpart of ParseNodeFeatureFromMemory - confirm against graph_data_client.cc
Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
// NOTE(review): presumably looks feature_type up in default_node_feature_map_ - confirm
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
// NOTE(review): presumably looks feature_type up in default_edge_feature_map_ - confirm
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
// NOTE(review): presumably issues the GetGraphData rpc through stub_ - confirm
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
// NOTE(review): presumably GetGraphData plus conversion of the response into a single tensor - confirm
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
std::shared_ptr<Tensor> *out);
// NOTE(review): presumably wrap the ClientRegister / ClientUnRegister rpcs - confirm
Status RegisterToServer();
Status UnRegisterToServer();
// NOTE(review): presumably sets up graph_feature_parser_ - confirm
Status InitFeatureParser();
// Fails with an "unexpected" status when the calling process id differs from the stored pid_,
// i.e. when the client is used from a process other than the one pid_ was recorded for.
Status CheckPid() {
CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
"Multi-process mode is not supported, please change to use multi-thread");
return Status::OK();
}
#endif
std::string dataset_file_;  // dataset file (presumably the constructor argument - confirm)
std::string host_;  // server hostname (presumably the constructor argument - confirm)
int32_t port_;  // server port (presumably the constructor argument - confirm)
int32_t pid_;  // process id that CheckPid() compares with getpid()
mindrecord::json data_schema_;  // data schema as json
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GnnGraphData::Stub> stub_;  // gRPC stub of the GnnGraphData service
key_t shared_memory_key_;  // key of the shared memory segment
int64_t shared_memory_size_;  // size of the shared memory segment
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
// Default feature tensors keyed by feature type, for nodes and edges respectively
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
#endif
bool registered_;  // presumably true once registration with the server succeeded - confirm
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
......@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#include <algorithm>
#include <memory>
......@@ -25,13 +25,11 @@
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/graph_data.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
namespace mindspore {
namespace dataset {
......@@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001;
const uint32_t kMaxNumWalks = 80;
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
struct MetaInfo {
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
std::map<NodeType, NodeIdType> node_num;
std::map<EdgeType, EdgeIdType> edge_num;
std::vector<FeatureType> node_feature_type;
std::vector<FeatureType> edge_feature_type;
};
class Graph {
class GraphDataImpl : public GraphData {
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
Graph(std::string dataset_file, int32_t num_workers);
GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false);
~Graph() = default;
~GraphDataImpl();
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
......@@ -85,7 +74,7 @@ class Graph {
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out);
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
......@@ -94,7 +83,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
......@@ -103,7 +92,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
......@@ -115,7 +104,7 @@ class Graph {
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out);
std::shared_ptr<Tensor> *out) override;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
......@@ -124,16 +113,22 @@ class Graph {
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out);
TensorRow *out) override;
Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edget - List of edges
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
TensorRow *out);
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information
......@@ -142,15 +137,34 @@ class Graph {
#ifdef ENABLE_PYTHON
// Return meta information to python layer
Status GraphInfo(py::dict *out);
Status GraphInfo(py::dict *out) override;
#endif
Status Init();
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() {
return &default_node_feature_map_;
}
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() {
return &default_edge_feature_map_;
}
Status Init() override;
Status Stop() override { return Status::OK(); }
std::string GetDataSchema() { return data_schema_.dump(); }
#if !defined(_WIN32) && !defined(_WIN64)
key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); }
int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); }
#endif
private:
friend class GraphLoader;
class RandomWalkBase {
public:
explicit RandomWalkBase(Graph *graph);
explicit RandomWalkBase(GraphDataImpl *graph);
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
......@@ -176,7 +190,7 @@ class Graph {
template <typename T>
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
Graph *graph_;
GraphDataImpl *graph_;
std::vector<NodeIdType> node_list_;
std::vector<NodeType> meta_path_;
float step_home_param_; // Return hyper parameter. Default is 1.0
......@@ -248,7 +262,11 @@ class Graph {
int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;
RandomWalkBase random_walk_;
mindrecord::json data_schema_;
bool server_mode_;
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
#endif
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
......@@ -264,4 +282,4 @@ class Graph {
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_server.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Creates a server for the given dataset file. Nothing is started here; call Init() to bring the server up.
// @param dataset_file - path handed to the underlying GraphDataImpl
// @param num_workers - number of RPC worker tasks spawned by Init(); also passed to GraphDataImpl
// @param hostname, port - forwarded to the gRPC server (unused on Windows)
// @param client_num - number of clients that must have connected before auto shutdown is considered
// @param auto_shutdown - whether Init() spawns the task that stops the server once all clients are gone
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
                                 int32_t port, int32_t client_num, bool auto_shutdown)
    : dataset_file_(dataset_file),
      num_workers_(num_workers),
      client_num_(client_num),
      max_connected_client_num_(0),
      auto_shutdown_(auto_shutdown),
      state_(kGdsUninit),
      tg_(std::make_unique<TaskGroup>()),
      graph_data_impl_(std::make_unique<GraphDataImpl>(dataset_file, num_workers, true)) {
#if !defined(_WIN32) && !defined(_WIN64)
  // The service keeps raw pointers to this server and to the graph implementation created above.
  service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
  async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
#endif
}
// Starts the gRPC server and spawns the background tasks. Always fails on Windows.
// @return Status - The error code return
Status GraphDataServer::Init() {
#if defined(_WIN32) || defined(_WIN64)
  RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
#else
  set_state(kGdsInitializing);
  RETURN_IF_NOT_OK(async_server_->Run());
  // The graph is loaded on its own task, so Init() returns while the state is still kGdsInitializing.
  RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", [this]() { return InitGraphDataImpl(); }));
  // One request-handling task per configured worker.
  for (int32_t worker = 0; worker < num_workers_; ++worker) {
    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("start async rpc service", [this]() { return StartAsyncRpcService(); }));
  }
  if (auto_shutdown_) {
    RETURN_IF_NOT_OK(
      tg_->CreateAsyncTask("judge auto shutdown server", [this]() { return JudgeAutoShutdownServer(); }));
  }
  return Status::OK();
#endif
}
// Task body: initializes the graph implementation, then moves the server to kGdsRunning on success
// or stops the server on failure.
// @return Status - the result of GraphDataImpl::Init()
Status GraphDataServer::InitGraphDataImpl() {
  TaskManager::FindMe()->Post();
  Status rc = graph_data_impl_->Init();
  if (!rc.IsOk()) {
    // The Init() error is what gets reported; the result of Stop() is intentionally dropped.
    (void)Stop();
    return rc;
  }
  set_state(kGdsRunning);
  return rc;
}
#if !defined(_WIN32) && !defined(_WIN64)
// Task body: hands this task over to the async gRPC server's request loop.
// @return Status - whatever GrpcAsyncServer::HandleRequest() reports
Status GraphDataServer::StartAsyncRpcService() {
  TaskManager::FindMe()->Post();
  return async_server_->HandleRequest();
}
#endif
// Task body: polls once per second and stops the server when auto shutdown is enabled, at least
// client_num_ clients have been connected at the same time, and every client has unregistered since.
// The loop also ends as soon as the server is observed in the kGdsStopped state.
// @return Status - The error code return
Status GraphDataServer::JudgeAutoShutdownServer() {
  TaskManager::FindMe()->Post();
  while (true) {
    bool all_clients_gone = false;
    {
      // client_pid_ and max_connected_client_num_ are modified by ClientRegister/ClientUnRegister
      // under mutex_, so they must be read under the same lock.
      std::unique_lock<std::mutex> lck(mutex_);
      all_clients_gone = (max_connected_client_num_ >= client_num_) && client_pid_.empty();
    }
    // Stop() is called without holding mutex_.
    if (auto_shutdown_ && all_clients_gone) {
      MS_LOG(INFO) << "All clients have been unregister, automatically exit the server.";
      RETURN_IF_NOT_OK(Stop());
      break;
    }
    if (state_ == kGdsStopped) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  return Status::OK();
}
// Stops the server: shuts down the gRPC server (non-Windows only), marks the state as kGdsStopped
// and releases the graph implementation.
// @return Status - always Status::OK()
Status GraphDataServer::Stop() {
#if !defined(_WIN32) && !defined(_WIN64)
async_server_->Stop();
#endif
set_state(kGdsStopped);
// NOTE(review): service_impl_ was constructed with graph_data_impl_.get(), so it keeps a raw pointer
// to the object released here. This presumably relies on async_server_->Stop() having drained all
// request handlers first - TODO confirm.
graph_data_impl_.reset();
return Status::OK();
}
// Records a client pid as connected and keeps track of the highest number of simultaneously
// registered clients seen so far.
// @param int32_t pid - process id reported by the client
// @return Status - always Status::OK()
Status GraphDataServer::ClientRegister(int32_t pid) {
  std::lock_guard<std::mutex> guard(mutex_);
  MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
  client_pid_.emplace(pid);
  const auto connected = client_pid_.size();
  if (connected > max_connected_client_num_) {
    max_connected_client_num_ = connected;
  }
  return Status::OK();
}
// Removes a client pid from the set of connected clients. Unknown pids are ignored.
// @param int32_t pid - process id reported by the client
// @return Status - always Status::OK()
Status GraphDataServer::ClientUnRegister(int32_t pid) {
  std::lock_guard<std::mutex> guard(mutex_);
  // erase() reports how many entries were removed; only log when the pid was actually registered.
  if (client_pid_.erase(pid) > 0) {
    MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#endif
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataImpl;
// Server side of the distributed graph data service. It owns the graph implementation, the task
// group that runs the background tasks and (on non-Windows platforms) the gRPC service objects,
// and it tracks which client pids are currently registered.
class GraphDataServer {
 public:
  enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };

  GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
                  int32_t client_num, bool auto_shutdown);
  ~GraphDataServer() = default;

  // Start the gRPC server and the background tasks
  // @return Status - The error code return
  Status Init();

  // Stop the gRPC server and release the graph data
  // @return Status - The error code return
  Status Stop();

  // Add / remove a client pid from the set of registered clients
  // @param int32_t pid - process id reported by the client
  // @return Status - The error code return
  Status ClientRegister(int32_t pid);
  Status ClientUnRegister(int32_t pid);

  // Current life-cycle state of the server
  enum ServerState state() { return state_; }

  // @return bool - true once the server has reached the kGdsStopped state
  bool IsStoped() { return state_ == kGdsStopped; }

 private:
  void set_state(enum ServerState state) { state_ = state; }

  Status InitGraphDataImpl();
#if !defined(_WIN32) && !defined(_WIN64)
  Status StartAsyncRpcService();
#endif
  Status JudgeAutoShutdownServer();

  std::string dataset_file_;
  int32_t num_workers_;  // The number of worker threads
  int32_t client_num_;
  int32_t max_connected_client_num_;
  bool auto_shutdown_;
  // Written by the initialization task and by Stop(), read by the RPC handlers and the auto-shutdown
  // task, hence atomic rather than a plain enum.
  std::atomic<ServerState> state_;
  std::unique_ptr<TaskGroup> tg_;  // Class for worker management
  std::unique_ptr<GraphDataImpl> graph_data_impl_;
  std::unordered_set<int32_t> client_pid_;  // pids of the currently registered clients
  std::mutex mutex_;                        // taken by ClientRegister/ClientUnRegister around client_pid_
#if !defined(_WIN32) && !defined(_WIN64)
  std::unique_ptr<GraphDataServiceImpl> service_impl_;
  std::unique_ptr<GrpcAsyncServer> async_server_;
#endif
};
#if !defined(_WIN32) && !defined(_WIN64)
// Type-erased handle for one asynchronous call. Completion-queue tags are cast back to this
// interface and invoked to advance the call.
class UntypedCall {
 public:
  virtual ~UntypedCall() = default;

  // Advance the call by one step
  // @return Status - The error code return
  virtual Status operator()() = 0;
};
// State of one asynchronous RPC of a single method.
// Each invocation of operator() advances the call by one step:
//   CREATE  -> registers itself with the async service to receive the next request of this method
//   PROCESS -> queues a replacement CallData, runs the handler and sends the response
//   FINISH  -> the response has been sent; the object deletes itself
// Instances are heap-allocated by EnqueueRequest() and own themselves; `this` is used as the tag.
template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
class CallData : public UntypedCall {
 public:
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
  // Pointer to the generated AsyncService::Request<Method> member function.
  using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
                                                 grpc::ServerAsyncResponseWriter<ResponseMessage> *,
                                                 grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
  // Pointer to the ServiceImpl member function that produces the response.
  using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
                                                              ResponseMessage *);

  CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
           EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
      : status_(STATE::CREATE),
        service_impl_(service_impl),
        async_service_(async_service),
        cq_(cq),
        enqueue_function_(enqueue_function),
        handle_request_function_(handle_request_function),
        responder_(&ctx_) {}

  ~CallData() override = default;

  // Allocates a new call object and performs its CREATE step.
  // @return Status - The error code return
  static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
                               EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
    // Ownership passes to the call itself: it is released by `delete this` in the FINISH step.
    auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
      service_impl, async_service, cq, enqueue_function, handle_request_function);
    RETURN_IF_NOT_OK((*call)());
    return Status::OK();
  }

  // Advance the call by one step (see the class comment).
  // @return Status - The error code return
  Status operator()() override {
    if (status_ == STATE::CREATE) {
      status_ = STATE::PROCESS;
      (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
    } else if (status_ == STATE::PROCESS) {
      // Queue the replacement before running the handler so another request of this method can be
      // received meanwhile.
      Status rearm = EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
      status_ = STATE::FINISH;
      grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
      responder_.Finish(response_, s, this);
      // The current request is always answered; a failure to re-arm is reported afterwards instead
      // of being silently dropped.
      RETURN_IF_NOT_OK(rearm);
    } else {
      GPR_ASSERT(status_ == STATE::FINISH);
      delete this;
    }
    return Status::OK();
  }

 private:
  STATE status_;
  ServiceImpl *service_impl_;
  AsyncService *async_service_;
  grpc::ServerCompletionQueue *cq_;
  EnqueueFunction enqueue_function_;
  HandleRequestFunction handle_request_function_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
  RequestMessage request_;
  ResponseMessage response_;
};
// Queues one CallData for the RPC `method`: instantiates CallData with the given request/response
// message types and calls its EnqueueRequest() with GnnGraphData::AsyncService::Request<method> as the
// enqueue function and GraphDataServiceImpl::<method> as the handler.
// Expands to RETURN_IF_NOT_OK, so it may only be used inside a function that returns Status.
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \
do { \
Status s = \
CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \
&gnn::GraphDataServiceImpl::method); \
RETURN_IF_NOT_OK(s); \
} while (0)
// Async gRPC server for the graph data service. It registers the generated AsyncService and
// queues one CallData per RPC method.
class GraphDataGrpcServer : public GrpcAsyncServer {
 public:
  GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
      : GrpcAsyncServer(host, port), service_impl_(service_impl) {}

  // Register the async service on the server builder
  // @return Status - always Status::OK()
  Status RegisterService(grpc::ServerBuilder *builder) {
    builder->RegisterService(&svc_);
    return Status::OK();
  }

  // Queue one pending call for each of the four RPC methods
  // @return Status - The error code return
  Status EnqueueRequest() {
    auto *const queue = cq_.get();
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, ClientRegister, GnnClientRegisterRequestPb,
                    GnnClientRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, ClientUnRegister, GnnClientUnRegisterRequestPb,
                    GnnClientUnRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
    return Status::OK();
  }

  // Advance the call identified by a completion-queue tag
  // @param void *tag - pointer to an UntypedCall
  // @return Status - the status returned by the call
  Status ProcessRequest(void *tag) { return (*static_cast<UntypedCall *>(tag))(); }

 private:
  GraphDataServiceImpl *service_impl_;
  GnnGraphData::AsyncService svc_;
};
#endif
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Handler signature shared by every graph-data operation served through GetGraphData().
using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *);

// Dispatch table from the op_name of a request to its handler. It is only ever looked up, so it is
// const: it cannot be modified after start-up and concurrent lookups need no locking.
static const std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = {
  {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes},
  {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges},
  {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges},
  {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors},
  {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors},
  {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors},
  {RANDOM_WALK, &GraphDataServiceImpl::RandomWalk},
  {GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature},
  {GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}};
GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl)
: server_(server), graph_data_impl_(graph_data_impl) {}
// Copies every default node feature and default edge feature of the graph into the register response.
// @param GnnClientRegisterResponsePb *response - response to be filled
// @return Status - The error code return
Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) {
  const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures();
  // Iterate by reference: each entry holds a shared_ptr, so copying the pair would cost an atomic
  // reference-count update per element for no benefit.
  for (const auto &feature : *default_node_features) {
    GnnFeatureInfoPb *feature_info = response->add_default_node_feature();
    feature_info->set_type(feature.first);
    RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
  }
  const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures();
  for (const auto &feature : *default_edge_features) {
    GnnFeatureInfoPb *feature_info = response->add_default_edge_feature();
    feature_info->set_type(feature.first);
    RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
  }
  return Status::OK();
}
// RPC handler: registers the calling client with the server and reports the server state in error_msg.
// Only when the server is running does the response also carry the data schema, the shared memory
// key/size and the default features.
// @return grpc::Status - always OK; failures are reported through the response's error_msg
grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
                                                  const GnnClientRegisterRequestPb *request,
                                                  GnnClientRegisterResponsePb *response) {
  Status rc = server_->ClientRegister(request->pid());
  if (!rc.IsOk()) {
    response->set_error_msg(rc.ToString());
    return ::grpc::Status::OK;
  }
  switch (server_->state()) {
    case GraphDataServer::kGdsUninit:
    case GraphDataServer::kGdsInitializing:
      response->set_error_msg("Initializing");
      break;
    case GraphDataServer::kGdsRunning:
      response->set_error_msg("Success");
      response->set_data_schema(graph_data_impl_->GetDataSchema());
      response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey());
      response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize());
      rc = FillDefaultFeature(response);
      if (!rc.IsOk()) {
        // Overwrites the "Success" marker set above.
        response->set_error_msg(rc.ToString());
      }
      break;
    case GraphDataServer::kGdsStopped:
      response->set_error_msg("Stoped");
      break;
  }
  return ::grpc::Status::OK;
}
// RPC handler: unregisters the calling client from the server.
// @return grpc::Status - always OK; the outcome is reported through the response's error_msg
grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context,
                                                    const GnnClientUnRegisterRequestPb *request,
                                                    GnnClientUnRegisterResponsePb *response) {
  Status rc = server_->ClientUnRegister(request->pid());
  if (!rc.IsOk()) {
    response->set_error_msg(rc.ToString());
    return ::grpc::Status::OK;
  }
  response->set_error_msg("Success");
  return ::grpc::Status::OK;
}
// RPC handler: looks up the operation named by the request in the dispatch table and runs it.
// @return grpc::Status - always OK; the outcome is reported through the response's error_msg
grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
                                                GnnGraphDataResponsePb *response) {
  const auto entry = g_get_graph_data_func_.find(request->op_name());
  if (entry == g_get_graph_data_func_.end()) {
    response->set_error_msg("Invalid op name.");
    return ::grpc::Status::OK;
  }
  const pFunction handler = entry->second;
  Status rc = (this->*handler)(request, response);
  if (rc.IsOk()) {
    response->set_error_msg("Success");
  } else {
    response->set_error_msg(rc.ToString());
  }
  return ::grpc::Status::OK;
}
// RPC handler: returns the node/edge types with their counts and the node/edge feature types.
// A type without an entry in the count map is reported with a count of 0.
// @return grpc::Status - always OK; the outcome is reported through the response's error_msg
grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
                                               GnnMetaInfoResponsePb *response) {
  MetaInfo info;
  Status rc = graph_data_impl_->GetMetaInfo(&info);
  if (!rc.IsOk()) {
    response->set_error_msg(rc.ToString());
    return ::grpc::Status::OK;
  }
  response->set_error_msg("Success");

  for (const auto &node_type : info.node_type) {
    auto entry = response->add_node_info();
    entry->set_type(static_cast<google::protobuf::int32>(node_type));
    const auto found = info.node_num.find(node_type);
    if (found == info.node_num.end()) {
      entry->set_num(0);
    } else {
      entry->set_num(static_cast<google::protobuf::int32>(found->second));
    }
  }

  for (const auto &edge_type : info.edge_type) {
    auto entry = response->add_edge_info();
    entry->set_type(static_cast<google::protobuf::int32>(edge_type));
    const auto found = info.edge_num.find(edge_type);
    if (found == info.edge_num.end()) {
      entry->set_num(0);
    } else {
      entry->set_num(static_cast<google::protobuf::int32>(found->second));
    }
  }

  for (const auto &feature_type : info.node_feature_type) {
    response->add_node_feature_type(static_cast<google::protobuf::int32>(feature_type));
  }
  for (const auto &feature_type : info.edge_feature_type) {
    response->add_edge_feature_type(static_cast<google::protobuf::int32>(feature_type));
  }
  return ::grpc::Status::OK;
}
// Returns all nodes of the single node type given in the request as one result tensor.
// @param request - must carry exactly one entry in `type`
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  // This handler validates a node type, so the message must say so.
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of node types is not 1");
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast<NodeType>(request->type()[0]), &tensor));
  TensorPb *result = response->add_result_data();
  RETURN_IF_NOT_OK(TensorToPb(tensor, result));
  return Status::OK();
}
// Returns all edges of the single edge type given in the request as one result tensor.
// @param request - must carry exactly one entry in `type`
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
  const auto edge_type = static_cast<EdgeType>(request->type()[0]);
  std::shared_ptr<Tensor> edges;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(edge_type, &edges));
  RETURN_IF_NOT_OK(TensorToPb(edges, response->add_result_data()));
  return Status::OK();
}
// Returns the nodes connected by the edges whose ids are listed in the request.
// @param request - `id` holds the edge ids and must not be empty
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty");
  std::vector<EdgeIdType> edge_ids;
  edge_ids.reserve(request->id().size());
  for (const google::protobuf::int32 id : request->id()) {
    edge_ids.push_back(static_cast<EdgeIdType>(id));
  }
  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_ids, &nodes));
  RETURN_IF_NOT_OK(TensorToPb(nodes, response->add_result_data()));
  return Status::OK();
}
// Returns all neighbors of one type for the nodes listed in the request.
// @param request - `id` holds the node ids (non-empty), `type` holds exactly one neighbor type
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
  std::vector<NodeIdType> node_ids;
  node_ids.reserve(request->id().size());
  for (const google::protobuf::int32 id : request->id()) {
    node_ids.push_back(static_cast<NodeIdType>(id));
  }
  const auto neighbor_type = static_cast<NodeType>(request->type()[0]);
  std::shared_ptr<Tensor> neighbors;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_ids, neighbor_type, &neighbors));
  RETURN_IF_NOT_OK(TensorToPb(neighbors, response->add_result_data()));
  return Status::OK();
}
// Samples neighbors for the nodes listed in the request.
// @param request - `id` holds the node ids, `number` the neighbor counts and `type` the neighbor types;
//                  all three must be non-empty
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request,
                                                 GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty");

  std::vector<NodeIdType> node_ids;
  node_ids.reserve(request->id().size());
  for (const google::protobuf::int32 id : request->id()) {
    node_ids.push_back(static_cast<NodeIdType>(id));
  }

  std::vector<NodeIdType> sample_counts;
  sample_counts.reserve(request->number().size());
  for (const google::protobuf::int32 num : request->number()) {
    sample_counts.push_back(static_cast<NodeIdType>(num));
  }

  std::vector<NodeType> sample_types;
  sample_types.reserve(request->type().size());
  for (const google::protobuf::int32 type : request->type()) {
    sample_types.push_back(static_cast<NodeType>(type));
  }

  std::shared_ptr<Tensor> sampled;
  RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_ids, sample_counts, sample_types, &sampled));
  RETURN_IF_NOT_OK(TensorToPb(sampled, response->add_result_data()));
  return Status::OK();
}
// Performs negative neighbor sampling for the nodes listed in the request.
// @param request - `id` holds the node ids (non-empty); `number` and `type` must each hold exactly one entry
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request,
                                                    GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor number is not 1");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1");

  std::vector<NodeIdType> node_ids;
  node_ids.reserve(request->id().size());
  for (const google::protobuf::int32 id : request->id()) {
    node_ids.push_back(static_cast<NodeIdType>(id));
  }

  const auto sample_count = static_cast<NodeIdType>(request->number()[0]);
  const auto neighbor_type = static_cast<NodeType>(request->type()[0]);
  std::shared_ptr<Tensor> sampled;
  RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_ids, sample_count, neighbor_type, &sampled));
  RETURN_IF_NOT_OK(TensorToPb(sampled, response->add_result_data()));
  return Status::OK();
}
// Runs a random walk from the nodes listed in the request along the given meta path.
// @param request - `id` holds the start node ids, `type` the meta path (both non-empty); the walk
//                  parameters p, q and default_id are taken from `random_walk`
// @param response - receives one result_data entry
// @return Status - The error code return
Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty");

  std::vector<NodeIdType> start_nodes;
  start_nodes.reserve(request->id().size());
  for (const google::protobuf::int32 id : request->id()) {
    start_nodes.push_back(static_cast<NodeIdType>(id));
  }

  std::vector<NodeType> path_types;
  path_types.reserve(request->type().size());
  for (const google::protobuf::int32 type : request->type()) {
    path_types.push_back(static_cast<NodeType>(type));
  }

  const auto &walk = request->random_walk();
  std::shared_ptr<Tensor> walks;
  RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(start_nodes, path_types, walk.p(), walk.q(), walk.default_id(), &walks));
  RETURN_IF_NOT_OK(TensorToPb(walks, response->add_result_data()));
  return Status::OK();
}
// Looks up node features through GetNodeFeatureSharedMemory, one result tensor per requested feature type.
// @param request - `id_tensor` holds the node ids, `type` the feature types
// @param response - receives one result_data entry per feature type, in request order
// @return Status - The error code return
Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  std::shared_ptr<Tensor> node_ids;
  RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &node_ids));
  for (const auto &feature_type : request->type()) {
    std::shared_ptr<Tensor> feature;
    RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(node_ids, feature_type, &feature));
    // The result slot is added only after the lookup has succeeded.
    RETURN_IF_NOT_OK(TensorToPb(feature, response->add_result_data()));
  }
  return Status::OK();
}
// Looks up edge features through GetEdgeFeatureSharedMemory, one result tensor per requested feature type.
// @param request - `id_tensor` holds the edge ids, `type` the feature types
// @param response - receives one result_data entry per feature type, in request order
// @return Status - The error code return
Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  std::shared_ptr<Tensor> edge_ids;
  RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edge_ids));
  for (const auto &feature_type : request->type()) {
    std::shared_ptr<Tensor> feature;
    RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edge_ids, feature_type, &feature));
    // The result slot is added only after the lookup has succeeded.
    RETURN_IF_NOT_OK(TensorToPb(feature, response->add_result_data()));
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#include <memory>
#include <string>
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataServer;
// class GraphDataServiceImpl : public GnnGraphData::Service {
// Implements the handlers of the graph data RPC service on top of a GraphDataServer and a
// GraphDataImpl. Both are held as raw pointers supplied by the caller.
class GraphDataServiceImpl {
public:
GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
~GraphDataServiceImpl() = default;
// RPC entry points. Each returns grpc::Status::OK and reports the outcome of the operation through
// the error_msg field of the response.
grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
GnnClientRegisterResponsePb *response);
grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
GnnClientUnRegisterResponsePb *response);
grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
GnnGraphDataResponsePb *response);
grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
GnnMetaInfoResponsePb *response);
// Per-operation handlers dispatched by GetGraphData() according to the op_name of the request.
// Each validates the request, calls the graph implementation and appends the result to result_data.
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
private:
// Copies the default node and edge features of the graph into a register response.
Status FillDefaultFeature(GnnClientRegisterResponsePb *response);
GraphDataServer *server_;  // not owned
GraphDataImpl *graph_data_impl_;  // not owned
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#include <memory>
#include <utility>
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::MSRStatus;
// Keeps a private copy of the given ShardColumn for later column lookups.
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column)
    : shard_column_(std::make_unique<ShardColumn>(shard_column)) {}
// Reads the column `key` from a mindrecord blob and returns it as a one-dimensional tensor.
// @param std::string key - column name
// @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param std::shared_ptr<Tensor> *tensor - return value, feature tensor with n_bytes / element-size elements
// @return Status - the status code
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
                                             std::shared_ptr<Tensor> *tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // The value is returned either through `data` or through the owning buffer `data_ptr`.
  // get() is used instead of &data_ptr[0] so that an empty buffer is never dereferenced.
  if (data == nullptr) data = data_ptr.get();
  // Guard the division below against a zero element size.
  CHECK_FAIL_RETURN_UNEXPECTED(col_type_size != 0, "invalid element size of column:" + key);
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)}),
                                            DataType(mindrecord::ColumnDataTypeNameNormalized[col_type]), data,
                                            tensor));
  return Status::OK();
}
#if !defined(_WIN32) && !defined(_WIN64)
// Reads the column `key` from a mindrecord blob, stores its bytes in shared memory and returns a
// two-element int64 tensor describing where they were stored: {offset, number of bytes}.
// @param std::string key - column name
// @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param GraphSharedMemory *shared_memory - shared memory the feature bytes are inserted into
// @param std::shared_ptr<Tensor> *out_tensor - return value, the {offset, size} tensor
// @return Status - the status code
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
                                                     GraphSharedMemory *shared_memory,
                                                     std::shared_ptr<Tensor> *out_tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // The value is returned either through `data` or through the owning buffer `data_ptr`.
  // get() is used instead of &data_ptr[0] so that an empty buffer is never dereferenced.
  if (data == nullptr) data = data_ptr.get();
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_INT64), &tensor));
  auto fea_itr = tensor->begin<int64_t>();
  int64_t offset = 0;
  RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
  // Element 0: offset reported by InsertData; element 1: size of the feature in bytes.
  *fea_itr = offset;
  ++fea_itr;
  *fea_itr = n_bytes;
  *out_tensor = std::move(tensor);
  return Status::OK();
}
#endif
// Reads the column `key` from a mindrecord blob as a list of feature indices.
// Only int32 and int64 columns are accepted; negative indices are skipped.
// @param std::string key - column name
// @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param std::vector<int32_t> *indices - return value, the non-negative indices are appended to it
// @return Status - the status code
Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
                                            std::vector<int32_t> *indices) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // The value is returned either through `data` or through the owning buffer `data_ptr`.
  if (data == nullptr) data = data_ptr.get();
  // A zero element size would make the loop below spin forever.
  CHECK_FAIL_RETURN_UNEXPECTED(col_type_size != 0, "invalid element size of column:" + key);

  // The offset has the same unsigned 64-bit type as n_bytes, and the condition only admits complete
  // elements so a truncated trailing element is never read.
  for (uint64_t offset = 0; offset + col_type_size <= n_bytes; offset += col_type_size) {
    int32_t feature_ind = -1;
    if (col_type == mindrecord::ColumnInt32) {
      feature_ind = *(reinterpret_cast<const int32_t *>(data + offset));
    } else if (col_type == mindrecord::ColumnInt64) {
      // Indices are stored as int32 by the caller; the narrowing is made explicit here.
      feature_ind = static_cast<int32_t>(*(reinterpret_cast<const int64_t *>(data + offset)));
    } else {
      RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
    }
    if (feature_ind >= 0) indices->push_back(feature_ind);
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::ShardColumn;
// Decodes graph feature columns from the blob field of a mindrecord row.
// Feature indices are returned as a list of int32_t; feature values are returned as a Tensor or,
// on non-Windows builds only, inserted into a GraphSharedMemory segment.
class GraphFeatureParser {
public:
// @param ShardColumn &shard_column - mindrecord column description used to decode blobs
// NOTE(review): the member is a std::unique_ptr<ShardColumn>, so this presumably keeps its own copy - confirm
explicit GraphFeatureParser(const ShardColumn &shard_column);
~GraphFeatureParser() = default;
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
#if !defined(_WIN32) && !defined(_WIN64)
// Only declared on non-Windows builds (see the preprocessor guard).
// @param std::string &key - column name
// @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param GraphSharedMemory *shared_memory - shared memory the feature data is inserted into
// @param std::shared_ptr<Tensor> *out_tensor - return value; NOTE(review): appears to hold the insert offset
//        and byte count as two int64 values rather than the feature data itself - confirm in the .cc
// @return Status - the status code
Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
#endif
private:
// Column description used by the Load* methods to decode blobs.
std::unique_ptr<ShardColumn> shard_column_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
......@@ -13,41 +13,42 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include <future>
#include <tuple>
#include <utility>
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/local_edge.h"
#include "minddata/dataset/engine/gnn/local_node.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/mindrecord/include/shard_error.h"
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::MSRStatus;
GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
: mr_path_(mr_filepath),
GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode)
: graph_impl_(graph_impl),
mr_path_(mr_filepath),
num_workers_(num_workers),
row_id_(0),
shard_reader_(nullptr),
graph_feature_parser_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map,
DefaultEdgeFeatureMap *default_edge_feature_map) {
Status GraphLoader::GetNodesAndEdges() {
NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_;
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
while (dq.empty() == false) {
std::shared_ptr<Node> node_ptr = dq.front();
n_id_map->insert({node_ptr->id(), node_ptr});
(*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id());
dq.pop_front();
}
}
......@@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
(*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front();
}
}
for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit();
for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit();
MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map);
MergeFeatureMaps();
return Status::OK();
}
......@@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() {
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"];
for (const std::string &key : keys_) {
if (schema.find(key) == schema.end()) {
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
}
}
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0;
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS,
"failed to get total blob size");
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
#endif
}
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(*shard_reader_->GetShardColumn());
// launching worker threads
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
......@@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[node_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor_sm;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
"node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
(*feature_map)[node_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
#endif
} else {
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[node_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
}
return Status::OK();
......@@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[edge_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor_sm;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
"edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
(*feature_map)[edge_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
}
return Status::OK();
}
Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
data, tensor));
return Status::OK();
}
Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
for (int i = 0; i < n_bytes; i += col_type_size) {
int32_t feature_ind = -1;
if (col_type == mindrecord::ColumnInt32) {
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
} else if (col_type == mindrecord::ColumnInt64) {
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
} else {
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
#endif
} else {
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[edge_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
if (feature_ind >= 0) indices->push_back(feature_ind);
}
return Status::OK();
}
......@@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
return Status::OK();
}
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
DefaultNodeFeatureMap *default_node_feature_map,
DefaultEdgeFeatureMap *default_edge_feature_map) {
void GraphLoader::MergeFeatureMaps() {
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
for (auto &m : n_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n);
}
for (auto &m : e_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n);
}
for (auto &m : default_node_feature_maps_[wkr_id]) {
(*default_node_feature_map)[m.first] = m.second;
graph_impl_->default_node_feature_map_[m.first] = m.second;
}
for (auto &m : default_edge_feature_maps_[wkr_id]) {
(*default_edge_feature_map)[m.first] = m.second;
graph_impl_->default_edge_feature_map_[m.first] = m.second;
}
}
n_feature_maps_.clear();
......
......@@ -26,10 +26,13 @@
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_reader.h"
namespace mindspore {
......@@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
class GraphDataImpl;
// this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges
// it is then the responsibility of graph to construct itself based on the nodes and edges
// if needed, this class could become a base where each derived class handles a specific storage format
class GraphLoader {
public:
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4);
GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
~GraphLoader() = default;
// Init mindrecord and load everything into memory multi-threaded
......@@ -63,8 +68,7 @@ class GraphLoader {
// nodes and edges are added to map without any connection. That's because the nodes and edges are read in
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
Status GetNodesAndEdges();
private:
//
......@@ -92,29 +96,15 @@ class GraphLoader {
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::shared_ptr<Tensor> *tensor);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
void MergeFeatureMaps();
GraphDataImpl *graph_impl_;
std::string mr_path_;
const int32_t num_workers_;
std::atomic_int row_id_;
std::string mr_path_;
std::unique_ptr<ShardReader> shard_reader_;
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
std::vector<NodeFeatureMap> n_feature_maps_;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#include <string>
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Build a handle for a shared memory segment whose key is already known.
// @param int64_t memory_size - size of the segment in bytes
// @param key_t memory_key - key identifying the segment
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
    : memory_size_(memory_size),
      memory_key_(memory_key),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {
  // Keep a hexadecimal rendering of the key for log and error messages.
  std::stringstream key_formatter;
  key_formatter << std::hex << memory_key_;
  memory_key_str_ = key_formatter.str();
}
// Build a handle for a shared memory segment whose key is not known yet.
// The key is left at -1 here; CreateSharedMemory() derives it from mr_file via ftok().
// @param int64_t memory_size - size of the segment in bytes
// @param std::string &mr_file - file path used later to generate the key
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
: mr_file_(mr_file),
memory_size_(memory_size),
memory_key_(-1),
memory_ptr_(nullptr),
memory_offset_(0),
is_new_create_(false) {}
// Release the segment: detach this process's mapping and, when this object created the
// segment, remove it as well.
GraphSharedMemory::~GraphSharedMemory() {
  // The address returned by shmat() stays mapped until it is detached, so detach it here;
  // the class never hands the raw pointer out, so nothing else can still be using it.
  if (memory_ptr_ != nullptr) {
    (void)shmdt(memory_ptr_);
    memory_ptr_ = nullptr;
  }
  // Only the creator removes the segment; failures cannot be reported from a destructor.
  if (is_new_create_) {
    (void)DeleteSharedMemory();
  }
}
// Create the shared memory segment and attach to it.
// When no key was supplied at construction, one is generated from the mindrecord file path.
// If exclusive creation fails, a second non-exclusive attempt is made; in that case the
// segment is not marked as created by this object.
// @return Status - the status code
Status GraphSharedMemory::CreateSharedMemory() {
  if (memory_key_ == -1) {
    // ftok to generate unique key
    memory_key_ = ftok(mr_file_.c_str(), kGnnSharedMemoryId);
    CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_);
    std::stringstream stream;
    stream << std::hex << memory_key_;
    memory_key_str_ = stream.str();
  }
  // First try to create a brand-new segment; IPC_EXCL makes this fail if the key is already in use.
  int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
  Status s = SharedMemoryImpl(shmflg);
  if (s.IsOk()) {
    is_new_create_ = true;
    MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
  } else {
    MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
    // Retry without IPC_EXCL so an existing segment with this key can be used.
    shmflg = (0666 | IPC_CREAT);
    s = SharedMemoryImpl(shmflg);
    if (!s.IsOk()) {
      RETURN_STATUS_UNEXPECTED("Create shared memory failed, key=0x" + memory_key_str_);
    }
  }
  return Status::OK();
}
// Attach to a shared memory segment without requesting creation (no flags are passed).
// @return Status - the status code
Status GraphSharedMemory::GetSharedMemory() {
  const int no_flags = 0;
  return SharedMemoryImpl(no_flags);
}
// Look up the segment by key and remove it with IPC_RMID.
// @return Status - the status code
Status GraphSharedMemory::DeleteSharedMemory() {
  // Size 0 and no flags: only resolve the identifier of the segment.
  const int segment_id = shmget(memory_key_, 0, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(segment_id != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  const int rc = shmctl(segment_id, IPC_RMID, nullptr);
  CHECK_FAIL_RETURN_UNEXPECTED(rc != -1, "Failed to delete shared memory. key=0x" + memory_key_str_);
  return Status::OK();
}
// Obtain the segment identified by memory_key_ with the given flags and attach it,
// storing the attached address in memory_ptr_.
// @param int &shmflg - flags forwarded to shmget
// @return Status - the status code
Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
  // Resolve (or create, depending on the flags) the segment identifier.
  const int segment_id = shmget(memory_key_, memory_size_, shmflg);
  CHECK_FAIL_RETURN_UNEXPECTED(segment_id != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  // Let the system choose the attach address; shmat signals failure with (void *)-1.
  void *attached = shmat(segment_id, nullptr, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(attached != reinterpret_cast<char *>(-1),
                               "Failed to address shared memory. key=0x" + memory_key_str_);
  memory_ptr_ = reinterpret_cast<uint8_t *>(attached);
  return Status::OK();
}
// Append `len` bytes to the shared memory and report where they were written.
// Insertions are serialized with mutex_.
// @param uint8_t *data - source bytes
// @param int64_t len - number of bytes to copy, must be positive
// @param int64_t *offset - return value, byte offset of the copied data inside the shared memory
// @return Status - the status code
Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");
  CHECK_FAIL_RETURN_UNEXPECTED(offset, "Input offset is nullptr.");
  // memory_ptr_ stays null until the segment has been attached; fail instead of writing through it.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_ptr_ != nullptr, "Shared memory is not attached. key=0x" + memory_key_str_);

  std::lock_guard<std::mutex> lck(mutex_);
  CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
                               "Insufficient shared memory space to insert data.");
  if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
    RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
  }
  // Report the position of this record, then advance the write cursor past it.
  *offset = memory_offset_;
  memory_offset_ += len;
  return Status::OK();
}
// Copy `get_data_len` bytes starting at `offset` in the shared memory into `data`.
// @param uint8_t *data - destination buffer
// @param int64_t data_len - capacity of the destination buffer in bytes
// @param int64_t offset - byte offset inside the shared memory to read from
// @param int64_t get_data_len - number of bytes to copy, must be positive
// @return Status - the status code
Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
  CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");
  // memory_ptr_ stays null until the segment has been attached; fail instead of reading through it.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_ptr_ != nullptr, "Shared memory is not attached. key=0x" + memory_key_str_);
  // A negative offset would read before the segment; reject it explicitly.
  CHECK_FAIL_RETURN_UNEXPECTED(offset >= 0 && offset <= memory_size_, "Input offset is invalid.");
  // Written as a subtraction so that a huge get_data_len + offset cannot overflow int64_t.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ - offset >= get_data_len,
                               "get_data_len is too large, beyond the space of shared memory.");
  if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
    RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory.");
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#include <sys/ipc.h>
#include <sys/shm.h>
#include <mutex>
#include <string>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Project id passed to ftok() together with the mindrecord file path to generate the shared memory key.
const int kGnnSharedMemoryId = 65;

// Handle for one shared memory segment (shmget/shmat) used to store graph feature data.
// Data is appended with InsertData() and read back by offset with GetData().
class GraphSharedMemory {
 public:
  // @param int64_t memory_size - size of the shared memory in bytes
  // @param key_t memory_key - key of an existing or to-be-created shared memory
  explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);

  // @param int64_t memory_size - size of the shared memory in bytes
  // @param std::string &mr_file - file path used to generate the key when the memory is created
  explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);

  // Deletes the shared memory if it was newly created by this object.
  ~GraphSharedMemory();

  // Create the shared memory (generating the key first if needed) and attach to it.
  // @return Status - the status code
  Status CreateSharedMemory();

  // Attach to the shared memory identified by the key, without requesting creation.
  // @return Status - the status code
  Status GetSharedMemory();

  // Remove the shared memory identified by the key.
  // @return Status - the status code
  Status DeleteSharedMemory();

  // @param uint8_t *data - source bytes to append
  // @param int64_t len - number of bytes to append
  // @param int64_t *offset - return value, offset at which the bytes were stored
  // @return Status - the status code
  Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);

  // @param uint8_t *data - destination buffer
  // @param int64_t data_len - capacity of the destination buffer
  // @param int64_t offset - offset inside the shared memory to read from
  // @param int64_t get_data_len - number of bytes to read
  // @return Status - the status code
  Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);

  // Accessors do not modify the object, so they are const-qualified.
  key_t memory_key() const { return memory_key_; }

  int64_t memory_size() const { return memory_size_; }

 private:
  // Shared shmget + shmat step used by CreateSharedMemory() and GetSharedMemory().
  Status SharedMemoryImpl(const int &shmflg);

  std::string mr_file_;         // file path used to generate the key
  int64_t memory_size_;         // size of the shared memory in bytes
  key_t memory_key_;            // key of the shared memory, -1 until generated
  std::string memory_key_str_;  // hexadecimal form of the key, used in messages
  uint8_t *memory_ptr_;         // attached address, nullptr until attached
  int64_t memory_offset_;       // next free offset for InsertData()
  std::mutex mutex_;            // serializes InsertData()
  bool is_new_create_;          // true when this object created the shared memory
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#include <limits>
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Remember the address to listen on; the server itself is started by Run().
// Initializers are listed in member declaration order (port_ before host_) so the written
// order matches the order in which the members are actually initialized.
GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : port_(port), host_(host) {}

// Shut the server and its completion queue down on destruction.
GrpcAsyncServer::~GrpcAsyncServer() { Stop(); }
// Configure and start the gRPC server on host_:port_.
// @return Status - OK when the server started, an unexpected-error status otherwise
Status GrpcAsyncServer::Run() {
  const std::string server_address = host_ + ":" + std::to_string(port_);

  grpc::ServerBuilder builder;
  // Default message size for gRPC is 4MB. Increase it to 2g-1
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
  builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);

  // Receives the port that was actually bound once the server is built.
  int bound_port = 0;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &bound_port);
  RETURN_IF_NOT_OK(RegisterService(&builder));
  cq_ = builder.AddCompletionQueue();
  server_ = builder.BuildAndStart();

  if (!server_) {
    std::string err_msg = "Fail to start server. ";
    if (bound_port != port_) {
      err_msg += "Unable to bind to address " + server_address + ".";
    }
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  MS_LOG(INFO) << "Server listening on " << server_address;
  return Status::OK();
}
// Event loop of the async server: enqueue the initial requests, then keep pulling events
// from the completion queue and hand each successful one to ProcessRequest().
// @return Status - OK once the queue stops delivering events, or the first error encountered
Status GrpcAsyncServer::HandleRequest() {
  // Each event carries back the tag that was registered with it; the derived class
  // decides what the tag means (see EnqueueRequest / ProcessRequest).
  RETURN_IF_NOT_OK(EnqueueRequest());

  void *event_tag = nullptr;
  bool event_ok = false;
  while (cq_->Next(&event_tag, &event_ok)) {
    RETURN_IF_INTERRUPTED();
    if (!event_ok) {
      MS_LOG(DEBUG) << "cq_->Next failed.";
      continue;
    }
    RETURN_IF_NOT_OK(ProcessRequest(event_tag));
  }
  return Status::OK();
}
// Shut down the server, then its completion queue. Either step is skipped when the
// corresponding object was never created.
void GrpcAsyncServer::Stop() {
  if (server_ != nullptr) {
    server_->Shutdown();
  }
  // Always shutdown the completion queue after the server.
  if (cq_ != nullptr) {
    cq_->Shutdown();
  }
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/async_unary_call.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Async server base class
/// Derived classes supply the service registration and the request handling through the
/// pure virtual functions below.
class GrpcAsyncServer {
public:
/// \brief Constructor
/// \param[in] host Host part of the address to listen on
/// \param[in] port Port part of the address to listen on
explicit GrpcAsyncServer(const std::string &host, int32_t port);
/// \brief Destructor, stops the server
virtual ~GrpcAsyncServer();
/// \brief Brings up gRPC server
/// \return Status - the status code
Status Run();
/// \brief Entry function to handle async server request
/// \return Status - the status code
Status HandleRequest();
/// \brief Shuts down the server and then the completion queue
void Stop();
/// \brief Register the concrete gRPC service with the builder; called from Run()
/// \param[in] builder Server builder being configured
/// \return Status - the status code
virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;
/// \brief Inject the initial request(s) into the completion queue; called from HandleRequest()
/// \return Status - the status code
virtual Status EnqueueRequest() = 0;
/// \brief Handle one event taken from the completion queue
/// \param[in] tag Tag delivered with the event
/// \return Status - the status code
virtual Status ProcessRequest(void *tag) = 0;
protected:
int32_t port_;
std::string host_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
......@@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
return Status::OK();
}
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore
......@@ -20,10 +20,10 @@
#include <unordered_map>
#include <utility>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
......
......@@ -20,9 +20,9 @@
#include <unordered_map>
#include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
......
......@@ -20,8 +20,8 @@
#include <unordered_map>
#include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
namespace mindspore {
namespace dataset {
const std::unordered_map<DataTypePb, DataType::Type> g_pb2datatype_map{
{DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN}, {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL},
{DataTypePb::DE_PB_INT8, DataType::DE_INT8}, {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8},
{DataTypePb::DE_PB_INT16, DataType::DE_INT16}, {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16},
{DataTypePb::DE_PB_INT32, DataType::DE_INT32}, {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32},
{DataTypePb::DE_PB_INT64, DataType::DE_INT64}, {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64},
{DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16}, {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32},
{DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64}, {DataTypePb::DE_PB_STRING, DataType::DE_STRING},
};
// Inverse of g_pb2datatype_map: translates the dataset engine's DataType::Type into the protobuf enum DataTypePb.
// TensorToPb consults this table; a type that is not listed here is reported as an invalid tensor type.
const std::unordered_map<DataType::Type, DataTypePb> g_datatype2pb_map{
{DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN}, {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL},
{DataType::DE_INT8, DataTypePb::DE_PB_INT8}, {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8},
{DataType::DE_INT16, DataTypePb::DE_PB_INT16}, {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16},
{DataType::DE_INT32, DataTypePb::DE_PB_INT32}, {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32},
{DataType::DE_INT64, DataTypePb::DE_PB_INT64}, {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64},
{DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16}, {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32},
{DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64}, {DataType::DE_STRING, DataTypePb::DE_PB_STRING},
};
// Serialize a Tensor into its protobuf representation.
//
// Copies the tensor's shape (one `dims` entry per dimension), its data type
// (translated through g_datatype2pb_map) and its raw buffer into tensor_pb.
//
// @param tensor     Tensor to serialize; must not be null.
// @param tensor_pb  Destination message; must not be null. Dims are appended to
//                   whatever the message already holds, so pass a fresh message.
// @return Status::OK() on success; an unexpected-error Status if either argument
//         is null or the tensor's type has no protobuf equivalent.
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb) {
  CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");
  CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");

  // Resolve the wire type before touching tensor_pb, so that an unmappable type
  // fails without leaving the output message half-populated with dims.
  const auto type_iter = g_datatype2pb_map.find(tensor->type().value());
  if (type_iter == g_datatype2pb_map.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString());
  }

  const std::vector<dsize_t> shape = tensor->shape().AsVector();
  for (const auto dim : shape) {
    tensor_pb->add_dims(static_cast<google::protobuf::int64>(dim));
  }
  tensor_pb->set_tensor_type(type_iter->second);
  tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes());
  return Status::OK();
}
// Rebuild a Tensor from its protobuf representation.
//
// Reads the shape from `dims`, translates the protobuf type through
// g_pb2datatype_map, and creates the tensor with Tensor::CreateFromMemory using
// the message's `data` bytes.
//
// @param tensor_pb  Source message; must not be null.
// @param tensor     Receives the newly created tensor; must not be null. It is only
//                   assigned after the tensor has been created successfully.
// @return Status::OK() on success; an unexpected-error Status if either argument
//         is null or the protobuf type is unknown, otherwise whatever
//         Tensor::CreateFromMemory reports.
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor) {
  CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");
  CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");

  const auto type_entry = g_pb2datatype_map.find(tensor_pb->tensor_type());
  if (type_entry == g_pb2datatype_map.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid Tensor_pb type: " + std::to_string(tensor_pb->tensor_type()));
  }
  const DataType::Type de_type = type_entry->second;

  // Element-wise int64 -> dsize_t conversion of the serialized dimensions.
  const auto &pb_dims = tensor_pb->dims();
  std::vector<dsize_t> dims;
  dims.reserve(pb_dims.size());
  for (const google::protobuf::int64 dim : pb_dims) {
    dims.push_back(static_cast<dsize_t>(dim));
  }

  const auto &payload = tensor_pb->data();
  const auto *raw_bytes = reinterpret_cast<const unsigned char *>(payload.data());

  std::shared_ptr<Tensor> created;
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(dims), DataType(de_type), raw_bytes, payload.size(), &created));
  *tensor = std::move(created);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#include <deque>
#include <memory>
#include <vector>
#include "proto/gnn_tensor.pb.h"
#include "minddata/dataset/core/tensor.h"
namespace mindspore {
namespace dataset {
/// \brief Serialize a Tensor (shape, data type and raw buffer) into a TensorPb message.
/// \param[in] tensor Tensor to serialize, must not be null
/// \param[out] tensor_pb Destination protobuf message, must not be null
/// \return Status::OK() on success, an error Status if an argument is null or the tensor type is not supported
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb);
/// \brief Create a Tensor from the shape, data type and data bytes stored in a TensorPb message.
/// \param[in] tensor_pb Source protobuf message, must not be null
/// \param[out] tensor Receives the created Tensor, must not be null
/// \return Status::OK() on success, an error Status if an argument is null, the type is unknown,
///     or the Tensor cannot be created
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
......@@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
class ShardColumn {
public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
explicit ShardColumn(const json &schema_json, bool compress_integer = true);
~ShardColumn() = default;
......@@ -72,23 +73,29 @@ class ShardColumn {
std::vector<int64_t> *column_shape);
/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size);
/// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; }
/// \brief getter
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
/// \brief getter
std::vector<std::string> GetColumnName() { return column_name_; }
/// \brief getter
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
/// \brief getter
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes);
/// \brief get column type
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type,
uint64_t *column_data_type_size,
......@@ -99,6 +106,9 @@ class ShardColumn {
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
private:
/// \brief initialization
void Init(const json &schema_json, bool compress_integer = true);
/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
......
......@@ -65,6 +65,11 @@ class ShardHeader {
/// \return the Statistic
std::vector<std::shared_ptr<Statistics>> GetStatistics();
/// \brief add the statistic and save it
/// \param[in] statistic info of slim size
/// \return null
int64_t GetSlimSizeStatistic(const json &slim_size_json);
/// \brief get the fields of the index
/// \return the fields of the index
std::vector<std::pair<uint64_t, std::string>> GetFields();
......@@ -114,10 +119,14 @@ class ShardHeader {
uint64_t GetPageSize() const { return page_size_; }
uint64_t GetCompressionSize() const { return compression_size_; }
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; }
std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name);
......@@ -177,6 +186,7 @@ class ShardHeader {
uint32_t shard_count_;
uint64_t header_size_;
uint64_t page_size_;
uint64_t compression_size_;
std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_;
......
......@@ -209,6 +209,9 @@ class ShardReader {
/// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
protected:
/// \brief sqlite call back function
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
......@@ -323,6 +326,7 @@ class ShardReader {
const std::string kThreadName = "THRD_ITER_"; // prefix of thread name
std::vector<std::thread> thread_set_; // thread list
int num_rows_; // number of rows
int64_t total_blob_size_; // total size of blob data
std::mutex mtx_delivery_; // locker for delivery
std::condition_variable cv_delivery_; // conditional variable for delivery
std::condition_variable cv_iterator_; // conditional variable for iterator
......
......@@ -257,6 +257,7 @@ class ShardWriter {
std::mutex check_mutex_; // mutex for data check
std::atomic<bool> flag_{false};
std::atomic<int64_t> compression_size_;
};
} // namespace mindrecord
} // namespace mindspore
......
......@@ -43,6 +43,7 @@ ShardReader::ShardReader() {
page_size_ = 0;
header_size_ = 0;
num_rows_ = 0;
total_blob_size_ = 0;
num_padded_ = 0;
}
......@@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
return {FAILED, {}};
}
auto header = ret.second;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"version", header["version"]}, {"index_fields", header["index_fields"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
uint64_t compression_size = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"compression_size", compression_size}, {"version", header["version"]},
{"index_fields", header["index_fields"]}, {"schema", header["schema"]},
{"blob_fields", header["blob_fields"]}};
return {SUCCESS, header["shard_addresses"]};
}
......@@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
for (const auto &rg : row_group_summary) {
num_rows_ += std::get<3>(rg);
}
auto disk_size = page_size_ * row_group_summary.size();
auto compression_size = shard_header_->GetCompressionSize();
total_blob_size_ = disk_size + compression_size;
MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , addtional uncompression: " << compression_size
<< " , Total: " << total_blob_size_;
MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully.";
......@@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
return row_group_summary;
}
// Report the total blob data size cached in total_blob_size_.
//
// @param total_blob_size  Output pointer that receives the cached total; must not be null.
// @return SUCCESS when the value has been written, FAILED when total_blob_size is null.
MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
  // Guard the out-parameter: writing through a null pointer would be undefined behaviour.
  if (total_blob_size == nullptr) {
    return FAILED;
  }
  *total_blob_size = total_blob_size_;
  return SUCCESS;
}
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
std::shared_ptr<std::fstream> fs,
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
......
......@@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace mindrecord {
ShardWriter::ShardWriter()
: shard_count_(1),
header_size_(kDefaultHeaderSize),
page_size_(kDefaultPageSize),
row_count_(0),
schema_count_(1) {}
: shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
compression_size_ = 0;
}
ShardWriter::~ShardWriter() {
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
......@@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (ret == FAILED) {
return FAILED;
}
compression_size_ = shard_header_->GetCompressionSize();
ret = Open(real_addresses, true);
if (ret == FAILED) {
MS_LOG(ERROR) << "Open file failed";
......@@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
int64_t compression_bytes = 0;
blob = shard_column_->CompressBlob(blob, &compression_bytes);
compression_size_ += compression_bytes;
}
}
......@@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() {
MS_LOG(ERROR) << "Shard header is null";
return FAILED;
}
int64_t compression_temp = compression_size_;
uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
shard_header_->SetCompressionSize(compression_size);
auto shard_header = shard_header_->SerializeHeader();
// Write header data to multi files
if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) {
......
......@@ -24,7 +24,15 @@ namespace mindspore {
namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
auto first_schema = shard_header->GetSchemas()[0];
auto schema = first_schema->GetSchema()["schema"];
json schema_json = first_schema->GetSchema();
Init(schema_json, compress_integer);
}
// Construct directly from a schema json; all setup is delegated to Init().
ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); }
void ShardColumn::Init(const json &schema_json, bool compress_integer) {
auto schema = schema_json["schema"];
auto blob_fields = schema_json["blob_fields"];
bool has_integer_array = false;
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
......@@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool
column_name_id_[column_name_[i]] = i;
}
auto blob_fields = first_schema->GetBlobFields();
for (const auto &field : blob_fields) {
blob_column_.push_back(field);
}
......@@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
}
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
// Skip if no compress columns
*compression_size = 0;
if (!CheckCompressBlob()) return blob;
std::vector<uint8_t> dst_blob;
......@@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
// Compress and return is blob has 1 column only
if (num_blob_column_ == 1) {
return CompressInt(blob, int_type);
dst_blob = CompressInt(blob, int_type);
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
return dst_blob;
}
// Just copy and continue if column dat type is not int32/int64
......@@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
i_src += kInt64Len + num_bytes;
}
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
return dst_blob;
}
......
......@@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
index_ = std::make_shared<Index>();
}
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size();
......@@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
ParseShardAddress(header["shard_addresses"]);
header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>();
compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
}
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED;
......@@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil
return {FAILED, json()};
}
json raw_header = ret.second;
uint64_t compression_size =
raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]},
{"compression_size", compression_size},
{"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]},
......@@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s += "\"index_fields\":" + index + ",";
s += "\"page\":" + pages[shardId] + ",";
s += "\"page_size\":" + std::to_string(page_size_) + ",";
s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
s += "\"schema\":" + schema + ",";
s += "\"shard_addresses\":" + address + ",";
s += "\"shard_id\":" + std::to_string(shardId) + ",";
......
......@@ -3085,20 +3085,22 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x, copy=False) for x in val])
def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
"""
indices = sampler.get_indices()
return _sampler_fn_mp(indices, dataset, num_worker)
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
"""
Multiprocessing generator function wrapper for mappable dataset with python sampler.
"""
indices = _fetch_py_sampler_indices(sampler, num_samples)
return _sampler_fn_mp(indices, dataset, num_worker)
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
def _fetch_py_sampler_indices(sampler, num_samples):
......@@ -3132,63 +3134,92 @@ def _fill_worker_indices(workers, indices, idx):
return idx
def _sampler_fn_mp(indices, dataset, num_worker):
class SamplerFn:
"""
Multiprocessing generator function wrapper master process.
Multiprocessing or multithread generator function wrapper master process.
"""
workers = []
# Event for end of epoch
eoe = multiprocessing.Event()
# Create workers
for _ in range(num_worker):
worker = _GeneratorWorker(dataset, eoe)
worker.daemon = True
workers.append(worker)
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
# Start all workers
for w in workers:
w.start()
# Fetch results
for i in range(len(indices)):
# Fetch result and put index
try:
result = workers[i % num_worker].get()
except queue.Empty:
raise Exception("Generator worker process timeout")
except KeyboardInterrupt:
for w in workers:
w.terminate()
def __init__(self, dataset, num_worker, multi_process):
self.workers = []
self.num_worker = num_worker
self.multi_process = multi_process
# Event for end of epoch
if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else:
self.eoe = threading.Event()
self.eof = threading.Event()
# Create workers
for _ in range(num_worker):
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
self.workers.append(worker)
def process(self, indices):
"""
The main process: start the child processes or child threads, fill the index queues,
fetch each result from the workers' result queues and yield it.
"""
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Start all workers
for w in self.workers:
w.start()
# Fetch results
for i in range(len(indices)):
# Fetch result and put index
try:
result = self.workers[i % self.num_worker].get()
except queue.Empty:
raise Exception("Generator worker process timeout")
except KeyboardInterrupt:
self.eof.set()
for w in self.workers:
w.terminate()
w.join()
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not self.eoe.is_set():
self.eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
def __del__(self):
self.eoe.set()
self.eof.set()
if self.multi_process is False:
for w in self.workers:
w.join()
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not eoe.is_set():
eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
"""
Multiprocessing generator worker process loop.
Multiprocessing or multithread generator worker process loop.
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
idx = idx_queue.get(timeout=10)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty:
if eof.is_set() or eoe.is_set():
raise Exception("Generator worker receives queue.Empty")
continue
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
if eof.is_set():
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
......@@ -3197,17 +3228,42 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx
if eoe.is_set() and idx_queue.empty():
return
class _GeneratorWorker(multiprocessing.Process):
class _GeneratorWorkerMt(threading.Thread):
    """
    Worker thread for the multithreaded Generator.

    Each worker owns a bounded index queue (input) and a bounded result queue (output)
    and runs `_generator_worker_loop` over them.
    """

    def __init__(self, dataset, eoe, eof):
        # The queues are created before Thread.__init__ because they are part of the target's arguments.
        self.idx_queue = queue.Queue(maxsize=16)
        self.res_queue = queue.Queue(maxsize=16)
        loop_args = (dataset, self.idx_queue, self.res_queue, eoe, eof)
        super().__init__(target=_generator_worker_loop, args=loop_args)

    def put(self, item):
        """
        Enqueue an index for this worker without blocking; raises queue.Full if the index queue is full.
        """
        self.idx_queue.put(item, block=False)

    def get(self):
        """
        Fetch the next result from this worker, blocking for at most 10 seconds; raises queue.Empty on timeout.
        """
        return self.res_queue.get(block=True, timeout=10)
class _GeneratorWorkerMp(multiprocessing.Process):
"""
Worker process for multiprocess Generator.
"""
def __init__(self, dataset, eoe):
def __init__(self, dataset, eoe, eof):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
def put(self, item):
"""
......@@ -3219,7 +3275,7 @@ class _GeneratorWorker(multiprocessing.Process):
"""
Get function for worker result queue. Block with timeout.
"""
return self.res_queue.get()
return self.res_queue.get(timeout=10)
def __del__(self):
self.terminate()
......@@ -3282,6 +3338,8 @@ class GeneratorDataset(MappableDataset):
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes. This
option could be beneficial if the python operation is computationally heavy (default=True).
Examples:
>>> import mindspore.dataset as ds
......@@ -3318,12 +3376,14 @@ class GeneratorDataset(MappableDataset):
@check_generatordataset
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True):
super().__init__(num_parallel_workers)
self.source = source
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.num_shards = num_shards
self.python_multiprocessing = python_multiprocessing
if column_names is not None and not isinstance(column_names, list):
column_names = [column_names]
......@@ -3405,12 +3465,16 @@ class GeneratorDataset(MappableDataset):
sampler_instance.set_num_rows(len(self.source))
sampler_instance.initialize()
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers))
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
else:
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
else:
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers))
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
else:
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
else:
......
......@@ -16,8 +16,11 @@
graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data.
"""
import atexit
import time
import numpy as np
from mindspore._c_dataengine import Graph
from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
......@@ -34,14 +37,52 @@ class GraphData:
dataset_file (str): One of file names in dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
(default=None).
working_mode (str, optional): Set working mode, now support 'local'/'client'/'server' (default='local').
- 'local', used in non-distributed training scenarios.
- 'client', used in distributed training scenarios, the client does not load data,
but obtains data from the server.
- 'server', used in distributed training scenarios, the server loads the data
and is available to the client.
hostname (str, optional): Valid when working_mode is set to 'client' or 'server',
set the hostname of the graph data server (default='127.0.0.1').
port (int, optional): Valid when working_mode is set to 'client' or 'server',
set the port of the graph data server, the range is 1024-65535 (default=50051).
num_client (int, optional): Valid when working_mode is set to 'server',
set the number of clients expected to connect, and the server will allocate corresponding
resources according to this parameter (default=1).
auto_shutdown (bool, optional): Valid when working_mode is set to 'server'.
Whether the server exits automatically once all expected clients have connected
and no client remains connected to it (default=True).
"""
@check_gnn_graphdata
def __init__(self, dataset_file, num_parallel_workers=None):
def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051,
num_client=1, auto_shutdown=True):
self._dataset_file = dataset_file
self._working_mode = working_mode
if num_parallel_workers is None:
num_parallel_workers = 1
self._graph = Graph(dataset_file, num_parallel_workers)
def stop():
self._graph_data.stop()
atexit.register(stop)
if working_mode in ['local', 'client']:
self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port)
if working_mode == 'server':
self._graph_data = GraphDataServer(
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
try:
while self._graph_data.is_stoped() is not True:
time.sleep(1)
except KeyboardInterrupt:
# self._graph_data.stop()
raise Exception("Graph data server receives KeyboardInterrupt")
@check_gnn_get_all_nodes
def get_all_nodes(self, node_type):
......@@ -62,7 +103,9 @@ class GraphData:
Raises:
TypeError: If `node_type` is not integer.
"""
return self._graph.get_all_nodes(node_type).as_array()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_nodes(node_type).as_array()
@check_gnn_get_all_edges
def get_all_edges(self, edge_type):
......@@ -83,7 +126,9 @@ class GraphData:
Raises:
TypeError: If `edge_type` is not integer.
"""
return self._graph.get_all_edges(edge_type).as_array()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_edges(edge_type).as_array()
@check_gnn_get_nodes_from_edges
def get_nodes_from_edges(self, edge_list):
......@@ -99,7 +144,9 @@ class GraphData:
Raises:
TypeError: If `edge_list` is not list or ndarray.
"""
return self._graph.get_nodes_from_edges(edge_list).as_array()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_nodes_from_edges(edge_list).as_array()
@check_gnn_get_all_neighbors
def get_all_neighbors(self, node_list, neighbor_type):
......@@ -123,7 +170,9 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `neighbor_type` is not integer.
"""
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
@check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
......@@ -155,7 +204,9 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray.
"""
return self._graph.get_sampled_neighbors(
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_sampled_neighbors(
node_list, neighbor_nums, neighbor_types).as_array()
@check_gnn_get_neg_sampled_neighbors
......@@ -182,7 +233,9 @@ class GraphData:
TypeError: If `neg_neighbor_num` is not integer.
TypeError: If `neg_neighbor_type` is not integer.
"""
return self._graph.get_neg_sampled_neighbors(
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_neg_sampled_neighbors(
node_list, neg_neighbor_num, neg_neighbor_type).as_array()
@check_gnn_get_node_feature
......@@ -207,10 +260,12 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray.
"""
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
if isinstance(node_list, list):
node_list = np.array(node_list, dtype=np.int32)
return [
t.as_array() for t in self._graph.get_node_feature(
t.as_array() for t in self._graph_data.get_node_feature(
Tensor(node_list),
feature_types)]
......@@ -236,10 +291,12 @@ class GraphData:
TypeError: If `edge_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray.
"""
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
if isinstance(edge_list, list):
edge_list = np.array(edge_list, dtype=np.int32)
return [
t.as_array() for t in self._graph.get_edge_feature(
t.as_array() for t in self._graph_data.get_edge_feature(
Tensor(edge_list),
feature_types)]
......@@ -252,7 +309,9 @@ class GraphData:
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
node_feature_type and edge_feature_type.
"""
return self._graph.graph_info()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.graph_info()
@check_gnn_random_walk
def random_walk(
......@@ -285,5 +344,7 @@ class GraphData:
TypeError: If `target_nodes` is not list or ndarray.
TypeError: If `meta_path` is not list or ndarray.
"""
return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
default_node).as_array()
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
default_node).as_array()
......@@ -18,6 +18,7 @@ Built-in validators.
"""
import inspect as ins
import os
import re
from functools import wraps
import numpy as np
......@@ -912,16 +913,36 @@ def check_split(method):
return new_method
def check_hostname(hostname):
    """
    Check whether `hostname` is a syntactically valid host name.

    A valid name is at most 255 characters long and consists of dot-separated labels,
    each 1-63 characters of ASCII letters, digits or '-', neither starting nor ending
    with '-'. A single trailing dot is tolerated. Dotted IPv4 literals such as
    '127.0.0.1' satisfy these rules as well.

    Args:
        hostname (str): The host name to validate.

    Returns:
        bool, True if the host name is valid, False otherwise.
    """
    # An empty string would make hostname[-1] raise IndexError; it is simply invalid.
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    # re.ASCII keeps \d and the case-insensitive [A-Z] restricted to ASCII characters;
    # fullmatch (unlike match with '$') does not accept a trailing newline.
    allowed = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)", re.IGNORECASE | re.ASCII)
    return all(allowed.fullmatch(x) for x in hostname.split("."))
def check_gnn_graphdata(method):
    """
    Decorator that checks the input arguments of graphdata before calling `method`.

    The wrapped call is validated in this order: dataset file, optional number of
    parallel workers, hostname, working mode, port, number of clients and the
    auto-shutdown flag. The first failing check raises and `method` is not called.

    Args:
        method (callable): The GraphData method whose arguments are validated.

    Returns:
        callable, the wrapping method with the same signature as `method`.

    Raises:
        ValueError: If the hostname is malformed or the working mode is not one of
            'local', 'client' or 'server' (raised directly by this wrapper).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Only the seven-name unpack is kept: the earlier two-name unpack of the same
        # parse_user_args() result cannot coexist with it and would raise first.
        [dataset_file, num_parallel_workers, working_mode, hostname,
         port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)

        check_file(dataset_file)
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        # The type check comes first so check_hostname only ever sees a str.
        type_check(hostname, (str,), "hostname")
        if not check_hostname(hostname):
            raise ValueError("The hostname is illegal")

        type_check(working_mode, (str,), "working_mode")
        if working_mode not in {'local', 'client', 'server'}:
            raise ValueError("Invalid working mode")

        type_check(port, (int,), "port")
        check_value(port, (1024, 65535), "port")

        type_check(num_client, (int,), "num_client")
        check_value(num_client, (1, 255), "num_client")

        type_check(auto_shutdown, (bool,), "auto_shutdown")

        return method(self, *args, **kwargs)

    return new_method
......
......@@ -15,6 +15,7 @@
"""
User-defined API for MindRecord GNN writer.
"""
import numpy as np
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
......@@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (0, [], [])
node_profile = (2, ["int64", "int32"], [[-1], [-1]])
edge_profile = (0, [], [])
......@@ -51,7 +52,9 @@ def yield_nodes(task_id=0):
node_list.sort()
print(node_list)
for node_id in node_list:
node = {'id': node_id, 'type': 1}
node = {'id': node_id, 'type': 1,
'feature_1': np.ones((5,), dtype=np.int64),
'feature_2': np.ones((10,), dtype=np.int32)}
yield node
......
......@@ -22,6 +22,7 @@
#include "gtest/gtest.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
using namespace mindspore::dataset;
......@@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common {
MindDataTestGNNGraph() = default;
};
// Loads the test graph through GraphLoader and checks the number of nodes, edges and
// per-type node groups it reports.
TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  // NOTE(review): the second argument is presumably the worker count - confirm against GraphLoader.
  GraphLoader gl(path, 4);
  // Loading is a precondition for everything below, so stop the test right here on failure.
  ASSERT_TRUE(gl.InitAndLoad().IsOk());

  NodeIdMap n_id_map;
  EdgeIdMap e_id_map;
  NodeTypeMap n_type_map;
  EdgeTypeMap e_type_map;
  NodeFeatureMap n_feature_map;
  EdgeFeatureMap e_feature_map;
  DefaultNodeFeatureMap default_node_feature_map;
  DefaultEdgeFeatureMap default_edge_feature_map;
  // The size checks are only meaningful when the maps were actually filled in.
  ASSERT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map,
                                  &default_node_feature_map, &default_edge_feature_map)
                .IsOk());

  // Expected contents of the test data set: 20 nodes, 40 edges, 10 nodes for each of types 1 and 2.
  EXPECT_EQ(n_id_map.size(), 20);
  EXPECT_EQ(e_id_map.size(), 40);
  EXPECT_EQ(n_type_map[2].size(), 10);
  EXPECT_EQ(n_type_map[1].size(), 10);
}
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1);
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
......@@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1);
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
......@@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1);
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
......@@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
......@@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import random
import time
from multiprocessing import Process
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
def graphdata_startserver():
    """
    Construct a GraphData object in 'server' working mode for the test data set.

    Used as the target of the child process started by the distributed test.
    """
    worker_count = 1
    mode = 'server'
    logger.info('test start server.\n')
    ds.GraphData(DATASET_FILE, worker_count, mode)
class RandomBatchedSampler(ds.Sampler):
    # Sampler that shuffles the ids 1..index_range and hands them out in fixed-size
    # batches, so an id appears at most once per pass (no replacement).
    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        batch_size = self.num_edges_per_sample
        shuffled = list(range(1, self.index_range + 1))
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(shuffled)
        # Stop at the last start offset that still leaves room for a full batch:
        # a trailing partial batch (the remainder) is dropped.
        for start in range(0, self.index_range - batch_size + 1, batch_size):
            yield shuffled[start: start + batch_size]
class GNNGraphDataset():
    """
    Random-access data source built on a graph object `g`.

    Each item is produced from a batch of edge ids: the first node of every edge is
    looked up, negative samples and sampled neighbours are drawn for those nodes, and
    node features are fetched for both neighbour sets.
    """

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Total sample size of the GNN dataset: total_num_edges / num_edges_per_sample.
        total_edges = self.g.graph_info()['edge_num'][0]
        return total_edges // self.batch_num

    def __getitem__(self, index):
        # `index` is the batch of indices yielded by the sampler.
        graph = self.g
        edge_ids = index.astype(np.int32)

        # Keep only the first column of the (edge -> nodes) lookup.
        src_nodes = graph.get_nodes_from_edges(edge_ids)[:, 0]

        neg_nodes = graph.get_neg_sampled_neighbors(
            node_list=src_nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        pos_neighbors = graph.get_sampled_neighbors(
            node_list=src_nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])

        # Column 0 of the negative-sampling result is skipped; the rest is flattened.
        neg_candidates = neg_nodes[:, 1:].reshape(-1)
        neg_neighbors = graph.get_sampled_neighbors(
            node_list=neg_candidates, neighbor_nums=[2, 2], neighbor_types=[2, 2])

        pos_features = graph.get_node_feature(
            node_list=pos_neighbors, feature_types=[2, 3])
        neg_features = graph.get_node_feature(
            node_list=neg_neighbors, feature_types=[2, 3])

        # Feature 0 is returned for the positive side and feature 1 for the negative side.
        return pos_neighbors, neg_neighbors, pos_features[0], neg_features[1]
def test_graphdata_distributed():
    """
    Test distributed

    Starts a GraphData server in a child process, connects to it as a client and
    checks node/edge queries, feature queries and a GeneratorDataset pipeline that
    pulls its samples through the client.
    """
    logger.info('test distributed.\n')

    # NOTE(review): the server process is never joined or terminated here; presumably
    # it exits on its own once the client goes away - confirm.
    server_proc = Process(target=graphdata_startserver)
    server_proc.start()
    # Fixed delay between starting the server process and connecting the client.
    time.sleep(2)

    g = ds.GraphData(DATASET_FILE, 1, 'client')

    # Nodes of type 1 and three of their features.
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == list(range(101, 111))
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    expected_node_feature = [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                             [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                             [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[0].tolist() == expected_node_feature
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    # Edges of type 0 and their first requested feature.
    edges = g.get_all_edges(0)
    assert edges.tolist() == list(range(1, 41))
    features = g.get_edge_feature(edges, [1, 2])
    # The expected edge feature is the same ten-value pattern repeated four times.
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0] * 4

    # Pipeline check: batches of two edges, two epochs.
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)

    expected_shapes = (('neighbors', (2, 7)),
                       ('neg_neighbors', (6, 7)),
                       ('neighbors_features', (2, 7)),
                       ('neg_neighbors_features', (6, 7)))
    row_count = 0
    for data in dataset.create_dict_iterator():
        for column, shape in expected_shapes:
            assert data[column].shape == shape
        row_count += 1
    assert row_count == 40
# Allow the distributed test to be run directly as a script.
if __name__ == '__main__':
    test_graphdata_distributed()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册