Commit 876aa717 (unverified) — authored on Dec 02, 2021 by seemingwang; committed via GitHub on Dec 02, 2021.
support distributed graph_split load and query. (#37740)
Parent: a710abee
Showing 10 changed files with 534 additions and 33 deletions (+534 −33):
paddle/fluid/distributed/service/graph_brpc_client.cc   +36  −0
paddle/fluid/distributed/service/graph_brpc_client.h    +2   −0
paddle/fluid/distributed/service/graph_brpc_server.cc   +17  −0
paddle/fluid/distributed/service/graph_brpc_server.h    +4   −0
paddle/fluid/distributed/service/sendrecv.proto         +1   −0
paddle/fluid/distributed/table/common_graph_table.cc    +185 −27
paddle/fluid/distributed/table/common_graph_table.h     +8   −6
paddle/fluid/distributed/table/graph/graph_node.cc      +3   −0
paddle/fluid/distributed/test/CMakeLists.txt            +3   −0
paddle/fluid/distributed/test/graph_node_split_test.cc  +275 −0
paddle/fluid/distributed/service/graph_brpc_client.cc
...
@@ -514,6 +514,42 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
   return fut;
 }
+std::future<int32_t> GraphBrpcClient::load_graph_split_config(
+    uint32_t table_id, std::string path) {
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      server_size, [&, server_size = this->server_size](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
+          if (closure->check_response(request_idx,
+                                      PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG) != 0) {
+            ++fail_num;
+            break;
+          }
+        }
+        ret = fail_num == 0 ? 0 : -1;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < server_size; i++) {
+    int server_index = i;
+    closure->request(server_index)
+        ->set_cmd_id(PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG);
+    closure->request(server_index)->set_table_id(table_id);
+    closure->request(server_index)->set_client_id(_client_id);
+    closure->request(server_index)->add_params(path);
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(server_index),
+                     closure->request(server_index),
+                     closure->response(server_index), closure);
+  }
+  return fut;
+}
 std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
     uint32_t table_id, size_t total_size_limit, size_t ttl) {
   DownpourBrpcClosure *closure = new DownpourBrpcClosure(
...
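The new client entry point reuses the fan-out pattern shared by the other graph RPCs: one request per server carrying the config path, a closure that checks every response and fails the whole call if any server fails, and a promise/future pair returned to the caller. Below is a minimal standalone sketch of that aggregation pattern; Reply, send_to_server, and broadcast_load are hypothetical stand-ins for the brpc stub and DownpourBrpcClosure machinery, and unlike the real client the sketch contacts servers sequentially on a background thread rather than issuing all requests at once.

#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <thread>

// Hypothetical per-server result; stands in for checking a brpc response.
struct Reply {
  int error_code;
};

// Hypothetical transport; stands in for GraphPsService_Stub::service().
Reply send_to_server(size_t server_index, const std::string &path) {
  (void)server_index;
  (void)path;
  return Reply{0};  // pretend the server loaded the config successfully
}

// Fan out one request per server and resolve a single future:
// 0 if every server succeeded, -1 as soon as one fails.
std::future<int32_t> broadcast_load(size_t server_size, std::string path) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
  std::thread([promise, server_size, path]() {
    size_t fail_num = 0;
    for (size_t i = 0; i < server_size; ++i) {
      if (send_to_server(i, path).error_code != 0) {
        ++fail_num;
        break;  // one failure fails the whole broadcast
      }
    }
    promise->set_value(fail_num == 0 ? 0 : -1);
  }).detach();
  return fut;
}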
paddle/fluid/distributed/service/graph_brpc_client.h
...
@@ -93,6 +93,8 @@ class GraphBrpcClient : public BrpcPsClient {
   virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
                                                           size_t size_limit,
                                                           size_t ttl);
+  virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
+                                                       std::string path);
   virtual std::future<int32_t> remove_graph_node(
       uint32_t table_id, std::vector<uint64_t> &node_id_list);
   virtual int32_t initialize();
...
paddle/fluid/distributed/service/graph_brpc_server.cc
...
@@ -204,6 +204,8 @@ int32_t GraphBrpcService::initialize() {
       &GraphBrpcService::sample_neighbors_across_multi_servers;
   _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
       &GraphBrpcService::use_neighbors_sample_cache;
+  _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
+      &GraphBrpcService::load_graph_split_config;
   // shard initialization: the shard info of server_list can only be read
   // from env after the server has started
   initialize_shard_info();
...
@@ -658,5 +660,20 @@ int32_t GraphBrpcService::use_neighbors_sample_cache(
   ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
   return 0;
 }
+int32_t GraphBrpcService::load_graph_split_config(
+    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(response, -1,
+                      "load_graph_split_config request requires at least 1 "
+                      "argument[file_path]");
+    return 0;
+  }
+  ((GraphTable *)table)->load_graph_split_config(request.params(0));
+  return 0;
+}
 }  // namespace distributed
 }  // namespace paddle
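On the server side the new command is wired in exactly like the existing ones: PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG maps to a member-function pointer in _service_handler_map, and the generic request loop dispatches on the command id. A self-contained sketch of that dispatch style, with a hypothetical Service class and handler name standing in for GraphBrpcService:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

enum PsCmdID { LOAD_GRAPH_SPLIT_CONFIG = 40 };

class Service {
 public:
  using Handler = int32_t (Service::*)(const std::string &);
  Service() {
    // Same idea as _service_handler_map in GraphBrpcService::initialize().
    handlers_[LOAD_GRAPH_SPLIT_CONFIG] = &Service::load_graph_split_config;
  }
  int32_t dispatch(int cmd_id, const std::string &arg) {
    auto it = handlers_.find(cmd_id);
    if (it == handlers_.end()) return -1;  // unknown command
    return (this->*(it->second))(arg);     // invoke the registered handler
  }

 private:
  int32_t load_graph_split_config(const std::string &path) {
    std::cout << "loading split config from " << path << "\n";
    return 0;
  }
  std::unordered_map<int, Handler> handlers_;
};

int main() {
  Service s;
  return s.dispatch(LOAD_GRAPH_SPLIT_CONFIG, "graph_split.txt");
}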
paddle/fluid/distributed/service/graph_brpc_server.h
...
@@ -126,6 +126,10 @@ class GraphBrpcService : public PsBaseService {
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl);
+  int32_t load_graph_split_config(Table *table,
+                                  const PsRequestMessage &request,
+                                  PsResponseMessage &response,
+                                  brpc::Controller *cntl);

  private:
   bool _is_initialize_shard_info;
   std::mutex _initialize_shard_mutex;
...
paddle/fluid/distributed/service/sendrecv.proto
...
@@ -58,6 +58,7 @@ enum PsCmdID {
   PS_GRAPH_SET_NODE_FEAT = 37;
   PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
   PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
+  PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
 }
 message PsRequestMessage {
...
paddle/fluid/distributed/table/common_graph_table.cc
...
@@ -56,7 +56,7 @@ int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
     tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
       for (auto &p : batch[i]) {
         size_t index = p.first % this->shard_num - this->shard_start;
-        this->shards[index].add_graph_node(p.first)->build_edges(p.second);
+        this->shards[index]->add_graph_node(p.first)->build_edges(p.second);
       }
       return 0;
     }));
...
@@ -79,7 +79,7 @@ int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
     tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
       for (auto &p : batch[i]) {
         size_t index = p % this->shard_num - this->shard_start;
-        this->shards[index].delete_node(p);
+        this->shards[index]->delete_node(p);
       }
       return 0;
     }));
...
@@ -97,6 +97,7 @@ void GraphShard::clear() {
 }
 GraphShard::~GraphShard() { clear(); }
 void GraphShard::delete_node(uint64_t id) {
   auto iter = node_location.find(id);
   if (iter == node_location.end()) return;
...
@@ -117,6 +118,14 @@ GraphNode *GraphShard::add_graph_node(uint64_t id) {
   return (GraphNode *)bucket[node_location[id]];
 }
+GraphNode *GraphShard::add_graph_node(Node *node) {
+  auto id = node->get_id();
+  if (node_location.find(id) == node_location.end()) {
+    node_location[id] = bucket.size();
+    bucket.push_back(node);
+  }
+  return (GraphNode *)bucket[node_location[id]];
+}
 FeatureNode *GraphShard::add_feature_node(uint64_t id) {
   if (node_location.find(id) == node_location.end()) {
     node_location[id] = bucket.size();
...
@@ -134,6 +143,33 @@ Node *GraphShard::find_node(uint64_t id) {
   return iter == node_location.end() ? nullptr : bucket[iter->second];
 }
+GraphTable::~GraphTable() {
+  for (auto p : shards) {
+    delete p;
+  }
+  for (auto p : extra_shards) {
+    delete p;
+  }
+  shards.clear();
+  extra_shards.clear();
+}
+int32_t GraphTable::load_graph_split_config(const std::string &path) {
+  VLOG(4) << "in server side load graph split config\n";
+  std::ifstream file(path);
+  std::string line;
+  while (std::getline(file, line)) {
+    auto values = paddle::string::split_string<std::string>(line, "\t");
+    if (values.size() < 2) continue;
+    size_t index = (size_t)std::stoi(values[0]);
+    if (index != _shard_idx) continue;
+    auto dst_id = std::stoull(values[1]);
+    extra_nodes.insert(dst_id);
+  }
+  if (extra_nodes.size() != 0) use_duplicate_nodes = true;
+  return 0;
+}
 int32_t GraphTable::load(const std::string &path, const std::string &param) {
   bool load_edge = (param[0] == 'e');
   bool load_node = (param[0] == 'n');
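load_graph_split_config expects a tab-separated text file in which each line names a shard (server) index and a node id; a server keeps only the lines matching its own _shard_idx, and those ids become the extra_nodes it is allowed to duplicate locally (the test below uses the single line "0\t97"). A standalone sketch of the same parsing, assuming that two-column format with well-formed numeric fields:

#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_set>

// Parse a split-config file into the set of node ids that shard `my_idx`
// should duplicate locally. Lines look like: "0\t97".
std::unordered_set<uint64_t> parse_split_config(const std::string &path,
                                                size_t my_idx) {
  std::unordered_set<uint64_t> extra;
  std::ifstream file(path);
  std::string line;
  while (std::getline(file, line)) {
    std::istringstream ss(line);
    std::string shard_str, id_str;
    if (!std::getline(ss, shard_str, '\t')) continue;
    if (!std::getline(ss, id_str, '\t')) continue;       // need two fields
    if ((size_t)std::stoi(shard_str) != my_idx) continue;  // not our shard
    extra.insert(std::stoull(id_str));
  }
  return extra;
}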
...
@@ -154,7 +190,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
   res.clear();
   std::vector<std::future<std::vector<uint64_t>>> tasks;
   for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
-    end = total_size + shards[i].get_size();
+    end = total_size + shards[i]->get_size();
     start = total_size;
     while (start < end && index < ranges.size()) {
       if (ranges[index].second <= start)
...
@@ -169,11 +205,11 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
         second -= total_size;
         tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
             [this, first, second, i]() -> std::vector<uint64_t> {
-              return shards[i].get_ids_by_range(first, second);
+              return shards[i]->get_ids_by_range(first, second);
             }));
       }
     }
-    total_size += shards[i].get_size();
+    total_size += shards[i]->get_size();
   }
   for (size_t i = 0; i < tasks.size(); i++) {
     auto vec = tasks[i].get();
...
@@ -217,7 +253,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
       size_t index = shard_id - shard_start;
-      auto node = shards[index].add_feature_node(id);
+      auto node = shards[index]->add_feature_node(id);
       node->set_feature_size(feat_name.size());
...
@@ -245,7 +281,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
   std::string sample_type = "random";
   bool is_weighted = false;
   int valid_count = 0;
+  int extra_alloc_index = 0;
   for (auto path : paths) {
     std::ifstream file(path);
     std::string line;
...
@@ -268,8 +304,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
       size_t src_shard_id = src_id % shard_num;
       if (src_shard_id >= shard_end || src_shard_id < shard_start) {
-        VLOG(4) << "will not load " << src_id << " from " << path
-                << ", please check id distribution";
+        if (use_duplicate_nodes == false ||
+            extra_nodes.find(src_id) == extra_nodes.end()) {
+          VLOG(4) << "will not load " << src_id << " from " << path
+                  << ", please check id distribution";
+          continue;
+        }
+        int index;
+        if (extra_nodes_to_thread_index.find(src_id) !=
+            extra_nodes_to_thread_index.end()) {
+          index = extra_nodes_to_thread_index[src_id];
+        } else {
+          index = extra_alloc_index++;
+          extra_alloc_index %= task_pool_size_;
+          extra_nodes_to_thread_index[src_id] = index;
+        }
+        extra_shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
+        extra_shards[index]->add_neighbor(src_id, dst_id, weight);
+        valid_count++;
         continue;
       }
       if (count % 1000000 == 0) {
...
@@ -278,36 +330,130 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
       }
       size_t index = src_shard_id - shard_start;
-      shards[index].add_graph_node(src_id)->build_edges(is_weighted);
-      shards[index].add_neighbor(src_id, dst_id, weight);
+      shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
+      shards[index]->add_neighbor(src_id, dst_id, weight);
       valid_count++;
     }
   }
   VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in "
           << path;
+  std::vector<int> used(task_pool_size_, 0);
   // Build Sampler j
   for (auto &shard : shards) {
-    auto bucket = shard.get_bucket();
+    auto bucket = shard->get_bucket();
     for (size_t i = 0; i < bucket.size(); i++) {
       bucket[i]->build_sampler(sample_type);
+      used[get_thread_pool_index(bucket[i]->get_id())]++;
     }
   }
+  /*-----------------------
+  relocate the duplicate nodes to make them distributed evenly among threads.
+  */
+  for (auto &shard : extra_shards) {
+    auto bucket = shard->get_bucket();
+    for (size_t i = 0; i < bucket.size(); i++) {
+      bucket[i]->build_sampler(sample_type);
+    }
+  }
+  int size = extra_nodes_to_thread_index.size();
+  if (size == 0) return 0;
+  std::vector<int> index;
+  for (int i = 0; i < used.size(); i++) index.push_back(i);
+  sort(index.begin(), index.end(),
+       [&](int &a, int &b) { return used[a] < used[b]; });
+  std::vector<int> alloc(index.size(), 0), has_alloc(index.size(), 0);
+  int t = 1, aim = 0, mod = 0;
+  for (; t < used.size(); t++) {
+    if ((used[index[t]] - used[index[t - 1]]) * t >= size) {
+      break;
+    } else {
+      size -= (used[index[t]] - used[index[t - 1]]) * t;
+    }
+  }
+  aim = used[index[t - 1]] + size / t;
+  mod = size % t;
+  for (int x = t - 1; x >= 0; x--) {
+    alloc[index[x]] = aim;
+    if (t - x <= mod) alloc[index[x]]++;
+    alloc[index[x]] -= used[index[x]];
+  }
+  std::vector<uint64_t> vec[index.size()];
+  for (auto p : extra_nodes_to_thread_index) {
+    has_alloc[p.second]++;
+    vec[p.second].push_back(p.first);
+  }
+  sort(index.begin(), index.end(), [&](int &a, int &b) {
+    return has_alloc[a] - alloc[a] < has_alloc[b] - alloc[b];
+  });
+  int left = 0, right = index.size() - 1;
+  while (left < right) {
+    if (has_alloc[index[right]] - alloc[index[right]] == 0) break;
+    int x = std::min(alloc[index[left]] - has_alloc[index[left]],
+                     has_alloc[index[right]] - alloc[index[right]]);
+    has_alloc[index[left]] += x;
+    has_alloc[index[right]] -= x;
+    uint64_t id;
+    while (x--) {
+      id = vec[index[right]].back();
+      vec[index[right]].pop_back();
+      extra_nodes_to_thread_index[id] = index[left];
+      vec[index[left]].push_back(id);
+    }
+    if (has_alloc[index[right]] - alloc[index[right]] == 0) right--;
+    if (alloc[index[left]] - has_alloc[index[left]] == 0) left++;
+  }
+  std::vector<GraphShard *> extra_shards_copy;
+  for (int i = 0; i < task_pool_size_; ++i) {
+    extra_shards_copy.push_back(new GraphShard());
+  }
+  for (auto &shard : extra_shards) {
+    auto &bucket = shard->get_bucket();
+    auto &node_location = shard->get_node_location();
+    while (bucket.size()) {
+      Node *temp = bucket.back();
+      bucket.pop_back();
+      node_location.erase(temp->get_id());
+      extra_shards_copy[extra_nodes_to_thread_index[temp->get_id()]]
+          ->add_graph_node(temp);
+    }
+  }
+  for (int i = 0; i < task_pool_size_; ++i) {
+    delete extra_shards[i];
+    extra_shards[i] = extra_shards_copy[i];
+  }
   return 0;
 }
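The relocation block above is a water-filling rebalance: used[] counts how many regular nodes each worker thread already owns, alloc[] is computed so that pouring the duplicate nodes onto the lightest threads levels the totals, and the two-pointer loop then moves node ids from over-filled to under-filled threads. A standalone sketch of just the target-allocation step, assuming loads are plain node counts:

#include <algorithm>
#include <numeric>
#include <vector>

// Water-filling: given current per-thread loads `used` and `extra` new items,
// return how many extra items each thread should receive so the final loads
// are as even as possible. Mirrors the alloc[] computation above.
std::vector<int> water_fill(std::vector<int> used, int extra) {
  const int n = used.size();
  std::vector<int> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return used[a] < used[b]; });
  // Raise the t lightest threads to a common level until `extra` runs out.
  int t = 1;
  for (; t < n; t++) {
    int step = (used[order[t]] - used[order[t - 1]]) * t;
    if (step >= extra) break;
    extra -= step;
  }
  int aim = used[order[t - 1]] + extra / t, mod = extra % t;
  std::vector<int> alloc(n, 0);
  for (int x = t - 1; x >= 0; x--) {
    alloc[order[x]] = aim + (t - x <= mod ? 1 : 0) - used[order[x]];
  }
  return alloc;
}

For example, water_fill({10, 2, 4}, 6) returns {0, 4, 2}: the two lightest threads are topped up to a load of 6 each, while the heaviest thread receives nothing.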
 Node *GraphTable::find_node(uint64_t id) {
   size_t shard_id = id % shard_num;
   if (shard_id >= shard_end || shard_id < shard_start) {
-    return nullptr;
+    if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0)
+      return nullptr;
+    auto iter = extra_nodes_to_thread_index.find(id);
+    if (iter == extra_nodes_to_thread_index.end())
+      return nullptr;
+    else {
+      return extra_shards[iter->second]->find_node(id);
+    }
   }
   size_t index = shard_id - shard_start;
-  Node *node = shards[index].find_node(id);
+  Node *node = shards[index]->find_node(id);
   return node;
 }
 uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
-  return node_id % shard_num % shard_num_per_server % task_pool_size_;
+  if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0)
+    return node_id % shard_num % shard_num_per_server % task_pool_size_;
+  size_t src_shard_id = node_id % shard_num;
+  if (src_shard_id >= shard_end || src_shard_id < shard_start) {
+    auto iter = extra_nodes_to_thread_index.find(node_id);
+    if (iter != extra_nodes_to_thread_index.end()) {
+      return iter->second;
+    }
+  }
+  return src_shard_id % shard_num_per_server % task_pool_size_;
 }
 uint32_t GraphTable::get_thread_pool_index_by_shard_index(
...
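get_thread_pool_index is now a two-step router: ids whose shard is locally owned keep the old modular mapping, while foreign ids that were pinned by the split config are routed to the thread recorded at load time. A minimal sketch of that routing, assuming the same fields and that shard ownership is the half-open range [shard_start, shard_end):

#include <cstddef>
#include <cstdint>
#include <unordered_map>

// Minimal stand-in for the GraphTable routing fields.
struct Router {
  size_t shard_num, shard_start, shard_end, shard_num_per_server,
      task_pool_size;
  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;

  uint32_t thread_for(uint64_t node_id) const {
    size_t shard_id = node_id % shard_num;
    if (shard_id >= shard_end || shard_id < shard_start) {
      auto it = extra_nodes_to_thread_index.find(node_id);
      if (it != extra_nodes_to_thread_index.end())
        return it->second;  // duplicated foreign node: use the pinned thread
    }
    return shard_id % shard_num_per_server % task_pool_size;  // modular default
  }
};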
@@ -319,11 +465,16 @@ int32_t GraphTable::clear_nodes() {
   std::vector<std::future<int>> tasks;
   for (size_t i = 0; i < shards.size(); i++) {
     tasks.push_back(
-        _shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue(
-            [this, i]() -> int {
-              this->shards[i].clear();
-              return 0;
-            }));
+        _shards_task_pool[i % task_pool_size_]->enqueue([this, i]() -> int {
+          this->shards[i]->clear();
+          return 0;
+        }));
   }
+  for (size_t i = 0; i < extra_shards.size(); i++) {
+    tasks.push_back(_shards_task_pool[i]->enqueue([this, i]() -> int {
+      this->extra_shards[i]->clear();
+      return 0;
+    }));
+  }
   for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
   return 0;
...
@@ -334,7 +485,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
                                         int &actual_size) {
   int total_size = 0;
   for (int i = 0; i < shards.size(); i++) {
-    total_size += shards[i].get_size();
+    total_size += shards[i]->get_size();
   }
   if (sample_size > total_size) sample_size = total_size;
   int range_num = random_sample_nodes_ranges;
...
@@ -401,8 +552,8 @@ int32_t GraphTable::random_sample_neighbors(
   size_t node_num = buffers.size();
   std::function<void(char *)> char_del = [](char *c) { delete[] c; };
   std::vector<std::future<int>> tasks;
-  std::vector<std::vector<uint32_t>> seq_id(shard_end - shard_start);
-  std::vector<std::vector<SampleKey>> id_list(shard_end - shard_start);
+  std::vector<std::vector<uint32_t>> seq_id(task_pool_size_);
+  std::vector<std::vector<SampleKey>> id_list(task_pool_size_);
   size_t index;
   for (size_t idx = 0; idx < node_num; ++idx) {
     index = get_thread_pool_index(node_ids[idx]);
...
@@ -524,7 +675,7 @@ int32_t GraphTable::set_node_feat(
     tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
         [&, idx, node_id]() -> int {
           size_t index = node_id % this->shard_num - this->shard_start;
-          auto node = shards[index].add_feature_node(node_id);
+          auto node = shards[index]->add_feature_node(node_id);
           node->set_feature_size(this->feat_name.size());
           for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
             const std::string &feature_name = feature_names[feat_idx];
...
@@ -581,7 +732,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
   int size = 0, cur_size;
   std::vector<std::future<std::vector<Node *>>> tasks;
   for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
-    cur_size = shards[i].get_size();
+    cur_size = shards[i]->get_size();
     if (size + cur_size <= start) {
       size += cur_size;
       continue;
...
@@ -590,7 +741,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
     int end = start + (count - 1) * step + 1;
     tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
         [this, i, start, end, step, size]() -> std::vector<Node *> {
-          return this->shards[i].get_batch(start - size, end - size, step);
+          return this->shards[i]->get_batch(start - size, end - size, step);
         }));
     start += count * step;
     total_size -= count;
...
@@ -665,7 +816,14 @@ int32_t GraphTable::initialize() {
   shard_end = shard_start + shard_num_per_server;
   VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
           << shard_start << " shard_end " << shard_end;
-  shards = std::vector<GraphShard>(shard_num_per_server, GraphShard(shard_num));
+  for (int i = 0; i < shard_num_per_server; i++) {
+    shards.push_back(new GraphShard());
+  }
+  use_duplicate_nodes = false;
+  for (int i = 0; i < task_pool_size_; i++) {
+    extra_shards.push_back(new GraphShard());
+  }
   return 0;
 }
 }  // namespace distributed
...
paddle/fluid/distributed/table/common_graph_table.h
...
@@ -47,7 +47,6 @@ class GraphShard {
  public:
   size_t get_size();
   GraphShard() {}
-  GraphShard(int shard_num) { this->shard_num = shard_num; }
   ~GraphShard();
   std::vector<Node *> &get_bucket() { return bucket; }
   std::vector<Node *> get_batch(int start, int end, int step);
...
@@ -60,18 +59,18 @@ class GraphShard {
   }
   GraphNode *add_graph_node(uint64_t id);
+  GraphNode *add_graph_node(Node *node);
   FeatureNode *add_feature_node(uint64_t id);
   Node *find_node(uint64_t id);
   void delete_node(uint64_t id);
   void clear();
   void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
-  std::unordered_map<uint64_t, int> get_node_location() {
+  std::unordered_map<uint64_t, int> &get_node_location() {
     return node_location;
   }

  private:
   std::unordered_map<uint64_t, int> node_location;
-  int shard_num;
   std::vector<Node *> bucket;
 };
...
@@ -355,7 +354,7 @@ class ScaledLRU {
 class GraphTable : public SparseTable {
  public:
   GraphTable() { use_cache = false; }
-  virtual ~GraphTable() {}
+  virtual ~GraphTable();
   virtual int32_t pull_graph_list(int start, int size,
                                   std::unique_ptr<char[]> &buffer,
                                   int &actual_size, bool need_feature,
...
@@ -374,6 +373,7 @@ class GraphTable : public SparseTable {
   virtual int32_t initialize();
   int32_t load(const std::string &path, const std::string &param);
+  int32_t load_graph_split_config(const std::string &path);
   int32_t load_edges(const std::string &path, bool reverse);
...
@@ -434,7 +434,7 @@ class GraphTable : public SparseTable {
   }

  protected:
-  std::vector<GraphShard> shards;
+  std::vector<GraphShard *> shards, extra_shards;
   size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
   const int task_pool_size_ = 24;
   const int random_sample_nodes_ranges = 3;
...
@@ -449,7 +449,9 @@ class GraphTable : public SparseTable {
   std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
   std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
   std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
-  bool use_cache;
+  std::unordered_set<uint64_t> extra_nodes;
+  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
+  bool use_cache, use_duplicate_nodes;
   mutable std::mutex mutex_;
 };
 }  // namespace distributed
...
paddle/fluid/distributed/table/graph/graph_node.cc
...
@@ -65,6 +65,9 @@ void GraphNode::build_edges(bool is_weighted) {
   }
 }
 void GraphNode::build_sampler(std::string sample_type) {
+  if (sampler != nullptr) {
+    return;
+  }
   if (sample_type == "random") {
     sampler = new RandomSampler();
   } else if (sample_type == "weighted") {
...
paddle/fluid/distributed/test/CMakeLists.txt
...
@@ -21,6 +21,9 @@ cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_funct
 set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
+set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
 set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
...
paddle/fluid/distributed/test/graph_node_split_test.cc
0 → 100644 (new file)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/service/graph_py_service.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;

std::vector<std::string> edges = {
    std::string("37\t45\t0.34"),  std::string("37\t145\t0.31"),
    std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
    std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
    std::string("59\t45\t0.34"),  std::string("59\t145\t0.31"),
    std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
    std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
char edge_file_name[] = "edges.txt";

std::vector<std::string> nodes = {
    std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"),
    std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"),
    std::string("user\t59\ta 0.11\tb 11 14"),
    std::string("user\t97\ta 0.11\tb 12 11"),
    std::string("item\t45\ta 0.21"),
    std::string("item\t145\ta 0.21"),
    std::string("item\t112\ta 0.21"),
    std::string("item\t48\ta 0.21"),
    std::string("item\t247\ta 0.21"),
    std::string("item\t111\ta 0.21"),
    std::string("item\t46\ta 0.21"),
    std::string("item\t146\ta 0.21"),
    std::string("item\t122\ta 0.21"),
    std::string("item\t49\ta 0.21"),
    std::string("item\t248\ta 0.21"),
    std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt";

std::vector<std::string> graph_split = {std::string("0\t97")};
char graph_split_file_name[] = "graph_split.txt";

void prepare_file(char file_name[], std::vector<std::string> data) {
  std::ofstream ofile;
  ofile.open(file_name);
  for (auto x : data) {
    ofile << x << std::endl;
  }
  ofile.close();
}

void GetDownpourSparseTableProto(
    ::paddle::distributed::TableParameter* sparse_table_proto) {
  sparse_table_proto->set_table_id(0);
  sparse_table_proto->set_table_class("GraphTable");
  sparse_table_proto->set_shard_num(127);
  sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
  ::paddle::distributed::TableAccessorParameter* accessor_proto =
      sparse_table_proto->mutable_accessor();
  accessor_proto->set_accessor_class("CommMergeAccessor");
}

::paddle::distributed::PSParameter GetServerProto() {
  // Generate server proto desc
  ::paddle::distributed::PSParameter server_fleet_desc;
  ::paddle::distributed::ServerParameter* server_proto =
      server_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  GetDownpourSparseTableProto(sparse_table_proto);
  return server_fleet_desc;
}

::paddle::distributed::PSParameter GetWorkerProto() {
  ::paddle::distributed::PSParameter worker_fleet_desc;
  ::paddle::distributed::WorkerParameter* worker_proto =
      worker_fleet_desc.mutable_worker_param();
  ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
      worker_proto->mutable_downpour_worker_param();
  ::paddle::distributed::TableParameter* worker_sparse_table_proto =
      downpour_worker_proto->add_downpour_table_param();
  GetDownpourSparseTableProto(worker_sparse_table_proto);

  ::paddle::distributed::ServerParameter* server_proto =
      worker_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  ::paddle::distributed::TableParameter* server_sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  GetDownpourSparseTableProto(server_sparse_table_proto);
  return worker_fleet_desc;
}

/*-------------------------------------------------------------------------*/

std::string ip_ = "127.0.0.1", ip2 = "127.0.0.1";
uint32_t port_ = 5209, port2 = 5210;

std::vector<std::string> host_sign_list_;

std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr_,
    pserver_ptr2;

std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr_;

void RunServer() {
  LOG(INFO) << "init first server";
  ::paddle::distributed::PSParameter server_proto = GetServerProto();

  auto _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.set_ps_servers(&host_sign_list_, 2);  // test
  pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
          paddle::distributed::PSServerFactory::create(server_proto));
  std::vector<framework::ProgramDesc> empty_vec;
  framework::ProgramDesc empty_prog;
  empty_vec.push_back(empty_prog);
  pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
  LOG(INFO) << "first server, run start(ip,port)";
  pserver_ptr_->start(ip_, port_);
  pserver_ptr_->build_peer2peer_connection(0);
  LOG(INFO) << "init first server Done";
}

void RunServer2() {
  LOG(INFO) << "init second server";
  ::paddle::distributed::PSParameter server_proto2 = GetServerProto();

  auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
  _ps_env2.set_ps_servers(&host_sign_list_, 2);  // test
  pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
          paddle::distributed::PSServerFactory::create(server_proto2));
  std::vector<framework::ProgramDesc> empty_vec2;
  framework::ProgramDesc empty_prog2;
  empty_vec2.push_back(empty_prog2);
  pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
  pserver_ptr2->start(ip2, port2);
  pserver_ptr2->build_peer2peer_connection(1);
}

void RunClient(
    std::map<uint64_t, std::vector<paddle::distributed::Region>>&
        dense_regions,
    int index, paddle::distributed::PsBaseService* service) {
  ::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
  paddle::distributed::PaddlePSEnvironment _ps_env;
  auto servers_ = host_sign_list_.size();
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.set_ps_servers(&host_sign_list_, servers_);
  worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
      (paddle::distributed::GraphBrpcClient*)
          paddle::distributed::PSClientFactory::create(worker_proto));
  worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
  worker_ptr_->set_shard_num(127);
  worker_ptr_->set_local_channel(index);
  worker_ptr_->set_local_graph_service(
      (paddle::distributed::GraphBrpcService*)service);
}

void RunGraphSplit() {
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
  prepare_file(edge_file_name, edges);
  prepare_file(node_file_name, nodes);
  prepare_file(graph_split_file_name, graph_split);
  auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
  host_sign_list_.push_back(ph_host.serialize_to_string());

  // test-start
  auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
  host_sign_list_.push_back(ph_host2.serialize_to_string());
  // test-end
  // Start Server
  std::thread* server_thread = new std::thread(RunServer);

  std::thread* server_thread2 = new std::thread(RunServer2);

  sleep(2);

  std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
  dense_regions.insert(
      std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
  auto regions = dense_regions[0];

  RunClient(dense_regions, 0, pserver_ptr_->get_service());

  /*-----------------------Test Server Init----------------------------------*/

  auto pull_status = worker_ptr_->load_graph_split_config(
      0, std::string(graph_split_file_name));
  pull_status.wait();
  pull_status =
      worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
  srand(time(0));
  pull_status.wait();

  std::vector<std::vector<uint64_t>> _vs;
  std::vector<std::vector<float>> vs;
  pull_status = worker_ptr_->batch_sample_neighbors(
      0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
  pull_status.wait();
  ASSERT_EQ(0, _vs[0].size());
  _vs.clear();
  vs.clear();
  pull_status = worker_ptr_->batch_sample_neighbors(
      0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true);
  pull_status.wait();
  ASSERT_EQ(3, _vs[0].size());

  std::remove(edge_file_name);
  std::remove(node_file_name);
  std::remove(graph_split_file_name);
  LOG(INFO) << "Run stop_server";
  worker_ptr_->stop_server();
  LOG(INFO) << "Run finalize_worker";
  worker_ptr_->finalize_worker();
}

TEST(RunGraphSplit, Run) { RunGraphSplit(); }
\ No newline at end of file
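The expected sizes follow from the shard arithmetic. With shard_num = 127 split across two servers, shard_num_per_server is presumably 64, so server 0 owns shards [0, 64) and server 1 the rest. The id 10240001024 hashes to shard 10240001024 % 127 = 41, a shard server 0 owns, but no such node was ever loaded, so the first sample comes back empty (ASSERT_EQ(0, ...)). Node 97 hashes to shard 97 % 127 = 97, which is foreign to server 0; because the split line "0\t97" additionally pins node 97 to server 0 as a duplicate, its three edges (97→48, 97→247, 97→111) are available there as well, and the second sample returns all 3 neighbors (ASSERT_EQ(3, ...)).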