Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c3efabeb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c3efabeb
编写于
8月 23, 2021
作者:
S
seemingwang
提交者:
GitHub
8月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
set node feature (#34994)
上级
77a8a394
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
224 addition
and
0 deletion
+224
-0
paddle/fluid/distributed/service/graph_brpc_client.cc
paddle/fluid/distributed/service/graph_brpc_client.cc
+96
-0
paddle/fluid/distributed/service/graph_brpc_client.h
paddle/fluid/distributed/service/graph_brpc_client.h
+5
-0
paddle/fluid/distributed/service/graph_brpc_server.cc
paddle/fluid/distributed/service/graph_brpc_server.cc
+42
-0
paddle/fluid/distributed/service/graph_brpc_server.h
paddle/fluid/distributed/service/graph_brpc_server.h
+4
-0
paddle/fluid/distributed/service/graph_py_service.cc
paddle/fluid/distributed/service/graph_py_service.cc
+13
-0
paddle/fluid/distributed/service/graph_py_service.h
paddle/fluid/distributed/service/graph_py_service.h
+3
-0
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-0
paddle/fluid/distributed/table/common_graph_table.cc
paddle/fluid/distributed/table/common_graph_table.cc
+28
-0
paddle/fluid/distributed/table/common_graph_table.h
paddle/fluid/distributed/table/common_graph_table.h
+6
-0
paddle/fluid/distributed/test/graph_node_test.cc
paddle/fluid/distributed/test/graph_node_test.cc
+11
-0
paddle/fluid/pybind/fleet_py.cc
paddle/fluid/pybind/fleet_py.cc
+15
-0
未找到文件。
paddle/fluid/distributed/service/graph_brpc_client.cc
浏览文件 @
c3efabeb
...
...
@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure
);
return
fut
;
}
std
::
future
<
int32_t
>
GraphBrpcClient
::
set_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
uint64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
features
)
{
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
if
(
server2request
[
server_index
]
==
-
1
)
{
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
}
}
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
string
>>>
features_idx_buckets
(
request_call_num
);
for
(
int
query_idx
=
0
;
query_idx
<
node_ids
.
size
();
++
query_idx
)
{
int
server_index
=
get_server_index_by_id
(
node_ids
[
query_idx
]);
int
request_idx
=
server2request
[
server_index
];
node_id_buckets
[
request_idx
].
push_back
(
node_ids
[
query_idx
]);
query_idx_buckets
[
request_idx
].
push_back
(
query_idx
);
if
(
features_idx_buckets
[
request_idx
].
size
()
==
0
)
{
features_idx_buckets
[
request_idx
].
resize
(
feature_names
.
size
());
}
for
(
int
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
features_idx_buckets
[
request_idx
][
feat_idx
].
push_back
(
features
[
feat_idx
][
query_idx
]);
}
}
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
&
,
node_id_buckets
,
query_idx_buckets
,
request_call_num
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
size_t
fail_num
=
0
;
for
(
int
request_idx
=
0
;
request_idx
<
request_call_num
;
++
request_idx
)
{
if
(
closure
->
check_response
(
request_idx
,
PS_GRAPH_SET_NODE_FEAT
)
!=
0
)
{
++
fail_num
;
}
if
(
fail_num
==
request_call_num
)
{
ret
=
-
1
;
}
}
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
int
request_idx
=
0
;
request_idx
<
request_call_num
;
++
request_idx
)
{
int
server_index
=
request2server
[
request_idx
];
closure
->
request
(
request_idx
)
->
set_cmd_id
(
PS_GRAPH_SET_NODE_FEAT
);
closure
->
request
(
request_idx
)
->
set_table_id
(
table_id
);
closure
->
request
(
request_idx
)
->
set_client_id
(
_client_id
);
size_t
node_num
=
node_id_buckets
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
uint64_t
)
*
node_num
);
std
::
string
joint_feature_name
=
paddle
::
string
::
join_strings
(
feature_names
,
'\t'
);
closure
->
request
(
request_idx
)
->
add_params
(
joint_feature_name
.
c_str
(),
joint_feature_name
.
size
());
// set features
std
::
string
set_feature
=
""
;
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
size_t
feat_len
=
features_idx_buckets
[
request_idx
][
feat_idx
][
node_idx
].
size
();
set_feature
.
append
((
char
*
)
&
feat_len
,
sizeof
(
size_t
));
set_feature
.
append
(
features_idx_buckets
[
request_idx
][
feat_idx
][
node_idx
].
data
(),
feat_len
);
}
}
closure
->
request
(
request_idx
)
->
add_params
(
set_feature
.
c_str
(),
set_feature
.
size
());
GraphPsService_Stub
rpc_stub
=
getServiceStub
(
get_cmd_channel
(
server_index
));
closure
->
cntl
(
request_idx
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
request_idx
),
closure
->
request
(
request_idx
),
closure
->
response
(
request_idx
),
closure
);
}
return
fut
;
}
int32_t
GraphBrpcClient
::
initialize
()
{
// set_shard_num(_config.shard_num());
BrpcPsClient
::
initialize
();
...
...
paddle/fluid/distributed/service/graph_brpc_client.h
浏览文件 @
c3efabeb
...
...
@@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
const
std
::
vector
<
std
::
string
>&
feature_names
,
std
::
vector
<
std
::
vector
<
std
::
string
>>&
res
);
virtual
std
::
future
<
int32_t
>
set_node_feat
(
const
uint32_t
&
table_id
,
const
std
::
vector
<
uint64_t
>&
node_ids
,
const
std
::
vector
<
std
::
string
>&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
features
);
virtual
std
::
future
<
int32_t
>
clear_nodes
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
add_graph_node
(
uint32_t
table_id
,
std
::
vector
<
uint64_t
>&
node_id_list
,
...
...
paddle/fluid/distributed/service/graph_brpc_server.cc
浏览文件 @
c3efabeb
...
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
...
...
@@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
&
GraphBrpcService
::
add_graph_node
;
_service_handler_map
[
PS_GRAPH_REMOVE_GRAPH_NODE
]
=
&
GraphBrpcService
::
remove_graph_node
;
_service_handler_map
[
PS_GRAPH_SET_NODE_FEAT
]
=
&
GraphBrpcService
::
graph_set_node_feat
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
...
...
@@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return
0
;
}
int32_t
GraphBrpcService
::
graph_set_node_feat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
3
)
{
set_response_code
(
response
,
-
1
,
"graph_set_node_feat request requires at least 2 arguments"
);
return
0
;
}
size_t
node_num
=
request
.
params
(
0
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
0
).
c_str
());
std
::
vector
<
uint64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
std
::
string
>
feature_names
=
paddle
::
string
::
split_string
<
std
::
string
>
(
request
.
params
(
1
),
"
\t
"
);
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
(
feature_names
.
size
(),
std
::
vector
<
std
::
string
>
(
node_num
));
const
char
*
buffer
=
request
.
params
(
2
).
c_str
();
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
size_t
feat_len
=
*
(
size_t
*
)(
buffer
);
buffer
+=
sizeof
(
size_t
);
auto
feat
=
std
::
string
(
buffer
,
feat_len
);
features
[
feat_idx
][
node_idx
]
=
feat
;
buffer
+=
feat_len
;
}
}
((
GraphTable
*
)
table
)
->
set_node_feat
(
node_ids
,
feature_names
,
features
);
return
0
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/service/graph_brpc_server.h
浏览文件 @
c3efabeb
...
...
@@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
graph_get_node_feat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
graph_set_node_feat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
clear_nodes
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
int32_t
add_graph_node
(
Table
*
table
,
const
PsRequestMessage
&
request
,
...
...
paddle/fluid/distributed/service/graph_py_service.cc
浏览文件 @
c3efabeb
...
...
@@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
return
v
;
}
void
GraphPyClient
::
set_node_feat
(
std
::
string
node_type
,
std
::
vector
<
uint64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
)
{
if
(
this
->
table_id_map
.
count
(
node_type
))
{
uint32_t
table_id
=
this
->
table_id_map
[
node_type
];
auto
status
=
worker_ptr
->
set_node_feat
(
table_id
,
node_ids
,
feature_names
,
features
);
status
.
wait
();
}
return
;
}
std
::
vector
<
FeatureNode
>
GraphPyClient
::
pull_graph_list
(
std
::
string
name
,
int
server_index
,
int
start
,
int
size
,
...
...
paddle/fluid/distributed/service/graph_py_service.h
浏览文件 @
c3efabeb
...
...
@@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
std
::
vector
<
std
::
vector
<
std
::
string
>>
get_node_feat
(
std
::
string
node_type
,
std
::
vector
<
uint64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
);
void
set_node_feat
(
std
::
string
node_type
,
std
::
vector
<
uint64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
features
);
std
::
vector
<
FeatureNode
>
pull_graph_list
(
std
::
string
name
,
int
server_index
,
int
start
,
int
size
,
int
step
=
1
);
::
paddle
::
distributed
::
PSParameter
GetWorkerProto
();
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
c3efabeb
...
...
@@ -55,6 +55,7 @@ enum PsCmdID {
PS_GRAPH_CLEAR
=
34
;
PS_GRAPH_ADD_GRAPH_NODE
=
35
;
PS_GRAPH_REMOVE_GRAPH_NODE
=
36
;
PS_GRAPH_SET_NODE_FEAT
=
37
;
}
message
PsRequestMessage
{
...
...
paddle/fluid/distributed/table/common_graph_table.cc
浏览文件 @
c3efabeb
...
...
@@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return
0
;
}
int32_t
GraphTable
::
set_node_feat
(
const
std
::
vector
<
uint64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
)
{
size_t
node_num
=
node_ids
.
size
();
std
::
vector
<
std
::
future
<
int
>>
tasks
;
for
(
size_t
idx
=
0
;
idx
<
node_num
;
++
idx
)
{
uint64_t
node_id
=
node_ids
[
idx
];
tasks
.
push_back
(
_shards_task_pool
[
get_thread_pool_index
(
node_id
)]
->
enqueue
(
[
&
,
idx
,
node_id
]()
->
int
{
size_t
index
=
node_id
%
this
->
shard_num
-
this
->
shard_start
;
auto
node
=
shards
[
index
].
add_feature_node
(
node_id
);
node
->
set_feature_size
(
this
->
feat_name
.
size
());
for
(
int
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
const
std
::
string
&
feature_name
=
feature_names
[
feat_idx
];
if
(
feat_id_map
.
find
(
feature_name
)
!=
feat_id_map
.
end
())
{
node
->
set_feature
(
feat_id_map
[
feature_name
],
res
[
feat_idx
][
idx
]);
}
}
return
0
;
}));
}
for
(
size_t
idx
=
0
;
idx
<
node_num
;
++
idx
)
{
tasks
[
idx
].
get
();
}
return
0
;
}
std
::
pair
<
int32_t
,
std
::
string
>
GraphTable
::
parse_feature
(
std
::
string
feat_str
)
{
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
...
...
paddle/fluid/distributed/table/common_graph_table.h
浏览文件 @
c3efabeb
...
...
@@ -46,6 +46,7 @@ class GraphShard {
}
return
res
;
}
GraphNode
*
add_graph_node
(
uint64_t
id
);
FeatureNode
*
add_feature_node
(
uint64_t
id
);
Node
*
find_node
(
uint64_t
id
);
...
...
@@ -122,6 +123,11 @@ class GraphTable : public SparseTable {
const
std
::
vector
<
std
::
string
>
&
feature_names
,
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
virtual
int32_t
set_node_feat
(
const
std
::
vector
<
uint64_t
>
&
node_ids
,
const
std
::
vector
<
std
::
string
>
&
feature_names
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>
&
res
);
protected:
std
::
vector
<
GraphShard
>
shards
;
size_t
shard_start
,
shard_end
,
server_num
,
shard_num_per_table
,
shard_num
;
...
...
paddle/fluid/distributed/test/graph_node_test.cc
浏览文件 @
c3efabeb
...
...
@@ -558,6 +558,17 @@ void RunBrpcPushSparse() {
VLOG
(
0
)
<<
"get_node_feat: "
<<
node_feat
[
1
][
0
];
VLOG
(
0
)
<<
"get_node_feat: "
<<
node_feat
[
1
][
1
];
node_feat
[
1
][
0
]
=
"helloworld"
;
client1
.
set_node_feat
(
std
::
string
(
"user"
),
node_ids
,
feature_names
,
node_feat
);
// sleep(5);
node_feat
=
client1
.
get_node_feat
(
std
::
string
(
"user"
),
node_ids
,
feature_names
);
VLOG
(
0
)
<<
"get_node_feat: "
<<
node_feat
[
1
][
0
];
ASSERT_TRUE
(
node_feat
[
1
][
0
]
==
"helloworld"
);
// Test string
node_ids
.
clear
();
node_ids
.
push_back
(
37
);
...
...
paddle/fluid/pybind/fleet_py.cc
浏览文件 @
c3efabeb
...
...
@@ -205,6 +205,7 @@ void BindGraphPyClient(py::module* m) {
.
def
(
"pull_graph_list"
,
&
GraphPyClient
::
pull_graph_list
)
.
def
(
"start_client"
,
&
GraphPyClient
::
start_client
)
.
def
(
"batch_sample_neighboors"
,
&
GraphPyClient
::
batch_sample_neighboors
)
.
def
(
"remove_graph_node"
,
&
GraphPyClient
::
remove_graph_node
)
.
def
(
"random_sample_nodes"
,
&
GraphPyClient
::
random_sample_nodes
)
.
def
(
"stop_server"
,
&
GraphPyClient
::
stop_server
)
.
def
(
"get_node_feat"
,
...
...
@@ -221,6 +222,20 @@ void BindGraphPyClient(py::module* m) {
}
return
bytes_feats
;
})
.
def
(
"set_node_feat"
,
[](
GraphPyClient
&
self
,
std
::
string
node_type
,
std
::
vector
<
uint64_t
>
node_ids
,
std
::
vector
<
std
::
string
>
feature_names
,
std
::
vector
<
std
::
vector
<
py
::
bytes
>>
bytes_feats
)
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
feats
(
bytes_feats
.
size
());
for
(
int
i
=
0
;
i
<
bytes_feats
.
size
();
++
i
)
{
for
(
int
j
=
0
;
j
<
bytes_feats
[
i
].
size
();
++
j
)
{
feats
[
i
].
push_back
(
std
::
string
(
bytes_feats
[
i
][
j
]));
}
}
self
.
set_node_feat
(
node_type
,
node_ids
,
feature_names
,
feats
);
return
;
})
.
def
(
"bind_local_server"
,
&
GraphPyClient
::
bind_local_server
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录