Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
256dccc6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
256dccc6
编写于
8月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4498 Gnn data processing supports distributed scenarios
Merge pull request !4498 from heleiwang/gnn_distributed
上级
39c81daf
8ee4d8e9
变更
48
展开全部
隐藏空白更改
内联
并排
Showing
48 changed file
with
3202 addition
and
340 deletion
+3202
-340
cmake/mind_expression.cmake
cmake/mind_expression.cmake
+7
-0
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
+5
-0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc
...ataset/api/python/bindings/dataset/engine/gnn/bindings.cc
+38
-18
mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt
+23
-3
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc
+2
-1
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h
+2
-1
mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto
...re/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto
+103
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto
mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto
+42
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h
+134
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc
...re/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc
+589
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h
...ore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h
+185
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc
...pore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc
+135
-54
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h
...spore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h
+56
-38
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc
...re/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc
+133
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h
...ore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h
+196
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc
...rc/minddata/dataset/engine/gnn/graph_data_service_impl.cc
+299
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h
...src/minddata/dataset/engine/gnn/graph_data_service_impl.h
+70
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc
...ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc
+106
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h
.../ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h
+67
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc
+100
-88
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h
+13
-23
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc
.../ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc
+134
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h
...e/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h
+72
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc
...re/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc
+82
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h
...ore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h
+59
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc
+1
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h
+1
-1
mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h
mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h
+1
-1
mindspore/ccsrc/minddata/dataset/engine/gnn/node.h
mindspore/ccsrc/minddata/dataset/engine/gnn/node.h
+1
-1
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc
+84
-0
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h
+36
-0
mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
+11
-1
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
+10
-0
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
+4
-0
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
+1
-0
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+16
-3
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
+12
-6
mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
+15
-5
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
+8
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+115
-51
mindspore/dataset/engine/graphdata.py
mindspore/dataset/engine/graphdata.py
+75
-14
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+23
-2
model_zoo/utils/graph_to_mindrecord/sns/mr_api.py
model_zoo/utils/graph_to_mindrecord/sns/mr_api.py
+5
-2
tests/ut/cpp/dataset/gnn_graph_test.cc
tests/ut/cpp/dataset/gnn_graph_test.cc
+6
-26
tests/ut/data/mindrecord/testGraphData/sns
tests/ut/data/mindrecord/testGraphData/sns
+0
-0
tests/ut/data/mindrecord/testGraphData/sns.db
tests/ut/data/mindrecord/testGraphData/sns.db
+0
-0
tests/ut/data/mindrecord/testGraphData/testdata
tests/ut/data/mindrecord/testGraphData/testdata
+0
-0
tests/ut/python/dataset/test_graphdata_distributed.py
tests/ut/python/dataset/test_graphdata_distributed.py
+125
-0
未找到文件。
cmake/mind_expression.cmake
浏览文件 @
256dccc6
...
...
@@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include
(
${
CMAKE_SOURCE_DIR
}
/cmake/dependency_securec.cmake
)
include
(
${
CMAKE_SOURCE_DIR
}
/cmake/external_libs/protobuf.cmake
)
SET
(
MS_BUILD_GRPC 0
)
if
(
ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES
)
SET
(
MS_BUILD_GRPC 1
)
endif
()
if
(
ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES
"Windows"
)
SET
(
MS_BUILD_GRPC 1
)
endif
()
if
(
"
${
MS_BUILD_GRPC
}
"
)
# build dependencies of gRPC
include
(
${
CMAKE_SOURCE_DIR
}
/cmake/external_libs/absl.cmake
)
include
(
${
CMAKE_SOURCE_DIR
}
/cmake/external_libs/c-ares.cmake
)
...
...
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
浏览文件 @
256dccc6
...
...
@@ -83,6 +83,7 @@ endif()
if
(
ENABLE_TDTQUE
)
add_dependencies
(
engine-tdt core
)
endif
()
################### Create _c_dataengine Library ######################
set
(
submodules
$<TARGET_OBJECTS:core>
...
...
@@ -182,3 +183,7 @@ else()
set_target_properties
(
_c_dataengine PROPERTIES MACOSX_RPATH ON
)
endif
()
endif
()
if
(
NOT CMAKE_SYSTEM_NAME MATCHES
"Windows"
)
target_link_libraries
(
_c_dataengine PRIVATE mindspore::grpc++
)
endif
()
\ No newline at end of file
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc
浏览文件 @
256dccc6
...
...
@@ -18,83 +18,103 @@
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_data_client.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace
mindspore
{
namespace
dataset
{
PYBIND_REGISTER
(
Graph
,
0
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
gnn
::
Graph
,
std
::
shared_ptr
<
gnn
::
Graph
>>
(
*
m
,
"Graph"
)
.
def
(
py
::
init
([](
std
::
string
dataset_file
,
int32_t
num_workers
)
{
std
::
shared_ptr
<
gnn
::
Graph
>
g_out
=
std
::
make_shared
<
gnn
::
Graph
>
(
dataset_file
,
num_workers
);
THROW_IF_ERROR
(
g_out
->
Init
());
return
g_out
;
(
void
)
py
::
class_
<
gnn
::
GraphData
,
std
::
shared_ptr
<
gnn
::
GraphData
>>
(
*
m
,
"GraphDataClient"
)
.
def
(
py
::
init
([](
const
std
::
string
&
dataset_file
,
int32_t
num_workers
,
const
std
::
string
&
working_mode
,
const
std
::
string
&
hostname
,
int32_t
port
)
{
std
::
shared_ptr
<
gnn
::
GraphData
>
out
;
if
(
working_mode
==
"local"
)
{
out
=
std
::
make_shared
<
gnn
::
GraphDataImpl
>
(
dataset_file
,
num_workers
);
}
else
if
(
working_mode
==
"client"
)
{
out
=
std
::
make_shared
<
gnn
::
GraphDataClient
>
(
dataset_file
,
hostname
,
port
);
}
THROW_IF_ERROR
(
out
->
Init
());
return
out
;
}))
.
def
(
"get_all_nodes"
,
[](
gnn
::
Graph
&
g
,
gnn
::
NodeType
node_type
)
{
[](
gnn
::
Graph
Data
&
g
,
gnn
::
NodeType
node_type
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetAllNodes
(
node_type
,
&
out
));
return
out
;
})
.
def
(
"get_all_edges"
,
[](
gnn
::
Graph
&
g
,
gnn
::
EdgeType
edge_type
)
{
[](
gnn
::
Graph
Data
&
g
,
gnn
::
EdgeType
edge_type
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetAllEdges
(
edge_type
,
&
out
));
return
out
;
})
.
def
(
"get_nodes_from_edges"
,
[](
gnn
::
Graph
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
edge_list
)
{
[](
gnn
::
Graph
Data
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
edge_list
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetNodesFromEdges
(
edge_list
,
&
out
));
return
out
;
})
.
def
(
"get_all_neighbors"
,
[](
gnn
::
Graph
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
gnn
::
NodeType
neighbor_type
)
{
[](
gnn
::
Graph
Data
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
gnn
::
NodeType
neighbor_type
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetAllNeighbors
(
node_list
,
neighbor_type
,
&
out
));
return
out
;
})
.
def
(
"get_sampled_neighbors"
,
[](
gnn
::
Graph
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
std
::
vector
<
gnn
::
NodeIdType
>
neighbor_nums
,
[](
gnn
::
Graph
Data
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
std
::
vector
<
gnn
::
NodeIdType
>
neighbor_nums
,
std
::
vector
<
gnn
::
NodeType
>
neighbor_types
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetSampledNeighbors
(
node_list
,
neighbor_nums
,
neighbor_types
,
&
out
));
return
out
;
})
.
def
(
"get_neg_sampled_neighbors"
,
[](
gnn
::
Graph
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
gnn
::
NodeIdType
neighbor_num
,
[](
gnn
::
Graph
Data
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
gnn
::
NodeIdType
neighbor_num
,
gnn
::
NodeType
neg_neighbor_type
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
GetNegSampledNeighbors
(
node_list
,
neighbor_num
,
neg_neighbor_type
,
&
out
));
return
out
;
})
.
def
(
"get_node_feature"
,
[](
gnn
::
Graph
&
g
,
std
::
shared_ptr
<
Tensor
>
node_list
,
std
::
vector
<
gnn
::
FeatureType
>
feature_types
)
{
[](
gnn
::
Graph
Data
&
g
,
std
::
shared_ptr
<
Tensor
>
node_list
,
std
::
vector
<
gnn
::
FeatureType
>
feature_types
)
{
TensorRow
out
;
THROW_IF_ERROR
(
g
.
GetNodeFeature
(
node_list
,
feature_types
,
&
out
));
return
out
.
getRow
();
})
.
def
(
"get_edge_feature"
,
[](
gnn
::
Graph
&
g
,
std
::
shared_ptr
<
Tensor
>
edge_list
,
std
::
vector
<
gnn
::
FeatureType
>
feature_types
)
{
[](
gnn
::
Graph
Data
&
g
,
std
::
shared_ptr
<
Tensor
>
edge_list
,
std
::
vector
<
gnn
::
FeatureType
>
feature_types
)
{
TensorRow
out
;
THROW_IF_ERROR
(
g
.
GetEdgeFeature
(
edge_list
,
feature_types
,
&
out
));
return
out
.
getRow
();
})
.
def
(
"graph_info"
,
[](
gnn
::
Graph
&
g
)
{
[](
gnn
::
Graph
Data
&
g
)
{
py
::
dict
out
;
THROW_IF_ERROR
(
g
.
GraphInfo
(
&
out
));
return
out
;
})
.
def
(
"random_walk"
,
[](
gnn
::
Graph
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
std
::
vector
<
gnn
::
NodeType
>
meta_path
,
[](
gnn
::
Graph
Data
&
g
,
std
::
vector
<
gnn
::
NodeIdType
>
node_list
,
std
::
vector
<
gnn
::
NodeType
>
meta_path
,
float
step_home_param
,
float
step_away_param
,
gnn
::
NodeIdType
default_node
)
{
std
::
shared_ptr
<
Tensor
>
out
;
THROW_IF_ERROR
(
g
.
RandomWalk
(
node_list
,
meta_path
,
step_home_param
,
step_away_param
,
default_node
,
&
out
));
return
out
;
});
})
.
def
(
"stop"
,
[](
gnn
::
GraphData
&
g
)
{
THROW_IF_ERROR
(
g
.
Stop
());
});
(
void
)
py
::
class_
<
gnn
::
GraphDataServer
,
std
::
shared_ptr
<
gnn
::
GraphDataServer
>>
(
*
m
,
"GraphDataServer"
)
.
def
(
py
::
init
([](
const
std
::
string
&
dataset_file
,
int32_t
num_workers
,
const
std
::
string
&
hostname
,
int32_t
port
,
int32_t
client_num
,
bool
auto_shutdown
)
{
std
::
shared_ptr
<
gnn
::
GraphDataServer
>
out
;
out
=
std
::
make_shared
<
gnn
::
GraphDataServer
>
(
dataset_file
,
num_workers
,
hostname
,
port
,
client_num
,
auto_shutdown
);
THROW_IF_ERROR
(
out
->
Init
());
return
out
;
}))
.
def
(
"stop"
,
[](
gnn
::
GraphDataServer
&
g
)
{
THROW_IF_ERROR
(
g
.
Stop
());
})
.
def
(
"is_stoped"
,
[](
gnn
::
GraphDataServer
&
g
)
{
return
g
.
IsStoped
();
});
}));
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt
浏览文件 @
256dccc6
file
(
GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"*.cc"
)
set_property
(
SOURCE
${
_CURRENT_SRC_FILES
}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD
)
add_library
(
engine-gnn OBJECT
graph.cc
set
(
DATASET_ENGINE_GNN_SRC_FILES
graph_data_impl.cc
graph_data_client.cc
graph_data_server.cc
graph_loader.cc
graph_feature_parser.cc
local_node.cc
local_edge.cc
feature.cc
)
)
if
(
CMAKE_SYSTEM_NAME MATCHES
"Windows"
)
add_library
(
engine-gnn OBJECT
${
DATASET_ENGINE_GNN_SRC_FILES
}
)
else
()
set
(
DATASET_ENGINE_GNN_SRC_FILES
${
DATASET_ENGINE_GNN_SRC_FILES
}
tensor_proto.cc
grpc_async_server.cc
graph_data_service_impl.cc
graph_shared_memory.cc
)
ms_protobuf_generate
(
TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS
"gnn_tensor.proto"
)
ms_grpc_generate
(
GNN_PROTO_SRCS GNN_PROTO_HDRS
"gnn_graph_data.proto"
)
add_library
(
engine-gnn OBJECT
${
DATASET_ENGINE_GNN_SRC_FILES
}
${
TENSOR_PROTO_SRCS
}
${
GNN_PROTO_SRCS
}
)
add_dependencies
(
engine-gnn mindspore::protobuf
)
endif
()
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc
浏览文件 @
256dccc6
...
...
@@ -19,7 +19,8 @@ namespace mindspore {
namespace
dataset
{
namespace
gnn
{
Feature
::
Feature
(
FeatureType
type_name
,
std
::
shared_ptr
<
Tensor
>
value
)
:
type_name_
(
type_name
),
value_
(
value
)
{}
Feature
::
Feature
(
FeatureType
type_name
,
std
::
shared_ptr
<
Tensor
>
value
,
bool
is_shared_memory
)
:
type_name_
(
type_name
),
value_
(
value
),
is_shared_memory_
(
is_shared_memory
)
{}
}
// namespace gnn
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h
浏览文件 @
256dccc6
...
...
@@ -31,7 +31,7 @@ class Feature {
// Constructor
// @param FeatureType type_name - feature type
// @param std::shared_ptr<Tensor> value - feature value
Feature
(
FeatureType
type_name
,
std
::
shared_ptr
<
Tensor
>
value
);
Feature
(
FeatureType
type_name
,
std
::
shared_ptr
<
Tensor
>
value
,
bool
is_shared_memory
=
false
);
~
Feature
()
=
default
;
...
...
@@ -45,6 +45,7 @@ class Feature {
private:
FeatureType
type_name_
;
std
::
shared_ptr
<
Tensor
>
value_
;
bool
is_shared_memory_
;
};
}
// namespace gnn
}
// namespace dataset
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax
=
"proto3"
;
package
mindspore
.
dataset
;
import
"gnn_tensor.proto"
;
message
GnnClientRegisterRequestPb
{
int32
pid
=
1
;
}
message
GnnFeatureInfoPb
{
int32
type
=
1
;
TensorPb
feature
=
2
;
}
message
GnnClientRegisterResponsePb
{
string
error_msg
=
1
;
string
data_schema
=
2
;
int64
shared_memory_key
=
3
;
int64
shared_memory_size
=
4
;
repeated
GnnFeatureInfoPb
default_node_feature
=
5
;
repeated
GnnFeatureInfoPb
default_edge_feature
=
6
;
}
message
GnnClientUnRegisterRequestPb
{
int32
pid
=
1
;
}
message
GnnClientUnRegisterResponsePb
{
string
error_msg
=
1
;
}
enum
GnnOpName
{
GET_ALL_NODES
=
0
;
GET_ALL_EDGES
=
1
;
GET_NODES_FROM_EDGES
=
2
;
GET_ALL_NEIGHBORS
=
3
;
GET_SAMPLED_NEIGHBORS
=
4
;
GET_NEG_SAMPLED_NEIGHBORS
=
5
;
RANDOM_WALK
=
6
;
GET_NODE_FEATURE
=
7
;
GET_EDGE_FEATURE
=
8
;
}
message
GnnRandomWalkPb
{
float
p
=
1
;
float
q
=
2
;
int32
default_id
=
3
;
}
message
GnnGraphDataRequestPb
{
GnnOpName
op_name
=
1
;
repeated
int32
id
=
2
;
// node id or edge id
repeated
int32
type
=
3
;
//node type or edge type or neighbor type or feature type
repeated
int32
number
=
4
;
// samples number
TensorPb
id_tensor
=
5
;
// input ids ,node id or edge id
GnnRandomWalkPb
random_walk
=
6
;
}
message
GnnGraphDataResponsePb
{
string
error_msg
=
1
;
repeated
TensorPb
result_data
=
2
;
}
message
GnnMetaInfoRequestPb
{
}
message
GnnNodeEdgeInfoPb
{
int32
type
=
1
;
int32
num
=
2
;
}
message
GnnMetaInfoResponsePb
{
string
error_msg
=
1
;
repeated
GnnNodeEdgeInfoPb
node_info
=
2
;
repeated
GnnNodeEdgeInfoPb
edge_info
=
3
;
repeated
int32
node_feature_type
=
4
;
repeated
int32
edge_feature_type
=
5
;
}
service
GnnGraphData
{
rpc
ClientRegister
(
GnnClientRegisterRequestPb
)
returns
(
GnnClientRegisterResponsePb
);
rpc
ClientUnRegister
(
GnnClientUnRegisterRequestPb
)
returns
(
GnnClientUnRegisterResponsePb
);
rpc
GetGraphData
(
GnnGraphDataRequestPb
)
returns
(
GnnGraphDataResponsePb
);
rpc
GetMetaInfo
(
GnnMetaInfoRequestPb
)
returns
(
GnnMetaInfoResponsePb
);
}
mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax
=
"proto3"
;
package
mindspore
.
dataset
;
enum
DataTypePb
{
DE_PB_UNKNOWN
=
0
;
DE_PB_BOOL
=
1
;
DE_PB_INT8
=
2
;
DE_PB_UINT8
=
3
;
DE_PB_INT16
=
4
;
DE_PB_UINT16
=
5
;
DE_PB_INT32
=
6
;
DE_PB_UINT32
=
7
;
DE_PB_INT64
=
8
;
DE_PB_UINT64
=
9
;
DE_PB_FLOAT16
=
10
;
DE_PB_FLOAT32
=
11
;
DE_PB_FLOAT64
=
12
;
DE_PB_STRING
=
13
;
}
message
TensorPb
{
repeated
int64
dims
=
1
;
// tensor shape info
DataTypePb
tensor_type
=
2
;
// tensor content data type
bytes
data
=
3
;
// tensor data
}
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
struct
MetaInfo
{
std
::
vector
<
NodeType
>
node_type
;
std
::
vector
<
EdgeType
>
edge_type
;
std
::
map
<
NodeType
,
NodeIdType
>
node_num
;
std
::
map
<
EdgeType
,
EdgeIdType
>
edge_num
;
std
::
vector
<
FeatureType
>
node_feature_type
;
std
::
vector
<
FeatureType
>
edge_feature_type
;
};
class
GraphData
{
public:
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
virtual
Status
GetAllNodes
(
NodeType
node_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
virtual
Status
GetAllEdges
(
EdgeType
edge_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
virtual
Status
GetNodesFromEdges
(
const
std
::
vector
<
EdgeIdType
>
&
edge_list
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
virtual
Status
GetAllNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeType
neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
virtual
Status
GetSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeIdType
>
&
neighbor_nums
,
const
std
::
vector
<
NodeType
>
&
neighbor_types
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
virtual
Status
GetNegSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeIdType
samples_num
,
NodeType
neg_neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
virtual
Status
RandomWalk
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeType
>
&
meta_path
,
float
step_home_param
,
float
step_away_param
,
NodeIdType
default_node
,
std
::
shared_ptr
<
Tensor
>
*
out
)
=
0
;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
virtual
Status
GetNodeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
nodes
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
)
=
0
;
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
virtual
Status
GetEdgeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
edges
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
)
=
0
;
// Return meta information to python layer
virtual
Status
GraphInfo
(
py
::
dict
*
out
)
=
0
;
virtual
Status
Init
()
=
0
;
virtual
Status
Stop
()
=
0
;
};
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc
0 → 100644
浏览文件 @
256dccc6
此差异已折叠。
点击以展开。
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
#if !defined(_WIN32) && !defined(_WIN64)
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
#endif
#include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
class
GraphDataClient
:
public
GraphData
{
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
GraphDataClient
(
const
std
::
string
&
dataset_file
,
const
std
::
string
&
hostname
,
int32_t
port
);
~
GraphDataClient
();
Status
Init
()
override
;
Status
Stop
()
override
;
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status
GetAllNodes
(
NodeType
node_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status
GetAllEdges
(
EdgeType
edge_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status
GetNodesFromEdges
(
const
std
::
vector
<
EdgeIdType
>
&
edge_list
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status
GetAllNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeType
neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status
GetSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeIdType
>
&
neighbor_nums
,
const
std
::
vector
<
NodeType
>
&
neighbor_types
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status
GetNegSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeIdType
samples_num
,
NodeType
neg_neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
Status
RandomWalk
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeType
>
&
meta_path
,
float
step_home_param
,
float
step_away_param
,
NodeIdType
default_node
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status
GetNodeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
nodes
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
)
override
;
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status
GetEdgeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
edges
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
)
override
;
// Return meta information to python layer
Status
GraphInfo
(
py
::
dict
*
out
)
override
;
private:
#if !defined(_WIN32) && !defined(_WIN64)
Status
ParseNodeFeatureFromMemory
(
const
std
::
shared_ptr
<
Tensor
>
&
nodes
,
FeatureType
feature_type
,
const
std
::
shared_ptr
<
Tensor
>
&
memory_tensor
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
ParseEdgeFeatureFromMemory
(
const
std
::
shared_ptr
<
Tensor
>
&
edges
,
FeatureType
feature_type
,
const
std
::
shared_ptr
<
Tensor
>
&
memory_tensor
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
GetNodeDefaultFeature
(
FeatureType
feature_type
,
std
::
shared_ptr
<
Tensor
>
*
out_feature
);
Status
GetEdgeDefaultFeature
(
FeatureType
feature_type
,
std
::
shared_ptr
<
Tensor
>
*
out_feature
);
Status
GetGraphData
(
const
GnnGraphDataRequestPb
&
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetGraphDataTensor
(
const
GnnGraphDataRequestPb
&
request
,
GnnGraphDataResponsePb
*
response
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
RegisterToServer
();
Status
UnRegisterToServer
();
Status
InitFeatureParser
();
Status
CheckPid
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
pid_
==
getpid
(),
"Multi-process mode is not supported, please change to use multi-thread"
);
return
Status
::
OK
();
}
#endif
std
::
string
dataset_file_
;
std
::
string
host_
;
int32_t
port_
;
int32_t
pid_
;
mindrecord
::
json
data_schema_
;
#if !defined(_WIN32) && !defined(_WIN64)
std
::
unique_ptr
<
GnnGraphData
::
Stub
>
stub_
;
key_t
shared_memory_key_
;
int64_t
shared_memory_size_
;
std
::
unique_ptr
<
GraphFeatureParser
>
graph_feature_parser_
;
std
::
unique_ptr
<
GraphSharedMemory
>
graph_shared_memory_
;
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Tensor
>>
default_node_feature_map_
;
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Tensor
>>
default_edge_feature_map_
;
#endif
bool
registered_
;
};
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc
→
mindspore/ccsrc/minddata/dataset/engine/gnn/graph
_data_impl
.cc
浏览文件 @
256dccc6
此差异已折叠。
点击以展开。
mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h
→
mindspore/ccsrc/minddata/dataset/engine/gnn/graph
_data_impl
.h
浏览文件 @
256dccc6
...
...
@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_
DATA_IMPL_
H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_
DATA_IMPL_
H_
#include <algorithm>
#include <memory>
...
...
@@ -25,13 +25,11 @@
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/graph_data.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001;
const
uint32_t
kMaxNumWalks
=
80
;
using
StochasticIndex
=
std
::
pair
<
std
::
vector
<
int32_t
>
,
std
::
vector
<
float
>>
;
struct
MetaInfo
{
std
::
vector
<
NodeType
>
node_type
;
std
::
vector
<
EdgeType
>
edge_type
;
std
::
map
<
NodeType
,
NodeIdType
>
node_num
;
std
::
map
<
EdgeType
,
EdgeIdType
>
edge_num
;
std
::
vector
<
FeatureType
>
node_feature_type
;
std
::
vector
<
FeatureType
>
edge_feature_type
;
};
class
Graph
{
class
GraphDataImpl
:
public
GraphData
{
public:
// Constructor
// @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads
Graph
(
std
::
string
dataset_file
,
int32_t
num_workers
);
Graph
DataImpl
(
std
::
string
dataset_file
,
int32_t
num_workers
,
bool
server_mode
=
false
);
~
Graph
()
=
default
;
~
Graph
DataImpl
()
;
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status
GetAllNodes
(
NodeType
node_type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
GetAllNodes
(
NodeType
node_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status
GetAllEdges
(
EdgeType
edge_type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
GetAllEdges
(
EdgeType
edge_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status
GetNodesFromEdges
(
const
std
::
vector
<
EdgeIdType
>
&
edge_list
,
std
::
shared_ptr
<
Tensor
>
*
out
);
Status
GetNodesFromEdges
(
const
std
::
vector
<
EdgeIdType
>
&
edge_list
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
...
...
@@ -85,7 +74,7 @@ class Graph {
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status
GetAllNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeType
neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
...
...
@@ -94,7 +83,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status
GetSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeIdType
>
&
neighbor_nums
,
const
std
::
vector
<
NodeType
>
&
neighbor_types
,
std
::
shared_ptr
<
Tensor
>
*
out
);
const
std
::
vector
<
NodeType
>
&
neighbor_types
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
...
...
@@ -103,7 +92,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status
GetNegSampledNeighbors
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
NodeIdType
samples_num
,
NodeType
neg_neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
NodeType
neg_neighbor_type
,
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
...
...
@@ -115,7 +104,7 @@ class Graph {
// @return Status - The error code return
Status
RandomWalk
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeType
>
&
meta_path
,
float
step_home_param
,
float
step_away_param
,
NodeIdType
default_node
,
std
::
shared_ptr
<
Tensor
>
*
out
);
std
::
shared_ptr
<
Tensor
>
*
out
)
override
;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
...
...
@@ -124,16 +113,22 @@ class Graph {
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status
GetNodeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
nodes
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
);
TensorRow
*
out
)
override
;
Status
GetNodeFeatureSharedMemory
(
const
std
::
shared_ptr
<
Tensor
>
&
nodes
,
FeatureType
type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
// Get the feature of a edge
// @param std::shared_ptr<Tensor> edge
t
- List of edges
// @param std::shared_ptr<Tensor> edge
s
- List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
Status
GetEdgeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
edget
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
);
Status
GetEdgeFeature
(
const
std
::
shared_ptr
<
Tensor
>
&
edges
,
const
std
::
vector
<
FeatureType
>
&
feature_types
,
TensorRow
*
out
)
override
;
Status
GetEdgeFeatureSharedMemory
(
const
std
::
shared_ptr
<
Tensor
>
&
edges
,
FeatureType
type
,
std
::
shared_ptr
<
Tensor
>
*
out
);
// Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information
...
...
@@ -142,15 +137,34 @@ class Graph {
#ifdef ENABLE_PYTHON
// Return meta information to python layer
Status
GraphInfo
(
py
::
dict
*
out
);
Status
GraphInfo
(
py
::
dict
*
out
)
override
;
#endif
Status
Init
();
const
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Feature
>>
*
GetAllDefaultNodeFeatures
()
{
return
&
default_node_feature_map_
;
}
const
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Feature
>>
*
GetAllDefaultEdgeFeatures
()
{
return
&
default_edge_feature_map_
;
}
Status
Init
()
override
;
Status
Stop
()
override
{
return
Status
::
OK
();
}
std
::
string
GetDataSchema
()
{
return
data_schema_
.
dump
();
}
#if !defined(_WIN32) && !defined(_WIN64)
key_t
GetSharedMemoryKey
()
{
return
graph_shared_memory_
->
memory_key
();
}
int64_t
GetSharedMemorySize
()
{
return
graph_shared_memory_
->
memory_size
();
}
#endif
private:
friend
class
GraphLoader
;
class
RandomWalkBase
{
public:
explicit
RandomWalkBase
(
Graph
*
graph
);
explicit
RandomWalkBase
(
Graph
DataImpl
*
graph
);
Status
Build
(
const
std
::
vector
<
NodeIdType
>
&
node_list
,
const
std
::
vector
<
NodeType
>
&
meta_path
,
float
step_home_param
=
1.0
,
float
step_away_param
=
1.0
,
NodeIdType
default_node
=
-
1
,
...
...
@@ -176,7 +190,7 @@ class Graph {
template
<
typename
T
>
std
::
vector
<
float
>
Normalize
(
const
std
::
vector
<
T
>
&
non_normalized_probability
);
Graph
*
graph_
;
Graph
DataImpl
*
graph_
;
std
::
vector
<
NodeIdType
>
node_list_
;
std
::
vector
<
NodeType
>
meta_path_
;
float
step_home_param_
;
// Return hyper parameter. Default is 1.0
...
...
@@ -248,7 +262,11 @@ class Graph {
int32_t
num_workers_
;
// The number of worker threads
std
::
mt19937
rnd_
;
RandomWalkBase
random_walk_
;
mindrecord
::
json
data_schema_
;
bool
server_mode_
;
#if !defined(_WIN32) && !defined(_WIN64)
std
::
unique_ptr
<
GraphSharedMemory
>
graph_shared_memory_
;
#endif
std
::
unordered_map
<
NodeType
,
std
::
vector
<
NodeIdType
>>
node_type_map_
;
std
::
unordered_map
<
NodeIdType
,
std
::
shared_ptr
<
Node
>>
node_id_map_
;
...
...
@@ -264,4 +282,4 @@ class Graph {
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_
DATA_IMPL_
H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_server.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/util/random.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
GraphDataServer
::
GraphDataServer
(
const
std
::
string
&
dataset_file
,
int32_t
num_workers
,
const
std
::
string
&
hostname
,
int32_t
port
,
int32_t
client_num
,
bool
auto_shutdown
)
:
dataset_file_
(
dataset_file
),
num_workers_
(
num_workers
),
client_num_
(
client_num
),
max_connected_client_num_
(
0
),
auto_shutdown_
(
auto_shutdown
),
state_
(
kGdsUninit
)
{
tg_
=
std
::
make_unique
<
TaskGroup
>
();
graph_data_impl_
=
std
::
make_unique
<
GraphDataImpl
>
(
dataset_file
,
num_workers
,
true
);
#if !defined(_WIN32) && !defined(_WIN64)
service_impl_
=
std
::
make_unique
<
GraphDataServiceImpl
>
(
this
,
graph_data_impl_
.
get
());
async_server_
=
std
::
make_unique
<
GraphDataGrpcServer
>
(
hostname
,
port
,
service_impl_
.
get
());
#endif
}
Status
GraphDataServer
::
Init
()
{
#if defined(_WIN32) || defined(_WIN64)
RETURN_STATUS_UNEXPECTED
(
"Graph data server is not supported in Windows OS"
);
#else
set_state
(
kGdsInitializing
);
RETURN_IF_NOT_OK
(
async_server_
->
Run
());
// RETURN_IF_NOT_OK(InitGraphDataImpl());
RETURN_IF_NOT_OK
(
tg_
->
CreateAsyncTask
(
"init graph data impl"
,
std
::
bind
(
&
GraphDataServer
::
InitGraphDataImpl
,
this
)));
for
(
int32_t
i
=
0
;
i
<
num_workers_
;
++
i
)
{
RETURN_IF_NOT_OK
(
tg_
->
CreateAsyncTask
(
"start async rpc service"
,
std
::
bind
(
&
GraphDataServer
::
StartAsyncRpcService
,
this
)));
}
if
(
auto_shutdown_
)
{
RETURN_IF_NOT_OK
(
tg_
->
CreateAsyncTask
(
"judge auto shutdown server"
,
std
::
bind
(
&
GraphDataServer
::
JudgeAutoShutdownServer
,
this
)));
}
return
Status
::
OK
();
#endif
}
Status
GraphDataServer
::
InitGraphDataImpl
()
{
TaskManager
::
FindMe
()
->
Post
();
Status
s
=
graph_data_impl_
->
Init
();
if
(
s
.
IsOk
())
{
set_state
(
kGdsRunning
);
}
else
{
(
void
)
Stop
();
}
return
s
;
}
#if !defined(_WIN32) && !defined(_WIN64)
Status
GraphDataServer
::
StartAsyncRpcService
()
{
TaskManager
::
FindMe
()
->
Post
();
RETURN_IF_NOT_OK
(
async_server_
->
HandleRequest
());
return
Status
::
OK
();
}
#endif
Status
GraphDataServer
::
JudgeAutoShutdownServer
()
{
TaskManager
::
FindMe
()
->
Post
();
while
(
true
)
{
if
(
auto_shutdown_
&&
(
max_connected_client_num_
>=
client_num_
)
&&
(
client_pid_
.
size
()
==
0
))
{
MS_LOG
(
INFO
)
<<
"All clients have been unregister, automatically exit the server."
;
RETURN_IF_NOT_OK
(
Stop
());
break
;
}
if
(
state_
==
kGdsStopped
)
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
return
Status
::
OK
();
}
Status
GraphDataServer
::
Stop
()
{
#if !defined(_WIN32) && !defined(_WIN64)
async_server_
->
Stop
();
#endif
set_state
(
kGdsStopped
);
graph_data_impl_
.
reset
();
return
Status
::
OK
();
}
Status
GraphDataServer
::
ClientRegister
(
int32_t
pid
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
MS_LOG
(
INFO
)
<<
"client register pid:"
<<
std
::
to_string
(
pid
);
client_pid_
.
emplace
(
pid
);
if
(
client_pid_
.
size
()
>
max_connected_client_num_
)
{
max_connected_client_num_
=
client_pid_
.
size
();
}
return
Status
::
OK
();
}
Status
GraphDataServer
::
ClientUnRegister
(
int32_t
pid
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
auto
itr
=
client_pid_
.
find
(
pid
);
if
(
itr
!=
client_pid_
.
end
())
{
client_pid_
.
erase
(
itr
);
MS_LOG
(
INFO
)
<<
"client unregister pid:"
<<
std
::
to_string
(
pid
);
}
return
Status
::
OK
();
}
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#endif
#include "minddata/dataset/util/task_manager.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
class
GraphDataImpl
;
class
GraphDataServer
{
public:
enum
ServerState
{
kGdsUninit
=
0
,
kGdsInitializing
,
kGdsRunning
,
kGdsStopped
};
GraphDataServer
(
const
std
::
string
&
dataset_file
,
int32_t
num_workers
,
const
std
::
string
&
hostname
,
int32_t
port
,
int32_t
client_num
,
bool
auto_shutdown
);
~
GraphDataServer
()
=
default
;
Status
Init
();
Status
Stop
();
Status
ClientRegister
(
int32_t
pid
);
Status
ClientUnRegister
(
int32_t
pid
);
enum
ServerState
state
()
{
return
state_
;
}
bool
IsStoped
()
{
if
(
state_
==
kGdsStopped
)
{
return
true
;
}
else
{
return
false
;
}
}
private:
void
set_state
(
enum
ServerState
state
)
{
state_
=
state
;
}
Status
InitGraphDataImpl
();
#if !defined(_WIN32) && !defined(_WIN64)
Status
StartAsyncRpcService
();
#endif
Status
JudgeAutoShutdownServer
();
std
::
string
dataset_file_
;
int32_t
num_workers_
;
// The number of worker threads
int32_t
client_num_
;
int32_t
max_connected_client_num_
;
bool
auto_shutdown_
;
enum
ServerState
state_
;
std
::
unique_ptr
<
TaskGroup
>
tg_
;
// Class for worker management
std
::
unique_ptr
<
GraphDataImpl
>
graph_data_impl_
;
std
::
unordered_set
<
int32_t
>
client_pid_
;
std
::
mutex
mutex_
;
#if !defined(_WIN32) && !defined(_WIN64)
std
::
unique_ptr
<
GraphDataServiceImpl
>
service_impl_
;
std
::
unique_ptr
<
GrpcAsyncServer
>
async_server_
;
#endif
};
#if !defined(_WIN32) && !defined(_WIN64)
class
UntypedCall
{
public:
virtual
~
UntypedCall
()
{}
virtual
Status
operator
()()
=
0
;
};
template
<
class
ServiceImpl
,
class
AsyncService
,
class
RequestMessage
,
class
ResponseMessage
>
class
CallData
:
public
UntypedCall
{
public:
enum
class
STATE
:
int8_t
{
CREATE
=
1
,
PROCESS
=
2
,
FINISH
=
3
};
using
EnqueueFunction
=
void
(
AsyncService
::*
)(
grpc
::
ServerContext
*
,
RequestMessage
*
,
grpc
::
ServerAsyncResponseWriter
<
ResponseMessage
>
*
,
grpc
::
CompletionQueue
*
,
grpc
::
ServerCompletionQueue
*
,
void
*
);
using
HandleRequestFunction
=
grpc
::
Status
(
ServiceImpl
::*
)(
grpc
::
ServerContext
*
,
const
RequestMessage
*
,
ResponseMessage
*
);
CallData
(
ServiceImpl
*
service_impl
,
AsyncService
*
async_service
,
grpc
::
ServerCompletionQueue
*
cq
,
EnqueueFunction
enqueue_function
,
HandleRequestFunction
handle_request_function
)
:
status_
(
STATE
::
CREATE
),
service_impl_
(
service_impl
),
async_service_
(
async_service
),
cq_
(
cq
),
enqueue_function_
(
enqueue_function
),
handle_request_function_
(
handle_request_function
),
responder_
(
&
ctx_
)
{}
~
CallData
()
=
default
;
static
Status
EnqueueRequest
(
ServiceImpl
*
service_impl
,
AsyncService
*
async_service
,
grpc
::
ServerCompletionQueue
*
cq
,
EnqueueFunction
enqueue_function
,
HandleRequestFunction
handle_request_function
)
{
auto
call
=
new
CallData
<
ServiceImpl
,
AsyncService
,
RequestMessage
,
ResponseMessage
>
(
service_impl
,
async_service
,
cq
,
enqueue_function
,
handle_request_function
);
RETURN_IF_NOT_OK
((
*
call
)());
return
Status
::
OK
();
}
Status
operator
()()
{
if
(
status_
==
STATE
::
CREATE
)
{
status_
=
STATE
::
PROCESS
;
(
async_service_
->*
enqueue_function_
)(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
else
if
(
status_
==
STATE
::
PROCESS
)
{
EnqueueRequest
(
service_impl_
,
async_service_
,
cq_
,
enqueue_function_
,
handle_request_function_
);
status_
=
STATE
::
FINISH
;
// new CallData(service_, cq_, this->s_type_);
grpc
::
Status
s
=
(
service_impl_
->*
handle_request_function_
)(
&
ctx_
,
&
request_
,
&
response_
);
responder_
.
Finish
(
response_
,
s
,
this
);
}
else
{
GPR_ASSERT
(
status_
==
STATE
::
FINISH
);
delete
this
;
}
return
Status
::
OK
();
}
private:
STATE
status_
;
ServiceImpl
*
service_impl_
;
AsyncService
*
async_service_
;
grpc
::
ServerCompletionQueue
*
cq_
;
EnqueueFunction
enqueue_function_
;
HandleRequestFunction
handle_request_function_
;
grpc
::
ServerContext
ctx_
;
grpc
::
ServerAsyncResponseWriter
<
ResponseMessage
>
responder_
;
RequestMessage
request_
;
ResponseMessage
response_
;
};
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \
do { \
Status s = \
CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \
&gnn::GraphDataServiceImpl::method); \
RETURN_IF_NOT_OK(s); \
} while (0)
class
GraphDataGrpcServer
:
public
GrpcAsyncServer
{
public:
GraphDataGrpcServer
(
const
std
::
string
&
host
,
int32_t
port
,
GraphDataServiceImpl
*
service_impl
)
:
GrpcAsyncServer
(
host
,
port
),
service_impl_
(
service_impl
)
{}
Status
RegisterService
(
grpc
::
ServerBuilder
*
builder
)
{
builder
->
RegisterService
(
&
svc_
);
return
Status
::
OK
();
}
Status
EnqueueRequest
()
{
ENQUEUE_REQUEST
(
service_impl_
,
&
svc_
,
cq_
.
get
(),
ClientRegister
,
GnnClientRegisterRequestPb
,
GnnClientRegisterResponsePb
);
ENQUEUE_REQUEST
(
service_impl_
,
&
svc_
,
cq_
.
get
(),
ClientUnRegister
,
GnnClientUnRegisterRequestPb
,
GnnClientUnRegisterResponsePb
);
ENQUEUE_REQUEST
(
service_impl_
,
&
svc_
,
cq_
.
get
(),
GetGraphData
,
GnnGraphDataRequestPb
,
GnnGraphDataResponsePb
);
ENQUEUE_REQUEST
(
service_impl_
,
&
svc_
,
cq_
.
get
(),
GetMetaInfo
,
GnnMetaInfoRequestPb
,
GnnMetaInfoResponsePb
);
return
Status
::
OK
();
}
Status
ProcessRequest
(
void
*
tag
)
{
auto
rq
=
static_cast
<
UntypedCall
*>
(
tag
);
RETURN_IF_NOT_OK
((
*
rq
)());
return
Status
::
OK
();
}
private:
GraphDataServiceImpl
*
service_impl_
;
GnnGraphData
::
AsyncService
svc_
;
};
#endif
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
using
pFunction
=
Status
(
GraphDataServiceImpl
::*
)(
const
GnnGraphDataRequestPb
*
,
GnnGraphDataResponsePb
*
);
static
std
::
unordered_map
<
uint32_t
,
pFunction
>
g_get_graph_data_func_
=
{
{
GET_ALL_NODES
,
&
GraphDataServiceImpl
::
GetAllNodes
},
{
GET_ALL_EDGES
,
&
GraphDataServiceImpl
::
GetAllEdges
},
{
GET_NODES_FROM_EDGES
,
&
GraphDataServiceImpl
::
GetNodesFromEdges
},
{
GET_ALL_NEIGHBORS
,
&
GraphDataServiceImpl
::
GetAllNeighbors
},
{
GET_SAMPLED_NEIGHBORS
,
&
GraphDataServiceImpl
::
GetSampledNeighbors
},
{
GET_NEG_SAMPLED_NEIGHBORS
,
&
GraphDataServiceImpl
::
GetNegSampledNeighbors
},
{
RANDOM_WALK
,
&
GraphDataServiceImpl
::
RandomWalk
},
{
GET_NODE_FEATURE
,
&
GraphDataServiceImpl
::
GetNodeFeature
},
{
GET_EDGE_FEATURE
,
&
GraphDataServiceImpl
::
GetEdgeFeature
}};
GraphDataServiceImpl
::
GraphDataServiceImpl
(
GraphDataServer
*
server
,
GraphDataImpl
*
graph_data_impl
)
:
server_
(
server
),
graph_data_impl_
(
graph_data_impl
)
{}
Status
GraphDataServiceImpl
::
FillDefaultFeature
(
GnnClientRegisterResponsePb
*
response
)
{
const
auto
default_node_features
=
graph_data_impl_
->
GetAllDefaultNodeFeatures
();
for
(
const
auto
feature
:
*
default_node_features
)
{
GnnFeatureInfoPb
*
feature_info
=
response
->
add_default_node_feature
();
feature_info
->
set_type
(
feature
.
first
);
RETURN_IF_NOT_OK
(
TensorToPb
(
feature
.
second
->
Value
(),
feature_info
->
mutable_feature
()));
}
const
auto
default_edge_features
=
graph_data_impl_
->
GetAllDefaultEdgeFeatures
();
for
(
const
auto
feature
:
*
default_edge_features
)
{
GnnFeatureInfoPb
*
feature_info
=
response
->
add_default_edge_feature
();
feature_info
->
set_type
(
feature
.
first
);
RETURN_IF_NOT_OK
(
TensorToPb
(
feature
.
second
->
Value
(),
feature_info
->
mutable_feature
()));
}
return
Status
::
OK
();
}
grpc
::
Status
GraphDataServiceImpl
::
ClientRegister
(
grpc
::
ServerContext
*
context
,
const
GnnClientRegisterRequestPb
*
request
,
GnnClientRegisterResponsePb
*
response
)
{
Status
s
=
server_
->
ClientRegister
(
request
->
pid
());
if
(
s
.
IsOk
())
{
switch
(
server_
->
state
())
{
case
GraphDataServer
::
kGdsUninit
:
case
GraphDataServer
::
kGdsInitializing
:
response
->
set_error_msg
(
"Initializing"
);
break
;
case
GraphDataServer
::
kGdsRunning
:
response
->
set_error_msg
(
"Success"
);
response
->
set_data_schema
(
graph_data_impl_
->
GetDataSchema
());
response
->
set_shared_memory_key
(
graph_data_impl_
->
GetSharedMemoryKey
());
response
->
set_shared_memory_size
(
graph_data_impl_
->
GetSharedMemorySize
());
s
=
FillDefaultFeature
(
response
);
if
(
!
s
.
IsOk
())
{
response
->
set_error_msg
(
s
.
ToString
());
}
break
;
case
GraphDataServer
::
kGdsStopped
:
response
->
set_error_msg
(
"Stoped"
);
break
;
}
}
else
{
response
->
set_error_msg
(
s
.
ToString
());
}
return
::
grpc
::
Status
::
OK
;
}
grpc
::
Status
GraphDataServiceImpl
::
ClientUnRegister
(
grpc
::
ServerContext
*
context
,
const
GnnClientUnRegisterRequestPb
*
request
,
GnnClientUnRegisterResponsePb
*
response
)
{
Status
s
=
server_
->
ClientUnRegister
(
request
->
pid
());
if
(
s
.
IsOk
())
{
response
->
set_error_msg
(
"Success"
);
}
else
{
response
->
set_error_msg
(
s
.
ToString
());
}
return
::
grpc
::
Status
::
OK
;
}
grpc
::
Status
GraphDataServiceImpl
::
GetGraphData
(
grpc
::
ServerContext
*
context
,
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
// MS_LOG(INFO) << "#### receive GetGraphData:" << request->op_name();
Status
s
;
auto
iter
=
g_get_graph_data_func_
.
find
(
request
->
op_name
());
if
(
iter
!=
g_get_graph_data_func_
.
end
())
{
pFunction
func
=
iter
->
second
;
s
=
(
this
->*
func
)(
request
,
response
);
if
(
s
.
IsOk
())
{
response
->
set_error_msg
(
"Success"
);
}
else
{
response
->
set_error_msg
(
s
.
ToString
());
}
}
else
{
response
->
set_error_msg
(
"Invalid op name."
);
}
// MS_LOG(INFO) << "#### end receive GetGraphData:" << request->op_name();
return
::
grpc
::
Status
::
OK
;
}
grpc
::
Status
GraphDataServiceImpl
::
GetMetaInfo
(
grpc
::
ServerContext
*
context
,
const
GnnMetaInfoRequestPb
*
request
,
GnnMetaInfoResponsePb
*
response
)
{
MetaInfo
meta_info
;
Status
s
=
graph_data_impl_
->
GetMetaInfo
(
&
meta_info
);
if
(
s
.
IsOk
())
{
response
->
set_error_msg
(
"Success"
);
for
(
const
auto
&
type
:
meta_info
.
node_type
)
{
auto
node_info
=
response
->
add_node_info
();
node_info
->
set_type
(
static_cast
<
google
::
protobuf
::
int32
>
(
type
));
auto
itr
=
meta_info
.
node_num
.
find
(
type
);
if
(
itr
!=
meta_info
.
node_num
.
end
())
{
node_info
->
set_num
(
static_cast
<
google
::
protobuf
::
int32
>
(
itr
->
second
));
}
else
{
node_info
->
set_num
(
0
);
}
}
for
(
const
auto
&
type
:
meta_info
.
edge_type
)
{
auto
edge_info
=
response
->
add_edge_info
();
edge_info
->
set_type
(
static_cast
<
google
::
protobuf
::
int32
>
(
type
));
auto
itr
=
meta_info
.
edge_num
.
find
(
type
);
if
(
itr
!=
meta_info
.
edge_num
.
end
())
{
edge_info
->
set_num
(
static_cast
<
google
::
protobuf
::
int32
>
(
itr
->
second
));
}
else
{
edge_info
->
set_num
(
0
);
}
}
for
(
const
auto
&
type
:
meta_info
.
node_feature_type
)
{
response
->
add_node_feature_type
(
static_cast
<
google
::
protobuf
::
int32
>
(
type
));
}
for
(
const
auto
&
type
:
meta_info
.
edge_feature_type
)
{
response
->
add_edge_feature_type
(
static_cast
<
google
::
protobuf
::
int32
>
(
type
));
}
}
else
{
response
->
set_error_msg
(
s
.
ToString
());
}
return
::
grpc
::
Status
::
OK
;
}
Status
GraphDataServiceImpl
::
GetAllNodes
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
==
1
,
"The number of edge types is not 1"
);
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetAllNodes
(
static_cast
<
NodeType
>
(
request
->
type
()[
0
]),
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetAllEdges
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
==
1
,
"The number of edge types is not 1"
);
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetAllEdges
(
static_cast
<
EdgeType
>
(
request
->
type
()[
0
]),
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetNodesFromEdges
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
id_size
()
>
0
,
"The input edge id is empty"
);
std
::
vector
<
EdgeIdType
>
edge_list
;
edge_list
.
resize
(
request
->
id
().
size
());
std
::
transform
(
request
->
id
().
begin
(),
request
->
id
().
end
(),
edge_list
.
begin
(),
[](
const
google
::
protobuf
::
int32
id
)
{
return
static_cast
<
EdgeIdType
>
(
id
);
});
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetNodesFromEdges
(
edge_list
,
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetAllNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
id_size
()
>
0
,
"The input node id is empty"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
==
1
,
"The number of edge types is not 1"
);
std
::
vector
<
NodeIdType
>
node_list
;
node_list
.
resize
(
request
->
id
().
size
());
std
::
transform
(
request
->
id
().
begin
(),
request
->
id
().
end
(),
node_list
.
begin
(),
[](
const
google
::
protobuf
::
int32
id
)
{
return
static_cast
<
NodeIdType
>
(
id
);
});
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetAllNeighbors
(
node_list
,
static_cast
<
NodeType
>
(
request
->
type
()[
0
]),
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetSampledNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
id_size
()
>
0
,
"The input node id is empty"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
number_size
()
>
0
,
"The input neighbor number is empty"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
>
0
,
"The input neighbor type is empty"
);
std
::
vector
<
NodeIdType
>
node_list
;
node_list
.
resize
(
request
->
id
().
size
());
std
::
transform
(
request
->
id
().
begin
(),
request
->
id
().
end
(),
node_list
.
begin
(),
[](
const
google
::
protobuf
::
int32
id
)
{
return
static_cast
<
NodeIdType
>
(
id
);
});
std
::
vector
<
NodeIdType
>
neighbor_nums
;
neighbor_nums
.
resize
(
request
->
number
().
size
());
std
::
transform
(
request
->
number
().
begin
(),
request
->
number
().
end
(),
neighbor_nums
.
begin
(),
[](
const
google
::
protobuf
::
int32
num
)
{
return
static_cast
<
NodeIdType
>
(
num
);
});
std
::
vector
<
NodeType
>
neighbor_types
;
neighbor_types
.
resize
(
request
->
type
().
size
());
std
::
transform
(
request
->
type
().
begin
(),
request
->
type
().
end
(),
neighbor_types
.
begin
(),
[](
const
google
::
protobuf
::
int32
type
)
{
return
static_cast
<
NodeType
>
(
type
);
});
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetSampledNeighbors
(
node_list
,
neighbor_nums
,
neighbor_types
,
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetNegSampledNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
id_size
()
>
0
,
"The input node id is empty"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
number_size
()
==
1
,
"The number of neighbor number is not 1"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
==
1
,
"The number of neighbor types is not 1"
);
std
::
vector
<
NodeIdType
>
node_list
;
node_list
.
resize
(
request
->
id
().
size
());
std
::
transform
(
request
->
id
().
begin
(),
request
->
id
().
end
(),
node_list
.
begin
(),
[](
const
google
::
protobuf
::
int32
id
)
{
return
static_cast
<
NodeIdType
>
(
id
);
});
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetNegSampledNeighbors
(
node_list
,
static_cast
<
NodeIdType
>
(
request
->
number
()[
0
]),
static_cast
<
NodeType
>
(
request
->
type
()[
0
]),
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
RandomWalk
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
id_size
()
>
0
,
"The input node id is empty"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
request
->
type_size
()
>
0
,
"The input meta path is empty"
);
std
::
vector
<
NodeIdType
>
node_list
;
node_list
.
resize
(
request
->
id
().
size
());
std
::
transform
(
request
->
id
().
begin
(),
request
->
id
().
end
(),
node_list
.
begin
(),
[](
const
google
::
protobuf
::
int32
id
)
{
return
static_cast
<
NodeIdType
>
(
id
);
});
std
::
vector
<
NodeType
>
meta_path
;
meta_path
.
resize
(
request
->
type
().
size
());
std
::
transform
(
request
->
type
().
begin
(),
request
->
type
().
end
(),
meta_path
.
begin
(),
[](
const
google
::
protobuf
::
int32
type
)
{
return
static_cast
<
NodeType
>
(
type
);
});
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
RandomWalk
(
node_list
,
meta_path
,
request
->
random_walk
().
p
(),
request
->
random_walk
().
q
(),
request
->
random_walk
().
default_id
(),
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetNodeFeature
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
std
::
shared_ptr
<
Tensor
>
nodes
;
RETURN_IF_NOT_OK
(
PbToTensor
(
&
request
->
id_tensor
(),
&
nodes
));
for
(
const
auto
&
type
:
request
->
type
())
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetNodeFeatureSharedMemory
(
nodes
,
type
,
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
}
return
Status
::
OK
();
}
Status
GraphDataServiceImpl
::
GetEdgeFeature
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
)
{
std
::
shared_ptr
<
Tensor
>
edges
;
RETURN_IF_NOT_OK
(
PbToTensor
(
&
request
->
id_tensor
(),
&
edges
));
for
(
const
auto
&
type
:
request
->
type
())
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_data_impl_
->
GetEdgeFeatureSharedMemory
(
edges
,
type
,
&
tensor
));
TensorPb
*
result
=
response
->
add_result_data
();
RETURN_IF_NOT_OK
(
TensorToPb
(
tensor
,
result
));
}
return
Status
::
OK
();
}
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#include <memory>
#include <string>
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
class
GraphDataServer
;
// class GraphDataServiceImpl : public GnnGraphData::Service {
class
GraphDataServiceImpl
{
public:
GraphDataServiceImpl
(
GraphDataServer
*
server
,
GraphDataImpl
*
graph_data_impl
);
~
GraphDataServiceImpl
()
=
default
;
grpc
::
Status
ClientRegister
(
grpc
::
ServerContext
*
context
,
const
GnnClientRegisterRequestPb
*
request
,
GnnClientRegisterResponsePb
*
response
);
grpc
::
Status
ClientUnRegister
(
grpc
::
ServerContext
*
context
,
const
GnnClientUnRegisterRequestPb
*
request
,
GnnClientUnRegisterResponsePb
*
response
);
grpc
::
Status
GetGraphData
(
grpc
::
ServerContext
*
context
,
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
grpc
::
Status
GetMetaInfo
(
grpc
::
ServerContext
*
context
,
const
GnnMetaInfoRequestPb
*
request
,
GnnMetaInfoResponsePb
*
response
);
Status
GetAllNodes
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetAllEdges
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetNodesFromEdges
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetAllNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetSampledNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetNegSampledNeighbors
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
RandomWalk
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetNodeFeature
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
Status
GetEdgeFeature
(
const
GnnGraphDataRequestPb
*
request
,
GnnGraphDataResponsePb
*
response
);
private:
Status
FillDefaultFeature
(
GnnClientRegisterResponsePb
*
response
);
GraphDataServer
*
server_
;
GraphDataImpl
*
graph_data_impl_
;
};
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#include <memory>
#include <utility>
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
using
mindrecord
::
MSRStatus
;
GraphFeatureParser
::
GraphFeatureParser
(
const
ShardColumn
&
shard_column
)
{
shard_column_
=
std
::
make_unique
<
ShardColumn
>
(
shard_column
);
}
Status
GraphFeatureParser
::
LoadFeatureTensor
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
std
::
shared_ptr
<
Tensor
>
*
tensor
)
{
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
,
col_type_size
=
1
;
mindrecord
::
ColumnDataType
col_type
=
mindrecord
::
ColumnNoDataType
;
std
::
vector
<
int64_t
>
column_shape
;
MSRStatus
rs
=
shard_column_
->
GetColumnValueByName
(
key
,
col_blob
,
{},
&
data
,
&
data_ptr
,
&
n_bytes
,
&
col_type
,
&
col_type_size
,
&
column_shape
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rs
==
mindrecord
::
SUCCESS
,
"fail to load column"
+
key
);
if
(
data
==
nullptr
)
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
&
data_ptr
[
0
]);
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
std
::
move
(
TensorShape
({
static_cast
<
dsize_t
>
(
n_bytes
/
col_type_size
)})),
std
::
move
(
DataType
(
mindrecord
::
ColumnDataTypeNameNormalized
[
col_type
])),
data
,
tensor
));
return
Status
::
OK
();
}
#if !defined(_WIN32) && !defined(_WIN64)
Status
GraphFeatureParser
::
LoadFeatureToSharedMemory
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
GraphSharedMemory
*
shared_memory
,
std
::
shared_ptr
<
Tensor
>
*
out_tensor
)
{
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
,
col_type_size
=
1
;
mindrecord
::
ColumnDataType
col_type
=
mindrecord
::
ColumnNoDataType
;
std
::
vector
<
int64_t
>
column_shape
;
MSRStatus
rs
=
shard_column_
->
GetColumnValueByName
(
key
,
col_blob
,
{},
&
data
,
&
data_ptr
,
&
n_bytes
,
&
col_type
,
&
col_type_size
,
&
column_shape
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rs
==
mindrecord
::
SUCCESS
,
"fail to load column"
+
key
);
if
(
data
==
nullptr
)
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
&
data_ptr
[
0
]);
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
std
::
move
(
TensorShape
({
2
})),
std
::
move
(
DataType
(
DataType
::
DE_INT64
)),
&
tensor
));
auto
fea_itr
=
tensor
->
begin
<
int64_t
>
();
int64_t
offset
=
0
;
RETURN_IF_NOT_OK
(
shared_memory
->
InsertData
(
data
,
n_bytes
,
&
offset
));
*
fea_itr
=
offset
;
++
fea_itr
;
*
fea_itr
=
n_bytes
;
*
out_tensor
=
std
::
move
(
tensor
);
return
Status
::
OK
();
}
#endif
Status
GraphFeatureParser
::
LoadFeatureIndex
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
std
::
vector
<
int32_t
>
*
indices
)
{
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
,
col_type_size
=
1
;
mindrecord
::
ColumnDataType
col_type
=
mindrecord
::
ColumnNoDataType
;
std
::
vector
<
int64_t
>
column_shape
;
MSRStatus
rs
=
shard_column_
->
GetColumnValueByName
(
key
,
col_blob
,
{},
&
data
,
&
data_ptr
,
&
n_bytes
,
&
col_type
,
&
col_type_size
,
&
column_shape
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rs
==
mindrecord
::
SUCCESS
,
"fail to load column:"
+
key
);
if
(
data
==
nullptr
)
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
&
data_ptr
[
0
]);
for
(
int
i
=
0
;
i
<
n_bytes
;
i
+=
col_type_size
)
{
int32_t
feature_ind
=
-
1
;
if
(
col_type
==
mindrecord
::
ColumnInt32
)
{
feature_ind
=
*
(
reinterpret_cast
<
const
int32_t
*>
(
data
+
i
));
}
else
if
(
col_type
==
mindrecord
::
ColumnInt64
)
{
feature_ind
=
*
(
reinterpret_cast
<
const
int64_t
*>
(
data
+
i
));
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Feature Index needs to be int32/int64 type!"
);
}
if
(
feature_ind
>=
0
)
indices
->
push_back
(
feature_ind
);
}
return
Status
::
OK
();
}
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
using
mindrecord
::
ShardColumn
;
class
GraphFeatureParser
{
public:
explicit
GraphFeatureParser
(
const
ShardColumn
&
shard_column
);
~
GraphFeatureParser
()
=
default
;
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status
LoadFeatureIndex
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
blob
,
std
::
vector
<
int32_t
>
*
ind
);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status
LoadFeatureTensor
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
blob
,
std
::
shared_ptr
<
Tensor
>
*
tensor
);
#if !defined(_WIN32) && !defined(_WIN64)
Status
LoadFeatureToSharedMemory
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
GraphSharedMemory
*
shared_memory
,
std
::
shared_ptr
<
Tensor
>
*
out_tensor
);
#endif
private:
std
::
unique_ptr
<
ShardColumn
>
shard_column_
;
};
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc
浏览文件 @
256dccc6
...
...
@@ -13,41 +13,42 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include <future>
#include <tuple>
#include <utility>
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/local_edge.h"
#include "minddata/dataset/engine/gnn/local_node.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/mindrecord/include/shard_error.h"
using
ShardTuple
=
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
mindspore
::
mindrecord
::
json
>>
;
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
using
mindrecord
::
MSRStatus
;
GraphLoader
::
GraphLoader
(
std
::
string
mr_filepath
,
int32_t
num_workers
)
:
mr_path_
(
mr_filepath
),
GraphLoader
::
GraphLoader
(
GraphDataImpl
*
graph_impl
,
std
::
string
mr_filepath
,
int32_t
num_workers
,
bool
server_mode
)
:
graph_impl_
(
graph_impl
),
mr_path_
(
mr_filepath
),
num_workers_
(
num_workers
),
row_id_
(
0
),
shard_reader_
(
nullptr
),
graph_feature_parser_
(
nullptr
),
keys_
({
"first_id"
,
"second_id"
,
"third_id"
,
"attribute"
,
"type"
,
"node_feature_index"
,
"edge_feature_index"
})
{}
Status
GraphLoader
::
GetNodesAndEdges
(
NodeIdMap
*
n_id_map
,
EdgeIdMap
*
e_id_map
,
NodeTypeMap
*
n_type_map
,
EdgeTypeMap
*
e_type_map
,
NodeFeatureMap
*
n_feature_map
,
EdgeFeatureMap
*
e_feature_map
,
DefaultNodeFeatureMap
*
default_node_feature_map
,
DefaultEdgeFeatureMap
*
default_edge_feature_map
)
{
Status
GraphLoader
::
GetNodesAndEdges
()
{
NodeIdMap
*
n_id_map
=
&
graph_impl_
->
node_id_map_
;
EdgeIdMap
*
e_id_map
=
&
graph_impl_
->
edge_id_map_
;
for
(
std
::
deque
<
std
::
shared_ptr
<
Node
>>
&
dq
:
n_deques_
)
{
while
(
dq
.
empty
()
==
false
)
{
std
::
shared_ptr
<
Node
>
node_ptr
=
dq
.
front
();
n_id_map
->
insert
({
node_ptr
->
id
(),
node_ptr
});
(
*
n_type_map
)
[
node_ptr
->
type
()].
push_back
(
node_ptr
->
id
());
graph_impl_
->
node_type_map_
[
node_ptr
->
type
()].
push_back
(
node_ptr
->
id
());
dq
.
pop_front
();
}
}
...
...
@@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N
RETURN_IF_NOT_OK
(
edge_ptr
->
SetNode
({
src_itr
->
second
,
dst_itr
->
second
}));
RETURN_IF_NOT_OK
(
src_itr
->
second
->
AddNeighbor
(
dst_itr
->
second
));
e_id_map
->
insert
({
edge_ptr
->
id
(),
edge_ptr
});
// add edge to edge_id_map_
(
*
e_type_map
)
[
edge_ptr
->
type
()].
push_back
(
edge_ptr
->
id
());
graph_impl_
->
edge_type_map_
[
edge_ptr
->
type
()].
push_back
(
edge_ptr
->
id
());
dq
.
pop_front
();
}
}
for
(
auto
&
itr
:
*
n_type_map
)
itr
.
second
.
shrink_to_fit
();
for
(
auto
&
itr
:
*
e_type_map
)
itr
.
second
.
shrink_to_fit
();
for
(
auto
&
itr
:
graph_impl_
->
node_type_map_
)
itr
.
second
.
shrink_to_fit
();
for
(
auto
&
itr
:
graph_impl_
->
edge_type_map_
)
itr
.
second
.
shrink_to_fit
();
MergeFeatureMaps
(
n_feature_map
,
e_feature_map
,
default_node_feature_map
,
default_edge_feature_map
);
MergeFeatureMaps
();
return
Status
::
OK
();
}
...
...
@@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() {
CHECK_FAIL_RETURN_UNEXPECTED
(
shard_reader_
->
GetShardHeader
()
->
GetSchemaCount
()
>
0
,
"No schema found!"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
shard_reader_
->
Launch
(
true
)
==
MSRStatus
::
SUCCESS
,
"fail to launch mr"
);
mindrecord
::
json
schema
=
(
shard_reader_
->
GetShardHeader
()
->
GetSchemas
()[
0
]
->
GetSchema
())[
"schema"
];
graph_impl_
->
data_schema_
=
(
shard_reader_
->
GetShardHeader
()
->
GetSchemas
()[
0
]
->
GetSchema
());
mindrecord
::
json
schema
=
graph_impl_
->
data_schema_
[
"schema"
];
for
(
const
std
::
string
&
key
:
keys_
)
{
if
(
schema
.
find
(
key
)
==
schema
.
end
())
{
RETURN_STATUS_UNEXPECTED
(
key
+
":doesn't exist in schema:"
+
schema
.
dump
());
}
}
if
(
graph_impl_
->
server_mode_
)
{
#if !defined(_WIN32) && !defined(_WIN64)
int64_t
total_blob_size
=
0
;
CHECK_FAIL_RETURN_UNEXPECTED
(
shard_reader_
->
GetTotalBlobSize
(
&
total_blob_size
)
==
MSRStatus
::
SUCCESS
,
"failed to get total blob size"
);
graph_impl_
->
graph_shared_memory_
=
std
::
make_unique
<
GraphSharedMemory
>
(
total_blob_size
,
mr_path_
);
RETURN_IF_NOT_OK
(
graph_impl_
->
graph_shared_memory_
->
CreateSharedMemory
());
#endif
}
graph_feature_parser_
=
std
::
make_unique
<
GraphFeatureParser
>
(
*
shard_reader_
->
GetShardColumn
());
// launching worker threads
for
(
int
wkr_id
=
0
;
wkr_id
<
num_workers_
;
++
wkr_id
)
{
RETURN_IF_NOT_OK
(
vg
.
CreateAsyncTask
(
"GraphLoader"
,
std
::
bind
(
&
GraphLoader
::
WorkerEntry
,
this
,
wkr_id
)));
...
...
@@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
NodeType
node_type
=
static_cast
<
NodeType
>
(
col_jsn
[
"type"
]);
(
*
node
)
=
std
::
make_shared
<
LocalNode
>
(
node_id
,
node_type
);
std
::
vector
<
int32_t
>
indices
;
RETURN_IF_NOT_OK
(
LoadFeatureIndex
(
"node_feature_index"
,
col_blob
,
col_jsn
,
&
indices
));
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
LoadFeatureTensor
(
"node_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
col_jsn
,
&
tensor
));
RETURN_IF_NOT_OK
((
*
node
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor
)));
(
*
feature_map
)[
node_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureIndex
(
"node_feature_index"
,
col_blob
,
&
indices
));
if
(
graph_impl_
->
server_mode_
)
{
#if !defined(_WIN32) && !defined(_WIN64)
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor_sm
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureToSharedMemory
(
"node_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
graph_impl_
->
graph_shared_memory_
.
get
(),
&
tensor_sm
));
RETURN_IF_NOT_OK
((
*
node
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor_sm
,
true
)));
(
*
feature_map
)[
node_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureTensor
(
"node_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
&
tensor
));
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
}
}
#endif
}
else
{
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureTensor
(
"node_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
&
tensor
));
RETURN_IF_NOT_OK
((
*
node
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor
)));
(
*
feature_map
)[
node_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
}
}
}
return
Status
::
OK
();
...
...
@@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
std
::
shared_ptr
<
Node
>
dst
=
std
::
make_shared
<
LocalNode
>
(
dst_id
,
-
1
);
(
*
edge
)
=
std
::
make_shared
<
LocalEdge
>
(
edge_id
,
edge_type
,
src
,
dst
);
std
::
vector
<
int32_t
>
indices
;
RETURN_IF_NOT_OK
(
LoadFeatureIndex
(
"edge_feature_index"
,
col_blob
,
col_jsn
,
&
indices
));
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
LoadFeatureTensor
(
"edge_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
col_jsn
,
&
tensor
));
RETURN_IF_NOT_OK
((
*
edge
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor
)));
(
*
feature_map
)[
edge_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureIndex
(
"edge_feature_index"
,
col_blob
,
&
indices
));
if
(
graph_impl_
->
server_mode_
)
{
#if !defined(_WIN32) && !defined(_WIN64)
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor_sm
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureToSharedMemory
(
"edge_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
graph_impl_
->
graph_shared_memory_
.
get
(),
&
tensor_sm
));
RETURN_IF_NOT_OK
((
*
edge
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor_sm
,
true
)));
(
*
feature_map
)[
edge_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureTensor
(
"edge_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
&
tensor
));
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
}
}
}
return
Status
::
OK
();
}
Status
GraphLoader
::
LoadFeatureTensor
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
const
mindrecord
::
json
&
col_jsn
,
std
::
shared_ptr
<
Tensor
>
*
tensor
)
{
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
,
col_type_size
=
1
;
mindrecord
::
ColumnDataType
col_type
=
mindrecord
::
ColumnNoDataType
;
std
::
vector
<
int64_t
>
column_shape
;
MSRStatus
rs
=
shard_reader_
->
GetShardColumn
()
->
GetColumnValueByName
(
key
,
col_blob
,
col_jsn
,
&
data
,
&
data_ptr
,
&
n_bytes
,
&
col_type
,
&
col_type_size
,
&
column_shape
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rs
==
mindrecord
::
SUCCESS
,
"fail to load column"
+
key
);
if
(
data
==
nullptr
)
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
&
data_ptr
[
0
]);
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
std
::
move
(
TensorShape
({
static_cast
<
dsize_t
>
(
n_bytes
/
col_type_size
)})),
std
::
move
(
DataType
(
mindrecord
::
ColumnDataTypeNameNormalized
[
col_type
])),
data
,
tensor
));
return
Status
::
OK
();
}
Status
GraphLoader
::
LoadFeatureIndex
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
col_blob
,
const
mindrecord
::
json
&
col_jsn
,
std
::
vector
<
int32_t
>
*
indices
)
{
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
,
col_type_size
=
1
;
mindrecord
::
ColumnDataType
col_type
=
mindrecord
::
ColumnNoDataType
;
std
::
vector
<
int64_t
>
column_shape
;
MSRStatus
rs
=
shard_reader_
->
GetShardColumn
()
->
GetColumnValueByName
(
key
,
col_blob
,
col_jsn
,
&
data
,
&
data_ptr
,
&
n_bytes
,
&
col_type
,
&
col_type_size
,
&
column_shape
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rs
==
mindrecord
::
SUCCESS
,
"fail to load column:"
+
key
);
if
(
data
==
nullptr
)
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
&
data_ptr
[
0
]);
for
(
int
i
=
0
;
i
<
n_bytes
;
i
+=
col_type_size
)
{
int32_t
feature_ind
=
-
1
;
if
(
col_type
==
mindrecord
::
ColumnInt32
)
{
feature_ind
=
*
(
reinterpret_cast
<
const
int32_t
*>
(
data
+
i
));
}
else
if
(
col_type
==
mindrecord
::
ColumnInt64
)
{
feature_ind
=
*
(
reinterpret_cast
<
const
int64_t
*>
(
data
+
i
));
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Feature Index needs to be int32/int64 type!"
);
#endif
}
else
{
for
(
int32_t
ind
:
indices
)
{
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
graph_feature_parser_
->
LoadFeatureTensor
(
"edge_feature_"
+
std
::
to_string
(
ind
),
col_blob
,
&
tensor
));
RETURN_IF_NOT_OK
((
*
edge
)
->
UpdateFeature
(
std
::
make_shared
<
Feature
>
(
ind
,
tensor
)));
(
*
feature_map
)[
edge_type
].
insert
(
ind
);
if
((
*
default_feature
)[
ind
]
==
nullptr
)
{
std
::
shared_ptr
<
Tensor
>
zero_tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
tensor
->
shape
(),
tensor
->
type
(),
&
zero_tensor
));
RETURN_IF_NOT_OK
(
zero_tensor
->
Zero
());
(
*
default_feature
)[
ind
]
=
std
::
make_shared
<
Feature
>
(
ind
,
zero_tensor
);
}
}
if
(
feature_ind
>=
0
)
indices
->
push_back
(
feature_ind
);
}
return
Status
::
OK
();
}
...
...
@@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
return
Status
::
OK
();
}
void
GraphLoader
::
MergeFeatureMaps
(
NodeFeatureMap
*
n_feature_map
,
EdgeFeatureMap
*
e_feature_map
,
DefaultNodeFeatureMap
*
default_node_feature_map
,
DefaultEdgeFeatureMap
*
default_edge_feature_map
)
{
void
GraphLoader
::
MergeFeatureMaps
()
{
for
(
int
wkr_id
=
0
;
wkr_id
<
num_workers_
;
wkr_id
++
)
{
for
(
auto
&
m
:
n_feature_maps_
[
wkr_id
])
{
for
(
auto
&
n
:
m
.
second
)
(
*
n_feature_map
)
[
m
.
first
].
insert
(
n
);
for
(
auto
&
n
:
m
.
second
)
graph_impl_
->
node_feature_map_
[
m
.
first
].
insert
(
n
);
}
for
(
auto
&
m
:
e_feature_maps_
[
wkr_id
])
{
for
(
auto
&
n
:
m
.
second
)
(
*
e_feature_map
)
[
m
.
first
].
insert
(
n
);
for
(
auto
&
n
:
m
.
second
)
graph_impl_
->
edge_feature_map_
[
m
.
first
].
insert
(
n
);
}
for
(
auto
&
m
:
default_node_feature_maps_
[
wkr_id
])
{
(
*
default_node_feature_map
)
[
m
.
first
]
=
m
.
second
;
graph_impl_
->
default_node_feature_map_
[
m
.
first
]
=
m
.
second
;
}
for
(
auto
&
m
:
default_edge_feature_maps_
[
wkr_id
])
{
(
*
default_edge_feature_map
)
[
m
.
first
]
=
m
.
second
;
graph_impl_
->
default_edge_feature_map_
[
m
.
first
]
=
m
.
second
;
}
}
n_feature_maps_
.
clear
();
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h
浏览文件 @
256dccc6
...
...
@@ -26,10 +26,13 @@
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_reader.h"
namespace
mindspore
{
...
...
@@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy
using
DefaultNodeFeatureMap
=
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Feature
>>
;
using
DefaultEdgeFeatureMap
=
std
::
unordered_map
<
FeatureType
,
std
::
shared_ptr
<
Feature
>>
;
class
GraphDataImpl
;
// this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges
// it is then the responsibility of graph to construct itself based on the nodes and edges
// if needed, this class could become a base where each derived class handles a specific storage format
class
GraphLoader
{
public:
explicit
GraphLoader
(
std
::
string
mr_filepath
,
int32_t
num_workers
=
4
);
GraphLoader
(
GraphDataImpl
*
graph_impl
,
std
::
string
mr_filepath
,
int32_t
num_workers
=
4
,
bool
server_mode
=
false
);
~
GraphLoader
()
=
default
;
// Init mindrecord and load everything into memory multi-threaded
...
...
@@ -63,8 +68,7 @@ class GraphLoader {
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly
Status
GetNodesAndEdges
(
NodeIdMap
*
,
EdgeIdMap
*
,
NodeTypeMap
*
,
EdgeTypeMap
*
,
NodeFeatureMap
*
,
EdgeFeatureMap
*
,
DefaultNodeFeatureMap
*
,
DefaultEdgeFeatureMap
*
);
Status
GetNodesAndEdges
();
private:
//
...
...
@@ -92,29 +96,15 @@ class GraphLoader {
Status
LoadEdge
(
const
std
::
vector
<
uint8_t
>
&
blob
,
const
mindrecord
::
json
&
jsn
,
std
::
shared_ptr
<
Edge
>
*
edge
,
EdgeFeatureMap
*
feature_map
,
DefaultEdgeFeatureMap
*
default_feature
);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status
LoadFeatureIndex
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
blob
,
const
mindrecord
::
json
&
jsn
,
std
::
vector
<
int32_t
>
*
ind
);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status
LoadFeatureTensor
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>
&
blob
,
const
mindrecord
::
json
&
jsn
,
std
::
shared_ptr
<
Tensor
>
*
tensor
);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void
MergeFeatureMaps
(
NodeFeatureMap
*
,
EdgeFeatureMap
*
,
DefaultNodeFeatureMap
*
,
DefaultEdgeFeatureMap
*
);
void
MergeFeatureMaps
();
GraphDataImpl
*
graph_impl_
;
std
::
string
mr_path_
;
const
int32_t
num_workers_
;
std
::
atomic_int
row_id_
;
std
::
string
mr_path_
;
std
::
unique_ptr
<
ShardReader
>
shard_reader_
;
std
::
unique_ptr
<
GraphFeatureParser
>
graph_feature_parser_
;
std
::
vector
<
std
::
deque
<
std
::
shared_ptr
<
Node
>>>
n_deques_
;
std
::
vector
<
std
::
deque
<
std
::
shared_ptr
<
Edge
>>>
e_deques_
;
std
::
vector
<
NodeFeatureMap
>
n_feature_maps_
;
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#include <string>
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
GraphSharedMemory
::
GraphSharedMemory
(
int64_t
memory_size
,
key_t
memory_key
)
:
memory_size_
(
memory_size
),
memory_key_
(
memory_key
),
memory_ptr_
(
nullptr
),
memory_offset_
(
0
),
is_new_create_
(
false
)
{
std
::
stringstream
stream
;
stream
<<
std
::
hex
<<
memory_key_
;
memory_key_str_
=
stream
.
str
();
}
GraphSharedMemory
::
GraphSharedMemory
(
int64_t
memory_size
,
const
std
::
string
&
mr_file
)
:
mr_file_
(
mr_file
),
memory_size_
(
memory_size
),
memory_key_
(
-
1
),
memory_ptr_
(
nullptr
),
memory_offset_
(
0
),
is_new_create_
(
false
)
{}
GraphSharedMemory
::~
GraphSharedMemory
()
{
if
(
is_new_create_
)
{
(
void
)
DeleteSharedMemory
();
}
}
Status
GraphSharedMemory
::
CreateSharedMemory
()
{
if
(
memory_key_
==
-
1
)
{
// ftok to generate unique key
memory_key_
=
ftok
(
mr_file_
.
data
(),
kGnnSharedMemoryId
);
CHECK_FAIL_RETURN_UNEXPECTED
(
memory_key_
!=
-
1
,
"Failed to get key of shared memory. file_name:"
+
mr_file_
);
std
::
stringstream
stream
;
stream
<<
std
::
hex
<<
memory_key_
;
memory_key_str_
=
stream
.
str
();
}
int
shmflg
=
(
0666
|
IPC_CREAT
|
IPC_EXCL
);
Status
s
=
SharedMemoryImpl
(
shmflg
);
if
(
s
.
IsOk
())
{
is_new_create_
=
true
;
MS_LOG
(
INFO
)
<<
"Create shared memory success, key=0x"
<<
memory_key_str_
;
}
else
{
MS_LOG
(
WARNING
)
<<
"Shared memory with the same key may already exist, key=0x"
<<
memory_key_str_
;
shmflg
=
(
0666
|
IPC_CREAT
);
s
=
SharedMemoryImpl
(
shmflg
);
if
(
!
s
.
IsOk
())
{
RETURN_STATUS_UNEXPECTED
(
"Create shared memory fao;ed, key=0x"
+
memory_key_str_
);
}
}
return
Status
::
OK
();
}
Status
GraphSharedMemory
::
GetSharedMemory
()
{
int
shmflg
=
0
;
RETURN_IF_NOT_OK
(
SharedMemoryImpl
(
shmflg
));
return
Status
::
OK
();
}
Status
GraphSharedMemory
::
DeleteSharedMemory
()
{
int
shmid
=
shmget
(
memory_key_
,
0
,
0
);
CHECK_FAIL_RETURN_UNEXPECTED
(
shmid
!=
-
1
,
"Failed to get shared memory. key=0x"
+
memory_key_str_
);
int
result
=
shmctl
(
shmid
,
IPC_RMID
,
0
);
CHECK_FAIL_RETURN_UNEXPECTED
(
result
!=
-
1
,
"Failed to delete shared memory. key=0x"
+
memory_key_str_
);
return
Status
::
OK
();
}
Status
GraphSharedMemory
::
SharedMemoryImpl
(
const
int
&
shmflg
)
{
// shmget returns an identifier in shmid
int
shmid
=
shmget
(
memory_key_
,
memory_size_
,
shmflg
);
CHECK_FAIL_RETURN_UNEXPECTED
(
shmid
!=
-
1
,
"Failed to get shared memory. key=0x"
+
memory_key_str_
);
// shmat to attach to shared memory
auto
data
=
shmat
(
shmid
,
reinterpret_cast
<
void
*>
(
0
),
0
);
CHECK_FAIL_RETURN_UNEXPECTED
(
data
!=
(
char
*
)(
-
1
),
"Failed to address shared memory. key=0x"
+
memory_key_str_
);
memory_ptr_
=
reinterpret_cast
<
uint8_t
*>
(
data
);
return
Status
::
OK
();
}
Status
GraphSharedMemory
::
InsertData
(
const
uint8_t
*
data
,
int64_t
len
,
int64_t
*
offset
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
data
,
"Input data is nullptr."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
len
>
0
,
"Input len is invalid."
);
std
::
lock_guard
<
std
::
mutex
>
lck
(
mutex_
);
CHECK_FAIL_RETURN_UNEXPECTED
((
memory_size_
-
memory_offset_
>=
len
),
"Insufficient shared memory space to insert data."
);
if
(
EOK
!=
memcpy_s
(
memory_ptr_
+
memory_offset_
,
memory_size_
-
memory_offset_
,
data
,
len
))
{
RETURN_STATUS_UNEXPECTED
(
"Failed to insert data into shared memory."
);
}
*
offset
=
memory_offset_
;
memory_offset_
+=
len
;
return
Status
::
OK
();
}
Status
GraphSharedMemory
::
GetData
(
uint8_t
*
data
,
int64_t
data_len
,
int64_t
offset
,
int64_t
get_data_len
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
data
,
"Input data is nullptr."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
get_data_len
>
0
,
"Input get_data_len is invalid."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
data_len
>=
get_data_len
,
"Insufficient target address space."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
memory_size_
>=
get_data_len
+
offset
,
"get_data_len is too large, beyond the space of shared memory."
);
if
(
EOK
!=
memcpy_s
(
data
,
data_len
,
memory_ptr_
+
offset
,
get_data_len
))
{
RETURN_STATUS_UNEXPECTED
(
"Failed to insert data into shared memory."
);
}
return
Status
::
OK
();
}
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#include <sys/ipc.h>
#include <sys/shm.h>
#include <mutex>
#include <string>
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
namespace
gnn
{
const
int
kGnnSharedMemoryId
=
65
;
class
GraphSharedMemory
{
public:
explicit
GraphSharedMemory
(
int64_t
memory_size
,
key_t
memory_key
);
explicit
GraphSharedMemory
(
int64_t
memory_size
,
const
std
::
string
&
mr_file
);
~
GraphSharedMemory
();
// @param uint8_t** shared_memory - shared memory address
// @return Status - the status code
Status
CreateSharedMemory
();
// @param uint8_t** shared_memory - shared memory address
// @return Status - the status code
Status
GetSharedMemory
();
Status
DeleteSharedMemory
();
Status
InsertData
(
const
uint8_t
*
data
,
int64_t
len
,
int64_t
*
offset
);
Status
GetData
(
uint8_t
*
data
,
int64_t
data_len
,
int64_t
offset
,
int64_t
get_data_len
);
key_t
memory_key
()
{
return
memory_key_
;
}
int64_t
memory_size
()
{
return
memory_size_
;
}
private:
Status
SharedMemoryImpl
(
const
int
&
shmflg
);
std
::
string
mr_file_
;
int64_t
memory_size_
;
key_t
memory_key_
;
std
::
string
memory_key_str_
;
uint8_t
*
memory_ptr_
;
int64_t
memory_offset_
;
std
::
mutex
mutex_
;
bool
is_new_create_
;
};
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#include <limits>
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
GrpcAsyncServer
::
GrpcAsyncServer
(
const
std
::
string
&
host
,
int32_t
port
)
:
host_
(
host
),
port_
(
port
)
{}
GrpcAsyncServer
::~
GrpcAsyncServer
()
{
Stop
();
}
Status
GrpcAsyncServer
::
Run
()
{
std
::
string
server_address
=
host_
+
":"
+
std
::
to_string
(
port_
);
grpc
::
ServerBuilder
builder
;
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int32_t
>::
max
());
builder
.
AddChannelArgument
(
GRPC_ARG_ALLOW_REUSEPORT
,
0
);
int
port_tcpip
=
0
;
builder
.
AddListeningPort
(
server_address
,
grpc
::
InsecureServerCredentials
(),
&
port_tcpip
);
RETURN_IF_NOT_OK
(
RegisterService
(
&
builder
));
cq_
=
builder
.
AddCompletionQueue
();
server_
=
builder
.
BuildAndStart
();
if
(
server_
)
{
MS_LOG
(
INFO
)
<<
"Server listening on "
<<
server_address
;
}
else
{
std
::
string
errMsg
=
"Fail to start server. "
;
if
(
port_tcpip
!=
port_
)
{
errMsg
+=
"Unable to bind to address "
+
server_address
+
"."
;
}
RETURN_STATUS_UNEXPECTED
(
errMsg
);
}
return
Status
::
OK
();
}
Status
GrpcAsyncServer
::
HandleRequest
()
{
bool
success
;
void
*
tag
;
// We loop through the grpc queue. Each connection if successful
// will come back with our own tag which is an instance of CallData
// and we simply call its functor. But first we need to create these instances
// and inject them into the grpc queue.
RETURN_IF_NOT_OK
(
EnqueueRequest
());
while
(
cq_
->
Next
(
&
tag
,
&
success
))
{
RETURN_IF_INTERRUPTED
();
if
(
success
)
{
RETURN_IF_NOT_OK
(
ProcessRequest
(
tag
));
}
else
{
MS_LOG
(
DEBUG
)
<<
"cq_->Next failed."
;
}
}
return
Status
::
OK
();
}
void
GrpcAsyncServer
::
Stop
()
{
if
(
server_
)
{
server_
->
Shutdown
();
}
// Always shutdown the completion queue after the server.
if
(
cq_
)
{
cq_
->
Shutdown
();
}
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/async_unary_call.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief Async server base class
class
GrpcAsyncServer
{
public:
explicit
GrpcAsyncServer
(
const
std
::
string
&
host
,
int32_t
port
);
virtual
~
GrpcAsyncServer
();
/// \brief Brings up gRPC server
/// \return none
Status
Run
();
/// \brief Entry function to handle async server request
Status
HandleRequest
();
void
Stop
();
virtual
Status
RegisterService
(
grpc
::
ServerBuilder
*
builder
)
=
0
;
virtual
Status
EnqueueRequest
()
=
0
;
virtual
Status
ProcessRequest
(
void
*
tag
)
=
0
;
protected:
int32_t
port_
;
std
::
string
host_
;
std
::
unique_ptr
<
grpc
::
ServerCompletionQueue
>
cq_
;
std
::
unique_ptr
<
grpc
::
Server
>
server_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc
浏览文件 @
256dccc6
...
...
@@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
return
Status
::
OK
();
}
}
}
// namespace gnn
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h
浏览文件 @
256dccc6
...
...
@@ -20,10 +20,10 @@
#include <unordered_map>
#include <utility>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h
浏览文件 @
256dccc6
...
...
@@ -20,9 +20,9 @@
#include <unordered_map>
#include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/node.h
浏览文件 @
256dccc6
...
...
@@ -20,8 +20,8 @@
#include <unordered_map>
#include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
...
...
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
namespace
mindspore
{
namespace
dataset
{
const
std
::
unordered_map
<
DataTypePb
,
DataType
::
Type
>
g_pb2datatype_map
{
{
DataTypePb
::
DE_PB_UNKNOWN
,
DataType
::
DE_UNKNOWN
},
{
DataTypePb
::
DE_PB_BOOL
,
DataType
::
DE_BOOL
},
{
DataTypePb
::
DE_PB_INT8
,
DataType
::
DE_INT8
},
{
DataTypePb
::
DE_PB_UINT8
,
DataType
::
DE_UINT8
},
{
DataTypePb
::
DE_PB_INT16
,
DataType
::
DE_INT16
},
{
DataTypePb
::
DE_PB_UINT16
,
DataType
::
DE_UINT16
},
{
DataTypePb
::
DE_PB_INT32
,
DataType
::
DE_INT32
},
{
DataTypePb
::
DE_PB_UINT32
,
DataType
::
DE_UINT32
},
{
DataTypePb
::
DE_PB_INT64
,
DataType
::
DE_INT64
},
{
DataTypePb
::
DE_PB_UINT64
,
DataType
::
DE_UINT64
},
{
DataTypePb
::
DE_PB_FLOAT16
,
DataType
::
DE_FLOAT16
},
{
DataTypePb
::
DE_PB_FLOAT32
,
DataType
::
DE_FLOAT32
},
{
DataTypePb
::
DE_PB_FLOAT64
,
DataType
::
DE_FLOAT64
},
{
DataTypePb
::
DE_PB_STRING
,
DataType
::
DE_STRING
},
};
const
std
::
unordered_map
<
DataType
::
Type
,
DataTypePb
>
g_datatype2pb_map
{
{
DataType
::
DE_UNKNOWN
,
DataTypePb
::
DE_PB_UNKNOWN
},
{
DataType
::
DE_BOOL
,
DataTypePb
::
DE_PB_BOOL
},
{
DataType
::
DE_INT8
,
DataTypePb
::
DE_PB_INT8
},
{
DataType
::
DE_UINT8
,
DataTypePb
::
DE_PB_UINT8
},
{
DataType
::
DE_INT16
,
DataTypePb
::
DE_PB_INT16
},
{
DataType
::
DE_UINT16
,
DataTypePb
::
DE_PB_UINT16
},
{
DataType
::
DE_INT32
,
DataTypePb
::
DE_PB_INT32
},
{
DataType
::
DE_UINT32
,
DataTypePb
::
DE_PB_UINT32
},
{
DataType
::
DE_INT64
,
DataTypePb
::
DE_PB_INT64
},
{
DataType
::
DE_UINT64
,
DataTypePb
::
DE_PB_UINT64
},
{
DataType
::
DE_FLOAT16
,
DataTypePb
::
DE_PB_FLOAT16
},
{
DataType
::
DE_FLOAT32
,
DataTypePb
::
DE_PB_FLOAT32
},
{
DataType
::
DE_FLOAT64
,
DataTypePb
::
DE_PB_FLOAT64
},
{
DataType
::
DE_STRING
,
DataTypePb
::
DE_PB_STRING
},
};
Status
TensorToPb
(
const
std
::
shared_ptr
<
Tensor
>
tensor
,
TensorPb
*
tensor_pb
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
tensor
,
"Parameter tensor is a null pointer"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
tensor_pb
,
"Parameter tensor_pb is a null pointer"
);
std
::
vector
<
dsize_t
>
shape
=
tensor
->
shape
().
AsVector
();
for
(
auto
dim
:
shape
)
{
tensor_pb
->
add_dims
(
static_cast
<
google
::
protobuf
::
int64
>
(
dim
));
}
auto
iter
=
g_datatype2pb_map
.
find
(
tensor
->
type
().
value
());
if
(
iter
==
g_datatype2pb_map
.
end
())
{
RETURN_STATUS_UNEXPECTED
(
"Invalid tensor type: "
+
tensor
->
type
().
ToString
());
}
tensor_pb
->
set_tensor_type
(
iter
->
second
);
tensor_pb
->
set_data
(
tensor
->
GetBuffer
(),
tensor
->
SizeInBytes
());
return
Status
::
OK
();
}
Status
PbToTensor
(
const
TensorPb
*
tensor_pb
,
std
::
shared_ptr
<
Tensor
>
*
tensor
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
tensor_pb
,
"Parameter tensor_pb is a null pointer"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
tensor
,
"Parameter tensor is a null pointer"
);
std
::
vector
<
dsize_t
>
shape
;
shape
.
resize
(
tensor_pb
->
dims
().
size
());
std
::
transform
(
tensor_pb
->
dims
().
begin
(),
tensor_pb
->
dims
().
end
(),
shape
.
begin
(),
[](
const
google
::
protobuf
::
int64
dim
)
{
return
static_cast
<
dsize_t
>
(
dim
);
});
auto
iter
=
g_pb2datatype_map
.
find
(
tensor_pb
->
tensor_type
());
if
(
iter
==
g_pb2datatype_map
.
end
())
{
RETURN_STATUS_UNEXPECTED
(
"Invalid Tensor_pb type: "
+
std
::
to_string
(
tensor_pb
->
tensor_type
()));
}
DataType
::
Type
type
=
iter
->
second
;
std
::
shared_ptr
<
Tensor
>
tensor_out
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
TensorShape
(
shape
),
DataType
(
type
),
reinterpret_cast
<
const
unsigned
char
*>
(
tensor_pb
->
data
().
data
()),
tensor_pb
->
data
().
size
(),
&
tensor_out
));
*
tensor
=
std
::
move
(
tensor_out
);
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h
0 → 100644
浏览文件 @
256dccc6
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#include <deque>
#include <memory>
#include <vector>
#include "proto/gnn_tensor.pb.h"
#include "minddata/dataset/core/tensor.h"
namespace
mindspore
{
namespace
dataset
{
Status
TensorToPb
(
const
std
::
shared_ptr
<
Tensor
>
tensor
,
TensorPb
*
tensor_pb
);
Status
PbToTensor
(
const
TensorPb
*
tensor_pb
,
std
::
shared_ptr
<
Tensor
>
*
tensor
);
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
浏览文件 @
256dccc6
...
...
@@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
class
ShardColumn
{
public:
explicit
ShardColumn
(
const
std
::
shared_ptr
<
ShardHeader
>
&
shard_header
,
bool
compress_integer
=
true
);
explicit
ShardColumn
(
const
json
&
schema_json
,
bool
compress_integer
=
true
);
~
ShardColumn
()
=
default
;
...
...
@@ -72,23 +73,29 @@ class ShardColumn {
std
::
vector
<
int64_t
>
*
column_shape
);
/// \brief compress blob
std
::
vector
<
uint8_t
>
CompressBlob
(
const
std
::
vector
<
uint8_t
>
&
blob
);
std
::
vector
<
uint8_t
>
CompressBlob
(
const
std
::
vector
<
uint8_t
>
&
blob
,
int64_t
*
compression_size
);
/// \brief check if blob compressed
bool
CheckCompressBlob
()
const
{
return
has_compress_blob_
;
}
/// \brief getter
uint64_t
GetNumBlobColumn
()
const
{
return
num_blob_column_
;
}
/// \brief getter
std
::
vector
<
std
::
string
>
GetColumnName
()
{
return
column_name_
;
}
/// \brief getter
std
::
vector
<
ColumnDataType
>
GeColumnDataType
()
{
return
column_data_type_
;
}
/// \brief getter
std
::
vector
<
std
::
vector
<
int64_t
>>
GetColumnShape
()
{
return
column_shape_
;
}
/// \brief get column value from blob
MSRStatus
GetColumnFromBlob
(
const
std
::
string
&
column_name
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
unsigned
char
**
data
,
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
uint64_t
*
const
n_bytes
);
/// \brief get column type
std
::
pair
<
MSRStatus
,
ColumnCategory
>
GetColumnTypeByName
(
const
std
::
string
&
column_name
,
ColumnDataType
*
column_data_type
,
uint64_t
*
column_data_type_size
,
...
...
@@ -99,6 +106,9 @@ class ShardColumn {
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
uint64_t
*
n_bytes
);
private:
/// \brief intialization
void
Init
(
const
json
&
schema_json
,
bool
compress_integer
=
true
);
/// \brief get float value from json
template
<
typename
T
>
MSRStatus
GetFloat
(
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
const
json
&
json_column_value
,
bool
use_double
);
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
浏览文件 @
256dccc6
...
...
@@ -65,6 +65,11 @@ class ShardHeader {
/// \return the Statistic
std
::
vector
<
std
::
shared_ptr
<
Statistics
>>
GetStatistics
();
/// \brief add the statistic and save it
/// \param[in] statistic info of slim size
/// \return null
int64_t
GetSlimSizeStatistic
(
const
json
&
slim_size_json
);
/// \brief get the fields of the index
/// \return the fields of the index
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
GetFields
();
...
...
@@ -114,10 +119,14 @@ class ShardHeader {
uint64_t
GetPageSize
()
const
{
return
page_size_
;
}
uint64_t
GetCompressionSize
()
const
{
return
compression_size_
;
}
void
SetHeaderSize
(
const
uint64_t
&
header_size
)
{
header_size_
=
header_size
;
}
void
SetPageSize
(
const
uint64_t
&
page_size
)
{
page_size_
=
page_size
;
}
void
SetCompressionSize
(
const
uint64_t
&
compression_size
)
{
compression_size_
=
compression_size
;
}
std
::
vector
<
std
::
string
>
SerializeHeader
();
MSRStatus
PagesToFile
(
const
std
::
string
dump_file_name
);
...
...
@@ -177,6 +186,7 @@ class ShardHeader {
uint32_t
shard_count_
;
uint64_t
header_size_
;
uint64_t
page_size_
;
uint64_t
compression_size_
;
std
::
shared_ptr
<
Index
>
index_
;
std
::
vector
<
std
::
string
>
shard_addresses_
;
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
浏览文件 @
256dccc6
...
...
@@ -209,6 +209,9 @@ class ShardReader {
/// \brief get all classes
MSRStatus
GetAllClasses
(
const
std
::
string
&
category_field
,
std
::
set
<
std
::
string
>
&
categories
);
/// \brief get the size of blob data
MSRStatus
GetTotalBlobSize
(
int64_t
*
total_blob_size
);
protected:
/// \brief sqlite call back function
static
int
SelectCallback
(
void
*
p_data
,
int
num_fields
,
char
**
p_fields
,
char
**
p_col_names
);
...
...
@@ -323,6 +326,7 @@ class ShardReader {
const
std
::
string
kThreadName
=
"THRD_ITER_"
;
// prefix of thread name
std
::
vector
<
std
::
thread
>
thread_set_
;
// thread list
int
num_rows_
;
// number of rows
int64_t
total_blob_size_
;
// total size of blob data
std
::
mutex
mtx_delivery_
;
// locker for delivery
std
::
condition_variable
cv_delivery_
;
// conditional variable for delivery
std
::
condition_variable
cv_iterator_
;
// conditional variable for iterator
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
浏览文件 @
256dccc6
...
...
@@ -257,6 +257,7 @@ class ShardWriter {
std
::
mutex
check_mutex_
;
// mutex for data check
std
::
atomic
<
bool
>
flag_
{
false
};
std
::
atomic
<
int64_t
>
compression_size_
;
};
}
// namespace mindrecord
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
浏览文件 @
256dccc6
...
...
@@ -43,6 +43,7 @@ ShardReader::ShardReader() {
page_size_
=
0
;
header_size_
=
0
;
num_rows_
=
0
;
total_blob_size_
=
0
;
num_padded_
=
0
;
}
...
...
@@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
return
{
FAILED
,
{}};
}
auto
header
=
ret
.
second
;
meta_data
=
{{
"header_size"
,
header
[
"header_size"
]},
{
"page_size"
,
header
[
"page_size"
]},
{
"version"
,
header
[
"version"
]},
{
"index_fields"
,
header
[
"index_fields"
]},
{
"schema"
,
header
[
"schema"
]},
{
"blob_fields"
,
header
[
"blob_fields"
]}};
uint64_t
compression_size
=
header
.
contains
(
"compression_size"
)
?
header
[
"compression_size"
].
get
<
uint64_t
>
()
:
0
;
meta_data
=
{{
"header_size"
,
header
[
"header_size"
]},
{
"page_size"
,
header
[
"page_size"
]},
{
"compression_size"
,
compression_size
},
{
"version"
,
header
[
"version"
]},
{
"index_fields"
,
header
[
"index_fields"
]},
{
"schema"
,
header
[
"schema"
]},
{
"blob_fields"
,
header
[
"blob_fields"
]}};
return
{
SUCCESS
,
header
[
"shard_addresses"
]};
}
...
...
@@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
for
(
const
auto
&
rg
:
row_group_summary
)
{
num_rows_
+=
std
::
get
<
3
>
(
rg
);
}
auto
disk_size
=
page_size_
*
row_group_summary
.
size
();
auto
compression_size
=
shard_header_
->
GetCompressionSize
();
total_blob_size_
=
disk_size
+
compression_size
;
MS_LOG
(
INFO
)
<<
"Blob data size, on disk: "
<<
disk_size
<<
" , addtional uncompression: "
<<
compression_size
<<
" , Total: "
<<
total_blob_size_
;
MS_LOG
(
INFO
)
<<
"Get meta from mindrecord file & index file successfully."
;
...
...
@@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
return
row_group_summary
;
}
MSRStatus
ShardReader
::
GetTotalBlobSize
(
int64_t
*
total_blob_size
)
{
*
total_blob_size
=
total_blob_size_
;
return
SUCCESS
;
}
MSRStatus
ShardReader
::
ConvertLabelToJson
(
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
labels
,
std
::
shared_ptr
<
std
::
fstream
>
fs
,
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
&
offsets
,
int
shard_id
,
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
浏览文件 @
256dccc6
...
...
@@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO;
namespace
mindspore
{
namespace
mindrecord
{
ShardWriter
::
ShardWriter
()
:
shard_count_
(
1
),
header_size_
(
kDefaultHeaderSize
),
page_size_
(
kDefaultPageSize
),
row_count_
(
0
),
schema_count_
(
1
)
{}
:
shard_count_
(
1
),
header_size_
(
kDefaultHeaderSize
),
page_size_
(
kDefaultPageSize
),
row_count_
(
0
),
schema_count_
(
1
)
{
compression_size_
=
0
;
}
ShardWriter
::~
ShardWriter
()
{
for
(
int
i
=
static_cast
<
int
>
(
file_streams_
.
size
())
-
1
;
i
>=
0
;
i
--
)
{
...
...
@@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if
(
ret
==
FAILED
)
{
return
FAILED
;
}
compression_size_
=
shard_header_
->
GetCompressionSize
();
ret
=
Open
(
real_addresses
,
true
);
if
(
ret
==
FAILED
)
{
MS_LOG
(
ERROR
)
<<
"Open file failed"
;
...
...
@@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
// compress blob
if
(
shard_column_
->
CheckCompressBlob
())
{
for
(
auto
&
blob
:
blob_data
)
{
blob
=
shard_column_
->
CompressBlob
(
blob
);
int64_t
compression_bytes
=
0
;
blob
=
shard_column_
->
CompressBlob
(
blob
,
&
compression_bytes
);
compression_size_
+=
compression_bytes
;
}
}
...
...
@@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() {
MS_LOG
(
ERROR
)
<<
"Shard header is null"
;
return
FAILED
;
}
int64_t
compression_temp
=
compression_size_
;
uint64_t
compression_size
=
compression_temp
>
0
?
compression_temp
:
0
;
shard_header_
->
SetCompressionSize
(
compression_size
);
auto
shard_header
=
shard_header_
->
SerializeHeader
();
// Write header data to multi files
if
(
shard_count_
>
static_cast
<
int
>
(
file_streams_
.
size
())
||
shard_count_
>
static_cast
<
int
>
(
shard_header
.
size
()))
{
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
浏览文件 @
256dccc6
...
...
@@ -24,7 +24,15 @@ namespace mindspore {
namespace
mindrecord
{
ShardColumn
::
ShardColumn
(
const
std
::
shared_ptr
<
ShardHeader
>
&
shard_header
,
bool
compress_integer
)
{
auto
first_schema
=
shard_header
->
GetSchemas
()[
0
];
auto
schema
=
first_schema
->
GetSchema
()[
"schema"
];
json
schema_json
=
first_schema
->
GetSchema
();
Init
(
schema_json
,
compress_integer
);
}
ShardColumn
::
ShardColumn
(
const
json
&
schema_json
,
bool
compress_integer
)
{
Init
(
schema_json
,
compress_integer
);
}
void
ShardColumn
::
Init
(
const
json
&
schema_json
,
bool
compress_integer
)
{
auto
schema
=
schema_json
[
"schema"
];
auto
blob_fields
=
schema_json
[
"blob_fields"
];
bool
has_integer_array
=
false
;
for
(
json
::
iterator
it
=
schema
.
begin
();
it
!=
schema
.
end
();
++
it
)
{
...
...
@@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool
column_name_id_
[
column_name_
[
i
]]
=
i
;
}
auto
blob_fields
=
first_schema
->
GetBlobFields
();
for
(
const
auto
&
field
:
blob_fields
)
{
blob_column_
.
push_back
(
field
);
}
...
...
@@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
return
it_blob
==
blob_column_id_
.
end
()
?
ColumnInRaw
:
ColumnInBlob
;
}
std
::
vector
<
uint8_t
>
ShardColumn
::
CompressBlob
(
const
std
::
vector
<
uint8_t
>
&
blob
)
{
std
::
vector
<
uint8_t
>
ShardColumn
::
CompressBlob
(
const
std
::
vector
<
uint8_t
>
&
blob
,
int64_t
*
compression_size
)
{
// Skip if no compress columns
*
compression_size
=
0
;
if
(
!
CheckCompressBlob
())
return
blob
;
std
::
vector
<
uint8_t
>
dst_blob
;
...
...
@@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
// Compress and return is blob has 1 column only
if
(
num_blob_column_
==
1
)
{
return
CompressInt
(
blob
,
int_type
);
dst_blob
=
CompressInt
(
blob
,
int_type
);
*
compression_size
=
static_cast
<
int64_t
>
(
blob
.
size
())
-
static_cast
<
int64_t
>
(
dst_blob
.
size
());
return
dst_blob
;
}
// Just copy and continue if column dat type is not int32/int64
...
...
@@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
i_src
+=
kInt64Len
+
num_bytes
;
}
MS_LOG
(
DEBUG
)
<<
"Compress all blob from "
<<
blob
.
size
()
<<
" to "
<<
dst_blob
.
size
()
<<
"."
;
*
compression_size
=
static_cast
<
int64_t
>
(
blob
.
size
())
-
static_cast
<
int64_t
>
(
dst_blob
.
size
());
return
dst_blob
;
}
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
浏览文件 @
256dccc6
...
...
@@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR;
namespace
mindspore
{
namespace
mindrecord
{
std
::
atomic
<
bool
>
thread_status
(
false
);
ShardHeader
::
ShardHeader
()
:
shard_count_
(
0
),
header_size_
(
0
),
page_size_
(
0
)
{
index_
=
std
::
make_shared
<
Index
>
();
}
ShardHeader
::
ShardHeader
()
:
shard_count_
(
0
),
header_size_
(
0
),
page_size_
(
0
),
compression_size_
(
0
)
{
index_
=
std
::
make_shared
<
Index
>
();
}
MSRStatus
ShardHeader
::
InitializeHeader
(
const
std
::
vector
<
json
>
&
headers
,
bool
load_dataset
)
{
shard_count_
=
headers
.
size
();
...
...
@@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
ParseShardAddress
(
header
[
"shard_addresses"
]);
header_size_
=
header
[
"header_size"
].
get
<
uint64_t
>
();
page_size_
=
header
[
"page_size"
].
get
<
uint64_t
>
();
compression_size_
=
header
.
contains
(
"compression_size"
)
?
header
[
"compression_size"
].
get
<
uint64_t
>
()
:
0
;
}
if
(
SUCCESS
!=
ParsePage
(
header
[
"page"
],
shard_index
,
load_dataset
))
{
return
FAILED
;
...
...
@@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil
return
{
FAILED
,
json
()};
}
json
raw_header
=
ret
.
second
;
uint64_t
compression_size
=
raw_header
.
contains
(
"compression_size"
)
?
raw_header
[
"compression_size"
].
get
<
uint64_t
>
()
:
0
;
json
header
=
{{
"shard_addresses"
,
raw_header
[
"shard_addresses"
]},
{
"header_size"
,
raw_header
[
"header_size"
]},
{
"page_size"
,
raw_header
[
"page_size"
]},
{
"compression_size"
,
compression_size
},
{
"index_fields"
,
raw_header
[
"index_fields"
]},
{
"blob_fields"
,
raw_header
[
"schema"
][
0
][
"blob_fields"
]},
{
"schema"
,
raw_header
[
"schema"
][
0
][
"schema"
]},
...
...
@@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s
+=
"
\"
index_fields
\"
:"
+
index
+
","
;
s
+=
"
\"
page
\"
:"
+
pages
[
shardId
]
+
","
;
s
+=
"
\"
page_size
\"
:"
+
std
::
to_string
(
page_size_
)
+
","
;
s
+=
"
\"
compression_size
\"
:"
+
std
::
to_string
(
compression_size_
)
+
","
;
s
+=
"
\"
schema
\"
:"
+
schema
+
","
;
s
+=
"
\"
shard_addresses
\"
:"
+
address
+
","
;
s
+=
"
\"
shard_id
\"
:"
+
std
::
to_string
(
shardId
)
+
","
;
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
256dccc6
...
...
@@ -3085,20 +3085,22 @@ def _cpp_sampler_fn(sampler, dataset):
yield
tuple
([
np
.
array
(
x
,
copy
=
False
)
for
x
in
val
])
def
_cpp_sampler_fn_mp
(
sampler
,
dataset
,
num_worker
):
def
_cpp_sampler_fn_mp
(
sampler
,
dataset
,
num_worker
,
multi_process
):
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
"""
indices
=
sampler
.
get_indices
()
return
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
)
sample_fn
=
SamplerFn
(
dataset
,
num_worker
,
multi_process
)
return
sample_fn
.
process
(
indices
)
def
_py_sampler_fn_mp
(
sampler
,
num_samples
,
dataset
,
num_worker
):
def
_py_sampler_fn_mp
(
sampler
,
num_samples
,
dataset
,
num_worker
,
multi_process
):
"""
Multiprocessing generator function wrapper for mappable dataset with python sampler.
"""
indices
=
_fetch_py_sampler_indices
(
sampler
,
num_samples
)
return
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
)
sample_fn
=
SamplerFn
(
dataset
,
num_worker
,
multi_process
)
return
sample_fn
.
process
(
indices
)
def
_fetch_py_sampler_indices
(
sampler
,
num_samples
):
...
...
@@ -3132,63 +3134,92 @@ def _fill_worker_indices(workers, indices, idx):
return
idx
def
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
)
:
class
SamplerFn
:
"""
Multiprocessing generator function wrapper master process.
Multiprocessing
or multithread
generator function wrapper master process.
"""
workers
=
[]
# Event for end of epoch
eoe
=
multiprocessing
.
Event
()
# Create workers
for
_
in
range
(
num_worker
):
worker
=
_GeneratorWorker
(
dataset
,
eoe
)
worker
.
daemon
=
True
workers
.
append
(
worker
)
# Fill initial index queues
idx_cursor
=
0
idx_cursor
=
_fill_worker_indices
(
workers
,
indices
,
idx_cursor
)
# Start all workers
for
w
in
workers
:
w
.
start
()
# Fetch results
for
i
in
range
(
len
(
indices
)):
# Fetch result and put index
try
:
result
=
workers
[
i
%
num_worker
].
get
()
except
queue
.
Empty
:
raise
Exception
(
"Generator worker process timeout"
)
except
KeyboardInterrupt
:
for
w
in
workers
:
w
.
terminate
()
def
__init__
(
self
,
dataset
,
num_worker
,
multi_process
):
self
.
workers
=
[]
self
.
num_worker
=
num_worker
self
.
multi_process
=
multi_process
# Event for end of epoch
if
multi_process
is
True
:
self
.
eoe
=
multiprocessing
.
Event
()
self
.
eof
=
multiprocessing
.
Event
()
else
:
self
.
eoe
=
threading
.
Event
()
self
.
eof
=
threading
.
Event
()
# Create workers
for
_
in
range
(
num_worker
):
if
multi_process
is
True
:
worker
=
_GeneratorWorkerMp
(
dataset
,
self
.
eoe
,
self
.
eof
)
else
:
worker
=
_GeneratorWorkerMt
(
dataset
,
self
.
eoe
,
self
.
eof
)
worker
.
daemon
=
True
self
.
workers
.
append
(
worker
)
def
process
(
self
,
indices
):
"""
The main process, start the child process or child thread, and fill the index queue,
get the result from the result and return.
"""
# Fill initial index queues
idx_cursor
=
0
idx_cursor
=
_fill_worker_indices
(
self
.
workers
,
indices
,
idx_cursor
)
# Start all workers
for
w
in
self
.
workers
:
w
.
start
()
# Fetch results
for
i
in
range
(
len
(
indices
)):
# Fetch result and put index
try
:
result
=
self
.
workers
[
i
%
self
.
num_worker
].
get
()
except
queue
.
Empty
:
raise
Exception
(
"Generator worker process timeout"
)
except
KeyboardInterrupt
:
self
.
eof
.
set
()
for
w
in
self
.
workers
:
w
.
terminate
()
w
.
join
()
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
if
idx_cursor
<
len
(
indices
):
idx_cursor
=
_fill_worker_indices
(
self
.
workers
,
indices
,
idx_cursor
)
# Set eoe event once all indices are sent
if
idx_cursor
==
len
(
indices
)
and
not
self
.
eoe
.
is_set
():
self
.
eoe
.
set
()
yield
tuple
([
np
.
array
(
x
,
copy
=
False
)
for
x
in
result
])
def
__del__
(
self
):
self
.
eoe
.
set
()
self
.
eof
.
set
()
if
self
.
multi_process
is
False
:
for
w
in
self
.
workers
:
w
.
join
()
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
if
idx_cursor
<
len
(
indices
):
idx_cursor
=
_fill_worker_indices
(
workers
,
indices
,
idx_cursor
)
# Set eoe event once all indices are sent
if
idx_cursor
==
len
(
indices
)
and
not
eoe
.
is_set
():
eoe
.
set
()
yield
tuple
([
np
.
array
(
x
,
copy
=
False
)
for
x
in
result
])
def
_generator_worker_loop
(
dataset
,
idx_queue
,
result_queue
,
eoe
):
def
_generator_worker_loop
(
dataset
,
idx_queue
,
result_queue
,
eoe
,
eof
):
"""
Multiprocessing generator worker process loop.
Multiprocessing
or multithread
generator worker process loop.
"""
while
True
:
# Fetch index, block
try
:
idx
=
idx_queue
.
get
()
idx
=
idx_queue
.
get
(
timeout
=
10
)
except
KeyboardInterrupt
:
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
except
queue
.
Empty
:
if
eof
.
is_set
()
or
eoe
.
is_set
():
raise
Exception
(
"Generator worker receives queue.Empty"
)
continue
if
idx
is
None
:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert
eoe
.
is_set
(),
""
return
if
eof
.
is_set
():
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result
=
dataset
[
idx
]
# Send data, block
...
...
@@ -3197,17 +3228,42 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
except
KeyboardInterrupt
:
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
del
result
,
idx
if
eoe
.
is_set
()
and
idx_queue
.
empty
():
return
class
_GeneratorWorker
(
multiprocessing
.
Process
):
class
_GeneratorWorkerMt
(
threading
.
Thread
):
"""
Worker process for multithread Generator.
"""
def
__init__
(
self
,
dataset
,
eoe
,
eof
):
self
.
idx_queue
=
queue
.
Queue
(
16
)
self
.
res_queue
=
queue
.
Queue
(
16
)
super
().
__init__
(
target
=
_generator_worker_loop
,
args
=
(
dataset
,
self
.
idx_queue
,
self
.
res_queue
,
eoe
,
eof
))
def
put
(
self
,
item
):
"""
Put function for worker index queue. Never block. Raise queue.Full on failure.
"""
self
.
idx_queue
.
put_nowait
(
item
)
def
get
(
self
):
"""
Get function for worker result queue. Block with timeout.
"""
return
self
.
res_queue
.
get
(
timeout
=
10
)
class
_GeneratorWorkerMp
(
multiprocessing
.
Process
):
"""
Worker process for multiprocess Generator.
"""
def
__init__
(
self
,
dataset
,
eoe
):
def
__init__
(
self
,
dataset
,
eoe
,
eof
):
self
.
idx_queue
=
multiprocessing
.
Queue
(
16
)
self
.
res_queue
=
multiprocessing
.
Queue
(
16
)
super
().
__init__
(
target
=
_generator_worker_loop
,
args
=
(
dataset
,
self
.
idx_queue
,
self
.
res_queue
,
eoe
))
super
().
__init__
(
target
=
_generator_worker_loop
,
args
=
(
dataset
,
self
.
idx_queue
,
self
.
res_queue
,
eoe
,
eof
))
def
put
(
self
,
item
):
"""
...
...
@@ -3219,7 +3275,7 @@ class _GeneratorWorker(multiprocessing.Process):
"""
Get function for worker result queue. Block with timeout.
"""
return
self
.
res_queue
.
get
()
return
self
.
res_queue
.
get
(
timeout
=
10
)
def
__del__
(
self
):
self
.
terminate
()
...
...
@@ -3282,6 +3338,8 @@ class GeneratorDataset(MappableDataset):
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=True).
Examples:
>>> import mindspore.dataset as ds
...
...
@@ -3318,12 +3376,14 @@ class GeneratorDataset(MappableDataset):
@
check_generatordataset
def
__init__
(
self
,
source
,
column_names
=
None
,
column_types
=
None
,
schema
=
None
,
num_samples
=
None
,
num_parallel_workers
=
1
,
shuffle
=
None
,
sampler
=
None
,
num_shards
=
None
,
shard_id
=
None
):
num_parallel_workers
=
1
,
shuffle
=
None
,
sampler
=
None
,
num_shards
=
None
,
shard_id
=
None
,
python_multiprocessing
=
True
):
super
().
__init__
(
num_parallel_workers
)
self
.
source
=
source
self
.
sampler
=
_select_sampler
(
num_samples
,
sampler
,
shuffle
,
num_shards
,
shard_id
)
self
.
num_samples
=
num_samples
self
.
num_shards
=
num_shards
self
.
python_multiprocessing
=
python_multiprocessing
if
column_names
is
not
None
and
not
isinstance
(
column_names
,
list
):
column_names
=
[
column_names
]
...
...
@@ -3405,12 +3465,16 @@ class GeneratorDataset(MappableDataset):
sampler_instance
.
set_num_rows
(
len
(
self
.
source
))
sampler_instance
.
initialize
()
if
new_op
.
num_parallel_workers
>
1
:
new_op
.
source
=
(
lambda
:
_cpp_sampler_fn_mp
(
sampler_instance
,
self
.
source
,
new_op
.
num_parallel_workers
))
new_op
.
source
=
(
lambda
:
_cpp_sampler_fn_mp
(
sampler_instance
,
self
.
source
,
new_op
.
num_parallel_workers
,
self
.
python_multiprocessing
))
else
:
new_op
.
source
=
(
lambda
:
_cpp_sampler_fn
(
sampler_instance
,
self
.
source
))
else
:
if
new_op
.
num_parallel_workers
>
1
:
new_op
.
source
=
(
lambda
:
_py_sampler_fn_mp
(
new_op
.
sampler
,
new_op
.
num_samples
,
self
.
source
,
new_op
.
num_parallel_workers
))
new_op
.
source
=
(
lambda
:
_py_sampler_fn_mp
(
new_op
.
sampler
,
new_op
.
num_samples
,
self
.
source
,
new_op
.
num_parallel_workers
,
self
.
python_multiprocessing
))
else
:
new_op
.
source
=
(
lambda
:
_py_sampler_fn
(
new_op
.
sampler
,
new_op
.
num_samples
,
self
.
source
))
else
:
...
...
mindspore/dataset/engine/graphdata.py
浏览文件 @
256dccc6
...
...
@@ -16,8 +16,11 @@
graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data.
"""
import
atexit
import
time
import
numpy
as
np
from
mindspore._c_dataengine
import
Graph
from
mindspore._c_dataengine
import
GraphDataClient
from
mindspore._c_dataengine
import
GraphDataServer
from
mindspore._c_dataengine
import
Tensor
from
.validators
import
check_gnn_graphdata
,
check_gnn_get_all_nodes
,
check_gnn_get_all_edges
,
\
...
...
@@ -34,14 +37,52 @@ class GraphData:
dataset_file (str): One of file names in dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
(default=None).
working_mode (str, optional): Set working mode, now support 'local'/'client'/'server' (default='local').
- 'local', used in non-distributed training scenarios.
- 'client', used in distributed training scenarios, the client does not load data,
but obtains data from the server.
- 'server', used in distributed training scenarios, the server loads the data
and is available to the client.
hostname (str, optional): Valid when working_mode is set to 'client' or 'server',
set the hostname of the graph data server (default='127.0.0.1').
port (int, optional): Valid when working_mode is set to 'client' or 'server',
set the port of the graph data server, the range is 1024-65535 (default=50051).
num_client (int, optional): Valid when working_mode is set to 'server',
set the number of clients expected to connect, and the server will allocate corresponding
resources according to this parameter (default=1).
auto_shutdown (bool, optional): Valid when working_mode is set to 'server',
Control when all clients have connected and no client connected to the server,
automatically exit the server (default=True).
"""
@
check_gnn_graphdata
def
__init__
(
self
,
dataset_file
,
num_parallel_workers
=
None
):
def
__init__
(
self
,
dataset_file
,
num_parallel_workers
=
None
,
working_mode
=
'local'
,
hostname
=
'127.0.0.1'
,
port
=
50051
,
num_client
=
1
,
auto_shutdown
=
True
):
self
.
_dataset_file
=
dataset_file
self
.
_working_mode
=
working_mode
if
num_parallel_workers
is
None
:
num_parallel_workers
=
1
self
.
_graph
=
Graph
(
dataset_file
,
num_parallel_workers
)
def
stop
():
self
.
_graph_data
.
stop
()
atexit
.
register
(
stop
)
if
working_mode
in
[
'local'
,
'client'
]:
self
.
_graph_data
=
GraphDataClient
(
dataset_file
,
num_parallel_workers
,
working_mode
,
hostname
,
port
)
if
working_mode
==
'server'
:
self
.
_graph_data
=
GraphDataServer
(
dataset_file
,
num_parallel_workers
,
hostname
,
port
,
num_client
,
auto_shutdown
)
try
:
while
self
.
_graph_data
.
is_stoped
()
is
not
True
:
time
.
sleep
(
1
)
except
KeyboardInterrupt
:
# self._graph_data.stop()
raise
Exception
(
"Graph data server receives KeyboardInterrupt"
)
@
check_gnn_get_all_nodes
def
get_all_nodes
(
self
,
node_type
):
...
...
@@ -62,7 +103,9 @@ class GraphData:
Raises:
TypeError: If `node_type` is not integer.
"""
return
self
.
_graph
.
get_all_nodes
(
node_type
).
as_array
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_all_nodes
(
node_type
).
as_array
()
@
check_gnn_get_all_edges
def
get_all_edges
(
self
,
edge_type
):
...
...
@@ -83,7 +126,9 @@ class GraphData:
Raises:
TypeError: If `edge_type` is not integer.
"""
return
self
.
_graph
.
get_all_edges
(
edge_type
).
as_array
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_all_edges
(
edge_type
).
as_array
()
@
check_gnn_get_nodes_from_edges
def
get_nodes_from_edges
(
self
,
edge_list
):
...
...
@@ -99,7 +144,9 @@ class GraphData:
Raises:
TypeError: If `edge_list` is not list or ndarray.
"""
return
self
.
_graph
.
get_nodes_from_edges
(
edge_list
).
as_array
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_nodes_from_edges
(
edge_list
).
as_array
()
@
check_gnn_get_all_neighbors
def
get_all_neighbors
(
self
,
node_list
,
neighbor_type
):
...
...
@@ -123,7 +170,9 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `neighbor_type` is not integer.
"""
return
self
.
_graph
.
get_all_neighbors
(
node_list
,
neighbor_type
).
as_array
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_all_neighbors
(
node_list
,
neighbor_type
).
as_array
()
@
check_gnn_get_sampled_neighbors
def
get_sampled_neighbors
(
self
,
node_list
,
neighbor_nums
,
neighbor_types
):
...
...
@@ -155,7 +204,9 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray.
"""
return
self
.
_graph
.
get_sampled_neighbors
(
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_sampled_neighbors
(
node_list
,
neighbor_nums
,
neighbor_types
).
as_array
()
@
check_gnn_get_neg_sampled_neighbors
...
...
@@ -182,7 +233,9 @@ class GraphData:
TypeError: If `neg_neighbor_num` is not integer.
TypeError: If `neg_neighbor_type` is not integer.
"""
return
self
.
_graph
.
get_neg_sampled_neighbors
(
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
get_neg_sampled_neighbors
(
node_list
,
neg_neighbor_num
,
neg_neighbor_type
).
as_array
()
@
check_gnn_get_node_feature
...
...
@@ -207,10 +260,12 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray.
"""
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
if
isinstance
(
node_list
,
list
):
node_list
=
np
.
array
(
node_list
,
dtype
=
np
.
int32
)
return
[
t
.
as_array
()
for
t
in
self
.
_graph
.
get_node_feature
(
t
.
as_array
()
for
t
in
self
.
_graph
_data
.
get_node_feature
(
Tensor
(
node_list
),
feature_types
)]
...
...
@@ -236,10 +291,12 @@ class GraphData:
TypeError: If `edge_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray.
"""
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
if
isinstance
(
edge_list
,
list
):
edge_list
=
np
.
array
(
edge_list
,
dtype
=
np
.
int32
)
return
[
t
.
as_array
()
for
t
in
self
.
_graph
.
get_edge_feature
(
t
.
as_array
()
for
t
in
self
.
_graph
_data
.
get_edge_feature
(
Tensor
(
edge_list
),
feature_types
)]
...
...
@@ -252,7 +309,9 @@ class GraphData:
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
node_feature_type and edge_feature_type.
"""
return
self
.
_graph
.
graph_info
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
graph_info
()
@
check_gnn_random_walk
def
random_walk
(
...
...
@@ -285,5 +344,7 @@ class GraphData:
TypeError: If `target_nodes` is not list or ndarray.
TypeError: If `meta_path` is not list or ndarray.
"""
return
self
.
_graph
.
random_walk
(
target_nodes
,
meta_path
,
step_home_param
,
step_away_param
,
default_node
).
as_array
()
if
self
.
_working_mode
==
'server'
:
raise
Exception
(
"This method is not supported when working mode is server"
)
return
self
.
_graph_data
.
random_walk
(
target_nodes
,
meta_path
,
step_home_param
,
step_away_param
,
default_node
).
as_array
()
mindspore/dataset/engine/validators.py
浏览文件 @
256dccc6
...
...
@@ -18,6 +18,7 @@ Built-in validators.
"""
import
inspect
as
ins
import
os
import
re
from
functools
import
wraps
import
numpy
as
np
...
...
@@ -912,16 +913,36 @@ def check_split(method):
return
new_method
def
check_hostname
(
hostname
):
if
len
(
hostname
)
>
255
:
return
False
if
hostname
[
-
1
]
==
"."
:
hostname
=
hostname
[:
-
1
]
# strip exactly one dot from the right, if present
allowed
=
re
.
compile
(
"(?!-)[A-Z
\\
d-]{1,63}(?<!-)$"
,
re
.
IGNORECASE
)
return
all
(
allowed
.
match
(
x
)
for
x
in
hostname
.
split
(
"."
))
def
check_gnn_graphdata
(
method
):
"""check the input arguments of graphdata."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
[
dataset_file
,
num_parallel_workers
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
[
dataset_file
,
num_parallel_workers
,
working_mode
,
hostname
,
port
,
num_client
,
auto_shutdown
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
check_file
(
dataset_file
)
if
num_parallel_workers
is
not
None
:
check_num_parallel_workers
(
num_parallel_workers
)
type_check
(
hostname
,
(
str
,),
"hostname"
)
if
check_hostname
(
hostname
)
is
False
:
raise
ValueError
(
"The hostname is illegal"
)
type_check
(
working_mode
,
(
str
,),
"working_mode"
)
if
working_mode
not
in
{
'local'
,
'client'
,
'server'
}:
raise
ValueError
(
"Invalid working mode"
)
type_check
(
port
,
(
int
,),
"port"
)
check_value
(
port
,
(
1024
,
65535
),
"port"
)
type_check
(
num_client
,
(
int
,),
"num_client"
)
check_value
(
num_client
,
(
1
,
255
),
"num_client"
)
type_check
(
auto_shutdown
,
(
bool
,),
"auto_shutdown"
)
return
method
(
self
,
*
args
,
**
kwargs
)
return
new_method
...
...
model_zoo/utils/graph_to_mindrecord/sns/mr_api.py
浏览文件 @
256dccc6
...
...
@@ -15,6 +15,7 @@
"""
User-defined API for MindRecord GNN writer.
"""
import
numpy
as
np
social_data
=
[[
348
,
350
],
[
348
,
327
],
[
348
,
329
],
[
348
,
331
],
[
348
,
335
],
[
348
,
336
],
[
348
,
337
],
[
348
,
338
],
[
348
,
340
],
[
348
,
341
],
[
348
,
342
],
[
348
,
343
],
[
348
,
344
],
[
348
,
345
],
[
348
,
346
],
...
...
@@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[
355
,
352
],
[
353
,
350
],
[
352
,
349
],
[
351
,
349
],
[
350
,
349
]]
# profile: (num_features, feature_data_types, feature_shapes)
node_profile
=
(
0
,
[],
[
])
node_profile
=
(
2
,
[
"int64"
,
"int32"
],
[[
-
1
],
[
-
1
]
])
edge_profile
=
(
0
,
[],
[])
...
...
@@ -51,7 +52,9 @@ def yield_nodes(task_id=0):
node_list
.
sort
()
print
(
node_list
)
for
node_id
in
node_list
:
node
=
{
'id'
:
node_id
,
'type'
:
1
}
node
=
{
'id'
:
node_id
,
'type'
:
1
,
'feature_1'
:
np
.
ones
((
5
,),
dtype
=
np
.
int64
),
'feature_2'
:
np
.
ones
((
10
,),
dtype
=
np
.
int32
)}
yield
node
...
...
tests/ut/cpp/dataset/gnn_graph_test.cc
浏览文件 @
256dccc6
...
...
@@ -22,6 +22,7 @@
#include "gtest/gtest.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
using
namespace
mindspore
::
dataset
;
...
...
@@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common {
MindDataTestGNNGraph
()
=
default
;
};
TEST_F
(
MindDataTestGNNGraph
,
TestGraphLoader
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/testdata"
;
GraphLoader
gl
(
path
,
4
);
EXPECT_TRUE
(
gl
.
InitAndLoad
().
IsOk
());
NodeIdMap
n_id_map
;
EdgeIdMap
e_id_map
;
NodeTypeMap
n_type_map
;
EdgeTypeMap
e_type_map
;
NodeFeatureMap
n_feature_map
;
EdgeFeatureMap
e_feature_map
;
DefaultNodeFeatureMap
default_node_feature_map
;
DefaultEdgeFeatureMap
default_edge_feature_map
;
EXPECT_TRUE
(
gl
.
GetNodesAndEdges
(
&
n_id_map
,
&
e_id_map
,
&
n_type_map
,
&
e_type_map
,
&
n_feature_map
,
&
e_feature_map
,
&
default_node_feature_map
,
&
default_edge_feature_map
)
.
IsOk
());
EXPECT_EQ
(
n_id_map
.
size
(),
20
);
EXPECT_EQ
(
e_id_map
.
size
(),
40
);
EXPECT_EQ
(
n_type_map
[
2
].
size
(),
10
);
EXPECT_EQ
(
n_type_map
[
1
].
size
(),
10
);
}
TEST_F
(
MindDataTestGNNGraph
,
TestGetAllNeighbors
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/testdata"
;
Graph
graph
(
path
,
1
);
Graph
DataImpl
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
...
...
@@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
TEST_F
(
MindDataTestGNNGraph
,
TestGetSampledNeighbors
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/testdata"
;
Graph
graph
(
path
,
1
);
Graph
DataImpl
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
...
...
@@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
TEST_F
(
MindDataTestGNNGraph
,
TestGetNegSampledNeighbors
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/testdata"
;
Graph
graph
(
path
,
1
);
Graph
DataImpl
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
...
...
@@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
TEST_F
(
MindDataTestGNNGraph
,
TestRandomWalk
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/sns"
;
Graph
graph
(
path
,
1
);
Graph
DataImpl
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
...
...
@@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
TEST_F
(
MindDataTestGNNGraph
,
TestRandomWalkDefaults
)
{
std
::
string
path
=
"data/mindrecord/testGraphData/sns"
;
Graph
graph
(
path
,
1
);
Graph
DataImpl
graph
(
path
,
1
);
Status
s
=
graph
.
Init
();
EXPECT_TRUE
(
s
.
IsOk
());
...
...
tests/ut/data/mindrecord/testGraphData/sns
浏览文件 @
256dccc6
无法预览此类型文件
tests/ut/data/mindrecord/testGraphData/sns.db
浏览文件 @
256dccc6
无法预览此类型文件
tests/ut/data/mindrecord/testGraphData/testdata
浏览文件 @
256dccc6
无法预览此类型文件
tests/ut/python/dataset/test_graphdata_distributed.py
0 → 100644
浏览文件 @
256dccc6
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
random
import
time
from
multiprocessing
import
Process
import
numpy
as
np
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
DATASET_FILE
=
"../data/mindrecord/testGraphData/testdata"
def
graphdata_startserver
():
"""
start graphdata server
"""
logger
.
info
(
'test start server.
\n
'
)
ds
.
GraphData
(
DATASET_FILE
,
1
,
'server'
)
class
RandomBatchedSampler
(
ds
.
Sampler
):
# RandomBatchedSampler generate random sequence without replacement in a batched manner
def
__init__
(
self
,
index_range
,
num_edges_per_sample
):
super
().
__init__
()
self
.
index_range
=
index_range
self
.
num_edges_per_sample
=
num_edges_per_sample
def
__iter__
(
self
):
indices
=
[
i
+
1
for
i
in
range
(
self
.
index_range
)]
# Reset random seed here if necessary
# random.seed(0)
random
.
shuffle
(
indices
)
for
i
in
range
(
0
,
self
.
index_range
,
self
.
num_edges_per_sample
):
# Drop reminder
if
i
+
self
.
num_edges_per_sample
<=
self
.
index_range
:
yield
indices
[
i
:
i
+
self
.
num_edges_per_sample
]
class
GNNGraphDataset
():
def
__init__
(
self
,
g
,
batch_num
):
self
.
g
=
g
self
.
batch_num
=
batch_num
def
__len__
(
self
):
# Total sample size of GNN dataset
# In this case, the size should be total_num_edges/num_edges_per_sample
return
self
.
g
.
graph_info
()[
'edge_num'
][
0
]
//
self
.
batch_num
def
__getitem__
(
self
,
index
):
# index will be a list of indices yielded from RandomBatchedSampler
# Fetch edges/nodes/samples/features based on indices
nodes
=
self
.
g
.
get_nodes_from_edges
(
index
.
astype
(
np
.
int32
))
nodes
=
nodes
[:,
0
]
neg_nodes
=
self
.
g
.
get_neg_sampled_neighbors
(
node_list
=
nodes
,
neg_neighbor_num
=
3
,
neg_neighbor_type
=
1
)
nodes_neighbors
=
self
.
g
.
get_sampled_neighbors
(
node_list
=
nodes
,
neighbor_nums
=
[
2
,
2
],
neighbor_types
=
[
2
,
1
])
neg_nodes_neighbors
=
self
.
g
.
get_sampled_neighbors
(
node_list
=
neg_nodes
[:,
1
:].
reshape
(
-
1
),
neighbor_nums
=
[
2
,
2
],
neighbor_types
=
[
2
,
2
])
nodes_neighbors_features
=
self
.
g
.
get_node_feature
(
node_list
=
nodes_neighbors
,
feature_types
=
[
2
,
3
])
neg_neighbors_features
=
self
.
g
.
get_node_feature
(
node_list
=
neg_nodes_neighbors
,
feature_types
=
[
2
,
3
])
return
nodes_neighbors
,
neg_nodes_neighbors
,
nodes_neighbors_features
[
0
],
neg_neighbors_features
[
1
]
def
test_graphdata_distributed
():
"""
Test distributed
"""
logger
.
info
(
'test distributed.
\n
'
)
p1
=
Process
(
target
=
graphdata_startserver
)
p1
.
start
()
time
.
sleep
(
2
)
g
=
ds
.
GraphData
(
DATASET_FILE
,
1
,
'client'
)
nodes
=
g
.
get_all_nodes
(
1
)
assert
nodes
.
tolist
()
==
[
101
,
102
,
103
,
104
,
105
,
106
,
107
,
108
,
109
,
110
]
row_tensor
=
g
.
get_node_feature
(
nodes
.
tolist
(),
[
1
,
2
,
3
])
assert
row_tensor
[
0
].
tolist
()
==
[[
0
,
1
,
0
,
0
,
0
],
[
1
,
0
,
0
,
0
,
1
],
[
0
,
0
,
1
,
1
,
0
],
[
0
,
0
,
0
,
0
,
0
],
[
1
,
1
,
0
,
1
,
0
],
[
0
,
0
,
0
,
0
,
1
],
[
0
,
1
,
0
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
],
[
0
,
1
,
1
,
0
,
0
],
[
0
,
1
,
0
,
1
,
0
]]
assert
row_tensor
[
2
].
tolist
()
==
[
1
,
2
,
3
,
1
,
4
,
3
,
5
,
3
,
5
,
4
]
edges
=
g
.
get_all_edges
(
0
)
assert
edges
.
tolist
()
==
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
]
features
=
g
.
get_edge_feature
(
edges
,
[
1
,
2
])
assert
features
[
0
].
tolist
()
==
[
0
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
]
batch_num
=
2
edge_num
=
g
.
graph_info
()[
'edge_num'
][
0
]
out_column_names
=
[
"neighbors"
,
"neg_neighbors"
,
"neighbors_features"
,
"neg_neighbors_features"
]
dataset
=
ds
.
GeneratorDataset
(
source
=
GNNGraphDataset
(
g
,
batch_num
),
column_names
=
out_column_names
,
sampler
=
RandomBatchedSampler
(
edge_num
,
batch_num
),
num_parallel_workers
=
4
,
python_multiprocessing
=
False
)
dataset
=
dataset
.
repeat
(
2
)
itr
=
dataset
.
create_dict_iterator
()
i
=
0
for
data
in
itr
:
assert
data
[
'neighbors'
].
shape
==
(
2
,
7
)
assert
data
[
'neg_neighbors'
].
shape
==
(
6
,
7
)
assert
data
[
'neighbors_features'
].
shape
==
(
2
,
7
)
assert
data
[
'neg_neighbors_features'
].
shape
==
(
6
,
7
)
i
+=
1
assert
i
==
40
if
__name__
==
'__main__'
:
test_graphdata_distributed
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录