Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2089b485
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2089b485
编写于
3月 30, 2022
作者:
Y
yaoxuefeng
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change to new api in ssync mode (#41022)
* change to new api in ssync mode * fix * fix * fix * fix
上级
60c4c9cd
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
296 addition
and
88 deletion
+296
-88
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+45
-37
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
+30
-7
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+3
-3
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+5
-3
paddle/fluid/distributed/ps/table/accessor.h
paddle/fluid/distributed/ps/table/accessor.h
+13
-1
paddle/fluid/distributed/ps/table/common_dense_table.cc
paddle/fluid/distributed/ps/table/common_dense_table.cc
+2
-2
paddle/fluid/distributed/ps/table/common_sparse_table.cc
paddle/fluid/distributed/ps/table/common_sparse_table.cc
+1
-1
paddle/fluid/distributed/ps/table/ctr_accessor.cc
paddle/fluid/distributed/ps/table/ctr_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/ctr_accessor.h
paddle/fluid/distributed/ps/table/ctr_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
+1
-0
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+17
-13
paddle/fluid/distributed/ps/table/sparse_accessor.cc
paddle/fluid/distributed/ps/table/sparse_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/sparse_accessor.h
paddle/fluid/distributed/ps/table/sparse_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/table.cc
paddle/fluid/distributed/ps/table/table.cc
+1
-0
paddle/fluid/distributed/ps/table/table.h
paddle/fluid/distributed/ps/table/table.h
+3
-2
paddle/fluid/distributed/ps/table/tensor_accessor.cc
paddle/fluid/distributed/ps/table/tensor_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/tensor_accessor.h
paddle/fluid/distributed/ps/table/tensor_accessor.h
+2
-1
paddle/fluid/distributed/ps/wrapper/fleet.cc
paddle/fluid/distributed/ps/wrapper/fleet.cc
+45
-9
未找到文件。
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
浏览文件 @
2089b485
...
@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
...
@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if
(
pull_context
.
value_type
==
Dense
)
{
// pull dense
if
(
pull_context
.
value_type
==
Dense
)
{
// pull dense
Region
*
dense_region
=
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
return
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
float
**
select_values
=
reinterpret_cast
<
float
**>
(
pull_context
.
sparse_values
);
size_t
table_id
=
pull_context
.
table
;
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
size_t
num
=
pull_context
.
num
;
bool
is_training
=
pull_context
.
is_training
;
bool
is_training
=
pull_context
.
is_training
;
if
(
pull_context
.
training_mode
==
Geo
)
{
// for geo
if
(
pull_context
.
training_mode
==
Geo
)
{
// for geo
pull_sparse_param
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
return
pull_sparse_param
(
pull_context
.
sparse_values
,
table_id
,
pull_context
.
keys
,
num
,
is_training
);
}
else
if
(
pull_context
.
training_mode
==
Async
)
{
// for async
}
else
if
(
pull_context
.
training_mode
==
Async
)
{
// for async
pull_sparse
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
return
pull_sparse
(
pull_context
.
sparse_values
,
table_id
,
pull_context
.
keys
,
num
,
is_training
);
}
}
}
}
}
}
...
@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
...
@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
std
::
future
<
int32_t
>
BrpcPsClient
::
Push
(
RequestContext
&
push_context
)
{
std
::
future
<
int32_t
>
BrpcPsClient
::
Push
(
RequestContext
&
push_context
)
{
if
(
push_context
.
value_type
==
Dense
)
{
// push dense
if
(
push_context
.
value_type
==
Dense
)
{
// push dense
const
Region
*
dense_region
=
push_context
.
push_context
.
push_dense_values
;
const
Region
*
dense_region
=
push_context
.
push_context
.
push_dense_values
;
push_dense
(
dense_region
,
push_context
.
num
,
push_context
.
table
);
return
push_dense
(
dense_region
,
push_context
.
num
,
push_context
.
table
);
}
else
{
// push sparse
}
else
{
// push sparse
size_t
table_id
=
push_context
.
table
;
size_t
table_id
=
push_context
.
table
;
size_t
num
=
push_context
.
num
;
size_t
num
=
push_context
.
num
;
...
@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
...
@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
}
else
if
(
push_context
.
training_mode
==
Async
)
{
// for async
}
else
if
(
push_context
.
training_mode
==
Async
)
{
// for async
const
uint64_t
*
keys
=
push_context
.
push_context
.
keys
;
const
uint64_t
*
keys
=
push_context
.
push_context
.
keys
;
const
float
**
update_values
=
push_context
.
push_context
.
push_values
;
const
float
**
update_values
=
push_context
.
push_context
.
push_values
;
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
return
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
}
}
}
}
}
}
...
@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
...
@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
&
shard_nums
),
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
&
shard_nums
),
sizeof
(
uint32_t
));
sizeof
(
uint32_t
));
keys
->
resize
(
shard_nums
);
keys
->
resize
(
shard_nums
);
values
->
resize
(
shard_nums
*
accessor
->
update_dim
(
));
values
->
resize
(
shard_nums
*
accessor
->
GetTableInfo
(
UPDATE_DIM
));
io_buffer_itr
.
copy_and_forward
((
void
*
)(
keys
->
data
()),
// NOLINT
io_buffer_itr
.
copy_and_forward
((
void
*
)(
keys
->
data
()),
// NOLINT
sizeof
(
uint64_t
)
*
shard_nums
);
sizeof
(
uint64_t
)
*
shard_nums
);
io_buffer_itr
.
copy_and_forward
((
void
*
)(
values
->
data
()),
// NOLINT
io_buffer_itr
.
copy_and_forward
(
shard_nums
*
accessor
->
update_size
());
(
void
*
)(
values
->
data
()),
// NOLINT
shard_nums
*
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
closure
->
set_promise_value
(
ret
);
closure
->
set_promise_value
(
ret
);
});
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
...
@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
...
@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
auto
kvs
=
ids
[
shard_idx
];
auto
kvs
=
ids
[
shard_idx
];
auto
value_ptr
=
value_ptrs
[
shard_idx
];
auto
value_ptr
=
value_ptrs
[
shard_idx
];
size_t
kv_size
=
kvs
.
size
();
size_t
kv_size
=
kvs
.
size
();
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
// 发送RPC请求
// 发送RPC请求
auto
*
push_request
=
closure
->
request
(
shard_idx
);
auto
*
push_request
=
closure
->
request
(
shard_idx
);
push_request
->
set_cmd_id
(
PS_PUSH_SPARSE_PARAM
);
push_request
->
set_cmd_id
(
PS_PUSH_SPARSE_PARAM
);
...
@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
...
@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
push_request
->
set_client_id
(
_client_id
);
push_request
->
set_client_id
(
_client_id
);
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
auto
*
push_data
=
push_request
->
mutable_data
();
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
()));
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
update_size
(
));
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
update_size
(
);
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
...
@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t
table_id
)
{
size_t
table_id
)
{
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_pull_dense"
);
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_pull_dense"
);
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
auto
fea_dim
=
accessor
->
GetTableInfo
(
FEA_DIM
);
auto
select_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
size_t
request_call_num
=
_server_channels
.
size
();
size_t
request_call_num
=
_server_channels
.
size
();
uint32_t
num_per_shard
=
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
// callback 将各shard结果,顺序填入region
// callback 将各shard结果,顺序填入region
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
request_call_num
,
num_per_shard
,
regions
,
region_num
,
request_call_num
,
[
request_call_num
,
num_per_shard
,
regions
,
region_num
,
...
@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
...
@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t
region_idx
=
0
;
// 当前填充的region偏移
size_t
region_idx
=
0
;
// 当前填充的region偏移
size_t
region_data_idx
=
0
;
// 当前填充的region内data偏移
size_t
region_data_idx
=
0
;
// 当前填充的region内data偏移
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
select_size
();
size_t
shard_data_size
=
num_per_shard
*
accessor
->
GetTableInfo
(
SELECT_SIZE
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
PS_PULL_DENSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
PS_PULL_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
ret
=
-
1
;
...
@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
...
@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std
::
vector
<
std
::
vector
<
Region
>>
regions_partition
(
request_call_num
);
std
::
vector
<
std
::
vector
<
Region
>>
regions_partition
(
request_call_num
);
uint32_t
num_per_shard
=
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
update_size
(
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
size_t
current_region_idx
=
0
;
size_t
current_region_idx
=
0
;
size_t
current_region_data_idx
=
0
;
size_t
current_region_data_idx
=
0
;
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
...
@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
...
@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
auto
value_ptr
=
value_ptrs
[
shard_idx
];
auto
value_ptr
=
value_ptrs
[
shard_idx
];
size_t
kv_size
=
kvs
.
size
();
size_t
kv_size
=
kvs
.
size
();
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
// 发送RPC请求
// 发送RPC请求
auto
*
push_request
=
closure
->
request
(
shard_idx
);
auto
*
push_request
=
closure
->
request
(
shard_idx
);
...
@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
...
@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request
->
set_client_id
(
_client_id
);
push_request
->
set_client_id
(
_client_id
);
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
auto
*
push_data
=
push_request
->
mutable_data
();
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
()));
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
update_size
(
));
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
update_size
(
);
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
...
@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std
::
future
<
int
>
fut
=
promise
->
get_future
();
std
::
future
<
int
>
fut
=
promise
->
get_future
();
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
uint32_t
num_per_shard
=
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
closure
->
request
(
i
)
->
set_cmd_id
(
PS_PUSH_DENSE_TABLE
);
closure
->
request
(
i
)
->
set_cmd_id
(
PS_PUSH_DENSE_TABLE
);
closure
->
request
(
i
)
->
set_table_id
(
table_id
);
closure
->
request
(
i
)
->
set_table_id
(
table_id
);
...
@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
...
@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
}
}
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
select_size
();
size_t
value_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
...
@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
...
@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
}
}
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
select_size
(
);
size_t
value_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
...
@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
...
@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
uint32_t
num
,
void
*
done
,
int
pserver_idx
)
{
uint32_t
num
,
void
*
done
,
int
pserver_idx
)
{
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
update_size
(
);
size_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
DownpourBrpcClosure
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
DownpourBrpcClosure
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
closure
->
add_promise
(
promise
);
...
@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
...
@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
shard_kv_data
.
kv_num
=
0
;
shard_kv_data
.
kv_num
=
0
;
continue
;
continue
;
}
}
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
uint32_t
value_size
=
accessor
->
update_size
();
for
(
size_t
kv_idx
=
0
;
kv_idx
<
sorted_kv_size
;
++
kv_idx
)
{
for
(
size_t
kv_idx
=
0
;
kv_idx
<
sorted_kv_size
;
++
kv_idx
)
{
shard_kv_data
.
key_list
[
kv_idx
]
=
sorted_kv_list
[
kv_idx
].
first
;
shard_kv_data
.
key_list
[
kv_idx
]
=
sorted_kv_list
[
kv_idx
].
first
;
shard_kv_data
.
value_list
[
kv_idx
].
assign
(
shard_kv_data
.
value_list
[
kv_idx
].
assign
(
...
@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {
...
@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {
void
sparse_local_merge
(
ValueAccessor
*
accessor
,
float
*
merge_data
,
void
sparse_local_merge
(
ValueAccessor
*
accessor
,
float
*
merge_data
,
const
float
*
another_data
)
{
const
float
*
another_data
)
{
size_t
col_num
=
accessor
->
update_size
(
)
/
sizeof
(
float
);
size_t
col_num
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
float
*
merge_data_shell
[
col_num
];
float
*
merge_data_shell
[
col_num
];
const
float
*
another_data_shell
[
col_num
];
const
float
*
another_data_shell
[
col_num
];
for
(
int
i
=
0
;
i
<
col_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
col_num
;
++
i
)
{
...
@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
...
@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
ValueAccessor
*
accessor
)
{
ValueAccessor
*
accessor
)
{
size_t
merged_kv_count
=
0
;
size_t
merged_kv_count
=
0
;
uint64_t
min_key
=
UINT64_MAX
;
uint64_t
min_key
=
UINT64_MAX
;
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
thread_local
std
::
vector
<
std
::
pair
<
uint64_t
,
const
float
*>>
sorted_kv_list
;
thread_local
std
::
vector
<
std
::
pair
<
uint64_t
,
const
float
*>>
sorted_kv_list
;
sorted_kv_list
.
clear
();
sorted_kv_list
.
clear
();
...
@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
...
@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
push_request
->
add_params
(
reinterpret_cast
<
char
*>
(
&
merged_kv_count
),
push_request
->
add_params
(
reinterpret_cast
<
char
*>
(
&
merged_kv_count
),
sizeof
(
uint32_t
));
// NOLINT
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
auto
*
push_data
=
push_request
->
mutable_data
();
int
update_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
push_data
->
resize
(
merged_kv_count
*
push_data
->
resize
(
merged_kv_count
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
(
)));
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
merged_key_list
.
data
(),
memcpy
(
push_data_ptr
,
merged_key_list
.
data
(),
merged_kv_count
*
sizeof
(
uint64_t
));
merged_kv_count
*
sizeof
(
uint64_t
));
...
@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
...
@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
const
char
*
task_data_ptr
=
merged_value_list
[
i
].
data
();
const
char
*
task_data_ptr
=
merged_value_list
[
i
].
data
();
memcpy
(
push_data_ptr
,
(
float
*
)(
task_data_ptr
),
// NOLINT
memcpy
(
push_data_ptr
,
(
float
*
)(
task_data_ptr
),
// NOLINT
accessor
->
update_size
(
));
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
update_size
(
);
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
...
@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t
region_num
,
size_t
region_num
,
size_t
table_id
)
{
size_t
table_id
)
{
auto
*
accessor
=
table_accessor
(
table_id
);
auto
*
accessor
=
table_accessor
(
table_id
);
int
fea_dim
=
accessor
->
GetTableInfo
(
FEA_DIM
);
int
update_dim
=
accessor
->
GetTableInfo
(
UPDATE_DIM
);
auto
push_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense"
);
auto
push_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense"
);
auto
parse_timer
=
auto
parse_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_parse"
);
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_parse"
);
...
@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
...
@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t
request_call_num
=
_server_channels
.
size
();
size_t
request_call_num
=
_server_channels
.
size
();
uint32_t
num_per_shard
=
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
// 将region数据拷贝到转置矩阵中
// 将region数据拷贝到转置矩阵中
async_task
->
data
()
->
resize
(
num_per_shard
*
request_call_num
*
async_task
->
data
()
->
resize
(
num_per_shard
*
request_call_num
*
accessor
->
update_dim
(
));
accessor
->
GetTableInfo
(
UPDATE_DIM
));
float
*
data
=
async_task
->
data
()
->
data
();
float
*
data
=
async_task
->
data
()
->
data
();
size_t
data_size
=
async_task
->
data
()
->
size
();
size_t
data_size
=
async_task
->
data
()
->
size
();
uint32_t
pos
=
0
;
uint32_t
pos
=
0
;
...
@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
...
@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_rpc"
);
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_rpc"
);
closure
->
add_timer
(
timer
);
closure
->
add_timer
(
timer
);
uint32_t
num_per_shard
=
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
auto
send_timer
=
auto
send_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_send"
);
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_send"
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
...
...
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
浏览文件 @
2089b485
...
@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
...
@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
res_data
->
resize
(
num
*
table
->
value_accesor
()
->
select_size
()
/
sizeof
(
float
));
res_data
->
resize
(
num
*
table
->
value_accesor
()
->
select_size
()
/
sizeof
(
float
));
table
->
pull_dense
(
res_data
->
data
(),
num
);
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
pull_context
.
values
=
res_data
->
data
();
table_context
.
num
=
num
;
table
->
Pull
(
table_context
);
// table->pull_dense(res_data->data(), num);
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
res_data
->
size
()
*
sizeof
(
float
));
res_data
->
size
()
*
sizeof
(
float
));
...
@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
...
@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
|--4B---|----------------|
|--4B---|----------------|
*/
*/
uint32_t
num
=
*
(
const
uint32_t
*
)(
request
.
data
().
data
());
uint32_t
num
=
*
(
const
uint32_t
*
)(
request
.
data
().
data
());
const
float
*
values
=
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
(
const
float
*
)(
request
.
data
().
data
()
+
sizeof
(
uint32_t
));
(
const
float
*
)(
request
.
data
().
data
()
+
sizeof
(
uint32_t
));
if
(
table
->
push_dense
(
values
,
num
)
!=
0
)
{
table_context
.
num
=
num
;
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if
(
table
->
Push
(
table_context
)
!=
0
)
{
// if (table->push_dense(values, num) != 0) {
set_response_code
(
response
,
-
1
,
"push_dense failed"
);
set_response_code
(
response
,
-
1
,
"push_dense failed"
);
}
}
...
@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,
...
@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
res_data
->
resize
(
num
*
dim
);
res_data
->
resize
(
num
*
dim
);
table
->
pull_sparse
(
res_data
->
data
(),
value
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
pull_context
.
pull_value
=
value
;
table_context
.
pull_context
.
values
=
res_data
->
data
();
table
->
Pull
(
table_context
);
// table->pull_sparse(res_data->data(), value);
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
res_data
->
size
()
*
sizeof
(
float
));
res_data
->
size
()
*
sizeof
(
float
));
...
@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
...
@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---|
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
|---8*{num}B---|----------------|
*/
*/
const
uint64_t
*
keys
=
(
const
uint64_t
*
)
push_data
.
data
();
TableContext
table_context
;
const
float
*
values
=
table_context
.
value_type
=
Sparse
;
table_context
.
push_context
.
keys
=
(
const
uint64_t
*
)
push_data
.
data
();
table_context
.
push_context
.
values
=
(
const
float
*
)(
push_data
.
data
()
+
sizeof
(
uint64_t
)
*
num
);
(
const
float
*
)(
push_data
.
data
()
+
sizeof
(
uint64_t
)
*
num
);
if
(
table
->
push_sparse
(
keys
,
values
,
num
)
!=
0
)
{
table_context
.
num
=
num
;
// const uint64_t *keys = (const uint64_t *)push_data.data();
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if
(
table
->
Push
(
table_context
)
!=
0
)
{
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code
(
response
,
-
1
,
"push_sparse error"
);
set_response_code
(
response
,
-
1
,
"push_sparse error"
);
}
}
return
0
;
return
0
;
...
...
paddle/fluid/distributed/ps/service/ps_client.h
浏览文件 @
2089b485
...
@@ -86,9 +86,9 @@ struct RequestContext {
...
@@ -86,9 +86,9 @@ struct RequestContext {
TrainingMode
training_mode
;
// 1 for async, 2 for geo, 3 for sync
TrainingMode
training_mode
;
// 1 for async, 2 for geo, 3 for sync
TrainingPhase
training_phase
;
// 1 for init, 2 for train
TrainingPhase
training_phase
;
// 1 for init, 2 for train
ValueType
value_type
;
// 1 for sparse, 2 for dense
ValueType
value_type
;
// 1 for sparse, 2 for dense
void
*
keys
;
uint64_t
*
keys
;
void
**
sparse_values
;
// for sparse values
float
**
sparse_values
;
// for sparse values
Region
*
dense_values
;
// for dense values
Region
*
dense_values
;
// for dense values
PushContext
push_context
;
PushContext
push_context
;
size_t
num
;
size_t
num
;
bool
is_training
;
bool
is_training
;
...
...
paddle/fluid/distributed/ps/service/ps_local_client.cc
浏览文件 @
2089b485
...
@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
...
@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char
**
select_values
=
reinterpret_cast
<
char
**>
(
pull_context
.
sparse_values
);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t
table_id
=
pull_context
.
table
;
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
size_t
num
=
pull_context
.
num
;
pull_sparse_ptr
(
select_values
,
table_id
,
keys
,
num
);
pull_sparse_ptr
(
reinterpret_cast
<
char
**>
(
pull_context
.
sparse_values
),
table_id
,
pull_context
.
keys
,
num
);
}
}
}
}
...
...
paddle/fluid/distributed/ps/table/accessor.h
浏览文件 @
2089b485
...
@@ -56,6 +56,17 @@ struct AccessorInfo {
...
@@ -56,6 +56,17 @@ struct AccessorInfo {
size_t
fea_dim
;
size_t
fea_dim
;
};
};
enum
InfoKey
{
DIM
=
0
,
SIZE
=
1
,
SELECT_SIZE
=
2
,
SELECT_DIM
=
3
,
UPDATE_SIZE
=
4
,
UPDATE_DIM
=
5
,
MF_SIZE
=
6
,
FEA_DIM
=
7
};
class
ValueAccessor
{
class
ValueAccessor
{
public:
public:
ValueAccessor
()
{}
ValueAccessor
()
{}
...
@@ -79,7 +90,8 @@ class ValueAccessor {
...
@@ -79,7 +90,8 @@ class ValueAccessor {
}
}
virtual
int
initialize
()
=
0
;
virtual
int
initialize
()
=
0
;
virtual
void
GetTableInfo
(
AccessorInfo
&
info
)
=
0
;
virtual
void
SetTableInfo
(
AccessorInfo
&
info
)
=
0
;
virtual
size_t
GetTableInfo
(
InfoKey
key
)
=
0
;
// value维度
// value维度
virtual
size_t
dim
()
=
0
;
virtual
size_t
dim
()
=
0
;
...
...
paddle/fluid/distributed/ps/table/common_dense_table.cc
浏览文件 @
2089b485
...
@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
...
@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
int32_t
CommonDenseTable
::
Push
(
TableContext
&
context
)
{
int32_t
CommonDenseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
CHECK
(
context
.
value_type
==
Dense
);
if
(
context
.
pu
ll
_context
.
values
!=
nullptr
)
{
if
(
context
.
pu
sh
_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
const
float
*
values
=
context
.
push_context
.
values
;
return
push_dense
(
values
,
context
.
num
);
return
push_dense
(
values
,
context
.
num
);
}
}
...
@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path,
...
@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path,
}
}
size_t
dim_num_per_file
=
_config
.
accessor
().
fea_dim
()
/
file_list
.
size
()
+
1
;
size_t
dim_num_per_file
=
_config
.
accessor
().
fea_dim
()
/
file_list
.
size
()
+
1
;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t
dim_num_per_shard
=
_
value_accesor
->
fea_dim
()
/
_shard_num
+
1
;
size_t
dim_num_per_shard
=
_
table_info
.
fea_dim
/
_shard_num
+
1
;
size_t
start_dim_idx
=
dim_num_per_shard
*
_shard_idx
;
size_t
start_dim_idx
=
dim_num_per_shard
*
_shard_idx
;
size_t
start_file_idx
=
start_dim_idx
/
dim_num_per_file
;
size_t
start_file_idx
=
start_dim_idx
/
dim_num_per_file
;
size_t
end_file_idx
=
(
start_dim_idx
+
param_dim_
)
/
dim_num_per_file
;
size_t
end_file_idx
=
(
start_dim_idx
+
param_dim_
)
/
dim_num_per_file
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.cc
浏览文件 @
2089b485
...
@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
...
@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
int32_t
CommonSparseTable
::
Push
(
TableContext
&
context
)
{
int32_t
CommonSparseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
pu
ll
_context
.
values
!=
nullptr
)
{
if
(
context
.
pu
sh
_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
const
float
*
values
=
context
.
push_context
.
values
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
values
,
context
.
num
);
return
push_sparse
(
keys
,
values
,
context
.
num
);
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.cc
浏览文件 @
2089b485
...
@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() {
...
@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() {
return
0
;
return
0
;
}
}
void
CtrCommonAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
CtrCommonAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
info
.
fea_dim
=
fea_dim
();
}
}
size_t
CtrCommonAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
CtrCommonAccessor
::
dim
()
{
return
common_feature_value
.
dim
();
}
size_t
CtrCommonAccessor
::
dim
()
{
return
common_feature_value
.
dim
();
}
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.h
浏览文件 @
2089b485
...
@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor {
...
@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor {
virtual
int
initialize
();
virtual
int
initialize
();
virtual
~
CtrCommonAccessor
()
{}
virtual
~
CtrCommonAccessor
()
{}
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
浏览文件 @
2089b485
...
@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() {
...
@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() {
return
0
;
return
0
;
}
}
void
DownpourCtrDoubleAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
DownpourCtrDoubleAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
info
.
fea_dim
=
fea_dim
();
}
}
size_t
DownpourCtrDoubleAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
浏览文件 @
2089b485
...
@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
...
@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor
()
{}
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
浏览文件 @
2089b485
...
@@ -24,6 +24,7 @@ namespace paddle {
...
@@ -24,6 +24,7 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
struct
PullSparseValue
{
struct
PullSparseValue
{
PullSparseValue
()
{}
explicit
PullSparseValue
(
int
numel
,
int
dim
)
explicit
PullSparseValue
(
int
numel
,
int
dim
)
:
numel_
(
numel
),
:
numel_
(
numel
),
dim_
(
dim
),
dim_
(
dim
),
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
浏览文件 @
2089b485
...
@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() {
...
@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() {
return
0
;
return
0
;
}
}
void
DownpourCtrAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
DownpourCtrAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
info
.
fea_dim
=
fea_dim
();
}
}
size_t
DownpourCtrAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
DownpourCtrAccessor
::
dim
()
{
size_t
DownpourCtrAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
浏览文件 @
2089b485
...
@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor {
...
@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual
~
DownpourCtrAccessor
()
{}
virtual
~
DownpourCtrAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
浏览文件 @
2089b485
...
@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path,
...
@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path,
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
feature_value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
feature_value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
omp_set_num_threads
(
thread_num
);
omp_set_num_threads
(
thread_num
);
...
@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
...
@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
feature_value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
feature_value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
omp_set_num_threads
(
thread_num
);
omp_set_num_threads
(
thread_num
);
...
@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) {
...
@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK
(
context
.
value_type
==
Sparse
);
CHECK
(
context
.
value_type
==
Sparse
);
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
context
.
push_context
.
ptr_
values
,
context
.
num
);
return
push_sparse
(
keys
,
context
.
push_context
.
values
,
context
.
num
);
}
}
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
...
@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
...
@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
CostTimer
timer
(
"pserver_sparse_select_all"
);
CostTimer
timer
(
"pserver_sparse_select_all"
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
const
size_t
value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
const
size_t
value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_size
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
mf_value_size
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
select_value_size
=
_value_accesor
->
select_size
()
/
sizeof
(
float
);
size_t
select_value_size
=
_value_accesor
->
GetTableInfo
(
SELECT_SIZE
)
/
sizeof
(
float
);
// std::atomic<uint32_t> missed_keys{0};
// std::atomic<uint32_t> missed_keys{0};
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
...
@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
...
@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
tasks
[
shard_id
].
wait
();
tasks
[
shard_id
].
wait
();
}
}
return
0
;
return
0
;
}
}
...
@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
...
@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
}
}
const
size_t
value_col
=
_value_accesor
->
size
()
/
sizeof
(
float
);
const
size_t
value_col
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
update_size
()
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
for
(
size_t
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
for
(
size_t
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
...
@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
...
@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
}
}
size_t
value_col
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
value_col
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
update_size
()
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
...
...
paddle/fluid/distributed/ps/table/sparse_accessor.cc
浏览文件 @
2089b485
...
@@ -38,16 +38,39 @@ int SparseAccessor::initialize() {
...
@@ -38,16 +38,39 @@ int SparseAccessor::initialize() {
return
0
;
return
0
;
}
}
void
SparseAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
SparseAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
info
.
fea_dim
=
fea_dim
();
}
}
size_t
SparseAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
SparseAccessor
::
dim
()
{
return
sparse_feature_value
.
dim
();
}
size_t
SparseAccessor
::
dim
()
{
return
sparse_feature_value
.
dim
();
}
size_t
SparseAccessor
::
dim_size
(
size_t
dim
)
{
size_t
SparseAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/sparse_accessor.h
浏览文件 @
2089b485
...
@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor {
...
@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor {
};
};
SparseAccessor
()
{}
SparseAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
virtual
~
SparseAccessor
()
{}
virtual
~
SparseAccessor
()
{}
// value维度
// value维度
...
...
paddle/fluid/distributed/ps/table/table.cc
浏览文件 @
2089b485
...
@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() {
...
@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() {
return
-
1
;
return
-
1
;
}
}
_value_accesor
.
reset
(
accessor
);
_value_accesor
.
reset
(
accessor
);
// _value_accesor->SetTableInfo(_table_info);
return
0
;
return
0
;
}
}
...
...
paddle/fluid/distributed/ps/table/table.h
浏览文件 @
2089b485
...
@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 };
...
@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 };
struct
PullContext
{
struct
PullContext
{
const
uint64_t
*
keys
;
const
uint64_t
*
keys
;
const
PullSparseValue
pull_value
;
PullSparseValue
pull_value
;
float
*
values
;
float
*
values
;
char
**
ptr_values
;
char
**
ptr_values
;
};
};
...
@@ -53,7 +53,7 @@ struct TableContext {
...
@@ -53,7 +53,7 @@ struct TableContext {
PullContext
pull_context
;
PullContext
pull_context
;
TablePushContext
push_context
;
TablePushContext
push_context
;
size_t
num
;
size_t
num
;
bool
use_ptr
;
bool
use_ptr
=
false
;
};
};
class
Table
{
class
Table
{
...
@@ -164,6 +164,7 @@ class Table {
...
@@ -164,6 +164,7 @@ class Table {
TableParameter
_config
;
TableParameter
_config
;
float
*
_global_lr
=
nullptr
;
float
*
_global_lr
=
nullptr
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
AccessorInfo
_table_info
;
AfsClient
_afs_client
;
AfsClient
_afs_client
;
};
};
REGISTER_PSCORE_REGISTERER
(
Table
);
REGISTER_PSCORE_REGISTERER
(
Table
);
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.cc
浏览文件 @
2089b485
...
@@ -20,16 +20,39 @@ namespace distributed {
...
@@ -20,16 +20,39 @@ namespace distributed {
int
CommMergeAccessor
::
initialize
()
{
return
0
;
}
int
CommMergeAccessor
::
initialize
()
{
return
0
;
}
void
CommMergeAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
CommMergeAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
info
.
fea_dim
=
fea_dim
();
}
}
size_t
CommMergeAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
// value 维度
// value 维度
size_t
CommMergeAccessor
::
dim
()
{
return
0
;
}
size_t
CommMergeAccessor
::
dim
()
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.h
浏览文件 @
2089b485
...
@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
...
@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor
()
{}
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
int
initialize
();
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
// value维度
virtual
size_t
dim
();
virtual
size_t
dim
();
// value各个维度的size
// value各个维度的size
...
...
paddle/fluid/distributed/ps/wrapper/fleet.cc
浏览文件 @
2089b485
...
@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
...
@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr
.
push_back
(
output_data
+
output_len
);
pull_result_ptr
.
push_back
(
output_data
+
output_len
);
}
}
}
}
auto
status
=
// ps client pull sparse
worker_ptr_
->
pull_sparse
(
pull_result_ptr
.
data
(),
table_id
,
// construct client request context
fea_keys
.
data
(),
fea_keys
.
size
(),
is_training
);
RequestContext
req_context
;
req_context
.
value_type
=
Sparse
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
sparse_values
=
pull_result_ptr
.
data
();
req_context
.
keys
=
fea_keys
.
data
();
req_context
.
num
=
fea_keys
.
size
();
req_context
.
is_training
=
is_training
;
auto
status
=
worker_ptr_
->
Pull
(
req_context
);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
status
.
wait
();
status
.
wait
();
auto
ret
=
status
.
get
();
auto
ret
=
status
.
get
();
if
(
ret
!=
0
)
{
if
(
ret
!=
0
)
{
...
@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync(
...
@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync(
paddle
::
distributed
::
Region
reg
(
w
,
tensor
->
numel
());
paddle
::
distributed
::
Region
reg
(
w
,
tensor
->
numel
());
regions
[
i
]
=
std
::
move
(
reg
);
regions
[
i
]
=
std
::
move
(
reg
);
}
}
auto
status
=
worker_ptr_
->
pull_dense
(
regions
.
data
(),
regions
.
size
(),
tid
);
RequestContext
req_context
;
req_context
.
value_type
=
Dense
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
tid
;
req_context
.
dense_values
=
regions
.
data
();
req_context
.
num
=
regions
.
size
();
auto
status
=
worker_ptr_
->
Pull
(
req_context
);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status
->
push_back
(
std
::
move
(
status
));
pull_dense_status
->
push_back
(
std
::
move
(
status
));
}
}
...
@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync(
...
@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync(
<<
g
[
tensor
->
numel
()
-
1
];
<<
g
[
tensor
->
numel
()
-
1
];
}
}
auto
push_status
=
RequestContext
req_context
;
worker_ptr_
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
table_id
);
req_context
.
value_type
=
Dense
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
push_context
.
push_dense_values
=
regions
.
data
();
req_context
.
num
=
regions
.
size
();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto
push_status
=
worker_ptr_
->
Push
(
req_context
);
}
}
void
FleetWrapper
::
PushSparseVarsAsync
(
void
FleetWrapper
::
PushSparseVarsAsync
(
...
@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync(
...
@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec
[
i
]
=
push_values
.
at
(
i
).
data
();
push_g_vec
[
i
]
=
push_values
.
at
(
i
).
data
();
}
}
auto
status
=
worker_ptr_
->
push_sparse
(
table_id
,
push_keys
.
data
(),
// ps client push sparse
(
const
float
**
)
push_g_vec
.
data
(),
// construct request context
push_keys
.
size
());
RequestContext
req_context
;
req_context
.
value_type
=
Sparse
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
push_context
.
push_values
=
(
const
float
**
)
push_g_vec
.
data
();
req_context
.
push_context
.
keys
=
push_keys
.
data
();
req_context
.
num
=
push_keys
.
size
();
auto
status
=
worker_ptr_
->
Push
(
req_context
);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
}
}
void
FleetWrapper
::
LoadModel
(
const
std
::
string
&
path
,
const
int
mode
)
{
void
FleetWrapper
::
LoadModel
(
const
std
::
string
&
path
,
const
int
mode
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录