Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5eb640c6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5eb640c6
编写于
10月 21, 2021
作者:
S
seemingwang
提交者:
GitHub
10月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Graph engine4 (#36587)
上级
d64f7b3b
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
292 addition
and
16 deletion
+292
-16
paddle/fluid/distributed/service/graph_brpc_client.cc
paddle/fluid/distributed/service/graph_brpc_client.cc
+57
-1
paddle/fluid/distributed/service/graph_brpc_client.h
paddle/fluid/distributed/service/graph_brpc_client.h
+2
-1
paddle/fluid/distributed/service/graph_brpc_server.cc
paddle/fluid/distributed/service/graph_brpc_server.cc
+199
-5
paddle/fluid/distributed/service/graph_brpc_server.h
paddle/fluid/distributed/service/graph_brpc_server.h
+9
-0
paddle/fluid/distributed/service/graph_py_service.cc
paddle/fluid/distributed/service/graph_py_service.cc
+1
-0
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-0
paddle/fluid/distributed/service/server.h
paddle/fluid/distributed/service/server.h
+2
-1
paddle/fluid/distributed/table/common_graph_table.cc
paddle/fluid/distributed/table/common_graph_table.cc
+11
-7
paddle/fluid/distributed/table/common_graph_table.h
paddle/fluid/distributed/table/common_graph_table.h
+4
-1
paddle/fluid/distributed/test/graph_node_test.cc
paddle/fluid/distributed/test/graph_node_test.cc
+6
-0
未找到文件。
paddle/fluid/distributed/service/graph_brpc_client.cc
浏览文件 @
5eb640c6
...
...
@@ -304,7 +304,63 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size
std
::
future
<
int32_t
>
GraphBrpcClient
::
batch_sample_neighboors
(
uint32_t
table_id
,
std
::
vector
<
uint64_t
>
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>
&
res
)
{
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>
&
res
,
int
server_index
)
{
if
(
server_index
!=
-
1
)
{
res
.
resize
(
node_ids
.
size
());
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
1
,
[
&
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
if
(
closure
->
check_response
(
0
,
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
)
!=
0
)
{
ret
=
-
1
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
0
)
->
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
size_t
bytes_size
=
io_buffer_itr
.
bytes_left
();
std
::
unique_ptr
<
char
[]
>
buffer_wrapper
(
new
char
[
bytes_size
]);
char
*
buffer
=
buffer_wrapper
.
get
();
io_buffer_itr
.
copy_and_forward
((
void
*
)(
buffer
),
bytes_size
);
size_t
node_num
=
*
(
size_t
*
)
buffer
;
int
*
actual_sizes
=
(
int
*
)(
buffer
+
sizeof
(
size_t
));
char
*
node_buffer
=
buffer
+
sizeof
(
size_t
)
+
sizeof
(
int
)
*
node_num
;
int
offset
=
0
;
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
int
actual_size
=
actual_sizes
[
node_idx
];
int
start
=
0
;
while
(
start
<
actual_size
)
{
res
[
node_idx
].
push_back
(
{
*
(
uint64_t
*
)(
node_buffer
+
offset
+
start
),
*
(
float
*
)(
node_buffer
+
offset
+
start
+
GraphNode
::
id_size
)});
start
+=
GraphNode
::
id_size
+
GraphNode
::
weight_size
;
}
offset
+=
actual_size
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
;
closure
->
request
(
0
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
);
closure
->
request
(
0
)
->
set_table_id
(
table_id
);
closure
->
request
(
0
)
->
set_client_id
(
_client_id
);
closure
->
request
(
0
)
->
add_params
((
char
*
)
node_ids
.
data
(),
sizeof
(
uint64_t
)
*
node_ids
.
size
());
closure
->
request
(
0
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
get_cmd_channel
(
server_index
));
closure
->
cntl
(
0
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
0
),
closure
->
request
(
0
),
closure
->
response
(
0
),
closure
);
return
fut
;
}
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
res
.
clear
();
...
...
paddle/fluid/distributed/service/graph_brpc_client.h
浏览文件 @
5eb640c6
...
...
@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighboors for each of them
virtual
std
::
future
<
int32_t
>
batch_sample_neighboors
(
uint32_t
table_id
,
std
::
vector
<
uint64_t
>
node_ids
,
int
sample_size
,
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>&
res
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
>>>&
res
,
int
server_index
=
-
1
);
virtual
std
::
future
<
int32_t
>
pull_graph_list
(
uint32_t
table_id
,
int
server_index
,
int
start
,
...
...
paddle/fluid/distributed/service/graph_brpc_server.cc
浏览文件 @
5eb640c6
...
...
@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
return
0
;
}
brpc
::
Channel
*
GraphBrpcServer
::
get_cmd_channel
(
size_t
server_index
)
{
return
_pserver_channels
[
server_index
].
get
();
}
uint64_t
GraphBrpcServer
::
start
(
const
std
::
string
&
ip
,
uint32_t
port
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
...
@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
return
0
;
}
int32_t
GraphBrpcServer
::
build_peer2peer_connection
(
int
rank
)
{
this
->
rank
=
rank
;
auto
_env
=
environment
();
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
timeout_ms
=
500000
;
options
.
connection_type
=
"pooled"
;
options
.
connect_timeout_ms
=
10000
;
options
.
max_retry
=
3
;
std
::
vector
<
PSHost
>
server_list
=
_env
->
get_ps_servers
();
_pserver_channels
.
resize
(
server_list
.
size
());
std
::
ostringstream
os
;
std
::
string
server_ip_port
;
for
(
size_t
i
=
0
;
i
<
server_list
.
size
();
++
i
)
{
server_ip_port
.
assign
(
server_list
[
i
].
ip
.
c_str
());
server_ip_port
.
append
(
":"
);
server_ip_port
.
append
(
std
::
to_string
(
server_list
[
i
].
port
));
_pserver_channels
[
i
].
reset
(
new
brpc
::
Channel
());
if
(
_pserver_channels
[
i
]
->
Init
(
server_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
VLOG
(
0
)
<<
"GraphServer connect to Server:"
<<
server_ip_port
<<
" Failed! Try again."
;
std
::
string
int_ip_port
=
GetIntTypeEndpoint
(
server_list
[
i
].
ip
,
server_list
[
i
].
port
);
if
(
_pserver_channels
[
i
]
->
Init
(
int_ip_port
.
c_str
(),
""
,
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"GraphServer connect to Server:"
<<
int_ip_port
<<
" Failed!"
;
return
-
1
;
}
}
os
<<
server_ip_port
<<
","
;
}
LOG
(
INFO
)
<<
"servers peer2peer connection success:"
<<
os
.
str
();
return
0
;
}
int32_t
GraphBrpcService
::
clear_nodes
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
...
...
@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
&
GraphBrpcService
::
remove_graph_node
;
_service_handler_map
[
PS_GRAPH_SET_NODE_FEAT
]
=
&
GraphBrpcService
::
graph_set_node_feat
;
_service_handler_map
[
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
]
=
&
GraphBrpcService
::
sample_neighboors_across_multi_servers
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
...
...
@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
if
(
_is_initialize_shard_info
)
{
return
0
;
}
s
ize_t
shard_num
=
_server
->
environment
()
->
get_ps_servers
().
size
();
s
erver_size
=
_server
->
environment
()
->
get_ps_servers
().
size
();
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
itr
:
table_map
)
{
itr
.
second
->
set_shard
(
_rank
,
s
hard_num
);
itr
.
second
->
set_shard
(
_rank
,
s
erver_size
);
}
_is_initialize_shard_info
=
true
;
}
...
...
@@ -209,7 +252,9 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
int
service_ret
=
(
this
->*
handler_func
)(
table
,
*
request
,
*
response
,
cntl
);
if
(
service_ret
!=
0
)
{
response
->
set_err_code
(
service_ret
);
response
->
set_err_msg
(
"server internal error"
);
if
(
!
response
->
has_err_msg
())
{
response
->
set_err_msg
(
"server internal error"
);
}
}
}
...
...
@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return
0
;
}
int32_t
GraphBrpcService
::
sample_neighboors_across_multi_servers
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
// sleep(5);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
response
,
-
1
,
"graph_random_sample request requires at least 2 arguments"
);
return
0
;
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
uint64_t
),
size_of_size_t
=
sizeof
(
size_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
0
).
c_str
());
int
sample_size
=
*
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
// std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
std
::
vector
<
uint64_t
>
local_id
;
std
::
vector
<
int
>
local_query_idx
;
size_t
rank
=
get_rank
();
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
if
(
server2request
[
server_index
]
==
-
1
)
{
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
}
}
if
(
server2request
[
rank
]
!=
-
1
)
{
auto
pos
=
server2request
[
rank
];
std
::
swap
(
request2server
[
pos
],
request2server
[(
int
)
request2server
.
size
()
-
1
]);
server2request
[
request2server
[
pos
]]
=
pos
;
server2request
[
request2server
[(
int
)
request2server
.
size
()
-
1
]]
=
request2server
.
size
()
-
1
;
}
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
unique_ptr
<
char
[]
>>
local_buffers
;
std
::
vector
<
int
>
local_actual_sizes
;
std
::
vector
<
size_t
>
seq
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
int
request_idx
=
server2request
[
server_index
];
node_id_buckets
[
request_idx
].
push_back
(
node_data
[
query_idx
]);
query_idx_buckets
[
request_idx
].
push_back
(
query_idx
);
seq
.
push_back
(
request_idx
);
}
size_t
remote_call_num
=
request_call_num
;
if
(
request2server
.
size
()
!=
0
&&
request2server
.
back
()
==
rank
)
{
remote_call_num
--
;
local_buffers
.
resize
(
node_id_buckets
.
back
().
size
());
local_actual_sizes
.
resize
(
node_id_buckets
.
back
().
size
());
}
cntl
->
response_attachment
().
append
(
&
node_num
,
sizeof
(
size_t
));
auto
local_promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
std
::
future
<
int
>
local_fut
=
local_promise
->
get_future
();
std
::
vector
<
bool
>
failed
(
server_size
,
false
);
std
::
function
<
void
(
void
*
)
>
func
=
[
&
,
node_id_buckets
,
query_idx_buckets
,
request_call_num
](
void
*
done
)
{
local_fut
.
get
();
std
::
vector
<
int
>
actual_size
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
std
::
vector
<
std
::
unique_ptr
<
butil
::
IOBufBytesIterator
>>
res
(
remote_call_num
);
size_t
fail_num
=
0
;
for
(
size_t
request_idx
=
0
;
request_idx
<
remote_call_num
;
++
request_idx
)
{
if
(
closure
->
check_response
(
request_idx
,
PS_GRAPH_SAMPLE_NEIGHBOORS
)
!=
0
)
{
++
fail_num
;
failed
[
request2server
[
request_idx
]]
=
true
;
}
else
{
auto
&
res_io_buffer
=
closure
->
cntl
(
request_idx
)
->
response_attachment
();
size_t
node_size
;
res
[
request_idx
].
reset
(
new
butil
::
IOBufBytesIterator
(
res_io_buffer
));
size_t
num
;
res
[
request_idx
]
->
copy_and_forward
(
&
num
,
sizeof
(
size_t
));
}
}
int
size
;
int
local_index
=
0
;
for
(
size_t
i
=
0
;
i
<
node_num
;
i
++
)
{
if
(
fail_num
>
0
&&
failed
[
seq
[
i
]])
{
size
=
0
;
}
else
if
(
request2server
[
seq
[
i
]]
!=
rank
)
{
res
[
seq
[
i
]]
->
copy_and_forward
(
&
size
,
sizeof
(
int
));
}
else
{
size
=
local_actual_sizes
[
local_index
++
];
}
actual_size
.
push_back
(
size
);
}
cntl
->
response_attachment
().
append
(
actual_size
.
data
(),
actual_size
.
size
()
*
sizeof
(
int
));
local_index
=
0
;
for
(
size_t
i
=
0
;
i
<
node_num
;
i
++
)
{
if
(
fail_num
>
0
&&
failed
[
seq
[
i
]])
{
continue
;
}
else
if
(
request2server
[
seq
[
i
]]
!=
rank
)
{
char
temp
[
actual_size
[
i
]
+
1
];
res
[
seq
[
i
]]
->
copy_and_forward
(
temp
,
actual_size
[
i
]);
cntl
->
response_attachment
().
append
(
temp
,
actual_size
[
i
]);
}
else
{
char
*
temp
=
local_buffers
[
local_index
++
].
get
();
cntl
->
response_attachment
().
append
(
temp
,
actual_size
[
i
]);
}
}
closure
->
set_promise_value
(
0
);
};
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
remote_call_num
,
func
);
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
int
request_idx
=
0
;
request_idx
<
remote_call_num
;
++
request_idx
)
{
int
server_index
=
request2server
[
request_idx
];
closure
->
request
(
request_idx
)
->
set_cmd_id
(
PS_GRAPH_SAMPLE_NEIGHBOORS
);
closure
->
request
(
request_idx
)
->
set_table_id
(
request
.
table_id
());
closure
->
request
(
request_idx
)
->
set_client_id
(
rank
);
size_t
node_num
=
node_id_buckets
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
uint64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
sample_size
,
sizeof
(
int
));
PsService_Stub
rpc_stub
(
((
GraphBrpcServer
*
)
get_server
())
->
get_cmd_channel
(
server_index
));
// GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index));
closure
->
cntl
(
request_idx
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
request_idx
),
closure
->
request
(
request_idx
),
closure
->
response
(
request_idx
),
closure
);
}
if
(
server2request
[
rank
]
!=
-
1
)
{
((
GraphTable
*
)
table
)
->
random_sample_neighboors
(
node_id_buckets
.
back
().
data
(),
sample_size
,
local_buffers
,
local_actual_sizes
);
}
local_promise
.
get
()
->
set_value
(
0
);
if
(
remote_call_num
==
0
)
func
(
closure
);
fut
.
get
();
return
0
;
}
int32_t
GraphBrpcService
::
graph_set_node_feat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
...
...
@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
if
(
request
.
params_size
()
<
3
)
{
set_response_code
(
response
,
-
1
,
"graph_set_node_feat request requires at least
2
arguments"
);
"graph_set_node_feat request requires at least
3
arguments"
);
return
0
;
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
uint64_t
);
...
...
paddle/fluid/distributed/service/graph_brpc_server.h
浏览文件 @
5eb640c6
...
...
@@ -32,6 +32,8 @@ class GraphBrpcServer : public PSServer {
virtual
~
GraphBrpcServer
()
{}
PsBaseService
*
get_service
()
{
return
_service
.
get
();
}
virtual
uint64_t
start
(
const
std
::
string
&
ip
,
uint32_t
port
);
virtual
int32_t
build_peer2peer_connection
(
int
rank
);
virtual
brpc
::
Channel
*
get_cmd_channel
(
size_t
server_index
);
virtual
int32_t
stop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
stoped_
)
return
0
;
...
...
@@ -50,6 +52,7 @@ class GraphBrpcServer : public PSServer {
mutable
std
::
mutex
mutex_
;
std
::
condition_variable
cv_
;
bool
stoped_
=
false
;
int
rank
;
brpc
::
Server
_server
;
std
::
shared_ptr
<
PsBaseService
>
_service
;
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
...
...
@@ -113,12 +116,18 @@ class GraphBrpcService : public PsBaseService {
int32_t
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
sample_neighboors_across_multi_servers
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
private:
bool
_is_initialize_shard_info
;
std
::
mutex
_initialize_shard_mutex
;
std
::
unordered_map
<
int32_t
,
serviceHandlerFunc
>
_msg_handler_map
;
std
::
vector
<
float
>
_ori_values
;
const
int
sample_nodes_ranges
=
23
;
size_t
server_size
;
std
::
shared_ptr
<::
ThreadPool
>
task_pool
;
};
}
// namespace distributed
...
...
paddle/fluid/distributed/service/graph_py_service.cc
浏览文件 @
5eb640c6
...
...
@@ -107,6 +107,7 @@ void GraphPyServer::start_server(bool block) {
empty_vec
.
push_back
(
empty_prog
);
pserver_ptr
->
configure
(
server_proto
,
_ps_env
,
rank
,
empty_vec
);
pserver_ptr
->
start
(
ip
,
port
);
pserver_ptr
->
build_peer2peer_connection
(
rank
);
std
::
condition_variable
*
cv_
=
pserver_ptr
->
export_cv
();
if
(
block
)
{
std
::
mutex
mutex_
;
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
5eb640c6
...
...
@@ -56,6 +56,7 @@ enum PsCmdID {
PS_GRAPH_ADD_GRAPH_NODE
=
35
;
PS_GRAPH_REMOVE_GRAPH_NODE
=
36
;
PS_GRAPH_SET_NODE_FEAT
=
37
;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER
=
38
;
}
message
PsRequestMessage
{
...
...
paddle/fluid/distributed/service/server.h
浏览文件 @
5eb640c6
...
...
@@ -147,7 +147,7 @@ class PsBaseService : public PsService {
public:
PsBaseService
()
:
_rank
(
0
),
_server
(
NULL
),
_config
(
NULL
)
{}
virtual
~
PsBaseService
()
{}
virtual
size_t
get_rank
()
{
return
_rank
;
}
virtual
int32_t
configure
(
PSServer
*
server
)
{
_server
=
server
;
_rank
=
_server
->
rank
();
...
...
@@ -167,6 +167,7 @@ class PsBaseService : public PsService {
}
virtual
int32_t
initialize
()
=
0
;
PSServer
*
get_server
()
{
return
_server
;
}
protected:
size_t
_rank
;
...
...
paddle/fluid/distributed/table/common_graph_table.cc
浏览文件 @
5eb640c6
...
...
@@ -305,12 +305,12 @@ Node *GraphTable::find_node(uint64_t id) {
return
node
;
}
uint32_t
GraphTable
::
get_thread_pool_index
(
uint64_t
node_id
)
{
return
node_id
%
shard_num
%
shard_num_per_
table
%
task_pool_size_
;
return
node_id
%
shard_num
%
shard_num_per_
server
%
task_pool_size_
;
}
uint32_t
GraphTable
::
get_thread_pool_index_by_shard_index
(
uint64_t
shard_index
)
{
return
shard_index
%
shard_num_per_
table
%
task_pool_size_
;
return
shard_index
%
shard_num_per_
server
%
task_pool_size_
;
}
int32_t
GraphTable
::
clear_nodes
()
{
...
...
@@ -575,6 +575,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
actual_size
=
size
;
return
0
;
}
int32_t
GraphTable
::
get_server_index_by_id
(
uint64_t
id
)
{
return
id
%
shard_num
/
shard_num_per_server
;
}
int32_t
GraphTable
::
initialize
()
{
_shards_task_pool
.
resize
(
task_pool_size_
);
for
(
size_t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
...
...
@@ -611,13 +616,12 @@ int32_t GraphTable::initialize() {
shard_num
=
_config
.
shard_num
();
VLOG
(
0
)
<<
"in init graph table shard num = "
<<
shard_num
<<
" shard_idx"
<<
_shard_idx
;
shard_num_per_
table
=
sparse_local_shard_num
(
shard_num
,
server_num
);
shard_start
=
_shard_idx
*
shard_num_per_
table
;
shard_end
=
shard_start
+
shard_num_per_
table
;
shard_num_per_
server
=
sparse_local_shard_num
(
shard_num
,
server_num
);
shard_start
=
_shard_idx
*
shard_num_per_
server
;
shard_end
=
shard_start
+
shard_num_per_
server
;
VLOG
(
0
)
<<
"in init graph table shard idx = "
<<
_shard_idx
<<
" shard_start "
<<
shard_start
<<
" shard_end "
<<
shard_end
;
// shards.resize(shard_num_per_table);
shards
=
std
::
vector
<
GraphShard
>
(
shard_num_per_table
,
GraphShard
(
shard_num
));
shards
=
std
::
vector
<
GraphShard
>
(
shard_num_per_server
,
GraphShard
(
shard_num
));
return
0
;
}
}
// namespace distributed
...
...
paddle/fluid/distributed/table/common_graph_table.h
浏览文件 @
5eb640c6
...
...
@@ -94,6 +94,7 @@ class GraphTable : public SparseTable {
int32_t
remove_graph_node
(
std
::
vector
<
uint64_t
>
&
id_list
);
int32_t
get_server_index_by_id
(
uint64_t
id
);
Node
*
find_node
(
uint64_t
id
);
virtual
int32_t
pull_sparse
(
float
*
values
,
...
...
@@ -128,9 +129,11 @@ class GraphTable : public SparseTable {
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
size_t
get_server_num
()
{
return
server_num
;
}
protected:
std
::
vector
<
GraphShard
>
shards
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_
table
,
shard_num
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_
server
,
shard_num
;
const
int
task_pool_size_
=
24
;
const
int
random_sample_nodes_ranges
=
3
;
...
...
paddle/fluid/distributed/test/graph_node_test.cc
浏览文件 @
5eb640c6
...
...
@@ -138,6 +138,10 @@ void testSingleSampleNeighboor(
for
(
auto
g
:
s
)
{
ASSERT_EQ
(
true
,
s1
.
find
(
g
)
!=
s1
.
end
());
}
vs
.
clear
();
pull_status
=
worker_ptr_
->
batch_sample_neighboors
(
0
,
{
96
,
37
},
4
,
vs
,
0
);
pull_status
.
wait
();
ASSERT_EQ
(
vs
.
size
(),
2
);
}
void
testAddNode
(
...
...
@@ -356,6 +360,7 @@ void RunServer() {
pserver_ptr_
->
configure
(
server_proto
,
_ps_env
,
0
,
empty_vec
);
LOG
(
INFO
)
<<
"first server, run start(ip,port)"
;
pserver_ptr_
->
start
(
ip_
,
port_
);
pserver_ptr_
->
build_peer2peer_connection
(
0
);
LOG
(
INFO
)
<<
"init first server Done"
;
}
...
...
@@ -373,6 +378,7 @@ void RunServer2() {
empty_vec2
.
push_back
(
empty_prog2
);
pserver_ptr2
->
configure
(
server_proto2
,
_ps_env2
,
1
,
empty_vec2
);
pserver_ptr2
->
start
(
ip2
,
port2
);
pserver_ptr2
->
build_peer2peer_connection
(
1
);
}
void
RunClient
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录