Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
31776199
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
31776199
编写于
3月 17, 2022
作者:
S
seemingwang
提交者:
GitHub
3月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge cpu and gpu graph engines (#40597)
* extract sub-graph * graph-engine merging * fix * fix * fix heter-ps config
上级
313bff6b
变更
25
展开全部
隐藏空白更改
内联
并排
Showing
25 changed file
with
1199 addition
and
318 deletion
+1199
-318
paddle/fluid/distributed/ps.proto
paddle/fluid/distributed/ps.proto
+23
-0
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
+26
-26
paddle/fluid/distributed/ps/service/graph_brpc_client.h
paddle/fluid/distributed/ps/service/graph_brpc_client.h
+8
-8
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
+24
-24
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc
...uid/distributed/ps/service/ps_service/graph_py_service.cc
+14
-15
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
...luid/distributed/ps/service/ps_service/graph_py_service.h
+32
-16
paddle/fluid/distributed/ps/table/CMakeLists.txt
paddle/fluid/distributed/ps/table/CMakeLists.txt
+0
-1
paddle/fluid/distributed/ps/table/common_graph_table.cc
paddle/fluid/distributed/ps/table/common_graph_table.cc
+383
-55
paddle/fluid/distributed/ps/table/common_graph_table.h
paddle/fluid/distributed/ps/table/common_graph_table.h
+172
-31
paddle/fluid/distributed/ps/table/graph/class_macro.h
paddle/fluid/distributed/ps/table/graph/class_macro.h
+39
-0
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
+2
-2
paddle/fluid/distributed/ps/table/graph/graph_edge.h
paddle/fluid/distributed/ps/table/graph/graph_edge.h
+5
-4
paddle/fluid/distributed/ps/table/graph/graph_node.h
paddle/fluid/distributed/ps/table/graph/graph_node.h
+2
-0
paddle/fluid/distributed/ps/table/table.cc
paddle/fluid/distributed/ps/table/table.cc
+2
-0
paddle/fluid/distributed/test/CMakeLists.txt
paddle/fluid/distributed/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/test/graph_node_split_test.cc
paddle/fluid/distributed/test/graph_node_split_test.cc
+4
-4
paddle/fluid/distributed/test/graph_node_test.cc
paddle/fluid/distributed/test/graph_node_test.cc
+30
-30
paddle/fluid/distributed/test/graph_table_sample_test.cc
paddle/fluid/distributed/test/graph_table_sample_test.cc
+148
-0
paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
+2
-1
paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
+120
-0
paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
+19
-98
paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
...e/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
+30
-1
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+1
-0
paddle/fluid/framework/fleet/heter_ps/test_cpu_graph_sample.cu
...e/fluid/framework/fleet/heter_ps/test_cpu_graph_sample.cu
+108
-0
paddle/fluid/pybind/fleet_py.cc
paddle/fluid/pybind/fleet_py.cc
+2
-2
未找到文件。
paddle/fluid/distributed/ps.proto
浏览文件 @
31776199
...
@@ -115,6 +115,7 @@ message TableParameter {
...
@@ -115,6 +115,7 @@ message TableParameter {
optional
CommonAccessorParameter
common
=
6
;
optional
CommonAccessorParameter
common
=
6
;
optional
TableType
type
=
7
;
optional
TableType
type
=
7
;
optional
bool
compress_in_save
=
8
[
default
=
false
];
optional
bool
compress_in_save
=
8
[
default
=
false
];
optional
GraphParameter
graph_parameter
=
9
;
}
}
message
TableAccessorParameter
{
message
TableAccessorParameter
{
...
@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
...
@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
optional
double
ada_epsilon
=
5
[
default
=
1e-08
];
optional
double
ada_epsilon
=
5
[
default
=
1e-08
];
repeated
float
weight_bounds
=
6
;
repeated
float
weight_bounds
=
6
;
}
}
message
GraphParameter
{
optional
int32
task_pool_size
=
1
[
default
=
24
];
optional
bool
gpups_mode
=
2
[
default
=
false
];
optional
string
gpups_graph_sample_class
=
3
[
default
=
"CompleteGraphSampler"
];
optional
string
gpups_graph_sample_args
=
4
[
default
=
""
];
optional
bool
use_cache
=
5
[
default
=
true
];
optional
float
cache_ratio
=
6
[
default
=
0.3
];
optional
int32
cache_ttl
=
7
[
default
=
5
];
optional
GraphFeature
graph_feature
=
8
;
optional
string
table_name
=
9
[
default
=
""
];
optional
string
table_type
=
10
[
default
=
""
];
optional
int32
gpups_mode_shard_num
=
11
[
default
=
127
];
optional
int32
gpu_num
=
12
[
default
=
1
];
}
message
GraphFeature
{
repeated
string
name
=
1
;
repeated
string
dtype
=
2
;
repeated
int32
shape
=
3
;
}
\ No newline at end of file
paddle/fluid/distributed/ps/service/graph_brpc_client.cc
浏览文件 @
31776199
...
@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
...
@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
}
}
}
}
int
GraphBrpcClient
::
get_server_index_by_id
(
u
int64_t
id
)
{
int
GraphBrpcClient
::
get_server_index_by_id
(
int64_t
id
)
{
int
shard_num
=
get_shard_num
();
int
shard_num
=
get_shard_num
();
int
shard_per_server
=
shard_num
%
server_size
==
0
int
shard_per_server
=
shard_num
%
server_size
==
0
?
shard_num
/
server_size
?
shard_num
/
server_size
...
@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
...
@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
}
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
get_node_feat
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
get_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
u
int64_t
>
&
node_ids
,
const
uint32_t
&
table_id
,
const
std
::
vector
<
int64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
)
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
)
{
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
request2server
;
...
@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
...
@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
}
}
}
}
size_t
request_call_num
=
request2server
.
size
();
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
vector
<
u
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
...
@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
...
@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
std
::
string
joint_feature_name
=
std
::
string
joint_feature_name
=
paddle
::
string
::
join_strings
(
feature_names
,
'\t'
);
paddle
::
string
::
join_strings
(
feature_names
,
'\t'
);
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
...
@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
...
@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return
fut
;
return
fut
;
}
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
add_graph_node
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
add_graph_node
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>
&
node_id_list
,
uint32_t
table_id
,
std
::
vector
<
int64_t
>
&
node_id_list
,
std
::
vector
<
bool
>
&
is_weighted_list
)
{
std
::
vector
<
bool
>
&
is_weighted_list
)
{
std
::
vector
<
std
::
vector
<
u
int64_t
>>
request_bucket
;
std
::
vector
<
std
::
vector
<
int64_t
>>
request_bucket
;
std
::
vector
<
std
::
vector
<
bool
>>
is_weighted_bucket
;
std
::
vector
<
std
::
vector
<
bool
>>
is_weighted_bucket
;
bool
add_weight
=
is_weighted_list
.
size
()
>
0
;
bool
add_weight
=
is_weighted_list
.
size
()
>
0
;
std
::
vector
<
int
>
server_index_arr
;
std
::
vector
<
int
>
server_index_arr
;
...
@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
...
@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
if
(
index_mapping
[
server_index
]
==
-
1
)
{
if
(
index_mapping
[
server_index
]
==
-
1
)
{
index_mapping
[
server_index
]
=
request_bucket
.
size
();
index_mapping
[
server_index
]
=
request_bucket
.
size
();
server_index_arr
.
push_back
(
server_index
);
server_index_arr
.
push_back
(
server_index
);
request_bucket
.
push_back
(
std
::
vector
<
u
int64_t
>
());
request_bucket
.
push_back
(
std
::
vector
<
int64_t
>
());
if
(
add_weight
)
is_weighted_bucket
.
push_back
(
std
::
vector
<
bool
>
());
if
(
add_weight
)
is_weighted_bucket
.
push_back
(
std
::
vector
<
bool
>
());
}
}
request_bucket
[
index_mapping
[
server_index
]].
push_back
(
request_bucket
[
index_mapping
[
server_index
]].
push_back
(
...
@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
...
@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
size_t
node_num
=
request_bucket
[
request_idx
].
size
();
size_t
node_num
=
request_bucket
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
request_bucket
[
request_idx
].
data
(),
->
add_params
((
char
*
)
request_bucket
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
if
(
add_weight
)
{
if
(
add_weight
)
{
bool
weighted
[
is_weighted_bucket
[
request_idx
].
size
()
+
1
];
bool
weighted
[
is_weighted_bucket
[
request_idx
].
size
()
+
1
];
for
(
size_t
j
=
0
;
j
<
is_weighted_bucket
[
request_idx
].
size
();
j
++
)
for
(
size_t
j
=
0
;
j
<
is_weighted_bucket
[
request_idx
].
size
();
j
++
)
...
@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
...
@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return
fut
;
return
fut
;
}
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
remove_graph_node
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
remove_graph_node
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>
&
node_id_list
)
{
uint32_t
table_id
,
std
::
vector
<
int64_t
>
&
node_id_list
)
{
std
::
vector
<
std
::
vector
<
u
int64_t
>>
request_bucket
;
std
::
vector
<
std
::
vector
<
int64_t
>>
request_bucket
;
std
::
vector
<
int
>
server_index_arr
;
std
::
vector
<
int
>
server_index_arr
;
std
::
vector
<
int
>
index_mapping
(
server_size
,
-
1
);
std
::
vector
<
int
>
index_mapping
(
server_size
,
-
1
);
for
(
size_t
query_idx
=
0
;
query_idx
<
node_id_list
.
size
();
++
query_idx
)
{
for
(
size_t
query_idx
=
0
;
query_idx
<
node_id_list
.
size
();
++
query_idx
)
{
...
@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
...
@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
if
(
index_mapping
[
server_index
]
==
-
1
)
{
if
(
index_mapping
[
server_index
]
==
-
1
)
{
index_mapping
[
server_index
]
=
request_bucket
.
size
();
index_mapping
[
server_index
]
=
request_bucket
.
size
();
server_index_arr
.
push_back
(
server_index
);
server_index_arr
.
push_back
(
server_index
);
request_bucket
.
push_back
(
std
::
vector
<
u
int64_t
>
());
request_bucket
.
push_back
(
std
::
vector
<
int64_t
>
());
}
}
request_bucket
[
index_mapping
[
server_index
]].
push_back
(
request_bucket
[
index_mapping
[
server_index
]].
push_back
(
node_id_list
[
query_idx
]);
node_id_list
[
query_idx
]);
...
@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
...
@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
request_bucket
[
request_idx
].
data
(),
->
add_params
((
char
*
)
request_bucket
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub
rpc_stub
=
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
get_cmd_channel
(
server_index
));
getServiceStub
(
get_cmd_channel
(
server_index
));
...
@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
...
@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
}
}
// char* &buffer,int &actual_size
// char* &buffer,int &actual_size
std
::
future
<
int32_t
>
GraphBrpcClient
::
batch_sample_neighbors
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
batch_sample_neighbors
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>
node_ids
,
int
sample_size
,
uint32_t
table_id
,
std
::
vector
<
int64_t
>
node_ids
,
int
sample_size
,
// std::vector<std::vector<std::pair<
u
int64_t, float>>> &res,
// std::vector<std::vector<std::pair<int64_t, float>>> &res,
std
::
vector
<
std
::
vector
<
u
int64_t
>>
&
res
,
std
::
vector
<
std
::
vector
<
int64_t
>>
&
res
,
std
::
vector
<
std
::
vector
<
float
>>
&
res_weight
,
bool
need_weight
,
std
::
vector
<
std
::
vector
<
float
>>
&
res_weight
,
bool
need_weight
,
int
server_index
)
{
int
server_index
)
{
if
(
server_index
!=
-
1
)
{
if
(
server_index
!=
-
1
)
{
...
@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int
start
=
0
;
int
start
=
0
;
while
(
start
<
actual_size
)
{
while
(
start
<
actual_size
)
{
res
[
node_idx
].
emplace_back
(
res
[
node_idx
].
emplace_back
(
*
(
u
int64_t
*
)(
node_buffer
+
offset
+
start
));
*
(
int64_t
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
id_size
;
start
+=
GraphNode
::
id_size
;
if
(
need_weight
)
{
if
(
need_weight
)
{
res_weight
[
node_idx
].
emplace_back
(
res_weight
[
node_idx
].
emplace_back
(
...
@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure
->
request
(
0
)
->
set_table_id
(
table_id
);
closure
->
request
(
0
)
->
set_table_id
(
table_id
);
closure
->
request
(
0
)
->
set_client_id
(
_client_id
);
closure
->
request
(
0
)
->
set_client_id
(
_client_id
);
closure
->
request
(
0
)
->
add_params
((
char
*
)
node_ids
.
data
(),
closure
->
request
(
0
)
->
add_params
((
char
*
)
node_ids
.
data
(),
sizeof
(
u
int64_t
)
*
node_ids
.
size
());
sizeof
(
int64_t
)
*
node_ids
.
size
());
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
need_weight
,
sizeof
(
bool
));
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
need_weight
,
sizeof
(
bool
));
;
;
...
@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
server2request
[
server_index
]
=
request2server
.
size
();
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
request2server
.
push_back
(
server_index
);
}
}
// res.push_back(std::vector<std::pair<
u
int64_t, float>>());
// res.push_back(std::vector<std::pair<int64_t, float>>());
res
.
push_back
({});
res
.
push_back
({});
if
(
need_weight
)
{
if
(
need_weight
)
{
res_weight
.
push_back
({});
res_weight
.
push_back
({});
}
}
}
}
size_t
request_call_num
=
request2server
.
size
();
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
vector
<
u
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
...
@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int
start
=
0
;
int
start
=
0
;
while
(
start
<
actual_size
)
{
while
(
start
<
actual_size
)
{
res
[
query_idx
].
emplace_back
(
res
[
query_idx
].
emplace_back
(
*
(
u
int64_t
*
)(
node_buffer
+
offset
+
start
));
*
(
int64_t
*
)(
node_buffer
+
offset
+
start
));
start
+=
GraphNode
::
id_size
;
start
+=
GraphNode
::
id_size
;
if
(
need_weight
)
{
if
(
need_weight
)
{
res_weight
[
query_idx
].
emplace_back
(
res_weight
[
query_idx
].
emplace_back
(
...
@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
...
@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
...
@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
}
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
random_sample_nodes
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
random_sample_nodes
(
uint32_t
table_id
,
int
server_index
,
int
sample_size
,
uint32_t
table_id
,
int
server_index
,
int
sample_size
,
std
::
vector
<
u
int64_t
>
&
ids
)
{
std
::
vector
<
int64_t
>
&
ids
)
{
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
1
,
[
&
](
void
*
done
)
{
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
1
,
[
&
](
void
*
done
)
{
int
ret
=
0
;
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
...
@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
...
@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto
size
=
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
auto
size
=
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
int
index
=
0
;
int
index
=
0
;
while
(
index
<
bytes_size
)
{
while
(
index
<
bytes_size
)
{
ids
.
push_back
(
*
(
u
int64_t
*
)(
buffer
+
index
));
ids
.
push_back
(
*
(
int64_t
*
)(
buffer
+
index
));
index
+=
GraphNode
::
id_size
;
index
+=
GraphNode
::
id_size
;
}
}
delete
[]
buffer
;
delete
[]
buffer
;
...
@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
...
@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
}
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
set_node_feat
(
std
::
future
<
int32_t
>
GraphBrpcClient
::
set_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
u
int64_t
>
&
node_ids
,
const
uint32_t
&
table_id
,
const
std
::
vector
<
int64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
features
)
{
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
features
)
{
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
request2server
;
...
@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
...
@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
}
}
}
}
size_t
request_call_num
=
request2server
.
size
();
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
vector
<
u
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
string
>>>
features_idx_buckets
(
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
string
>>>
features_idx_buckets
(
request_call_num
);
request_call_num
);
...
@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
...
@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
std
::
string
joint_feature_name
=
std
::
string
joint_feature_name
=
paddle
::
string
::
join_strings
(
feature_names
,
'\t'
);
paddle
::
string
::
join_strings
(
feature_names
,
'\t'
);
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
...
...
paddle/fluid/distributed/ps/service/graph_brpc_client.h
浏览文件 @
31776199
...
@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
...
@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
virtual
~
GraphBrpcClient
()
{}
virtual
~
GraphBrpcClient
()
{}
// given a batch of nodes, sample graph_neighbors for each of them
// given a batch of nodes, sample graph_neighbors for each of them
virtual
std
::
future
<
int32_t
>
batch_sample_neighbors
(
virtual
std
::
future
<
int32_t
>
batch_sample_neighbors
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>
node_ids
,
int
sample_size
,
uint32_t
table_id
,
std
::
vector
<
int64_t
>
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
vector
<
u
int64_t
>>&
res
,
std
::
vector
<
std
::
vector
<
int64_t
>>&
res
,
std
::
vector
<
std
::
vector
<
float
>>&
res_weight
,
bool
need_weight
,
std
::
vector
<
std
::
vector
<
float
>>&
res_weight
,
bool
need_weight
,
int
server_index
=
-
1
);
int
server_index
=
-
1
);
...
@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
...
@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
virtual
std
::
future
<
int32_t
>
random_sample_nodes
(
uint32_t
table_id
,
virtual
std
::
future
<
int32_t
>
random_sample_nodes
(
uint32_t
table_id
,
int
server_index
,
int
server_index
,
int
sample_size
,
int
sample_size
,
std
::
vector
<
u
int64_t
>&
ids
);
std
::
vector
<
int64_t
>&
ids
);
virtual
std
::
future
<
int32_t
>
get_node_feat
(
virtual
std
::
future
<
int32_t
>
get_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
u
int64_t
>&
node_ids
,
const
uint32_t
&
table_id
,
const
std
::
vector
<
int64_t
>&
node_ids
,
const
std
::
vector
<
std
::
string
>&
feature_names
,
const
std
::
vector
<
std
::
string
>&
feature_names
,
std
::
vector
<
std
::
vector
<
std
::
string
>>&
res
);
std
::
vector
<
std
::
vector
<
std
::
string
>>&
res
);
virtual
std
::
future
<
int32_t
>
set_node_feat
(
virtual
std
::
future
<
int32_t
>
set_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
u
int64_t
>&
node_ids
,
const
uint32_t
&
table_id
,
const
std
::
vector
<
int64_t
>&
node_ids
,
const
std
::
vector
<
std
::
string
>&
feature_names
,
const
std
::
vector
<
std
::
string
>&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
features
);
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
features
);
virtual
std
::
future
<
int32_t
>
clear_nodes
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
clear_nodes
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
add_graph_node
(
virtual
std
::
future
<
int32_t
>
add_graph_node
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>&
node_id_list
,
uint32_t
table_id
,
std
::
vector
<
int64_t
>&
node_id_list
,
std
::
vector
<
bool
>&
is_weighted_list
);
std
::
vector
<
bool
>&
is_weighted_list
);
virtual
std
::
future
<
int32_t
>
use_neighbors_sample_cache
(
uint32_t
table_id
,
virtual
std
::
future
<
int32_t
>
use_neighbors_sample_cache
(
uint32_t
table_id
,
size_t
size_limit
,
size_t
size_limit
,
...
@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
...
@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
virtual
std
::
future
<
int32_t
>
load_graph_split_config
(
uint32_t
table_id
,
virtual
std
::
future
<
int32_t
>
load_graph_split_config
(
uint32_t
table_id
,
std
::
string
path
);
std
::
string
path
);
virtual
std
::
future
<
int32_t
>
remove_graph_node
(
virtual
std
::
future
<
int32_t
>
remove_graph_node
(
uint32_t
table_id
,
std
::
vector
<
u
int64_t
>&
node_id_list
);
uint32_t
table_id
,
std
::
vector
<
int64_t
>&
node_id_list
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize
();
int
get_shard_num
()
{
return
shard_num
;
}
int
get_shard_num
()
{
return
shard_num
;
}
void
set_shard_num
(
int
shard_num
)
{
this
->
shard_num
=
shard_num
;
}
void
set_shard_num
(
int
shard_num
)
{
this
->
shard_num
=
shard_num
;
}
int
get_server_index_by_id
(
u
int64_t
id
);
int
get_server_index_by_id
(
int64_t
id
);
void
set_local_channel
(
int
index
)
{
void
set_local_channel
(
int
index
)
{
this
->
local_channel
=
get_cmd_channel
(
index
);
this
->
local_channel
=
get_cmd_channel
(
index
);
}
}
...
...
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
浏览文件 @
31776199
...
@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
...
@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
);
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
vector
<
u
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
bool
>
is_weighted_list
;
std
::
vector
<
bool
>
is_weighted_list
;
if
(
request
.
params_size
()
==
2
)
{
if
(
request
.
params_size
()
==
2
)
{
size_t
weight_list_size
=
request
.
params
(
1
).
size
()
/
sizeof
(
bool
);
size_t
weight_list_size
=
request
.
params
(
1
).
size
()
/
sizeof
(
bool
);
...
@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
...
@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"graph_get_node_feat request requires at least 1 argument"
);
"graph_get_node_feat request requires at least 1 argument"
);
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
);
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
vector
<
u
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
((
GraphTable
*
)
table
)
->
remove_graph_node
(
node_ids
);
((
GraphTable
*
)
table
)
->
remove_graph_node
(
node_ids
);
return
0
;
return
0
;
...
@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
...
@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments"
);
"graph_random_sample_neighbors request requires at least 3 arguments"
);
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
);
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int
sample_size
=
*
(
u
int64_t
*
)(
request
.
params
(
1
).
c_str
());
int
sample_size
=
*
(
int64_t
*
)(
request
.
params
(
1
).
c_str
());
bool
need_weight
=
*
(
bool
*
)(
request
.
params
(
2
).
c_str
());
bool
need_weight
=
*
(
bool
*
)(
request
.
params
(
2
).
c_str
());
std
::
vector
<
std
::
shared_ptr
<
char
>>
buffers
(
node_num
);
std
::
vector
<
std
::
shared_ptr
<
char
>>
buffers
(
node_num
);
std
::
vector
<
int
>
actual_sizes
(
node_num
,
0
);
std
::
vector
<
int
>
actual_sizes
(
node_num
,
0
);
...
@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
...
@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
int32_t
GraphBrpcService
::
graph_random_sample_nodes
(
int32_t
GraphBrpcService
::
graph_random_sample_nodes
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
size_t
size
=
*
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
size_t
size
=
*
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
unique_ptr
<
char
[]
>
buffer
;
std
::
unique_ptr
<
char
[]
>
buffer
;
int
actual_size
;
int
actual_size
;
if
(((
GraphTable
*
)
table
)
->
random_sample_nodes
(
size
,
buffer
,
actual_size
)
==
if
(((
GraphTable
*
)
table
)
->
random_sample_nodes
(
size
,
buffer
,
actual_size
)
==
...
@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
...
@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 2 arguments"
);
"graph_get_node_feat request requires at least 2 arguments"
);
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
);
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
vector
<
u
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
std
::
string
>
feature_names
=
std
::
vector
<
std
::
string
>
feature_names
=
paddle
::
string
::
split_string
<
std
::
string
>
(
request
.
params
(
1
),
"
\t
"
);
paddle
::
string
::
split_string
<
std
::
string
>
(
request
.
params
(
1
),
"
\t
"
);
...
@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
...
@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
"at least 3 arguments"
);
"at least 3 arguments"
);
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
),
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
),
size_of_size_t
=
sizeof
(
size_t
);
size_of_size_t
=
sizeof
(
size_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int
sample_size
=
*
(
u
int64_t
*
)(
request
.
params
(
1
).
c_str
());
int
sample_size
=
*
(
int64_t
*
)(
request
.
params
(
1
).
c_str
());
bool
need_weight
=
*
(
u
int64_t
*
)(
request
.
params
(
2
).
c_str
());
bool
need_weight
=
*
(
int64_t
*
)(
request
.
params
(
2
).
c_str
());
// std::vector<
u
int64_t> res = ((GraphTable
// std::vector<int64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
std
::
vector
<
u
int64_t
>
local_id
;
std
::
vector
<
int64_t
>
local_id
;
std
::
vector
<
int
>
local_query_idx
;
std
::
vector
<
int
>
local_query_idx
;
size_t
rank
=
get_rank
();
size_t
rank
=
get_rank
();
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
...
@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
...
@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std
::
vector
<
std
::
shared_ptr
<
char
>>
local_buffers
;
std
::
vector
<
std
::
shared_ptr
<
char
>>
local_buffers
;
std
::
vector
<
int
>
local_actual_sizes
;
std
::
vector
<
int
>
local_actual_sizes
;
std
::
vector
<
size_t
>
seq
;
std
::
vector
<
size_t
>
seq
;
std
::
vector
<
std
::
vector
<
u
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
int
server_index
=
...
@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
...
@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
u
int64_t
)
*
node_num
);
sizeof
(
int64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
closure
->
request
(
request_idx
)
...
@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
...
@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments"
);
"graph_set_node_feat request requires at least 3 arguments"
);
return
0
;
return
0
;
}
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
u
int64_t
);
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
u
int64_t
*
)(
request
.
params
(
0
).
c_str
());
int64_t
*
node_data
=
(
int64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
vector
<
u
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
int64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
std
::
string
>
feature_names
=
std
::
vector
<
std
::
string
>
feature_names
=
paddle
::
string
::
split_string
<
std
::
string
>
(
request
.
params
(
1
),
"
\t
"
);
paddle
::
string
::
split_string
<
std
::
string
>
(
request
.
params
(
1
),
"
\t
"
);
...
...
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc
浏览文件 @
31776199
...
@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
...
@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
}
}
}
}
void
add_graph_node
(
std
::
vector
<
u
int64_t
>
node_ids
,
void
add_graph_node
(
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
bool
>
weight_list
)
{}
std
::
vector
<
bool
>
weight_list
)
{}
void
remove_graph_node
(
std
::
vector
<
u
int64_t
>
node_ids
)
{}
void
remove_graph_node
(
std
::
vector
<
int64_t
>
node_ids
)
{}
void
GraphPyService
::
set_up
(
std
::
string
ips_str
,
int
shard_num
,
void
GraphPyService
::
set_up
(
std
::
string
ips_str
,
int
shard_num
,
std
::
vector
<
std
::
string
>
node_types
,
std
::
vector
<
std
::
string
>
node_types
,
std
::
vector
<
std
::
string
>
edge_types
)
{
std
::
vector
<
std
::
string
>
edge_types
)
{
...
@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
...
@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
}
}
void
GraphPyClient
::
add_graph_node
(
std
::
string
name
,
void
GraphPyClient
::
add_graph_node
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>&
node_ids
,
std
::
vector
<
int64_t
>&
node_ids
,
std
::
vector
<
bool
>&
weight_list
)
{
std
::
vector
<
bool
>&
weight_list
)
{
if
(
this
->
table_id_map
.
count
(
name
))
{
if
(
this
->
table_id_map
.
count
(
name
))
{
uint32_t
table_id
=
this
->
table_id_map
[
name
];
uint32_t
table_id
=
this
->
table_id_map
[
name
];
...
@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
...
@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
}
}
void
GraphPyClient
::
remove_graph_node
(
std
::
string
name
,
void
GraphPyClient
::
remove_graph_node
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>&
node_ids
)
{
std
::
vector
<
int64_t
>&
node_ids
)
{
if
(
this
->
table_id_map
.
count
(
name
))
{
if
(
this
->
table_id_map
.
count
(
name
))
{
uint32_t
table_id
=
this
->
table_id_map
[
name
];
uint32_t
table_id
=
this
->
table_id_map
[
name
];
auto
status
=
get_ps_client
()
->
remove_graph_node
(
table_id
,
node_ids
);
auto
status
=
get_ps_client
()
->
remove_graph_node
(
table_id
,
node_ids
);
...
@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
...
@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
}
}
}
}
std
::
pair
<
std
::
vector
<
std
::
vector
<
u
int64_t
>>
,
std
::
vector
<
float
>>
std
::
pair
<
std
::
vector
<
std
::
vector
<
int64_t
>>
,
std
::
vector
<
float
>>
GraphPyClient
::
batch_sample_neighbors
(
std
::
string
name
,
GraphPyClient
::
batch_sample_neighbors
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
vector
<
int64_t
>
node_ids
,
int
sample_size
,
bool
return_weight
,
int
sample_size
,
bool
return_weight
,
bool
return_edges
)
{
bool
return_edges
)
{
// std::vector<std::vector<std::pair<uint64_t, float>>> v;
std
::
vector
<
std
::
vector
<
int64_t
>>
v
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
v
;
std
::
vector
<
std
::
vector
<
float
>>
v1
;
std
::
vector
<
std
::
vector
<
float
>>
v1
;
if
(
this
->
table_id_map
.
count
(
name
))
{
if
(
this
->
table_id_map
.
count
(
name
))
{
uint32_t
table_id
=
this
->
table_id_map
[
name
];
uint32_t
table_id
=
this
->
table_id_map
[
name
];
...
@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
...
@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
// res.first[1]: slice index
// res.first[1]: slice index
// res.first[2]: src nodes
// res.first[2]: src nodes
// res.second: edges weight
// res.second: edges weight
std
::
pair
<
std
::
vector
<
std
::
vector
<
u
int64_t
>>
,
std
::
vector
<
float
>>
res
;
std
::
pair
<
std
::
vector
<
std
::
vector
<
int64_t
>>
,
std
::
vector
<
float
>>
res
;
res
.
first
.
push_back
({});
res
.
first
.
push_back
({});
res
.
first
.
push_back
({});
res
.
first
.
push_back
({});
if
(
return_edges
)
res
.
first
.
push_back
({});
if
(
return_edges
)
res
.
first
.
push_back
({});
...
@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
...
@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
status
.
wait
();
status
.
wait
();
}
}
}
}
std
::
vector
<
u
int64_t
>
GraphPyClient
::
random_sample_nodes
(
std
::
string
name
,
std
::
vector
<
int64_t
>
GraphPyClient
::
random_sample_nodes
(
std
::
string
name
,
int
server_index
,
int
server_index
,
int
sample_size
)
{
int
sample_size
)
{
std
::
vector
<
u
int64_t
>
v
;
std
::
vector
<
int64_t
>
v
;
if
(
this
->
table_id_map
.
count
(
name
))
{
if
(
this
->
table_id_map
.
count
(
name
))
{
uint32_t
table_id
=
this
->
table_id_map
[
name
];
uint32_t
table_id
=
this
->
table_id_map
[
name
];
auto
status
=
auto
status
=
...
@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
...
@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
// (name, dtype, ndarray)
// (name, dtype, ndarray)
std
::
vector
<
std
::
vector
<
std
::
string
>>
GraphPyClient
::
get_node_feat
(
std
::
vector
<
std
::
vector
<
std
::
string
>>
GraphPyClient
::
get_node_feat
(
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
string
node_type
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
)
{
std
::
vector
<
std
::
string
>
feature_names
)
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
v
(
std
::
vector
<
std
::
vector
<
std
::
string
>>
v
(
feature_names
.
size
(),
std
::
vector
<
std
::
string
>
(
node_ids
.
size
()));
feature_names
.
size
(),
std
::
vector
<
std
::
string
>
(
node_ids
.
size
()));
...
@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
...
@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
}
}
void
GraphPyClient
::
set_node_feat
(
void
GraphPyClient
::
set_node_feat
(
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
string
node_type
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
std
::
vector
<
std
::
string
>
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
)
{
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
)
{
if
(
this
->
table_id_map
.
count
(
node_type
))
{
if
(
this
->
table_id_map
.
count
(
node_type
))
{
...
...
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
浏览文件 @
31776199
...
@@ -70,18 +70,34 @@ class GraphPyService {
...
@@ -70,18 +70,34 @@ class GraphPyService {
::
paddle
::
distributed
::
TableAccessorParameter
*
accessor_proto
=
::
paddle
::
distributed
::
TableAccessorParameter
*
accessor_proto
=
sparse_table_proto
->
mutable_accessor
();
sparse_table_proto
->
mutable_accessor
();
::
paddle
::
distributed
::
CommonAccessorParameter
*
common_proto
=
//
::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto
->
mutable_common
();
//
sparse_table_proto->mutable_common();
::
paddle
::
distributed
::
GraphParameter
*
graph_proto
=
sparse_table_proto
->
mutable_graph_parameter
();
::
paddle
::
distributed
::
GraphFeature
*
graph_feature
=
graph_proto
->
mutable_graph_feature
();
graph_proto
->
set_task_pool_size
(
24
);
graph_proto
->
set_table_name
(
table_name
);
graph_proto
->
set_table_type
(
table_type
);
graph_proto
->
set_use_cache
(
false
);
// Set GraphTable Parameter
// Set GraphTable Parameter
common_proto
->
set_table_name
(
table_name
);
// common_proto->set_table_name(table_name);
common_proto
->
set_name
(
table_type
);
// common_proto->set_name(table_type);
// for (size_t i = 0; i < feat_name.size(); i++) {
// common_proto->add_params(feat_dtype[i]);
// common_proto->add_dims(feat_shape[i]);
// common_proto->add_attributes(feat_name[i]);
// }
for
(
size_t
i
=
0
;
i
<
feat_name
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
feat_name
.
size
();
i
++
)
{
common_proto
->
add_params
(
feat_dtype
[
i
]);
graph_feature
->
add_dtype
(
feat_dtype
[
i
]);
common_proto
->
add_dims
(
feat_shape
[
i
]);
graph_feature
->
add_shape
(
feat_shape
[
i
]);
common_proto
->
add_attributes
(
feat_name
[
i
]);
graph_feature
->
add_name
(
feat_name
[
i
]);
}
}
accessor_proto
->
set_accessor_class
(
"CommMergeAccessor"
);
accessor_proto
->
set_accessor_class
(
"CommMergeAccessor"
);
}
}
...
@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
...
@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
void
load_edge_file
(
std
::
string
name
,
std
::
string
filepath
,
bool
reverse
);
void
load_edge_file
(
std
::
string
name
,
std
::
string
filepath
,
bool
reverse
);
void
load_node_file
(
std
::
string
name
,
std
::
string
filepath
);
void
load_node_file
(
std
::
string
name
,
std
::
string
filepath
);
void
clear_nodes
(
std
::
string
name
);
void
clear_nodes
(
std
::
string
name
);
void
add_graph_node
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>&
node_ids
,
void
add_graph_node
(
std
::
string
name
,
std
::
vector
<
int64_t
>&
node_ids
,
std
::
vector
<
bool
>&
weight_list
);
std
::
vector
<
bool
>&
weight_list
);
void
remove_graph_node
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>&
node_ids
);
void
remove_graph_node
(
std
::
string
name
,
std
::
vector
<
int64_t
>&
node_ids
);
int
get_client_id
()
{
return
client_id
;
}
int
get_client_id
()
{
return
client_id
;
}
void
set_client_id
(
int
client_id
)
{
this
->
client_id
=
client_id
;
}
void
set_client_id
(
int
client_id
)
{
this
->
client_id
=
client_id
;
}
void
start_client
();
void
start_client
();
std
::
pair
<
std
::
vector
<
std
::
vector
<
u
int64_t
>>
,
std
::
vector
<
float
>>
std
::
pair
<
std
::
vector
<
std
::
vector
<
int64_t
>>
,
std
::
vector
<
float
>>
batch_sample_neighbors
(
std
::
string
name
,
std
::
vector
<
u
int64_t
>
node_ids
,
batch_sample_neighbors
(
std
::
string
name
,
std
::
vector
<
int64_t
>
node_ids
,
int
sample_size
,
bool
return_weight
,
int
sample_size
,
bool
return_weight
,
bool
return_edges
);
bool
return_edges
);
std
::
vector
<
u
int64_t
>
random_sample_nodes
(
std
::
string
name
,
int
server_index
,
std
::
vector
<
int64_t
>
random_sample_nodes
(
std
::
string
name
,
int
server_index
,
int
sample_size
);
int
sample_size
);
std
::
vector
<
std
::
vector
<
std
::
string
>>
get_node_feat
(
std
::
vector
<
std
::
vector
<
std
::
string
>>
get_node_feat
(
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
string
node_type
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
);
std
::
vector
<
std
::
string
>
feature_names
);
void
use_neighbors_sample_cache
(
std
::
string
name
,
size_t
total_size_limit
,
void
use_neighbors_sample_cache
(
std
::
string
name
,
size_t
total_size_limit
,
size_t
ttl
);
size_t
ttl
);
void
set_node_feat
(
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
void
set_node_feat
(
std
::
string
node_type
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
std
::
vector
<
std
::
string
>
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
);
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
);
std
::
vector
<
FeatureNode
>
pull_graph_list
(
std
::
string
name
,
int
server_index
,
std
::
vector
<
FeatureNode
>
pull_graph_list
(
std
::
string
name
,
int
server_index
,
...
...
paddle/fluid/distributed/ps/table/CMakeLists.txt
浏览文件 @
31776199
...
@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
...
@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
set_source_files_properties
(
memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_library
(
memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto
${
TABLE_DEPS
}
common_table
)
cc_library
(
memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto
${
TABLE_DEPS
}
common_table
)
cc_library
(
table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost
)
cc_library
(
table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost
)
target_link_libraries
(
table -fopenmp
)
target_link_libraries
(
table -fopenmp
)
paddle/fluid/distributed/ps/table/common_graph_table.cc
浏览文件 @
31776199
此差异已折叠。
点击以展开。
paddle/fluid/distributed/ps/table/common_graph_table.h
浏览文件 @
31776199
...
@@ -38,10 +38,14 @@
...
@@ -38,10 +38,14 @@
#include <vector>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
class
GraphShard
{
class
GraphShard
{
...
@@ -51,37 +55,37 @@ class GraphShard {
...
@@ -51,37 +55,37 @@ class GraphShard {
~
GraphShard
();
~
GraphShard
();
std
::
vector
<
Node
*>
&
get_bucket
()
{
return
bucket
;
}
std
::
vector
<
Node
*>
&
get_bucket
()
{
return
bucket
;
}
std
::
vector
<
Node
*>
get_batch
(
int
start
,
int
end
,
int
step
);
std
::
vector
<
Node
*>
get_batch
(
int
start
,
int
end
,
int
step
);
std
::
vector
<
u
int64_t
>
get_ids_by_range
(
int
start
,
int
end
)
{
std
::
vector
<
int64_t
>
get_ids_by_range
(
int
start
,
int
end
)
{
std
::
vector
<
u
int64_t
>
res
;
std
::
vector
<
int64_t
>
res
;
for
(
int
i
=
start
;
i
<
end
&&
i
<
(
int
)
bucket
.
size
();
i
++
)
{
for
(
int
i
=
start
;
i
<
end
&&
i
<
(
int
)
bucket
.
size
();
i
++
)
{
res
.
push_back
(
bucket
[
i
]
->
get_id
());
res
.
push_back
(
bucket
[
i
]
->
get_id
());
}
}
return
res
;
return
res
;
}
}
GraphNode
*
add_graph_node
(
u
int64_t
id
);
GraphNode
*
add_graph_node
(
int64_t
id
);
GraphNode
*
add_graph_node
(
Node
*
node
);
GraphNode
*
add_graph_node
(
Node
*
node
);
FeatureNode
*
add_feature_node
(
u
int64_t
id
);
FeatureNode
*
add_feature_node
(
int64_t
id
);
Node
*
find_node
(
u
int64_t
id
);
Node
*
find_node
(
int64_t
id
);
void
delete_node
(
u
int64_t
id
);
void
delete_node
(
int64_t
id
);
void
clear
();
void
clear
();
void
add_neighbor
(
uint64_t
id
,
u
int64_t
dst_id
,
float
weight
);
void
add_neighbor
(
int64_t
id
,
int64_t
dst_id
,
float
weight
);
std
::
unordered_map
<
u
int64_t
,
int
>
&
get_node_location
()
{
std
::
unordered_map
<
int64_t
,
int
>
&
get_node_location
()
{
return
node_location
;
return
node_location
;
}
}
private:
private:
std
::
unordered_map
<
u
int64_t
,
int
>
node_location
;
std
::
unordered_map
<
int64_t
,
int
>
node_location
;
std
::
vector
<
Node
*>
bucket
;
std
::
vector
<
Node
*>
bucket
;
};
};
enum
LRUResponse
{
ok
=
0
,
blocked
=
1
,
err
=
2
};
enum
LRUResponse
{
ok
=
0
,
blocked
=
1
,
err
=
2
};
struct
SampleKey
{
struct
SampleKey
{
u
int64_t
node_key
;
int64_t
node_key
;
size_t
sample_size
;
size_t
sample_size
;
bool
is_weighted
;
bool
is_weighted
;
SampleKey
(
u
int64_t
_node_key
,
size_t
_sample_size
,
bool
_is_weighted
)
SampleKey
(
int64_t
_node_key
,
size_t
_sample_size
,
bool
_is_weighted
)
:
node_key
(
_node_key
),
:
node_key
(
_node_key
),
sample_size
(
_sample_size
),
sample_size
(
_sample_size
),
is_weighted
(
_is_weighted
)
{}
is_weighted
(
_is_weighted
)
{}
...
@@ -300,7 +304,7 @@ class ScaledLRU {
...
@@ -300,7 +304,7 @@ class ScaledLRU {
node_size
+=
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
;
node_size
+=
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
;
}
}
if
(
node_size
<=
size_t
(
1.1
*
size_limit
)
+
1
)
return
0
;
if
(
(
size_t
)
node_size
<=
size_t
(
1.1
*
size_limit
)
+
1
)
return
0
;
if
(
pthread_rwlock_wrlock
(
&
rwlock
)
==
0
)
{
if
(
pthread_rwlock_wrlock
(
&
rwlock
)
==
0
)
{
// VLOG(0)<"in shrink\n";
// VLOG(0)<"in shrink\n";
global_count
=
0
;
global_count
=
0
;
...
@@ -308,9 +312,9 @@ class ScaledLRU {
...
@@ -308,9 +312,9 @@ class ScaledLRU {
global_count
+=
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
;
global_count
+=
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
;
}
}
// VLOG(0)<<"global_count "<<global_count<<"\n";
// VLOG(0)<<"global_count "<<global_count<<"\n";
if
(
global_count
>
size_limit
)
{
if
(
(
size_t
)
global_count
>
size_limit
)
{
size_t
remove
=
global_count
-
size_limit
;
size_t
remove
=
global_count
-
size_limit
;
for
(
in
t
i
=
0
;
i
<
lru_pool
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
lru_pool
.
size
();
i
++
)
{
lru_pool
[
i
].
total_diff
=
0
;
lru_pool
[
i
].
total_diff
=
0
;
lru_pool
[
i
].
remove_count
+=
lru_pool
[
i
].
remove_count
+=
1.0
*
(
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
)
/
1.0
*
(
lru_pool
[
i
].
node_size
-
lru_pool
[
i
].
remove_count
)
/
...
@@ -352,9 +356,69 @@ class ScaledLRU {
...
@@ -352,9 +356,69 @@ class ScaledLRU {
friend
class
RandomSampleLRU
<
K
,
V
>
;
friend
class
RandomSampleLRU
<
K
,
V
>
;
};
};
#ifdef PADDLE_WITH_HETERPS
enum
GraphSamplerStatus
{
waiting
=
0
,
running
=
1
,
terminating
=
2
};
class
GraphTable
;
class
GraphSampler
{
public:
GraphSampler
()
{
status
=
GraphSamplerStatus
::
waiting
;
thread_pool
.
reset
(
new
::
ThreadPool
(
1
));
callback
=
[](
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
res
)
{
return
;
};
}
virtual
int
run_graph_sampling
()
=
0
;
virtual
int
start_graph_sampling
()
{
if
(
status
!=
GraphSamplerStatus
::
waiting
)
{
return
-
1
;
}
std
::
promise
<
int
>
prom
;
std
::
future
<
int
>
fut
=
prom
.
get_future
();
graph_sample_task_over
=
thread_pool
->
enqueue
([
&
prom
,
this
]()
{
prom
.
set_value
(
0
);
status
=
GraphSamplerStatus
::
running
;
return
run_graph_sampling
();
});
return
fut
.
get
();
}
virtual
void
init
(
size_t
gpu_num
,
GraphTable
*
graph_table
,
std
::
vector
<
std
::
string
>
args
)
=
0
;
virtual
void
set_graph_sample_callback
(
std
::
function
<
void
(
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
)
>
callback
)
{
this
->
callback
=
callback
;
}
virtual
int
end_graph_sampling
()
{
if
(
status
==
GraphSamplerStatus
::
running
)
{
status
=
GraphSamplerStatus
::
terminating
;
return
graph_sample_task_over
.
get
();
}
return
-
1
;
}
virtual
GraphSamplerStatus
get_graph_sampler_status
()
{
return
status
;
}
protected:
std
::
function
<
void
(
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
)
>
callback
;
std
::
shared_ptr
<::
ThreadPool
>
thread_pool
;
GraphSamplerStatus
status
;
std
::
future
<
int
>
graph_sample_task_over
;
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
sample_res
;
};
#endif
class
GraphTable
:
public
SparseTable
{
class
GraphTable
:
public
SparseTable
{
public:
public:
GraphTable
()
{
use_cache
=
false
;
}
GraphTable
()
{
use_cache
=
false
;
shard_num
=
0
;
#ifdef PADDLE_WITH_HETERPS
gpups_mode
=
false
;
#endif
rw_lock
.
reset
(
new
pthread_rwlock_t
());
}
virtual
~
GraphTable
();
virtual
~
GraphTable
();
virtual
int32_t
pull_graph_list
(
int
start
,
int
size
,
virtual
int32_t
pull_graph_list
(
int
start
,
int
size
,
std
::
unique_ptr
<
char
[]
>
&
buffer
,
std
::
unique_ptr
<
char
[]
>
&
buffer
,
...
@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
...
@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
int
step
);
int
step
);
virtual
int32_t
random_sample_neighbors
(
virtual
int32_t
random_sample_neighbors
(
u
int64_t
*
node_ids
,
int
sample_size
,
int64_t
*
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
shared_ptr
<
char
>>
&
buffers
,
std
::
vector
<
std
::
shared_ptr
<
char
>>
&
buffers
,
std
::
vector
<
int
>
&
actual_sizes
,
bool
need_weight
);
std
::
vector
<
int
>
&
actual_sizes
,
bool
need_weight
);
...
@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
...
@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
int
&
actual_sizes
);
int
&
actual_sizes
);
virtual
int32_t
get_nodes_ids_by_ranges
(
virtual
int32_t
get_nodes_ids_by_ranges
(
std
::
vector
<
std
::
pair
<
int
,
int
>>
ranges
,
std
::
vector
<
uint64_t
>
&
res
);
std
::
vector
<
std
::
pair
<
int
,
int
>>
ranges
,
std
::
vector
<
int64_t
>
&
res
);
virtual
int32_t
initialize
();
virtual
int32_t
initialize
()
{
return
0
;
}
virtual
int32_t
initialize
(
const
TableParameter
&
config
,
const
FsClientParameter
&
fs_config
);
virtual
int32_t
initialize
(
const
GraphParameter
&
config
);
int32_t
load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
);
int32_t
load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
);
int32_t
load_graph_split_config
(
const
std
::
string
&
path
);
int32_t
load_graph_split_config
(
const
std
::
string
&
path
);
...
@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
...
@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
int32_t
load_nodes
(
const
std
::
string
&
path
,
std
::
string
node_type
);
int32_t
load_nodes
(
const
std
::
string
&
path
,
std
::
string
node_type
);
int32_t
add_graph_node
(
std
::
vector
<
u
int64_t
>
&
id_list
,
int32_t
add_graph_node
(
std
::
vector
<
int64_t
>
&
id_list
,
std
::
vector
<
bool
>
&
is_weight_list
);
std
::
vector
<
bool
>
&
is_weight_list
);
int32_t
remove_graph_node
(
std
::
vector
<
u
int64_t
>
&
id_list
);
int32_t
remove_graph_node
(
std
::
vector
<
int64_t
>
&
id_list
);
int32_t
get_server_index_by_id
(
u
int64_t
id
);
int32_t
get_server_index_by_id
(
int64_t
id
);
Node
*
find_node
(
u
int64_t
id
);
Node
*
find_node
(
int64_t
id
);
virtual
int32_t
pull_sparse
(
float
*
values
,
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
...
@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
...
@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
return
0
;
return
0
;
}
}
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
uint32_t
get_thread_pool_index_by_shard_index
(
uint64_t
shard_index
);
virtual
int32_t
set_shard
(
size_t
shard_idx
,
size_t
server_num
)
{
virtual
uint32_t
get_thread_pool_index
(
uint64_t
node_id
);
_shard_idx
=
shard_idx
;
/*
_shard_num is not used in graph_table, this following operation is for the
purpose of
being compatible with base class table.
*/
_shard_num
=
server_num
;
this
->
server_num
=
server_num
;
return
0
;
}
virtual
uint32_t
get_thread_pool_index_by_shard_index
(
int64_t
shard_index
);
virtual
uint32_t
get_thread_pool_index
(
int64_t
node_id
);
virtual
std
::
pair
<
int32_t
,
std
::
string
>
parse_feature
(
std
::
string
feat_str
);
virtual
std
::
pair
<
int32_t
,
std
::
string
>
parse_feature
(
std
::
string
feat_str
);
virtual
int32_t
get_node_feat
(
const
std
::
vector
<
u
int64_t
>
&
node_ids
,
virtual
int32_t
get_node_feat
(
const
std
::
vector
<
int64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
virtual
int32_t
set_node_feat
(
virtual
int32_t
set_node_feat
(
const
std
::
vector
<
u
int64_t
>
&
node_ids
,
const
std
::
vector
<
int64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
...
@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
...
@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
}
}
return
0
;
return
0
;
}
}
#ifdef PADDLE_WITH_HETERPS
virtual
int32_t
start_graph_sampling
()
{
return
this
->
graph_sampler
->
start_graph_sampling
();
}
virtual
int32_t
end_graph_sampling
()
{
return
this
->
graph_sampler
->
end_graph_sampling
();
}
virtual
int32_t
set_graph_sample_callback
(
std
::
function
<
void
(
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
)
>
callback
)
{
graph_sampler
->
set_graph_sample_callback
(
callback
);
return
0
;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
#endif
protected:
protected:
std
::
vector
<
GraphShard
*>
shards
,
extra_shards
;
std
::
vector
<
GraphShard
*>
shards
,
extra_shards
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_server
,
shard_num
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_server
,
shard_num
;
const
int
task_pool_size_
=
24
;
int
task_pool_size_
=
24
;
const
int
random_sample_nodes_ranges
=
3
;
const
int
random_sample_nodes_ranges
=
3
;
std
::
vector
<
std
::
string
>
feat_name
;
std
::
vector
<
std
::
string
>
feat_name
;
...
@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
...
@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
std
::
vector
<
std
::
shared_ptr
<::
ThreadPool
>>
_shards_task_pool
;
std
::
vector
<
std
::
shared_ptr
<::
ThreadPool
>>
_shards_task_pool
;
std
::
vector
<
std
::
shared_ptr
<
std
::
mt19937_64
>>
_shards_task_rng_pool
;
std
::
vector
<
std
::
shared_ptr
<
std
::
mt19937_64
>>
_shards_task_rng_pool
;
std
::
shared_ptr
<
ScaledLRU
<
SampleKey
,
SampleResult
>>
scaled_lru
;
std
::
shared_ptr
<
ScaledLRU
<
SampleKey
,
SampleResult
>>
scaled_lru
;
std
::
unordered_set
<
u
int64_t
>
extra_nodes
;
std
::
unordered_set
<
int64_t
>
extra_nodes
;
std
::
unordered_map
<
u
int64_t
,
size_t
>
extra_nodes_to_thread_index
;
std
::
unordered_map
<
int64_t
,
size_t
>
extra_nodes_to_thread_index
;
bool
use_cache
,
use_duplicate_nodes
;
bool
use_cache
,
use_duplicate_nodes
;
mutable
std
::
mutex
mutex_
;
mutable
std
::
mutex
mutex_
;
std
::
shared_ptr
<
pthread_rwlock_t
>
rw_lock
;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool
gpups_mode
;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std
::
shared_ptr
<
GraphSampler
>
graph_sampler
;
REGISTER_GRAPH_FRIEND_CLASS
(
2
,
CompleteGraphSampler
,
BasicBfsGraphSampler
)
#endif
};
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER
(
GraphSampler
);
class
CompleteGraphSampler
:
public
GraphSampler
{
public:
CompleteGraphSampler
()
{}
~
CompleteGraphSampler
()
{}
// virtual pthread_rwlock_t *export_rw_lock();
virtual
int
run_graph_sampling
();
virtual
void
init
(
size_t
gpu_num
,
GraphTable
*
graph_table
,
std
::
vector
<
std
::
string
>
args_
);
protected:
GraphTable
*
graph_table
;
std
::
vector
<
std
::
vector
<
paddle
::
framework
::
GpuPsGraphNode
>>
sample_nodes
;
std
::
vector
<
std
::
vector
<
int64_t
>>
sample_neighbors
;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int
gpu_num
;
};
class
BasicBfsGraphSampler
:
public
GraphSampler
{
public:
BasicBfsGraphSampler
()
{}
~
BasicBfsGraphSampler
()
{}
// virtual pthread_rwlock_t *export_rw_lock();
virtual
int
run_graph_sampling
();
virtual
void
init
(
size_t
gpu_num
,
GraphTable
*
graph_table
,
std
::
vector
<
std
::
string
>
args_
);
protected:
GraphTable
*
graph_table
;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std
::
vector
<
std
::
vector
<
paddle
::
framework
::
GpuPsGraphNode
>>
sample_nodes
;
std
::
vector
<
std
::
vector
<
int64_t
>>
sample_neighbors
;
size_t
gpu_num
;
int
node_num_for_each_shard
,
edge_num_for_each_node
;
int
rounds
,
interval
;
std
::
vector
<
std
::
unordered_map
<
int64_t
,
std
::
vector
<
int64_t
>>>
sample_neighbors_map
;
};
};
#endif
}
// namespace distributed
}
// namespace distributed
};
// namespace paddle
};
// namespace paddle
...
...
paddle/fluid/distributed/ps/table/graph/class_macro.h
0 → 100644
浏览文件 @
31776199
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
#define DECLARE_1_FRIEND_CLASS(a, ...) DECLARE_GRAPH_FRIEND_CLASS(a)
#define DECLARE_2_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
paddle/fluid/distributed/ps/table/graph/graph_edge.cc
浏览文件 @
31776199
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
void
GraphEdgeBlob
::
add_edge
(
u
int64_t
id
,
float
weight
=
1
)
{
void
GraphEdgeBlob
::
add_edge
(
int64_t
id
,
float
weight
=
1
)
{
id_arr
.
push_back
(
id
);
id_arr
.
push_back
(
id
);
}
}
void
WeightedGraphEdgeBlob
::
add_edge
(
u
int64_t
id
,
float
weight
=
1
)
{
void
WeightedGraphEdgeBlob
::
add_edge
(
int64_t
id
,
float
weight
=
1
)
{
id_arr
.
push_back
(
id
);
id_arr
.
push_back
(
id
);
weight_arr
.
push_back
(
weight
);
weight_arr
.
push_back
(
weight
);
}
}
...
...
paddle/fluid/distributed/ps/table/graph/graph_edge.h
浏览文件 @
31776199
...
@@ -24,19 +24,20 @@ class GraphEdgeBlob {
...
@@ -24,19 +24,20 @@ class GraphEdgeBlob {
GraphEdgeBlob
()
{}
GraphEdgeBlob
()
{}
virtual
~
GraphEdgeBlob
()
{}
virtual
~
GraphEdgeBlob
()
{}
size_t
size
()
{
return
id_arr
.
size
();
}
size_t
size
()
{
return
id_arr
.
size
();
}
virtual
void
add_edge
(
u
int64_t
id
,
float
weight
);
virtual
void
add_edge
(
int64_t
id
,
float
weight
);
u
int64_t
get_id
(
int
idx
)
{
return
id_arr
[
idx
];
}
int64_t
get_id
(
int
idx
)
{
return
id_arr
[
idx
];
}
virtual
float
get_weight
(
int
idx
)
{
return
1
;
}
virtual
float
get_weight
(
int
idx
)
{
return
1
;
}
std
::
vector
<
int64_t
>&
export_id_array
()
{
return
id_arr
;
}
protected:
protected:
std
::
vector
<
u
int64_t
>
id_arr
;
std
::
vector
<
int64_t
>
id_arr
;
};
};
class
WeightedGraphEdgeBlob
:
public
GraphEdgeBlob
{
class
WeightedGraphEdgeBlob
:
public
GraphEdgeBlob
{
public:
public:
WeightedGraphEdgeBlob
()
{}
WeightedGraphEdgeBlob
()
{}
virtual
~
WeightedGraphEdgeBlob
()
{}
virtual
~
WeightedGraphEdgeBlob
()
{}
virtual
void
add_edge
(
u
int64_t
id
,
float
weight
);
virtual
void
add_edge
(
int64_t
id
,
float
weight
);
virtual
float
get_weight
(
int
idx
)
{
return
weight_arr
[
idx
];
}
virtual
float
get_weight
(
int
idx
)
{
return
weight_arr
[
idx
];
}
protected:
protected:
...
...
paddle/fluid/distributed/ps/table/graph/graph_node.h
浏览文件 @
31776199
...
@@ -48,6 +48,7 @@ class Node {
...
@@ -48,6 +48,7 @@ class Node {
virtual
void
set_feature
(
int
idx
,
std
::
string
str
)
{}
virtual
void
set_feature
(
int
idx
,
std
::
string
str
)
{}
virtual
void
set_feature_size
(
int
size
)
{}
virtual
void
set_feature_size
(
int
size
)
{}
virtual
int
get_feature_size
()
{
return
0
;
}
virtual
int
get_feature_size
()
{
return
0
;
}
virtual
size_t
get_neighbor_size
()
{
return
0
;
}
protected:
protected:
uint64_t
id
;
uint64_t
id
;
...
@@ -70,6 +71,7 @@ class GraphNode : public Node {
...
@@ -70,6 +71,7 @@ class GraphNode : public Node {
}
}
virtual
uint64_t
get_neighbor_id
(
int
idx
)
{
return
edges
->
get_id
(
idx
);
}
virtual
uint64_t
get_neighbor_id
(
int
idx
)
{
return
edges
->
get_id
(
idx
);
}
virtual
float
get_neighbor_weight
(
int
idx
)
{
return
edges
->
get_weight
(
idx
);
}
virtual
float
get_neighbor_weight
(
int
idx
)
{
return
edges
->
get_weight
(
idx
);
}
virtual
size_t
get_neighbor_size
()
{
return
edges
->
size
();
}
protected:
protected:
Sampler
*
sampler
;
Sampler
*
sampler
;
...
...
paddle/fluid/distributed/ps/table/table.cc
浏览文件 @
31776199
...
@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
...
@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS
(
Table
,
CommonSparseTable
);
REGISTER_PSCORE_CLASS
(
Table
,
CommonSparseTable
);
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS
(
Table
,
SSDSparseTable
);
REGISTER_PSCORE_CLASS
(
Table
,
SSDSparseTable
);
REGISTER_PSCORE_CLASS
(
GraphSampler
,
CompleteGraphSampler
);
REGISTER_PSCORE_CLASS
(
GraphSampler
,
BasicBfsGraphSampler
);
#endif
#endif
REGISTER_PSCORE_CLASS
(
Table
,
SparseGeoTable
);
REGISTER_PSCORE_CLASS
(
Table
,
SparseGeoTable
);
REGISTER_PSCORE_CLASS
(
Table
,
BarrierTable
);
REGISTER_PSCORE_CLASS
(
Table
,
BarrierTable
);
...
...
paddle/fluid/distributed/test/CMakeLists.txt
浏览文件 @
31776199
...
@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
...
@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
set_source_files_properties
(
graph_node_split_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
graph_node_split_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto
${
COMMON_DEPS
}
)
cc_test
(
graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto
${
COMMON_DEPS
}
)
set_source_files_properties
(
graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto
${
COMMON_DEPS
}
)
set_source_files_properties
(
feature_value_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
feature_value_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
feature_value_test SRCS feature_value_test.cc DEPS
${
COMMON_DEPS
}
boost table
)
cc_test
(
feature_value_test SRCS feature_value_test.cc DEPS
${
COMMON_DEPS
}
boost table
)
...
...
paddle/fluid/distributed/test/graph_node_split_test.cc
浏览文件 @
31776199
...
@@ -236,7 +236,7 @@ void RunGraphSplit() {
...
@@ -236,7 +236,7 @@ void RunGraphSplit() {
sleep
(
2
);
sleep
(
2
);
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
dense_regions
;
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
dense_regions
;
dense_regions
.
insert
(
dense_regions
.
insert
(
std
::
pair
<
u
int64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
(
0
,
{}));
std
::
pair
<
int64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
(
0
,
{}));
auto
regions
=
dense_regions
[
0
];
auto
regions
=
dense_regions
[
0
];
RunClient
(
dense_regions
,
0
,
pserver_ptr_
->
get_service
());
RunClient
(
dense_regions
,
0
,
pserver_ptr_
->
get_service
());
...
@@ -250,16 +250,16 @@ void RunGraphSplit() {
...
@@ -250,16 +250,16 @@ void RunGraphSplit() {
worker_ptr_
->
load
(
0
,
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
worker_ptr_
->
load
(
0
,
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
srand
(
time
(
0
));
srand
(
time
(
0
));
pull_status
.
wait
();
pull_status
.
wait
();
std
::
vector
<
std
::
vector
<
u
int64_t
>>
_vs
;
std
::
vector
<
std
::
vector
<
int64_t
>>
_vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs
;
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
10240001024
),
4
,
_vs
,
vs
,
true
);
0
,
std
::
vector
<
int64_t
>
(
1
,
10240001024
),
4
,
_vs
,
vs
,
true
);
pull_status
.
wait
();
pull_status
.
wait
();
ASSERT_EQ
(
0
,
_vs
[
0
].
size
());
ASSERT_EQ
(
0
,
_vs
[
0
].
size
());
_vs
.
clear
();
_vs
.
clear
();
vs
.
clear
();
vs
.
clear
();
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
97
),
4
,
_vs
,
vs
,
true
);
0
,
std
::
vector
<
int64_t
>
(
1
,
97
),
4
,
_vs
,
vs
,
true
);
pull_status
.
wait
();
pull_status
.
wait
();
ASSERT_EQ
(
3
,
_vs
[
0
].
size
());
ASSERT_EQ
(
3
,
_vs
[
0
].
size
());
std
::
remove
(
edge_file_name
);
std
::
remove
(
edge_file_name
);
...
...
paddle/fluid/distributed/test/graph_node_test.cc
浏览文件 @
31776199
...
@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
...
@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
void
testSampleNodes
(
void
testSampleNodes
(
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
vector
<
u
int64_t
>
ids
;
std
::
vector
<
int64_t
>
ids
;
auto
pull_status
=
worker_ptr_
->
random_sample_nodes
(
0
,
0
,
6
,
ids
);
auto
pull_status
=
worker_ptr_
->
random_sample_nodes
(
0
,
0
,
6
,
ids
);
std
::
unordered_set
<
u
int64_t
>
s
;
std
::
unordered_set
<
int64_t
>
s
;
std
::
unordered_set
<
u
int64_t
>
s1
=
{
37
,
59
};
std
::
unordered_set
<
int64_t
>
s1
=
{
37
,
59
};
pull_status
.
wait
();
pull_status
.
wait
();
for
(
auto
id
:
ids
)
s
.
insert
(
id
);
for
(
auto
id
:
ids
)
s
.
insert
(
id
);
ASSERT_EQ
(
true
,
s
.
size
()
==
s1
.
size
());
ASSERT_EQ
(
true
,
s
.
size
()
==
s1
.
size
());
...
@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
...
@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
void
testSingleSampleNeighboor
(
void
testSingleSampleNeighboor
(
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
vector
<
std
::
vector
<
u
int64_t
>>
vs
;
std
::
vector
<
std
::
vector
<
int64_t
>>
vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs1
;
std
::
vector
<
std
::
vector
<
float
>>
vs1
;
auto
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
auto
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
37
),
4
,
vs
,
vs1
,
true
);
0
,
std
::
vector
<
int64_t
>
(
1
,
37
),
4
,
vs
,
vs1
,
true
);
pull_status
.
wait
();
pull_status
.
wait
();
std
::
unordered_set
<
u
int64_t
>
s
;
std
::
unordered_set
<
int64_t
>
s
;
std
::
unordered_set
<
u
int64_t
>
s1
=
{
112
,
45
,
145
};
std
::
unordered_set
<
int64_t
>
s1
=
{
112
,
45
,
145
};
for
(
auto
g
:
vs
[
0
])
{
for
(
auto
g
:
vs
[
0
])
{
s
.
insert
(
g
);
s
.
insert
(
g
);
}
}
...
@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
...
@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
vs
.
clear
();
vs
.
clear
();
vs1
.
clear
();
vs1
.
clear
();
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
96
),
4
,
vs
,
vs1
,
true
);
0
,
std
::
vector
<
int64_t
>
(
1
,
96
),
4
,
vs
,
vs1
,
true
);
pull_status
.
wait
();
pull_status
.
wait
();
s1
=
{
111
,
48
,
247
};
s1
=
{
111
,
48
,
247
};
for
(
auto
g
:
vs
[
0
])
{
for
(
auto
g
:
vs
[
0
])
{
...
@@ -147,30 +147,30 @@ void testAddNode(
...
@@ -147,30 +147,30 @@ void testAddNode(
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
worker_ptr_
->
clear_nodes
(
0
);
worker_ptr_
->
clear_nodes
(
0
);
int
total_num
=
270000
;
int
total_num
=
270000
;
u
int64_t
id
;
int64_t
id
;
std
::
unordered_set
<
u
int64_t
>
id_set
;
std
::
unordered_set
<
int64_t
>
id_set
;
for
(
int
i
=
0
;
i
<
total_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
total_num
;
i
++
)
{
while
(
id_set
.
find
(
id
=
rand
())
!=
id_set
.
end
())
while
(
id_set
.
find
(
id
=
rand
())
!=
id_set
.
end
())
;
;
id_set
.
insert
(
id
);
id_set
.
insert
(
id
);
}
}
std
::
vector
<
u
int64_t
>
id_list
(
id_set
.
begin
(),
id_set
.
end
());
std
::
vector
<
int64_t
>
id_list
(
id_set
.
begin
(),
id_set
.
end
());
std
::
vector
<
bool
>
weight_list
;
std
::
vector
<
bool
>
weight_list
;
auto
status
=
worker_ptr_
->
add_graph_node
(
0
,
id_list
,
weight_list
);
auto
status
=
worker_ptr_
->
add_graph_node
(
0
,
id_list
,
weight_list
);
status
.
wait
();
status
.
wait
();
std
::
vector
<
u
int64_t
>
ids
[
2
];
std
::
vector
<
int64_t
>
ids
[
2
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
auto
sample_status
=
auto
sample_status
=
worker_ptr_
->
random_sample_nodes
(
0
,
i
,
total_num
,
ids
[
i
]);
worker_ptr_
->
random_sample_nodes
(
0
,
i
,
total_num
,
ids
[
i
]);
sample_status
.
wait
();
sample_status
.
wait
();
}
}
std
::
unordered_set
<
u
int64_t
>
id_set_check
(
ids
[
0
].
begin
(),
ids
[
0
].
end
());
std
::
unordered_set
<
int64_t
>
id_set_check
(
ids
[
0
].
begin
(),
ids
[
0
].
end
());
for
(
auto
x
:
ids
[
1
])
id_set_check
.
insert
(
x
);
for
(
auto
x
:
ids
[
1
])
id_set_check
.
insert
(
x
);
ASSERT_EQ
(
id_set
.
size
(),
id_set_check
.
size
());
ASSERT_EQ
(
id_set
.
size
(),
id_set_check
.
size
());
for
(
auto
x
:
id_set
)
{
for
(
auto
x
:
id_set
)
{
ASSERT_EQ
(
id_set_check
.
find
(
x
)
!=
id_set_check
.
end
(),
true
);
ASSERT_EQ
(
id_set_check
.
find
(
x
)
!=
id_set_check
.
end
(),
true
);
}
}
std
::
vector
<
u
int64_t
>
remove_ids
;
std
::
vector
<
int64_t
>
remove_ids
;
for
(
auto
p
:
id_set_check
)
{
for
(
auto
p
:
id_set_check
)
{
if
(
remove_ids
.
size
()
==
0
)
if
(
remove_ids
.
size
()
==
0
)
remove_ids
.
push_back
(
p
);
remove_ids
.
push_back
(
p
);
...
@@ -187,7 +187,7 @@ void testAddNode(
...
@@ -187,7 +187,7 @@ void testAddNode(
worker_ptr_
->
random_sample_nodes
(
0
,
i
,
total_num
,
ids
[
i
]);
worker_ptr_
->
random_sample_nodes
(
0
,
i
,
total_num
,
ids
[
i
]);
sample_status
.
wait
();
sample_status
.
wait
();
}
}
std
::
unordered_set
<
u
int64_t
>
id_set_check1
(
ids
[
0
].
begin
(),
ids
[
0
].
end
());
std
::
unordered_set
<
int64_t
>
id_set_check1
(
ids
[
0
].
begin
(),
ids
[
0
].
end
());
for
(
auto
x
:
ids
[
1
])
id_set_check1
.
insert
(
x
);
for
(
auto
x
:
ids
[
1
])
id_set_check1
.
insert
(
x
);
ASSERT_EQ
(
id_set_check1
.
size
(),
id_set_check
.
size
());
ASSERT_EQ
(
id_set_check1
.
size
(),
id_set_check
.
size
());
for
(
auto
x
:
id_set_check1
)
{
for
(
auto
x
:
id_set_check1
)
{
...
@@ -196,14 +196,14 @@ void testAddNode(
...
@@ -196,14 +196,14 @@ void testAddNode(
}
}
void
testBatchSampleNeighboor
(
void
testBatchSampleNeighboor
(
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
shared_ptr
<
paddle
::
distributed
::
GraphBrpcClient
>&
worker_ptr_
)
{
std
::
vector
<
std
::
vector
<
u
int64_t
>>
vs
;
std
::
vector
<
std
::
vector
<
int64_t
>>
vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs1
;
std
::
vector
<
std
::
vector
<
float
>>
vs1
;
std
::
vector
<
std
::
u
int64_t
>
v
=
{
37
,
96
};
std
::
vector
<
std
::
int64_t
>
v
=
{
37
,
96
};
auto
pull_status
=
auto
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
v
,
4
,
vs
,
vs1
,
false
);
worker_ptr_
->
batch_sample_neighbors
(
0
,
v
,
4
,
vs
,
vs1
,
false
);
pull_status
.
wait
();
pull_status
.
wait
();
std
::
unordered_set
<
u
int64_t
>
s
;
std
::
unordered_set
<
int64_t
>
s
;
std
::
unordered_set
<
u
int64_t
>
s1
=
{
112
,
45
,
145
};
std
::
unordered_set
<
int64_t
>
s1
=
{
112
,
45
,
145
};
for
(
auto
g
:
vs
[
0
])
{
for
(
auto
g
:
vs
[
0
])
{
s
.
insert
(
g
);
s
.
insert
(
g
);
}
}
...
@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
...
@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
dense_regions
;
std
::
map
<
uint64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
dense_regions
;
dense_regions
.
insert
(
dense_regions
.
insert
(
std
::
pair
<
u
int64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
(
0
,
{}));
std
::
pair
<
int64_t
,
std
::
vector
<
paddle
::
distributed
::
Region
>>
(
0
,
{}));
auto
regions
=
dense_regions
[
0
];
auto
regions
=
dense_regions
[
0
];
RunClient
(
dense_regions
,
0
,
pserver_ptr_
->
get_service
());
RunClient
(
dense_regions
,
0
,
pserver_ptr_
->
get_service
());
...
@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
...
@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
worker_ptr_
->
load
(
0
,
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
worker_ptr_
->
load
(
0
,
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
srand
(
time
(
0
));
srand
(
time
(
0
));
pull_status
.
wait
();
pull_status
.
wait
();
std
::
vector
<
std
::
vector
<
u
int64_t
>>
_vs
;
std
::
vector
<
std
::
vector
<
int64_t
>>
_vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs
;
std
::
vector
<
std
::
vector
<
float
>>
vs
;
testSampleNodes
(
worker_ptr_
);
testSampleNodes
(
worker_ptr_
);
sleep
(
5
);
sleep
(
5
);
testSingleSampleNeighboor
(
worker_ptr_
);
testSingleSampleNeighboor
(
worker_ptr_
);
testBatchSampleNeighboor
(
worker_ptr_
);
testBatchSampleNeighboor
(
worker_ptr_
);
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
10240001024
),
4
,
_vs
,
vs
,
true
);
0
,
std
::
vector
<
int64_t
>
(
1
,
10240001024
),
4
,
_vs
,
vs
,
true
);
pull_status
.
wait
();
pull_status
.
wait
();
ASSERT_EQ
(
0
,
_vs
[
0
].
size
());
ASSERT_EQ
(
0
,
_vs
[
0
].
size
());
paddle
::
distributed
::
GraphTable
*
g
=
paddle
::
distributed
::
GraphTable
*
g
=
...
@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
...
@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
while
(
round
--
)
{
while
(
round
--
)
{
vs
.
clear
();
vs
.
clear
();
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
37
),
1
,
_vs
,
vs
,
false
);
0
,
std
::
vector
<
int64_t
>
(
1
,
37
),
1
,
_vs
,
vs
,
false
);
pull_status
.
wait
();
pull_status
.
wait
();
for
(
int
i
=
0
;
i
<
ttl
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ttl
;
i
++
)
{
std
::
vector
<
std
::
vector
<
u
int64_t
>>
vs1
;
std
::
vector
<
std
::
vector
<
int64_t
>>
vs1
;
std
::
vector
<
std
::
vector
<
float
>>
vs2
;
std
::
vector
<
std
::
vector
<
float
>>
vs2
;
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
pull_status
=
worker_ptr_
->
batch_sample_neighbors
(
0
,
std
::
vector
<
u
int64_t
>
(
1
,
37
),
1
,
vs1
,
vs2
,
false
);
0
,
std
::
vector
<
int64_t
>
(
1
,
37
),
1
,
vs1
,
vs2
,
false
);
pull_status
.
wait
();
pull_status
.
wait
();
ASSERT_EQ
(
_vs
[
0
].
size
(),
vs1
[
0
].
size
());
ASSERT_EQ
(
_vs
[
0
].
size
(),
vs1
[
0
].
size
());
...
@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
...
@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
// Test Pull by step
// Test Pull by step
std
::
unordered_set
<
u
int64_t
>
count_item_nodes
;
std
::
unordered_set
<
int64_t
>
count_item_nodes
;
// pull by step 2
// pull by step 2
for
(
int
test_step
=
1
;
test_step
<
4
;
test_step
++
)
{
for
(
int
test_step
=
1
;
test_step
<
4
;
test_step
++
)
{
count_item_nodes
.
clear
();
count_item_nodes
.
clear
();
...
@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
...
@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
ASSERT_EQ
(
count_item_nodes
.
size
(),
12
);
ASSERT_EQ
(
count_item_nodes
.
size
(),
12
);
}
}
std
::
pair
<
std
::
vector
<
std
::
vector
<
u
int64_t
>>
,
std
::
vector
<
float
>>
res
;
std
::
pair
<
std
::
vector
<
std
::
vector
<
int64_t
>>
,
std
::
vector
<
float
>>
res
;
res
=
client1
.
batch_sample_neighbors
(
res
=
client1
.
batch_sample_neighbors
(
std
::
string
(
"user2item"
),
std
::
vector
<
u
int64_t
>
(
1
,
96
),
4
,
true
,
false
);
std
::
string
(
"user2item"
),
std
::
vector
<
int64_t
>
(
1
,
96
),
4
,
true
,
false
);
ASSERT_EQ
(
res
.
first
[
0
].
size
(),
3
);
ASSERT_EQ
(
res
.
first
[
0
].
size
(),
3
);
std
::
vector
<
u
int64_t
>
node_ids
;
std
::
vector
<
int64_t
>
node_ids
;
node_ids
.
push_back
(
96
);
node_ids
.
push_back
(
96
);
node_ids
.
push_back
(
37
);
node_ids
.
push_back
(
37
);
res
=
client1
.
batch_sample_neighbors
(
std
::
string
(
"user2item"
),
node_ids
,
4
,
res
=
client1
.
batch_sample_neighbors
(
std
::
string
(
"user2item"
),
node_ids
,
4
,
true
,
false
);
true
,
false
);
ASSERT_EQ
(
res
.
first
[
1
].
size
(),
1
);
ASSERT_EQ
(
res
.
first
[
1
].
size
(),
1
);
std
::
vector
<
u
int64_t
>
nodes_ids
=
client2
.
random_sample_nodes
(
"user"
,
0
,
6
);
std
::
vector
<
int64_t
>
nodes_ids
=
client2
.
random_sample_nodes
(
"user"
,
0
,
6
);
ASSERT_EQ
(
nodes_ids
.
size
(),
2
);
ASSERT_EQ
(
nodes_ids
.
size
(),
2
);
ASSERT_EQ
(
true
,
(
nodes_ids
[
0
]
==
59
&&
nodes_ids
[
1
]
==
37
)
||
ASSERT_EQ
(
true
,
(
nodes_ids
[
0
]
==
59
&&
nodes_ids
[
1
]
==
37
)
||
(
nodes_ids
[
0
]
==
37
&&
nodes_ids
[
1
]
==
59
));
(
nodes_ids
[
0
]
==
37
&&
nodes_ids
[
1
]
==
59
));
...
...
paddle/fluid/distributed/test/graph_table_sample_test.cc
0 → 100644
浏览文件 @
31776199
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
operators
=
paddle
::
operators
;
namespace
memory
=
paddle
::
memory
;
namespace
distributed
=
paddle
::
distributed
;
std
::
vector
<
std
::
string
>
edges
=
{
std
::
string
(
"37
\t
45
\t
0.34"
),
std
::
string
(
"37
\t
145
\t
0.31"
),
std
::
string
(
"37
\t
112
\t
0.21"
),
std
::
string
(
"96
\t
48
\t
1.4"
),
std
::
string
(
"96
\t
247
\t
0.31"
),
std
::
string
(
"96
\t
111
\t
1.21"
),
std
::
string
(
"59
\t
45
\t
0.34"
),
std
::
string
(
"59
\t
145
\t
0.31"
),
std
::
string
(
"59
\t
122
\t
0.21"
),
std
::
string
(
"97
\t
48
\t
0.34"
),
std
::
string
(
"97
\t
247
\t
0.31"
),
std
::
string
(
"97
\t
111
\t
0.21"
)};
// odd id:96 48 122 112
char
edge_file_name
[]
=
"edges.txt"
;
std
::
vector
<
std
::
string
>
nodes
=
{
std
::
string
(
"user
\t
37
\t
a 0.34
\t
b 13 14
\t
c hello
\t
d abc"
),
std
::
string
(
"user
\t
96
\t
a 0.31
\t
b 15 10
\t
c 96hello
\t
d abcd"
),
std
::
string
(
"user
\t
59
\t
a 0.11
\t
b 11 14"
),
std
::
string
(
"user
\t
97
\t
a 0.11
\t
b 12 11"
),
std
::
string
(
"item
\t
45
\t
a 0.21"
),
std
::
string
(
"item
\t
145
\t
a 0.21"
),
std
::
string
(
"item
\t
112
\t
a 0.21"
),
std
::
string
(
"item
\t
48
\t
a 0.21"
),
std
::
string
(
"item
\t
247
\t
a 0.21"
),
std
::
string
(
"item
\t
111
\t
a 0.21"
),
std
::
string
(
"item
\t
46
\t
a 0.21"
),
std
::
string
(
"item
\t
146
\t
a 0.21"
),
std
::
string
(
"item
\t
122
\t
a 0.21"
),
std
::
string
(
"item
\t
49
\t
a 0.21"
),
std
::
string
(
"item
\t
248
\t
a 0.21"
),
std
::
string
(
"item
\t
113
\t
a 0.21"
)};
char
node_file_name
[]
=
"nodes.txt"
;
void
prepare_file
(
char
file_name
[],
std
::
vector
<
std
::
string
>
data
)
{
std
::
ofstream
ofile
;
ofile
.
open
(
file_name
);
for
(
auto
x
:
data
)
{
ofile
<<
x
<<
std
::
endl
;
}
ofile
.
close
();
}
void
testGraphSample
()
{
#ifdef PADDLE_WITH_HETERPS
::
paddle
::
distributed
::
GraphParameter
table_proto
;
table_proto
.
set_gpups_mode
(
true
);
table_proto
.
set_gpups_mode_shard_num
(
127
);
table_proto
.
set_gpu_num
(
2
);
distributed
::
GraphTable
graph_table
,
graph_table1
;
graph_table
.
initialize
(
table_proto
);
prepare_file
(
edge_file_name
,
edges
);
graph_table
.
load
(
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
res
;
std
::
promise
<
int
>
prom
;
std
::
future
<
int
>
fut
=
prom
.
get_future
();
graph_table
.
set_graph_sample_callback
(
[
&
res
,
&
prom
](
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
res0
)
{
res
=
res0
;
prom
.
set_value
(
0
);
});
graph_table
.
start_graph_sampling
();
fut
.
get
();
graph_table
.
end_graph_sampling
();
ASSERT_EQ
(
2
,
res
.
size
());
// 37 59 97
for
(
int
i
=
0
;
i
<
(
int
)
res
[
1
].
node_size
;
i
++
)
{
std
::
cout
<<
res
[
1
].
node_list
[
i
].
node_id
<<
std
::
endl
;
}
ASSERT_EQ
(
3
,
res
[
1
].
node_size
);
::
paddle
::
distributed
::
GraphParameter
table_proto1
;
table_proto1
.
set_gpups_mode
(
true
);
table_proto1
.
set_gpups_mode_shard_num
(
127
);
table_proto1
.
set_gpu_num
(
2
);
table_proto1
.
set_gpups_graph_sample_class
(
"BasicBfsGraphSampler"
);
table_proto1
.
set_gpups_graph_sample_args
(
"5,5,1,1"
);
graph_table1
.
initialize
(
table_proto1
);
graph_table1
.
load
(
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
res1
;
std
::
promise
<
int
>
prom1
;
std
::
future
<
int
>
fut1
=
prom1
.
get_future
();
graph_table1
.
set_graph_sample_callback
(
[
&
res1
,
&
prom1
](
std
::
vector
<
paddle
::
framework
::
GpuPsCommGraph
>
&
res0
)
{
res1
=
res0
;
prom1
.
set_value
(
0
);
});
graph_table1
.
start_graph_sampling
();
fut1
.
get
();
graph_table1
.
end_graph_sampling
();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ
(
2
,
res1
.
size
());
// odd id:96 48 122 112
for
(
int
i
=
0
;
i
<
(
int
)
res1
[
0
].
node_size
;
i
++
)
{
std
::
cout
<<
res1
[
0
].
node_list
[
i
].
node_id
<<
std
::
endl
;
}
ASSERT_EQ
(
4
,
res1
[
0
].
node_size
);
#endif
}
TEST
(
testGraphSample
,
Run
)
{
testGraphSample
();
}
paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
浏览文件 @
31776199
...
@@ -10,8 +10,9 @@ IF(WITH_GPU)
...
@@ -10,8 +10,9 @@ IF(WITH_GPU)
nv_library
(
heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS
${
HETERPS_DEPS
}
)
nv_library
(
heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS
${
HETERPS_DEPS
}
)
nv_test
(
test_heter_comm SRCS feature_value.h DEPS heter_comm
)
nv_test
(
test_heter_comm SRCS feature_value.h DEPS heter_comm
)
nv_library
(
heter_ps SRCS heter_ps.cu DEPS heter_comm
)
nv_library
(
heter_ps SRCS heter_ps.cu DEPS heter_comm
)
nv_library
(
graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm
)
nv_library
(
graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm
table
)
nv_test
(
test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps
)
nv_test
(
test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps
)
nv_test
(
test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps
)
ENDIF
()
ENDIF
()
IF
(
WITH_ROCM
)
IF
(
WITH_ROCM
)
hip_library
(
heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context
)
hip_library
(
heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context
)
...
...
paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
0 → 100644
浏览文件 @
31776199
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
namespace
paddle
{
namespace
framework
{
struct
GpuPsGraphNode
{
int64_t
node_id
;
int
neighbor_size
,
neighbor_offset
;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct
GpuPsCommGraph
{
int64_t
*
neighbor_list
;
GpuPsGraphNode
*
node_list
;
int
neighbor_size
,
node_size
;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph
()
:
neighbor_list
(
NULL
),
node_list
(
NULL
),
neighbor_size
(
0
),
node_size
(
0
)
{}
GpuPsCommGraph
(
int64_t
*
neighbor_list_
,
GpuPsGraphNode
*
node_list_
,
int
neighbor_size_
,
int
node_size_
)
:
neighbor_list
(
neighbor_list_
),
node_list
(
node_list_
),
neighbor_size
(
neighbor_size_
),
node_size
(
node_size_
)
{}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct
NeighborSampleResult
{
int64_t
*
val
;
int
*
actual_sample_size
,
sample_size
,
key_size
;
NeighborSampleResult
(
int
_sample_size
,
int
_key_size
)
:
sample_size
(
_sample_size
),
key_size
(
_key_size
)
{
actual_sample_size
=
NULL
;
val
=
NULL
;
};
~
NeighborSampleResult
()
{
if
(
val
!=
NULL
)
cudaFree
(
val
);
if
(
actual_sample_size
!=
NULL
)
cudaFree
(
actual_sample_size
);
}
};
struct
NodeQueryResult
{
int64_t
*
val
;
int
actual_sample_size
;
NodeQueryResult
()
{
val
=
NULL
;
actual_sample_size
=
0
;
};
~
NodeQueryResult
()
{
if
(
val
!=
NULL
)
cudaFree
(
val
);
}
};
}
};
#endif
paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
浏览文件 @
31776199
...
@@ -14,114 +14,25 @@
...
@@ -14,114 +14,25 @@
#pragma once
#pragma once
#include "heter_comm.h"
#include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
struct
GpuPsGraphNode
{
int64_t
node_id
;
int
neighbor_size
,
neighbor_offset
;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct
GpuPsCommGraph
{
int64_t
*
neighbor_list
;
GpuPsGraphNode
*
node_list
;
int
neighbor_size
,
node_size
;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph
()
:
neighbor_list
(
NULL
),
node_list
(
NULL
),
neighbor_size
(
0
),
node_size
(
0
)
{}
GpuPsCommGraph
(
int64_t
*
neighbor_list_
,
GpuPsGraphNode
*
node_list_
,
int
neighbor_size_
,
int
node_size_
)
:
neighbor_list
(
neighbor_list_
),
node_list
(
node_list_
),
neighbor_size
(
neighbor_size_
),
node_size
(
node_size_
)
{}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct
NeighborSampleResult
{
int64_t
*
val
;
int
*
actual_sample_size
,
sample_size
,
key_size
;
NeighborSampleResult
(
int
_sample_size
,
int
_key_size
)
:
sample_size
(
_sample_size
),
key_size
(
_key_size
)
{
actual_sample_size
=
NULL
;
val
=
NULL
;
};
~
NeighborSampleResult
()
{
if
(
val
!=
NULL
)
cudaFree
(
val
);
if
(
actual_sample_size
!=
NULL
)
cudaFree
(
actual_sample_size
);
}
};
struct
NodeQueryResult
{
int64_t
*
val
;
int
actual_sample_size
;
NodeQueryResult
()
{
val
=
NULL
;
actual_sample_size
=
0
;
};
~
NodeQueryResult
()
{
if
(
val
!=
NULL
)
cudaFree
(
val
);
}
};
class
GpuPsGraphTable
:
public
HeterComm
<
int64_t
,
int
,
int
>
{
class
GpuPsGraphTable
:
public
HeterComm
<
int64_t
,
int
,
int
>
{
public:
public:
GpuPsGraphTable
(
std
::
shared_ptr
<
HeterPsResource
>
resource
)
GpuPsGraphTable
(
std
::
shared_ptr
<
HeterPsResource
>
resource
)
:
HeterComm
<
int64_t
,
int
,
int
>
(
1
,
resource
)
{
:
HeterComm
<
int64_t
,
int
,
int
>
(
1
,
resource
)
{
load_factor_
=
0.25
;
load_factor_
=
0.25
;
rw_lock
.
reset
(
new
pthread_rwlock_t
());
cpu_table_status
=
-
1
;
}
~
GpuPsGraphTable
()
{
if
(
cpu_table_status
!=
-
1
)
{
end_graph_sampling
();
}
}
}
void
build_graph_from_cpu
(
std
::
vector
<
GpuPsCommGraph
>
&
cpu_node_list
);
void
build_graph_from_cpu
(
std
::
vector
<
GpuPsCommGraph
>
&
cpu_node_list
);
NodeQueryResult
*
graph_node_sample
(
int
gpu_id
,
int
sample_size
);
NodeQueryResult
*
graph_node_sample
(
int
gpu_id
,
int
sample_size
);
...
@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
...
@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int
*
h_right
,
int
*
h_right
,
int64_t
*
src_sample_res
,
int64_t
*
src_sample_res
,
int
*
actual_sample_size
);
int
*
actual_sample_size
);
int
init_cpu_table
(
const
paddle
::
distributed
::
GraphParameter
&
graph
);
int
load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
);
virtual
int32_t
end_graph_sampling
()
{
return
cpu_graph_table
->
end_graph_sampling
();
}
private:
private:
std
::
vector
<
GpuPsCommGraph
>
gpu_graph_list
;
std
::
vector
<
GpuPsCommGraph
>
gpu_graph_list
;
std
::
shared_ptr
<
paddle
::
distributed
::
GraphTable
>
cpu_graph_table
;
std
::
shared_ptr
<
pthread_rwlock_t
>
rw_lock
;
mutable
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
int
cpu_table_status
;
};
};
}
}
};
};
...
...
paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
浏览文件 @
31776199
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
/*
/*
...
@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
...
@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
}
}
}
}
int
GpuPsGraphTable
::
init_cpu_table
(
const
paddle
::
distributed
::
GraphParameter
&
graph
)
{
cpu_graph_table
.
reset
(
new
paddle
::
distributed
::
GraphTable
);
cpu_table_status
=
cpu_graph_table
->
initialize
(
graph
);
if
(
cpu_table_status
!=
0
)
return
cpu_table_status
;
std
::
function
<
void
(
std
::
vector
<
GpuPsCommGraph
>&
)
>
callback
=
[
this
](
std
::
vector
<
GpuPsCommGraph
>&
res
)
{
pthread_rwlock_wrlock
(
this
->
rw_lock
.
get
());
this
->
clear_graph_info
();
this
->
build_graph_from_cpu
(
res
);
pthread_rwlock_unlock
(
this
->
rw_lock
.
get
());
cv_
.
notify_one
();
};
cpu_graph_table
->
set_graph_sample_callback
(
callback
);
return
cpu_table_status
;
}
int
GpuPsGraphTable
::
load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
int
status
=
cpu_graph_table
->
load
(
path
,
param
);
if
(
status
!=
0
)
{
return
status
;
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cpu_graph_table
->
start_graph_sampling
();
cv_
.
wait
(
lock
);
return
0
;
}
/*
/*
comment 1
comment 1
...
@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
...
@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
that's what fill_dvals does.
that's what fill_dvals does.
*/
*/
void
GpuPsGraphTable
::
move_neighbor_sample_result_to_source_gpu
(
void
GpuPsGraphTable
::
move_neighbor_sample_result_to_source_gpu
(
int
gpu_id
,
int
gpu_num
,
int
sample_size
,
int
*
h_left
,
int
*
h_right
,
int
gpu_id
,
int
gpu_num
,
int
sample_size
,
int
*
h_left
,
int
*
h_right
,
int64_t
*
src_sample_res
,
int
*
actual_sample_size
)
{
int64_t
*
src_sample_res
,
int
*
actual_sample_size
)
{
...
@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
...
@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto
d_shard_keys
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
int64_t
));
auto
d_shard_keys
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
int64_t
));
int64_t
*
d_shard_keys_ptr
=
reinterpret_cast
<
int64_t
*>
(
d_shard_keys
->
ptr
());
int64_t
*
d_shard_keys_ptr
=
reinterpret_cast
<
int64_t
*>
(
d_shard_keys
->
ptr
());
auto
d_shard_vals
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
int64_t
));
auto
d_shard_vals
=
memory
::
Alloc
(
place
,
sample_size
*
len
*
sizeof
(
int64_t
));
int64_t
*
d_shard_vals_ptr
=
reinterpret_cast
<
int64_t
*>
(
d_shard_vals
->
ptr
());
int64_t
*
d_shard_vals_ptr
=
reinterpret_cast
<
int64_t
*>
(
d_shard_vals
->
ptr
());
auto
d_shard_actual_sample_size
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
int
));
auto
d_shard_actual_sample_size
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
int
));
int
*
d_shard_actual_sample_size_ptr
=
int
*
d_shard_actual_sample_size_ptr
=
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
浏览文件 @
31776199
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include <queue>
#include <queue>
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/framework/fleet/heter_ps/test_cpu_graph_sample.cu
0 → 100644
浏览文件 @
31776199
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using
namespace
paddle
::
framework
;
void
prepare_file
(
char
file_name
[],
std
::
vector
<
std
::
string
>
data
)
{
std
::
ofstream
ofile
;
ofile
.
open
(
file_name
);
for
(
auto
x
:
data
)
{
ofile
<<
x
<<
std
::
endl
;
}
ofile
.
close
();
}
char
edge_file_name
[]
=
"edges.txt"
;
TEST
(
TEST_FLEET
,
graph_sample
)
{
std
::
vector
<
std
::
string
>
edges
;
int
gpu_count
=
3
;
std
::
vector
<
int
>
dev_ids
;
dev_ids
.
push_back
(
0
);
dev_ids
.
push_back
(
1
);
dev_ids
.
push_back
(
2
);
std
::
shared_ptr
<
HeterPsResource
>
resource
=
std
::
make_shared
<
HeterPsResource
>
(
dev_ids
);
resource
->
enable_p2p
();
GpuPsGraphTable
g
(
resource
);
int
node_count
=
10
;
std
::
vector
<
std
::
vector
<
int64_t
>>
neighbors
(
node_count
);
int
ind
=
0
;
int64_t
node_id
=
0
;
// std::vector<GpuPsCommGraph> graph_list(gpu_count);
while
(
ind
<
node_count
)
{
int
neighbor_size
=
ind
+
1
;
while
(
neighbor_size
--
)
{
edges
.
push_back
(
std
::
to_string
(
ind
)
+
"
\t
"
+
std
::
to_string
(
node_id
)
+
"
\t
1.0"
);
node_id
++
;
}
ind
++
;
}
/*
gpu 0:
0,3,6,9
gpu 1:
1,4,7
gpu 2:
2,5,8
query(2,6) returns nodes [6,9,1,4,7,2]
*/
::
paddle
::
distributed
::
GraphParameter
table_proto
;
table_proto
.
set_gpups_mode
(
true
);
table_proto
.
set_gpups_mode_shard_num
(
127
);
table_proto
.
set_gpu_num
(
3
);
table_proto
.
set_gpups_graph_sample_class
(
"BasicBfsGraphSampler"
);
table_proto
.
set_gpups_graph_sample_args
(
"5,5,1,1"
);
prepare_file
(
edge_file_name
,
edges
);
g
.
init_cpu_table
(
table_proto
);
g
.
load
(
std
::
string
(
edge_file_name
),
std
::
string
(
"e>"
));
/*
node x's neighbor list = [(1+x)*x/2,(1+x)*x/2 + 1,.....,(1+x)*x/2 + x]
so node 6's neighbors are [21,22...,27]
node 7's neighbors are [28,29,..35]
node 0's neighbors are [0]
query([7,0,6],sample_size=3) should return [28,29,30,0,x,x,21,22,23]
6 --index-->2
0 --index--->0
7 --index-->2
*/
int64_t
cpu_key
[
3
]
=
{
7
,
0
,
6
};
void
*
key
;
cudaMalloc
((
void
**
)
&
key
,
3
*
sizeof
(
int64_t
));
cudaMemcpy
(
key
,
cpu_key
,
3
*
sizeof
(
int64_t
),
cudaMemcpyHostToDevice
);
auto
neighbor_sample_res
=
g
.
graph_neighbor_sample
(
0
,
(
int64_t
*
)
key
,
3
,
3
);
int64_t
*
res
=
new
int64_t
[
9
];
cudaMemcpy
(
res
,
neighbor_sample_res
->
val
,
72
,
cudaMemcpyDeviceToHost
);
std
::
sort
(
res
,
res
+
3
);
std
::
sort
(
res
+
6
,
res
+
9
);
int64_t
expected_sample_val
[]
=
{
28
,
29
,
30
,
0
,
-
1
,
-
1
,
21
,
22
,
23
};
for
(
int
i
=
0
;
i
<
9
;
i
++
)
{
if
(
expected_sample_val
[
i
]
!=
-
1
)
{
ASSERT_EQ
(
res
[
i
],
expected_sample_val
[
i
]);
}
}
delete
[]
res
;
delete
neighbor_sample_res
;
}
paddle/fluid/pybind/fleet_py.cc
浏览文件 @
31776199
...
@@ -225,7 +225,7 @@ void BindGraphPyClient(py::module* m) {
...
@@ -225,7 +225,7 @@ void BindGraphPyClient(py::module* m) {
.
def
(
"stop_server"
,
&
GraphPyClient
::
stop_server
)
.
def
(
"stop_server"
,
&
GraphPyClient
::
stop_server
)
.
def
(
"get_node_feat"
,
.
def
(
"get_node_feat"
,
[](
GraphPyClient
&
self
,
std
::
string
node_type
,
[](
GraphPyClient
&
self
,
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
)
{
std
::
vector
<
std
::
string
>
feature_names
)
{
auto
feats
=
auto
feats
=
self
.
get_node_feat
(
node_type
,
node_ids
,
feature_names
);
self
.
get_node_feat
(
node_type
,
node_ids
,
feature_names
);
...
@@ -239,7 +239,7 @@ void BindGraphPyClient(py::module* m) {
...
@@ -239,7 +239,7 @@ void BindGraphPyClient(py::module* m) {
})
})
.
def
(
"set_node_feat"
,
.
def
(
"set_node_feat"
,
[](
GraphPyClient
&
self
,
std
::
string
node_type
,
[](
GraphPyClient
&
self
,
std
::
string
node_type
,
std
::
vector
<
u
int64_t
>
node_ids
,
std
::
vector
<
int64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
std
::
vector
<
std
::
string
>
feature_names
,
std
::
vector
<
std
::
vector
<
py
::
bytes
>>
bytes_feats
)
{
std
::
vector
<
std
::
vector
<
py
::
bytes
>>
bytes_feats
)
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
feats
(
bytes_feats
.
size
());
std
::
vector
<
std
::
vector
<
std
::
string
>>
feats
(
bytes_feats
.
size
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录