Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9b3b53ba
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9b3b53ba
编写于
1月 30, 2022
作者:
Z
zhaocaibei123
提交者:
GitHub
1月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
geo memory sparse table (#39250)
* geo depends * add memory geo table * fix
上级
bafea65c
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
604 addition
and
33 deletion
+604
-33
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+132
-14
paddle/fluid/distributed/ps/service/brpc_ps_client.h
paddle/fluid/distributed/ps/service/brpc_ps_client.h
+4
-0
paddle/fluid/distributed/ps/service/communicator/communicator.cc
...fluid/distributed/ps/service/communicator/communicator.cc
+22
-10
paddle/fluid/distributed/ps/service/communicator/communicator.h
.../fluid/distributed/ps/service/communicator/communicator.h
+3
-3
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+11
-0
paddle/fluid/distributed/ps/table/CMakeLists.txt
paddle/fluid/distributed/ps/table/CMakeLists.txt
+4
-1
paddle/fluid/distributed/ps/table/depends/geo_recorder.h
paddle/fluid/distributed/ps/table/depends/geo_recorder.h
+0
-4
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
+220
-0
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
+78
-0
paddle/fluid/distributed/ps/table/table.cc
paddle/fluid/distributed/ps/table/table.cc
+2
-0
paddle/fluid/distributed/test/CMakeLists.txt
paddle/fluid/distributed/test/CMakeLists.txt
+3
-0
paddle/fluid/distributed/test/memory_geo_table_test.cc
paddle/fluid/distributed/test/memory_geo_table_test.cc
+123
-0
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+2
-1
未找到文件。
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
浏览文件 @
9b3b53ba
...
@@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() {
...
@@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() {
auto
&
profiler
=
CostProfiler
::
instance
();
auto
&
profiler
=
CostProfiler
::
instance
();
profiler
.
register_profiler
(
"pserver_client_pull_dense"
);
profiler
.
register_profiler
(
"pserver_client_pull_dense"
);
profiler
.
register_profiler
(
"pserver_client_pull_sparse"
);
profiler
.
register_profiler
(
"pserver_client_pull_sparse"
);
profiler
.
register_profiler
(
"pserver_client_pull_sparse_param"
);
profiler
.
register_profiler
(
"pserver_client_pull_sparse_local"
);
profiler
.
register_profiler
(
"pserver_client_pull_sparse_local"
);
profiler
.
register_profiler
(
"pserver_client_push_sparse"
);
profiler
.
register_profiler
(
"pserver_client_push_sparse"
);
profiler
.
register_profiler
(
"pserver_client_push_sparse_parse"
);
profiler
.
register_profiler
(
"pserver_client_push_sparse_parse"
);
...
@@ -543,6 +544,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
...
@@ -543,6 +544,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
return
fut
;
return
fut
;
}
}
// for GEO
std
::
future
<
int32_t
>
BrpcPsClient
::
push_sparse_param
(
std
::
future
<
int32_t
>
BrpcPsClient
::
push_sparse_param
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
done
)
{
size_t
num
,
void
*
done
)
{
...
@@ -558,18 +560,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
...
@@ -558,18 +560,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
ids
.
resize
(
request_call_num
);
ids
.
resize
(
request_call_num
);
value_ptrs
.
resize
(
request_call_num
);
value_ptrs
.
resize
(
request_call_num
);
const
auto
&
server_param
=
_config
.
server_param
().
downpour_server_param
();
uint64_t
shard_num
=
FLAGS_pserver_sparse_table_shard_num
;
for
(
int
i
=
0
;
i
<
server_param
.
downpour_table_param_size
();
++
i
)
{
const
auto
&
table_param
=
server_param
.
downpour_table_param
(
i
);
if
(
table_param
.
table_id
()
==
table_id
)
{
shard_num
=
table_param
.
shard_num
();
break
;
}
}
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
size_t
pserver_idx
=
get_sparse_shard
(
shard_num
,
request_call_num
,
keys
[
i
])
;
size_t
pserver_idx
=
keys
[
i
]
%
request_call_num
;
ids
[
pserver_idx
].
push_back
(
keys
[
i
]);
ids
[
pserver_idx
].
push_back
(
keys
[
i
]);
value_ptrs
[
pserver_idx
].
push_back
(
update_values
[
i
]);
value_ptrs
[
pserver_idx
].
push_back
(
update_values
[
i
]);
}
}
...
@@ -1003,6 +995,120 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
...
@@ -1003,6 +995,120 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
return
fut
;
return
fut
;
}
}
// for GEO
std
::
future
<
int32_t
>
BrpcPsClient
::
pull_sparse_param
(
float
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
)
{
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_pull_sparse_param"
);
size_t
request_call_num
=
_server_channels
.
size
();
auto
shard_sorted_kvs
=
std
::
make_shared
<
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
float
*>>>>
();
shard_sorted_kvs
->
resize
(
request_call_num
);
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
size_t
shard_id
=
keys
[
i
]
%
request_call_num
;
shard_sorted_kvs
->
at
(
shard_id
).
push_back
({
keys
[
i
],
select_values
[
i
]});
}
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
select_size
();
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
for
(
size_t
i
=
0
;
i
<
shard_sorted_kvs
->
size
();
++
i
)
{
if
(
closure
->
check_response
(
i
,
PS_PULL_SPARSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
auto
&
request_kvs
=
shard_sorted_kvs
->
at
(
i
);
auto
&
res_io_buffer
=
closure
->
cntl
(
i
)
->
response_attachment
();
butil
::
IOBufBytesIterator
io_buffer_itr
(
res_io_buffer
);
uint64_t
last_key
=
UINT64_MAX
;
float
*
last_value_data
=
NULL
;
// can remove sort&unique
for
(
size_t
kv_idx
=
0
;
kv_idx
<
request_kvs
.
size
();
++
kv_idx
)
{
auto
*
kv_pair
=
&
(
request_kvs
[
kv_idx
]);
if
(
kv_pair
->
first
==
last_key
)
{
memcpy
(
reinterpret_cast
<
void
*>
(
kv_pair
->
second
),
reinterpret_cast
<
void
*>
(
last_value_data
),
value_size
);
}
else
{
last_key
=
kv_pair
->
first
;
last_value_data
=
kv_pair
->
second
;
if
(
value_size
!=
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
last_value_data
),
value_size
))
{
LOG
(
WARNING
)
<<
"res data is lack or not in format"
;
ret
=
-
1
;
break
;
}
}
}
}
closure
->
set_promise_value
(
ret
);
});
closure
->
add_timer
(
timer
);
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
std
::
future
<
int
>
fut
=
promise
->
get_future
();
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
auto
&
sorted_kvs
=
shard_sorted_kvs
->
at
(
i
);
std
::
sort
(
sorted_kvs
.
begin
(),
sorted_kvs
.
end
(),
[](
const
std
::
pair
<
uint64_t
,
float
*>
&
k1
,
const
std
::
pair
<
uint64_t
,
float
*>
&
k2
)
{
return
k1
.
first
<
k2
.
first
;
});
uint64_t
last_key
=
UINT64_MAX
;
uint32_t
kv_request_count
=
0
;
size_t
sorted_kv_size
=
sorted_kvs
.
size
();
auto
&
request_buffer
=
closure
->
cntl
(
i
)
->
request_attachment
();
request_buffer
.
append
(
reinterpret_cast
<
void
*>
(
&
is_training
),
sizeof
(
bool
));
std
::
vector
<
uint32_t
>
keys_counter
;
keys_counter
.
reserve
(
sorted_kv_size
);
for
(
size_t
kv_idx
=
0
;
kv_idx
<
sorted_kv_size
;
++
kv_idx
)
{
++
kv_request_count
;
uint32_t
keys
=
1
;
last_key
=
sorted_kvs
[
kv_idx
].
first
;
request_buffer
.
append
(
reinterpret_cast
<
void
*>
(
&
last_key
),
sizeof
(
uint64_t
));
while
(
kv_idx
<
sorted_kv_size
-
1
&&
last_key
==
sorted_kvs
[
kv_idx
+
1
].
first
)
{
++
kv_idx
;
++
keys
;
}
keys_counter
.
push_back
(
keys
);
}
request_buffer
.
append
(
reinterpret_cast
<
void
*>
(
keys_counter
.
data
()),
sizeof
(
uint32_t
)
*
keys_counter
.
size
());
if
(
kv_request_count
==
0
)
{
closure
->
Run
();
}
else
{
closure
->
request
(
i
)
->
set_cmd_id
(
PS_PULL_SPARSE_TABLE
);
closure
->
request
(
i
)
->
set_table_id
(
table_id
);
closure
->
request
(
i
)
->
set_client_id
(
_client_id
);
closure
->
request
(
i
)
->
add_params
((
char
*
)
&
kv_request_count
,
// NOLINT
sizeof
(
uint32_t
));
PsService_Stub
rpc_stub
(
get_cmd_channel
(
i
));
closure
->
cntl
(
i
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
closure
->
response
(
i
),
closure
);
}
}
return
fut
;
}
std
::
future
<
int32_t
>
BrpcPsClient
::
send_client2client_msg
(
std
::
future
<
int32_t
>
BrpcPsClient
::
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
...
@@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
...
@@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
std
::
string
var_name
=
""
;
std
::
string
var_name
=
""
;
int64_t
var_num
=
0
;
int64_t
var_num
=
0
;
int64_t
var_shape
=
0
;
int64_t
var_shape
=
0
;
std
::
string
table_class
;
const
auto
&
worker_param
=
_config
.
worker_param
().
downpour_worker_param
();
const
auto
&
worker_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
size_t
i
=
0
;
i
<
worker_param
.
downpour_table_param_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
worker_param
.
downpour_table_param_size
();
++
i
)
{
if
(
worker_param
.
downpour_table_param
(
i
).
table_id
()
==
table_id
)
{
if
(
worker_param
.
downpour_table_param
(
i
).
table_id
()
==
table_id
)
{
var_name
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_name
();
var_name
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_name
();
var_num
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_num
();
var_num
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_num
();
var_shape
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_dim
();
var_shape
=
worker_param
.
downpour_table_param
(
i
).
common
().
table_dim
();
table_class
=
worker_param
.
downpour_table_param
(
i
).
table_class
();
break
;
break
;
}
}
}
}
...
@@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
...
@@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec
.
push_back
(
save_huge_vec
.
data
()
+
i
*
var_shape
);
save_vec
.
push_back
(
save_huge_vec
.
data
()
+
i
*
var_shape
);
}
}
VLOG
(
2
)
<<
"recv_and_save_table: table_class: "
<<
table_class
;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table
if
(
table_class
==
"MemorySparseGeoTable"
)
{
auto
status
=
pull_sparse_param
(
reinterpret_cast
<
float
**>
(
save_vec
.
data
()),
table_id
,
save_key
.
data
(),
save_key
.
size
(),
true
);
status
.
wait
();
}
else
{
auto
status
=
pull_sparse
(
reinterpret_cast
<
float
**>
(
save_vec
.
data
()),
auto
status
=
pull_sparse
(
reinterpret_cast
<
float
**>
(
save_vec
.
data
()),
table_id
,
save_key
.
data
(),
save_key
.
size
(),
true
);
table_id
,
save_key
.
data
(),
save_key
.
size
(),
true
);
status
.
wait
();
status
.
wait
();
}
// create lod tensor
// create lod tensor
std
::
shared_ptr
<
framework
::
Scope
>
scope
;
std
::
shared_ptr
<
framework
::
Scope
>
scope
;
...
...
paddle/fluid/distributed/ps/service/brpc_ps_client.h
浏览文件 @
9b3b53ba
...
@@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient {
...
@@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient {
size_t
table_id
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
bool
is_training
);
virtual
std
::
future
<
int32_t
>
pull_sparse_param
(
float
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
virtual
std
::
future
<
int32_t
>
print_table_stat
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
print_table_stat
(
uint32_t
table_id
);
...
...
paddle/fluid/distributed/ps/service/communicator/communicator.cc
浏览文件 @
9b3b53ba
...
@@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
...
@@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool
training
=
true
;
bool
training
=
true
;
auto
status
=
_worker_ptr
->
pull_sparse
(
auto
status
=
_worker_ptr
->
pull_sparse
_param
(
(
float
**
)
push_g_vec
.
data
(),
table_id
,
// NOLINT
(
float
**
)
push_g_vec
.
data
(),
table_id
,
// NOLINT
sparse_push_keys
.
data
(),
sparse_push_keys
.
size
(),
training
);
sparse_push_keys
.
data
(),
sparse_push_keys
.
size
(),
training
);
status
.
wait
();
status
.
wait
();
...
@@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
...
@@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto
&
sparse_ids_set
=
iter
.
second
;
auto
&
sparse_ids_set
=
iter
.
second
;
auto
sparse_ids_vec
=
std
::
make_shared
<
std
::
vector
<
int64_t
>>
();
auto
sparse_ids_vec
=
std
::
make_shared
<
std
::
vector
<
int64_t
>>
();
sparse_ids_vec
->
assign
(
sparse_ids_set
.
begin
(),
sparse_ids_set
.
end
());
sparse_ids_vec
->
assign
(
sparse_ids_set
.
begin
(),
sparse_ids_set
.
end
());
sparse_id_queues_
.
at
(
key
)
->
Pu
sh
(
sparse_ids_vec
);
sparse_id_queues_
.
at
(
key
)
->
Pu
t
(
sparse_ids_vec
);
VLOG
(
3
)
<<
"push "
<<
sparse_ids_vec
->
size
()
<<
" ids to "
<<
key
VLOG
(
3
)
<<
"push "
<<
sparse_ids_vec
->
size
()
<<
" ids to "
<<
key
<<
"'s queue"
;
<<
"'s queue"
;
}
}
...
@@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
...
@@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
auto
&
ctx
=
iter
.
second
;
auto
&
ctx
=
iter
.
second
;
if
(
!
ctx
.
is_sparse
)
continue
;
if
(
!
ctx
.
is_sparse
)
{
parallel_task_nums_
+=
1
;
continue
;
}
auto
&
varnames
=
ctx
.
origin_varnames
;
auto
&
varnames
=
ctx
.
origin_varnames
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
varnames
.
size
(),
1
,
varnames
.
size
(),
1
,
...
@@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
...
@@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for
(
auto
&
splited_var
:
ctx
.
splited_varnames
)
{
for
(
auto
&
splited_var
:
ctx
.
splited_varnames
)
{
parallel_task_nums_
+=
1
;
parallel_task_nums_
+=
1
;
sparse_id_queues_
.
insert
(
sparse_id_queues_
.
insert
(
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
pair
<
std
::
string
,
paddle
::
framework
::
Channel
<
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
>
(
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
(
splited_var
,
splited_var
,
std
::
make_shared
<
paddle
::
framework
::
MakeChannel
<
BlockingQueue
<
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
(
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>
(
send_queue_size_
)));
send_queue_size_
)));
}
}
}
}
...
@@ -1242,8 +1244,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
...
@@ -1242,8 +1244,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
VLOG
(
3
)
<<
"Merge Number of "
<<
send_varname
<<
" = "
<<
merge_num
;
VLOG
(
3
)
<<
"Merge Number of "
<<
send_varname
<<
" = "
<<
merge_num
;
if
(
sparse_id_queues_
.
at
(
send_varname
)
->
Size
()
>
0
)
{
if
(
sparse_id_queues_
.
at
(
send_varname
)
->
Size
()
>
0
)
{
wait_times
=
0
;
wait_times
=
0
;
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>
pop_ids
=
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>
pop_ids
=
nullptr
;
sparse_id_queues_
.
at
(
send_varname
)
->
Pop
(
);
sparse_id_queues_
.
at
(
send_varname
)
->
Get
(
pop_ids
);
for
(
size_t
j
=
0
;
j
<
pop_ids
->
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
pop_ids
->
size
();
j
++
)
{
sparse_ids
.
insert
(
pop_ids
->
at
(
j
));
sparse_ids
.
insert
(
pop_ids
->
at
(
j
));
}
}
...
@@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname,
...
@@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std
::
vector
<
int64_t
>
&
sparse_ids
,
int
table_id
,
std
::
vector
<
int64_t
>
&
sparse_ids
,
int
table_id
,
int
ep_idx
)
{
int
ep_idx
)
{
platform
::
RecordEvent
record_event
(
"GeoCommunicator->SendSparse"
);
platform
::
RecordEvent
record_event
(
"GeoCommunicator->SendSparse"
);
if
(
sparse_ids
.
size
()
==
0
)
{
return
;
}
std
::
string
param_name
=
SplitedGradToParam
(
varname
);
std
::
string
param_name
=
SplitedGradToParam
(
varname
);
VLOG
(
1
)
<<
"In GeoCommunicator::SendSparse("
<<
varname
<<
" "
<<
param_name
VLOG
(
1
)
<<
"In GeoCommunicator::SendSparse("
<<
varname
<<
" "
<<
param_name
<<
", ids.size = "
<<
sparse_ids
.
size
()
<<
", table_id: "
<<
table_id
<<
", ids.size = "
<<
sparse_ids
.
size
()
<<
", table_id: "
<<
table_id
...
@@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname,
...
@@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname,
t_value
+
j
*
dims1
,
t_value
+
j
*
dims1
,
t_old
->
data
<
float
>
()
+
sparse_ids
[
j
]
*
dims1
);
t_old
->
data
<
float
>
()
+
sparse_ids
[
j
]
*
dims1
);
push_g_vec
.
push_back
(
t_value
+
j
*
dims1
);
push_g_vec
.
push_back
(
t_value
+
j
*
dims1
);
VLOG
(
5
)
<<
"DEBUG GeoCommunicator::SendSparse send sparse key "
<<
sparse_ids
[
j
]
<<
" value[0] "
<<
push_g_vec
[
j
][
0
]
<<
" value[-1] "
<<
push_g_vec
[
j
][
dims1
-
1
];
}
}
++
_async_call_num
;
++
_async_call_num
;
...
@@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
...
@@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
cpu_ctx
);
cpu_ctx
);
for
(
auto
j
=
0
;
j
<
static_cast
<
int
>
(
keys
.
size
());
++
j
)
{
for
(
auto
j
=
0
;
j
<
static_cast
<
int
>
(
keys
.
size
());
++
j
)
{
VLOG
(
5
)
<<
"DEBUG GeoCommunicator::RecvSparse recv sparse key"
<<
keys
[
j
]
<<
"value[0] "
<<
values
[
j
*
dims1
]
<<
" value[-1] "
<<
values
[
j
*
dims1
+
dims1
-
1
];
float
*
latest_data
=
t_latest
->
data
<
float
>
()
+
keys
[
j
]
*
dims1
;
float
*
latest_data
=
t_latest
->
data
<
float
>
()
+
keys
[
j
]
*
dims1
;
float
*
old_data
=
t_old
->
data
<
float
>
()
+
keys
[
j
]
*
dims1
;
float
*
old_data
=
t_old
->
data
<
float
>
()
+
keys
[
j
]
*
dims1
;
// pserver - old => delta
// pserver - old => delta
...
...
paddle/fluid/distributed/ps/service/communicator/communicator.h
浏览文件 @
9b3b53ba
...
@@ -30,6 +30,7 @@ limitations under the License. */
...
@@ -30,6 +30,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
...
@@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator {
...
@@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
// parameter on pserver
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
std
::
unordered_map
<
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
Channel
<
std
::
string
,
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
std
::
vector
<
int64_t
>>>>>
sparse_id_queues_
;
sparse_id_queues_
;
};
};
...
...
paddle/fluid/distributed/ps/service/ps_client.h
浏览文件 @
9b3b53ba
...
@@ -128,6 +128,17 @@ class PSClient {
...
@@ -128,6 +128,17 @@ class PSClient {
const
uint64_t
*
keys
,
size_t
num
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
)
=
0
;
bool
is_training
)
=
0
;
virtual
std
::
future
<
int32_t
>
pull_sparse_param
(
float
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
)
{
VLOG
(
0
)
<<
"Did not implement"
;
std
::
promise
<
int32_t
>
promise
;
std
::
future
<
int
>
fut
=
promise
.
get_future
();
promise
.
set_value
(
-
1
);
return
fut
;
}
virtual
::
std
::
future
<
int32_t
>
pull_sparse_ptr
(
char
**
select_values
,
virtual
::
std
::
future
<
int32_t
>
pull_sparse_ptr
(
char
**
select_values
,
size_t
table_id
,
size_t
table_id
,
const
uint64_t
*
keys
,
const
uint64_t
*
keys
,
...
...
paddle/fluid/distributed/ps/table/CMakeLists.txt
浏览文件 @
9b3b53ba
...
@@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo
...
@@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo
cc_library
(
ctr_accessor SRCS ctr_accessor.cc DEPS
${
TABLE_DEPS
}
ps_framework_proto sparse_sgd_rule
)
cc_library
(
ctr_accessor SRCS ctr_accessor.cc DEPS
${
TABLE_DEPS
}
ps_framework_proto sparse_sgd_rule
)
cc_library
(
memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto
${
TABLE_DEPS
}
fs afs_wrapper ctr_accessor common_table
)
cc_library
(
memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto
${
TABLE_DEPS
}
fs afs_wrapper ctr_accessor common_table
)
cc_library
(
table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost
)
set_source_files_properties
(
memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_library
(
memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto
${
TABLE_DEPS
}
common_table
)
cc_library
(
table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost
)
target_link_libraries
(
table -fopenmp
)
target_link_libraries
(
table -fopenmp
)
paddle/fluid/distributed/ps/table/depends/geo_recorder.h
浏览文件 @
9b3b53ba
...
@@ -15,13 +15,9 @@
...
@@ -15,13 +15,9 @@
#pragma once
#pragma once
#include <ThreadPool.h>
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <future> // NOLINT
#include <memory>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <vector>
#include <vector>
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc
0 → 100644
浏览文件 @
9b3b53ba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
namespace
paddle
{
namespace
distributed
{
int32_t
MemorySparseGeoTable
::
push_sparse_param
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
)
{
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::push_sparse_param begin "
"push_sparse_param "
<<
num
;
auto
shard_num
=
_task_pool_size
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
offset_bucket
;
offset_bucket
.
resize
(
shard_num
);
for
(
int
x
=
0
;
x
<
num
;
++
x
)
{
auto
y
=
keys
[
x
]
%
shard_num
;
offset_bucket
[
y
].
push_back
(
x
);
if
(
x
<
10
)
{
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::push_sparse_param key: "
<<
keys
[
x
]
<<
" shard: "
<<
y
;
}
}
std
::
vector
<
std
::
future
<
int
>>
tasks
(
shard_num
);
for
(
int
shard_id
=
0
;
shard_id
<
shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
]
->
enqueue
(
[
this
,
shard_id
,
&
keys
,
&
offset_bucket
,
&
values
]()
->
int
{
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
offsets
=
offset_bucket
[
shard_id
];
for
(
int
i
=
0
;
i
<
offsets
.
size
();
++
i
)
{
auto
offset
=
offsets
[
i
];
auto
id
=
keys
[
offset
];
auto
&
feature_value
=
local_shard
[
id
];
feature_value
.
resize
(
_dim
);
std
::
copy_n
(
values
+
_dim
*
offset
,
_dim
,
feature_value
.
data
());
if
(
i
<
10
)
{
VLOG
(
5
)
<<
"MemorySparseGeoTable::push_sparse_param "
"push_sparse_param key "
<<
id
<<
" value[0]: "
<<
(
values
+
_dim
*
offset
)[
0
]
<<
" data: "
<<
feature_value
.
data
()[
0
]
<<
" value[-1]: "
<<
(
values
+
_dim
*
offset
)[
_dim
-
1
]
<<
" data: "
<<
feature_value
.
data
()[
_dim
-
1
];
}
}
return
0
;
});
}
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
tasks
[
shard_id
].
wait
();
}
return
0
;
}
int32_t
MemorySparseGeoTable
::
pull_geo_param
(
const
uint32_t
trainer_id
,
std
::
vector
<
float
>*
values
,
std
::
vector
<
uint64_t
>*
ids
)
{
_geo_recorder
->
GetAndClear
(
trainer_id
,
ids
);
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
<<
trainer_id
<<
" id_num: "
<<
ids
->
size
();
std
::
vector
<
uint32_t
>
frequencies
;
frequencies
.
resize
(
ids
->
size
(),
1
);
auto
pull_value
=
PullSparseValue
(
ids
->
size
(),
_dim
);
pull_value
.
is_training_
=
true
;
pull_value
.
feasigns_
=
ids
->
data
();
pull_value
.
frequencies_
=
frequencies
.
data
();
values
->
resize
(
ids
->
size
()
*
_dim
);
pull_sparse
(
values
->
data
(),
pull_value
);
return
0
;
}
int32_t
MemorySparseGeoTable
::
push_sparse
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
)
{
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::push_sparse keys[0]"
<<
keys
[
0
]
<<
" key_num: "
<<
num
;
std
::
vector
<
uint64_t
>
ids
;
ids
.
resize
(
num
);
std
::
copy_n
(
keys
,
num
,
ids
.
begin
());
_geo_recorder
->
Update
(
ids
);
_push_sparse
(
keys
,
values
,
num
);
return
0
;
}
int32_t
MemorySparseGeoTable
::
initialize
()
{
if
(
!
_geo_recorder
)
{
auto
trainers
=
_config
.
common
().
trainer_num
();
_geo_recorder
=
std
::
make_shared
<
GeoRecorder
>
(
trainers
);
}
_dim
=
_config
.
common
().
dims
()[
0
];
_shards_task_pool
.
resize
(
_task_pool_size
);
for
(
int
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
_shards_task_pool
[
i
].
reset
(
new
::
ThreadPool
(
1
));
}
_local_shards
.
reset
(
new
shard_type
[
_task_pool_size
]);
return
0
;
}
int32_t
MemorySparseGeoTable
::
pull_sparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
auto
shard_num
=
_task_pool_size
;
std
::
vector
<
std
::
future
<
int
>>
tasks
(
shard_num
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
shard_num
);
size_t
num
=
pull_value
.
numel_
;
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int
shard_id
=
pull_value
.
feasigns_
[
i
]
%
shard_num
;
task_keys
[
shard_id
].
push_back
({
pull_value
.
feasigns_
[
i
],
i
});
}
for
(
int
shard_id
=
0
;
shard_id
<
shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
]
->
enqueue
(
[
this
,
shard_id
,
&
task_keys
,
pull_values
]()
->
int
{
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
keys
=
task_keys
[
shard_id
];
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
i
++
)
{
uint64_t
key
=
keys
[
i
].
first
;
auto
offset
=
keys
[
i
].
second
;
float
*
select_data
=
pull_values
+
_dim
*
offset
;
auto
itr
=
local_shard
.
find
(
key
);
if
(
itr
==
local_shard
.
end
())
{
// ++missed_keys;
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
_dim
);
memset
(
feature_value
.
data
(),
0
,
sizeof
(
float
)
*
_dim
);
VLOG
(
0
)
<<
"MemorySparseGeoTable pull_sparse key not found!!! "
<<
key
;
itr
=
local_shard
.
find
(
key
);
}
memcpy
(
select_data
,
itr
.
value
().
data
(),
_dim
*
sizeof
(
float
));
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::pull_sparse key: "
<<
key
<<
" select_data[0] "
<<
select_data
[
0
]
<<
" value[0]: "
<<
itr
.
value
().
data
()[
0
];
}
return
0
;
});
}
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
tasks
[
shard_id
].
wait
();
}
return
0
;
}
int32_t
MemorySparseGeoTable
::
_push_sparse
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
)
{
auto
shard_num
=
_task_pool_size
;
std
::
vector
<
std
::
future
<
int
>>
tasks
(
shard_num
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
shard_num
);
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int
shard_id
=
keys
[
i
]
%
shard_num
;
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
}
for
(
size_t
shard_id
=
0
;
shard_id
<
shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
]
->
enqueue
(
[
this
,
shard_id
,
values
,
&
task_keys
]()
->
int
{
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
blas
=
GetBlas
<
float
>
();
for
(
int
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
push_data_idx
=
keys
[
i
].
second
;
const
float
*
update_data
=
values
+
push_data_idx
*
_dim
;
auto
itr
=
local_shard
.
find
(
key
);
if
(
itr
==
local_shard
.
end
())
{
VLOG
(
0
)
<<
"sparse geo table push not found key!!! "
<<
key
;
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
_dim
);
memset
(
feature_value
.
data
(),
0
,
sizeof
(
float
)
*
_dim
);
itr
=
local_shard
.
find
(
key
);
}
auto
&
feature_value
=
itr
.
value
();
float
*
value_data
=
feature_value
.
data
();
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::_push_sparse before key: "
<<
key
<<
" update_data[0] "
<<
update_data
[
0
]
<<
" value[0]: "
<<
value_data
[
0
];
blas
.
VADD
(
_dim
,
update_data
,
value_data
,
value_data
);
VLOG
(
5
)
<<
"DEBUG MemorySparseGeoTable::_push_sparse after key: "
<<
key
<<
" value[0]: "
<<
value_data
[
0
];
}
return
0
;
});
}
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
tasks
[
shard_id
].
wait
();
}
return
0
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h
0 → 100644
浏览文件 @
9b3b53ba
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
// #include <pthread.h>
#include <stdint.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/distributed/ps/table/depends/geo_recorder.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
namespace
distributed
{
class
GeoRecorder
;
class
MemorySparseGeoTable
:
public
SparseTable
{
public:
typedef
SparseTableShard
<
uint64_t
,
FixedFeatureValue
>
shard_type
;
MemorySparseGeoTable
()
{
_geo_recorder
=
nullptr
;
}
virtual
~
MemorySparseGeoTable
()
{}
virtual
int32_t
initialize
();
virtual
int32_t
initialize_shard
()
{
return
0
;
}
virtual
int32_t
load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
return
0
;
}
virtual
int32_t
save
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
return
0
;
}
virtual
int32_t
flush
()
{
return
0
;
}
virtual
int32_t
shrink
(
const
std
::
string
&
param
)
{
return
0
;
}
virtual
void
clear
()
{
return
;
}
virtual
int32_t
pull_sparse
(
float
*
values
,
const
PullSparseValue
&
pull_value
);
int32_t
push_sparse_param
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t
pull_geo_param
(
const
uint32_t
trainer_id
,
std
::
vector
<
float
>*
values
,
std
::
vector
<
uint64_t
>*
keys
);
int32_t
push_sparse
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
)
override
;
int32_t
_push_sparse
(
const
uint64_t
*
keys
,
const
float
*
values
,
size_t
num
);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value);
private:
std
::
shared_ptr
<
GeoRecorder
>
_geo_recorder
;
const
int
_task_pool_size
=
10
;
std
::
vector
<
std
::
shared_ptr
<::
ThreadPool
>>
_shards_task_pool
;
std
::
unique_ptr
<
shard_type
[]
>
_local_shards
;
int
_dim
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/table/table.cc
浏览文件 @
9b3b53ba
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/distributed/ps/table/common_dense_table.h"
#include "paddle/fluid/distributed/ps/table/common_dense_table.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/common_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/common_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_geo_table.h"
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
...
@@ -43,6 +44,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable);
...
@@ -43,6 +44,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS
(
Table
,
DenseTensorTable
);
REGISTER_PSCORE_CLASS
(
Table
,
DenseTensorTable
);
REGISTER_PSCORE_CLASS
(
Table
,
GlobalStepTable
);
REGISTER_PSCORE_CLASS
(
Table
,
GlobalStepTable
);
REGISTER_PSCORE_CLASS
(
Table
,
MemorySparseTable
);
REGISTER_PSCORE_CLASS
(
Table
,
MemorySparseTable
);
REGISTER_PSCORE_CLASS
(
Table
,
MemorySparseGeoTable
);
REGISTER_PSCORE_CLASS
(
ValueAccessor
,
CommMergeAccessor
);
REGISTER_PSCORE_CLASS
(
ValueAccessor
,
CommMergeAccessor
);
REGISTER_PSCORE_CLASS
(
ValueAccessor
,
CtrCommonAccessor
);
REGISTER_PSCORE_CLASS
(
ValueAccessor
,
CtrCommonAccessor
);
REGISTER_PSCORE_CLASS
(
SparseValueSGDRule
,
StdAdaGradSGDRule
);
REGISTER_PSCORE_CLASS
(
SparseValueSGDRule
,
StdAdaGradSGDRule
);
...
...
paddle/fluid/distributed/test/CMakeLists.txt
浏览文件 @
9b3b53ba
...
@@ -35,3 +35,6 @@ cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost ta
...
@@ -35,3 +35,6 @@ cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost ta
set_source_files_properties
(
memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS
${
COMMON_DEPS
}
boost table
)
cc_test
(
memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS
${
COMMON_DEPS
}
boost table
)
set_source_files_properties
(
memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS
${
COMMON_DEPS
}
boost table
)
paddle/fluid/distributed/test/memory_geo_table_test.cc
0 → 100644
浏览文件 @
9b3b53ba
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace
paddle
{
namespace
distributed
{
// MemorySparseGeoTable
TEST
(
MemorySparseGeoTable
,
SSUM
)
{
int
emb_dim
=
10
;
int
trainers
=
2
;
TableParameter
table_config
;
table_config
.
set_table_class
(
"MemorySparseGeoTable"
);
FsClientParameter
fs_config
;
Table
*
table
=
new
MemorySparseGeoTable
();
TableAccessorParameter
*
accessor_config
=
table_config
.
mutable_accessor
();
accessor_config
->
set_accessor_class
(
"CommMergeAccessor"
);
accessor_config
->
set_fea_dim
(
10
);
CommonAccessorParameter
*
common_config
=
table_config
.
mutable_common
();
common_config
->
set_name
(
"sum"
);
common_config
->
set_table_name
(
"ssum_test_table"
);
common_config
->
set_trainer_num
(
trainers
);
common_config
->
add_params
(
"Param"
);
common_config
->
add_dims
(
emb_dim
);
common_config
->
add_initializers
(
"fill_constant&1.0"
);
auto
ret
=
table
->
initialize
(
table_config
,
fs_config
);
ASSERT_EQ
(
ret
,
0
);
// test push_sparse_param, and create params
std
::
vector
<
uint64_t
>
init_keys
=
{
0
,
1
,
2
,
3
,
4
};
std
::
vector
<
uint32_t
>
init_fres
=
{
1
,
1
,
1
,
1
,
1
};
std
::
vector
<
float
>
init_values
;
for
(
size_t
i
=
0
;
i
<
init_keys
.
size
()
*
emb_dim
;
i
++
)
{
init_values
.
push_back
(
0.0
);
}
table
->
push_sparse_param
(
init_keys
.
data
(),
init_values
.
data
(),
init_keys
.
size
());
std
::
vector
<
float
>
pull_values
(
init_values
.
size
());
auto
value
=
PullSparseValue
(
init_keys
,
init_fres
,
emb_dim
);
table
->
pull_sparse
(
pull_values
.
data
(),
value
);
for
(
size_t
i
=
0
;
i
<
init_keys
.
size
()
*
emb_dim
;
i
++
)
{
ASSERT_TRUE
(
abs
(
pull_values
[
i
]
-
init_values
[
i
])
<
1e-5
);
}
std
::
vector
<
std
::
vector
<
uint64_t
>>
trainer_keys
;
std
::
vector
<
std
::
vector
<
float
>>
trainer_values
;
trainer_keys
.
resize
(
trainers
);
trainer_values
.
resize
(
trainers
);
float
start
=
0.0
;
for
(
int
i
=
0
;
i
<
trainers
;
i
++
)
{
trainer_keys
[
i
]
=
init_keys
;
for
(
size_t
j
=
0
;
j
<
trainer_keys
[
i
].
size
();
j
++
)
{
auto
id
=
trainer_keys
[
i
][
j
];
for
(
int
k
=
0
;
k
<
emb_dim
;
k
++
)
{
trainer_values
[
i
].
push_back
(
start
);
pull_values
[
id
*
emb_dim
+
k
]
+=
start
;
start
+=
0.1
;
}
}
}
std
::
shared_ptr
<::
ThreadPool
>
pool_
=
std
::
make_shared
<::
ThreadPool
>
(
trainers
);
std
::
vector
<
std
::
future
<
void
>>
task_status
;
for
(
int
i
=
0
;
i
<
trainers
;
i
++
)
{
auto
&
push_keys
=
trainer_keys
[
i
];
auto
&
push_values
=
trainer_values
[
i
];
auto
task
=
[
table
,
&
push_keys
,
&
push_values
]
{
table
->
push_sparse
(
push_keys
.
data
(),
push_values
.
data
(),
push_keys
.
size
());
};
task_status
.
push_back
(
pool_
->
enqueue
(
std
::
move
(
task
)));
}
for
(
auto
&
status
:
task_status
)
{
status
.
wait
();
}
std
::
vector
<
std
::
vector
<
uint64_t
>>
geo_pull_ids
;
std
::
vector
<
std
::
vector
<
float
>>
geo_pull_values
;
geo_pull_ids
.
resize
(
trainers
);
geo_pull_values
.
resize
(
trainers
);
for
(
int
i
=
0
;
i
<
trainers
;
i
++
)
{
table
->
pull_geo_param
(
i
,
&
geo_pull_values
[
i
],
&
geo_pull_ids
[
i
]);
ASSERT_EQ
(
geo_pull_values
[
i
].
size
(),
geo_pull_ids
[
i
].
size
()
*
emb_dim
);
for
(
size_t
j
=
0
;
j
<
geo_pull_ids
[
i
].
size
();
++
j
)
{
auto
id
=
geo_pull_ids
[
i
][
j
];
for
(
int
k
=
0
;
k
<
emb_dim
;
k
++
)
{
ASSERT_TRUE
(
abs
(
geo_pull_values
[
i
][
j
*
emb_dim
+
k
]
-
pull_values
[
id
*
emb_dim
+
k
])
<
1e-5
);
}
}
}
}
}
// namespace distributed
}
// namespace paddle
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
9b3b53ba
...
@@ -943,7 +943,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -943,7 +943,7 @@ class TheOnePSRuntime(RuntimeBase):
ctx
.
origin_varnames
()[
0
]]
ctx
.
origin_varnames
()[
0
]]
if
self
.
compiled_strategy
.
is_geo_mode
():
if
self
.
compiled_strategy
.
is_geo_mode
():
table
.
table_class
=
"SparseGeoTable"
table
.
table_class
=
"
Memory
SparseGeoTable"
else
:
else
:
all_table_proto
=
self
.
context
[
all_table_proto
=
self
.
context
[
"user_defined_strategy"
].
sparse_table_configs
"user_defined_strategy"
].
sparse_table_configs
...
@@ -1306,6 +1306,7 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -1306,6 +1306,7 @@ class TheOnePSRuntime(RuntimeBase):
is_dense
=
True
,
is_dense
=
True
,
split_dense_table
=
self
.
role_maker
.
_is_heter_parameter_server_mode
,
split_dense_table
=
self
.
role_maker
.
_is_heter_parameter_server_mode
,
use_origin_program
=
True
)
use_origin_program
=
True
)
# TODO(zhaocaibei123): for GEO: should call GeoCommunicator::RecvDense
self
.
_communicator
.
pull_dense
(
denses
)
self
.
_communicator
.
pull_dense
(
denses
)
generate_vars
=
self
.
context
[
generate_vars
=
self
.
context
[
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录