Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5eb640c6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5eb640c6
编写于
10月 21, 2021
作者:
S
seemingwang
提交者:
GitHub
10月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Graph engine4 (#36587)
上级
d64f7b3b
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
292 addition
and
16 deletion
+292
-16
paddle/fluid/distributed/service/graph_brpc_client.cc
paddle/fluid/distributed/service/graph_brpc_client.cc
+57
-1
paddle/fluid/distributed/service/graph_brpc_client.h
paddle/fluid/distributed/service/graph_brpc_client.h
+2
-1
paddle/fluid/distributed/service/graph_brpc_server.cc
paddle/fluid/distributed/service/graph_brpc_server.cc
+199
-5
paddle/fluid/distributed/service/graph_brpc_server.h
paddle/fluid/distributed/service/graph_brpc_server.h
+9
-0
paddle/fluid/distributed/service/graph_py_service.cc
paddle/fluid/distributed/service/graph_py_service.cc
+1
-0
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-0
paddle/fluid/distributed/service/server.h
paddle/fluid/distributed/service/server.h
+2
-1
paddle/fluid/distributed/table/common_graph_table.cc
paddle/fluid/distributed/table/common_graph_table.cc
+11
-7
paddle/fluid/distributed/table/common_graph_table.h
paddle/fluid/distributed/table/common_graph_table.h
+4
-1
paddle/fluid/distributed/test/graph_node_test.cc
paddle/fluid/distributed/test/graph_node_test.cc
+6
-0
未找到文件。
paddle/fluid/distributed/service/graph_brpc_client.cc
浏览文件 @
5eb640c6
...
...
@@ -304,7 +304,63 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size
std
::
future
<
int32_t
>
GraphBrpcClient
::
batch_sample_neighboors
(
uint32_t
table_id
,
std
::
vector
<
uint64_t
>
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>
&
res
)
{
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>
&
res
,
int
server_index
)
{
if
(
server_index
!=
-
1
)
{
res
.
resize
(
node_ids
.
size
());
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
1
,
[
&
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
if
(
closure
->
check_response
(
0
,
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
)
!=
0
)
{
ret
=
-
1
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
0
)
->
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
size_t
bytes_size
=
io_buffer_itr
.
bytes_left
();
std
::
unique_ptr
<
char
[]
>
buffer_wrapper
(
new
char
[
bytes_size
]);
char
*
buffer
=
buffer_wrapper
.
get
();
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
size_t
node_num
=
*
(
size_t
*
)
buffer
;
int
*
actual_sizes
=
(
int
*
)(
buffer
+
sizeof
(
size_t
));
char
*
node_buffer
=
buffer
+
sizeof
(
size_t
)
+
sizeof
(
int
)
*
node_num
;
int
offset
=
0
;
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
int
actual_size
=
actual_sizes
[
node_idx
];
int
start
=
0
;
while
(
start
<
actual_size
)
{
res
[
node_idx
].
push_back
(
{
*
(
uint64_t
*
)(
node_buffer
+
offset
+
start
),
*
(
float
*
)(
node_buffer
+
offset
+
start
+
GraphNode
::
id_size
)});
start
+=
GraphNode
::
id_size
+
GraphNode
::
weight_size
;
}
offset
+=
actual_size
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
;
closure
->
request
(
0
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
);
closure
->
request
(
0
)
->
set_table_id
(
table_id
);
closure
->
request
(
0
)
->
set_client_id
(
_client_id
);
closure
->
request
(
0
)
->
add_params
((
char
*
)
node_ids
.
data
(),
sizeof
(
uint64_t
)
*
node_ids
.
size
());
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
get_cmd_channel
(
server_index
));
closure
->
cntl
(
0
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
0
),
closure
->
request
(
0
),
closure
->
response
(
0
),
closure
);
return
fut
;
}
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
res
.
clear
();
...
...
paddle/fluid/distributed/service/graph_brpc_client.h
浏览文件 @
5eb640c6
...
...
@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighboors for each of them
virtual
std
::
future
<
int32_t
>
batch_sample_neighboors
(
uint32_t
table_id
,
std
::
vector
<
uint64_t
>
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>&
res
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>&
res
,
int
server_index
=
-
1
);
virtual
std
::
future
<
int32_t
>
pull_graph_list
(
uint32_t
table_id
,
int
server_index
,
int
start
,
...
...
paddle/fluid/distributed/service/graph_brpc_server.cc
浏览文件 @
5eb640c6
...
...
@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
return
0
;
}
brpc
::
Channel
*
GraphBrpcServer
::
get_cmd_channel
(
size_t
server_index
)
{
return
_pserver_channels
[
server_index
].
get
();
}
uint64_t
GraphBrpcServer
::
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
...
@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
return
0
;
}
int32_t
GraphBrpcServer
::
build_peer2peer_connection
(
int
rank
)
{
this
->
rank
=
rank
;
auto
_env
=
environment
();
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
timeout_ms
=
500000
;
options
.
connection_type
=
"pooled"
;
options
.
connect_timeout_ms
=
10000
;
options
.
max_retry
=
3
;
std
::
vector
<
PSHost
>
server_list
=
_env
->
get_ps_servers
();
_pserver_channels
.
resize
(
server_list
.
size
());
std
::
ostringstream
os
;
std
::
string
server_ip_port
;
for
(
size_t
i
=
0
;
i
<
server_list
.
size
();
++
i
)
{
server_ip_port
.
assign
(
server_list
[
i
].
ip
.
c_str
());
server_ip_port
.
append
(
":"
);
server_ip_port
.
append
(
std
::
to_string
(
server_list
[
i
].
port
));
_pserver_channels
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
_pserver_channels
[
i
]
->
Init
(
server_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"GraphServer connect to Server:"
<<
server_ip_port
<<
" Failed! Try again."
;
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
server_list
[
i
].
ip
,
server_list
[
i
].
port
);
if
(
_pserver_channels
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"GraphServer connect to Server:"
<<
int_ip_port
<<
" Failed!"
;
return
-
1
;
}
}
os
<<
server_ip_port
<<
","
;
}
LOG
(
INFO
)
<<
"servers peer2peer connection success:"
<<
os
.
str
();
return
0
;
}
int32_t
GraphBrpcService
::
clear_nodes
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
...
...
@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
&
GraphBrpcService
::
remove_graph_node
;
_service_handler_map
[
PS_GRAPH_SET_NODE_FEAT
]
=
&
GraphBrpcService
::
graph_set_node_feat
;
_service_handler_map
[
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
]
=
&
GraphBrpcService
::
sample_neighboors_across_multi_servers
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
...
...
@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
if
(
_is_initialize_shard_info
)
{
return
0
;
}
s
ize_t
shard_num
=
_server
->
environment
()
->
get_ps_servers
().
size
();
s
erver_size
=
_server
->
environment
()
->
get_ps_servers
().
size
();
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
itr
:
table_map
)
{
itr
.
second
->
set_shard
(
_rank
,
s
hard_num
);
itr
.
second
->
set_shard
(
_rank
,
s
erver_size
);
}
_is_initialize_shard_info
=
true
;
}
...
...
@@ -209,7 +252,9 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
int
service_ret
=
(
this
->*
handler_func
)(
table
,
*
request
,
*
response
,
cntl
);
if
(
service_ret
!=
0
)
{
response
->
set_err_code
(
service_ret
);
response
->
set_err_msg
(
"server internal error"
);
if
(
!
response
->
has_err_msg
())
{
response
->
set_err_msg
(
"server internal error"
);
}
}
}
...
...
@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return
0
;
}
int32_t
GraphBrpcService
::
sample_neighboors_across_multi_servers
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
// sleep(5);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
response
,
-
1
,
"graph_random_sample request requires at least 2 arguments"
);
return
0
;
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
uint64_t
),
size_of_size_t
=
sizeof
(
size_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
0
).
c_str
());
int
sample_size
=
*
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
// std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
std
::
vector
<
uint64_t
>
local_id
;
std
::
vector
<
int
>
local_query_idx
;
size_t
rank
=
get_rank
();
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
if
(
server2request
[
server_index
]
==
-
1
)
{
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
}
}
if
(
server2request
[
rank
]
!=
-
1
)
{
auto
pos
=
server2request
[
rank
];
std
::
swap
(
request2server
[
pos
],
request2server
[(
int
)
request2server
.
size
()
-
1
]);
server2request
[
request2server
[
pos
]]
=
pos
;
server2request
[
request2server
[(
int
)
request2server
.
size
()
-
1
]]
=
request2server
.
size
()
-
1
;
}
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
unique_ptr
<
char
[]
>>
local_buffers
;
std
::
vector
<
int
>
local_actual_sizes
;
std
::
vector
<
size_t
>
seq
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
int
request_idx
=
server2request
[
server_index
];
node_id_buckets
[
request_idx
].
push_back
(
node_data
[
query_idx
]);
query_idx_buckets
[
request_idx
].
push_back
(
query_idx
);
seq
.
push_back
(
request_idx
);
}
size_t
remote_call_num
=
request_call_num
;
if
(
request2server
.
size
()
!=
0
&&
request2server
.
back
()
==
rank
)
{
remote_call_num
--
;
local_buffers
.
resize
(
node_id_buckets
.
back
().
size
());
local_actual_sizes
.
resize
(
node_id_buckets
.
back
().
size
());
}
cntl
->
response_attachment
().
append
(
&
node_num
,
sizeof
(
size_t
));
auto
local_promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
std
::
future
<
int
>
local_fut
=
local_promise
->
get_future
();
std
::
vector
<
bool
>
failed
(
server_size
,
false
);
std
::
function
<
void
(
void
*
)
>
func
=
[
&
,
node_id_buckets
,
query_idx_buckets
,
request_call_num
](
void
*
done
)
{
local_fut
.
get
();
std
::
vector
<
int
>
actual_size
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
std
::
vector
<
std
::
unique_ptr
<
butil
::
IOBufBytesIterator
>>
res
(
remote_call_num
);
size_t
fail_num
=
0
;
for
(
size_t
request_idx
=
0
;
request_idx
<
remote_call_num
;
++
request_idx
)
{
if
(
closure
->
check_response
(
request_idx
,
PS_GRAPH_SAMPLE_NEIGHBOORS
)
!=
0
)
{
++
fail_num
;
failed
[
request2server
[
request_idx
]]
=
true
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
request_idx
)
->
response_attachment
();
size_t
node_size
;
res
[
request_idx
].
reset
(
new
butil
::
IOBufBytesIterator
(
res_io_buffer
));
size_t
num
;
res
[
request_idx
]
->
copy_and_forward
(
&
num
,
sizeof
(
size_t
));
}
}
int
size
;
int
local_index
=
0
;
for
(
size_t
i
=
0
;
i
<
node_num
;
i
++
)
{
if
(
fail_num
>
0
&&
failed
[
seq
[
i
]])
{
size
=
0
;
}
else
if
(
request2server
[
seq
[
i
]]
!=
rank
)
{
res
[
seq
[
i
]]
->
copy_and_forward
(
&
size
,
sizeof
(
int
));
}
else
{
size
=
local_actual_sizes
[
local_index
++
];
}
actual_size
.
push_back
(
size
);
}
cntl
->
response_attachment
().
append
(
actual_size
.
data
(),
actual_size
.
size
()
*
sizeof
(
int
));
local_index
=
0
;
for
(
size_t
i
=
0
;
i
<
node_num
;
i
++
)
{
if
(
fail_num
>
0
&&
failed
[
seq
[
i
]])
{
continue
;
}
else
if
(
request2server
[
seq
[
i
]]
!=
rank
)
{
char
temp
[
actual_size
[
i
]
+
1
];
res
[
seq
[
i
]]
->
copy_and_forward
(
temp
,
actual_size
[
i
]);
cntl
->
response_attachment
().
append
(
temp
,
actual_size
[
i
]);
}
else
{
char
*
temp
=
local_buffers
[
local_index
++
].
get
();
cntl
->
response_attachment
().
append
(
temp
,
actual_size
[
i
]);
}
}
closure
->
set_promise_value
(
0
);
};
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
remote_call_num
,
func
);
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
int
request_idx
=
0
;
request_idx
<
remote_call_num
;
++
request_idx
)
{
int
server_index
=
request2server
[
request_idx
];
closure
->
request
(
request_idx
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NEIGHBOORS
);
closure
->
request
(
request_idx
)
->
set_table_id
(
request
.
table_id
());
closure
->
request
(
request_idx
)
->
set_client_id
(
rank
);
size_t
node_num
=
node_id_buckets
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
uint64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
PsService_Stub
rpc_stub
(
((
GraphBrpcServer
*
)
get_server
())
->
get_cmd_channel
(
server_index
));
// GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index));
closure
->
cntl
(
request_idx
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
request_idx
),
closure
->
request
(
request_idx
),
closure
->
response
(
request_idx
),
closure
);
}
if
(
server2request
[
rank
]
!=
-
1
)
{
((
GraphTable
*
)
table
)
->
random_sample_neighboors
(
node_id_buckets
.
back
().
data
(),
sample_size
,
local_buffers
,
local_actual_sizes
);
}
local_promise
.
get
()
->
set_value
(
0
);
if
(
remote_call_num
==
0
)
func
(
closure
);
fut
.
get
();
return
0
;
}
int32_t
GraphBrpcService
::
graph_set_node_feat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
...
...
@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
if
(
request
.
params_size
()
<
3
)
{
set_response_code
(
response
,
-
1
,
"graph_set_node_feat request requires at least
2
arguments"
);
"graph_set_node_feat request requires at least
3
arguments"
);
return
0
;
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
uint64_t
);
...
...
paddle/fluid/distributed/service/graph_brpc_server.h
浏览文件 @
5eb640c6
...
...
@@ -32,6 +32,8 @@ class GraphBrpcServer : public PSServer {
virtual
~
GraphBrpcServer
()
{}
PsBaseService
*
get_service
()
{
return
_service
.
get
();
}
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
);
virtual
int32_t
build_peer2peer_connection
(
int
rank
);
virtual
brpc
::
Channel
*
get_cmd_channel
(
size_t
server_index
);
virtual
int32_t
stop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
stoped_
)
return
0
;
...
...
@@ -50,6 +52,7 @@ class GraphBrpcServer : public PSServer {
mutable
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
bool
stoped_
=
false
;
int
rank
;
brpc
::
Server
_server
;
std
::
shared_ptr
<
PsBaseService
>
_service
;
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
...
...
@@ -113,12 +116,18 @@ class GraphBrpcService : public PsBaseService {
int32_t
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
sample_neighboors_across_multi_servers
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
private:
bool
_is_initialize_shard_info
;
std
::
mutex
_initialize_shard_mutex
;
std
::
unordered_map
<
int32_t
,
serviceHandlerFunc
>
_msg_handler_map
;
std
::
vector
<
float
>
_ori_values
;
const
int
sample_nodes_ranges
=
23
;
size_t
server_size
;
std
::
shared_ptr
<::
ThreadPool
>
task_pool
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/service/graph_py_service.cc
浏览文件 @
5eb640c6
...
...
@@ -107,6 +107,7 @@ void GraphPyServer::start_server(bool block) {
empty_vec
.
push_back
(
empty_prog
);
pserver_ptr
->
configure
(
server_proto
,
_ps_env
,
rank
,
empty_vec
);
pserver_ptr
->
start
(
ip
,
port
);
pserver_ptr
->
build_peer2peer_connection
(
rank
);
std
::
condition_variable
*
cv_
=
pserver_ptr
->
export_cv
();
if
(
block
)
{
std
::
mutex
mutex_
;
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
5eb640c6
...
...
@@ -56,6 +56,7 @@ enum PsCmdID {
PS_GRAPH_ADD_GRAPH_NODE
=
35
;
PS_GRAPH_REMOVE_GRAPH_NODE
=
36
;
PS_GRAPH_SET_NODE_FEAT
=
37
;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
=
38
;
}
message
PsRequestMessage
{
...
...
paddle/fluid/distributed/service/server.h
浏览文件 @
5eb640c6
...
...
@@ -147,7 +147,7 @@ class PsBaseService : public PsService {
public:
PsBaseService
()
:
_rank
(
0
),
_server
(
NULL
),
_config
(
NULL
)
{}
virtual
~
PsBaseService
()
{}
virtual
size_t
get_rank
()
{
return
_rank
;
}
virtual
int32_t
configure
(
PSServer
*
server
)
{
_server
=
server
;
_rank
=
_server
->
rank
();
...
...
@@ -167,6 +167,7 @@ class PsBaseService : public PsService {
}
virtual
int32_t
initialize
()
=
0
;
PSServer
*
get_server
()
{
return
_server
;
}
protected:
size_t
_rank
;
...
...
paddle/fluid/distributed/table/common_graph_table.cc
浏览文件 @
5eb640c6
...
...
@@ -305,12 +305,12 @@ Node *GraphTable::find_node(uint64_t id) {
return
node
;
}
uint32_t
GraphTable
::
get_thread_pool_index
(
uint64_t
node_id
)
{
return
node_id
%
shard_num
%
shard_num_per_
table
%
task_pool_size_
;
return
node_id
%
shard_num
%
shard_num_per_
server
%
task_pool_size_
;
}
uint32_t
GraphTable
::
get_thread_pool_index_by_shard_index
(
uint64_t
shard_index
)
{
return
shard_index
%
shard_num_per_
table
%
task_pool_size_
;
return
shard_index
%
shard_num_per_
server
%
task_pool_size_
;
}
int32_t
GraphTable
::
clear_nodes
()
{
...
...
@@ -575,6 +575,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
actual_size
=
size
;
return
0
;
}
int32_t
GraphTable
::
get_server_index_by_id
(
uint64_t
id
)
{
return
id
%
shard_num
/
shard_num_per_server
;
}
int32_t
GraphTable
::
initialize
()
{
_shards_task_pool
.
resize
(
task_pool_size_
);
for
(
size_t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
...
...
@@ -611,13 +616,12 @@ int32_t GraphTable::initialize() {
shard_num
=
_config
.
shard_num
();
VLOG
(
0
)
<<
"in init graph table shard num = "
<<
shard_num
<<
" shard_idx"
<<
_shard_idx
;
shard_num_per_
table
=
sparse_local_shard_num
(
shard_num
,
server_num
);
shard_start
=
_shard_idx
*
shard_num_per_
table
;
shard_end
=
shard_start
+
shard_num_per_
table
;
shard_num_per_
server
=
sparse_local_shard_num
(
shard_num
,
server_num
);
shard_start
=
_shard_idx
*
shard_num_per_
server
;
shard_end
=
shard_start
+
shard_num_per_
server
;
VLOG
(
0
)
<<
"in init graph table shard idx = "
<<
_shard_idx
<<
" shard_start "
<<
shard_start
<<
" shard_end "
<<
shard_end
;
// shards.resize(shard_num_per_table);
shards
=
std
::
vector
<
GraphShard
>
(
shard_num_per_table
,
GraphShard
(
shard_num
));
shards
=
std
::
vector
<
GraphShard
>
(
shard_num_per_server
,
GraphShard
(
shard_num
));
return
0
;
}
}
// namespace distributed
...
...
paddle/fluid/distributed/table/common_graph_table.h
浏览文件 @
5eb640c6
...
...
@@ -94,6 +94,7 @@ class GraphTable : public SparseTable {
int32_t
remove_graph_node
(
std
::
vector
<
uint64_t
>
&
id_list
);
int32_t
get_server_index_by_id
(
uint64_t
id
);
Node
*
find_node
(
uint64_t
id
);
virtual
int32_t
pull_sparse
(
float
*
values
,
...
...
@@ -128,9 +129,11 @@ class GraphTable : public SparseTable {
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
size_t
get_server_num
()
{
return
server_num
;
}
protected:
std
::
vector
<
GraphShard
>
shards
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_
table
,
shard_num
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_
server
,
shard_num
;
const
int
task_pool_size_
=
24
;
const
int
random_sample_nodes_ranges
=
3
;
...
...
paddle/fluid/distributed/test/graph_node_test.cc
浏览文件 @
5eb640c6
...
...
@@ -138,6 +138,10 @@ void testSingleSampleNeighboor(
for
(
auto
g
:
s
)
{
ASSERT_EQ
(
true
,
s1
.
find
(
g
)
!=
s1
.
end
());
}
vs
.
clear
();
pull_status
=
worker_ptr_
->
batch_sample_neighboors
(
0
,
{
96
,
37
},
4
,
vs
,
0
);
pull_status
.
wait
();
ASSERT_EQ
(
vs
.
size
(),
2
);
}
void
testAddNode
(
...
...
@@ -356,6 +360,7 @@ void RunServer() {
pserver_ptr_
->
configure
(
server_proto
,
_ps_env
,
0
,
empty_vec
);
LOG
(
INFO
)
<<
"first server, run start(ip,port)"
;
pserver_ptr_
->
start
(
ip_
,
port_
);
pserver_ptr_
->
build_peer2peer_connection
(
0
);
LOG
(
INFO
)
<<
"init first server Done"
;
}
...
...
@@ -373,6 +378,7 @@ void RunServer2() {
empty_vec2
.
push_back
(
empty_prog2
);
pserver_ptr2
->
configure
(
server_proto2
,
_ps_env2
,
1
,
empty_vec2
);
pserver_ptr2
->
start
(
ip2
,
port2
);
pserver_ptr2
->
build_peer2peer_connection
(
1
);
}
void
RunClient
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录