Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2089b485
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2089b485
编写于
3月 30, 2022
作者:
Y
yaoxuefeng
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change to new api in ssync mode (#41022)
* change to new api in ssync mode * fix * fix * fix * fix
上级
60c4c9cd
变更
22
显示空白变更内容
内联
并排
Showing
22 changed file
with
296 addition
and
88 deletion
+296
-88
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+45
-37
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
+30
-7
paddle/fluid/distributed/ps/service/ps_client.h
paddle/fluid/distributed/ps/service/ps_client.h
+3
-3
paddle/fluid/distributed/ps/service/ps_local_client.cc
paddle/fluid/distributed/ps/service/ps_local_client.cc
+5
-3
paddle/fluid/distributed/ps/table/accessor.h
paddle/fluid/distributed/ps/table/accessor.h
+13
-1
paddle/fluid/distributed/ps/table/common_dense_table.cc
paddle/fluid/distributed/ps/table/common_dense_table.cc
+2
-2
paddle/fluid/distributed/ps/table/common_sparse_table.cc
paddle/fluid/distributed/ps/table/common_sparse_table.cc
+1
-1
paddle/fluid/distributed/ps/table/ctr_accessor.cc
paddle/fluid/distributed/ps/table/ctr_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/ctr_accessor.h
paddle/fluid/distributed/ps/table/ctr_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
+1
-0
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+17
-13
paddle/fluid/distributed/ps/table/sparse_accessor.cc
paddle/fluid/distributed/ps/table/sparse_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/sparse_accessor.h
paddle/fluid/distributed/ps/table/sparse_accessor.h
+2
-1
paddle/fluid/distributed/ps/table/table.cc
paddle/fluid/distributed/ps/table/table.cc
+1
-0
paddle/fluid/distributed/ps/table/table.h
paddle/fluid/distributed/ps/table/table.h
+3
-2
paddle/fluid/distributed/ps/table/tensor_accessor.cc
paddle/fluid/distributed/ps/table/tensor_accessor.cc
+24
-1
paddle/fluid/distributed/ps/table/tensor_accessor.h
paddle/fluid/distributed/ps/table/tensor_accessor.h
+2
-1
paddle/fluid/distributed/ps/wrapper/fleet.cc
paddle/fluid/distributed/ps/wrapper/fleet.cc
+45
-9
未找到文件。
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
浏览文件 @
2089b485
...
...
@@ -532,18 +532,17 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if
(
pull_context
.
value_type
==
Dense
)
{
// pull dense
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
return
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
float
**
select_values
=
reinterpret_cast
<
float
**>
(
pull_context
.
sparse_values
);
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
bool
is_training
=
pull_context
.
is_training
;
if
(
pull_context
.
training_mode
==
Geo
)
{
// for geo
pull_sparse_param
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
return
pull_sparse_param
(
pull_context
.
sparse_values
,
table_id
,
pull_context
.
keys
,
num
,
is_training
);
}
else
if
(
pull_context
.
training_mode
==
Async
)
{
// for async
pull_sparse
(
select_values
,
table_id
,
keys
,
num
,
is_training
);
return
pull_sparse
(
pull_context
.
sparse_values
,
table_id
,
pull_context
.
keys
,
num
,
is_training
);
}
}
}
...
...
@@ -551,7 +550,7 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
std
::
future
<
int32_t
>
BrpcPsClient
::
Push
(
RequestContext
&
push_context
)
{
if
(
push_context
.
value_type
==
Dense
)
{
// push dense
const
Region
*
dense_region
=
push_context
.
push_context
.
push_dense_values
;
push_dense
(
dense_region
,
push_context
.
num
,
push_context
.
table
);
return
push_dense
(
dense_region
,
push_context
.
num
,
push_context
.
table
);
}
else
{
// push sparse
size_t
table_id
=
push_context
.
table
;
size_t
num
=
push_context
.
num
;
...
...
@@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
}
else
if
(
push_context
.
training_mode
==
Async
)
{
// for async
const
uint64_t
*
keys
=
push_context
.
push_context
.
keys
;
const
float
**
update_values
=
push_context
.
push_context
.
push_values
;
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
return
push_sparse
(
table_id
,
keys
,
update_values
,
num
);
}
}
}
...
...
@@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
io_buffer_itr
.
copy_and_forward
(
reinterpret_cast
<
void
*>
(
&
shard_nums
),
sizeof
(
uint32_t
));
keys
->
resize
(
shard_nums
);
values
->
resize
(
shard_nums
*
accessor
->
update_dim
(
));
values
->
resize
(
shard_nums
*
accessor
->
GetTableInfo
(
UPDATE_DIM
));
io_buffer_itr
.
copy_and_forward
((
void
*
)(
keys
->
data
()),
// NOLINT
sizeof
(
uint64_t
)
*
shard_nums
);
io_buffer_itr
.
copy_and_forward
((
void
*
)(
values
->
data
()),
// NOLINT
shard_nums
*
accessor
->
update_size
());
io_buffer_itr
.
copy_and_forward
(
(
void
*
)(
values
->
data
()),
// NOLINT
shard_nums
*
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
closure
->
set_promise_value
(
ret
);
});
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
...
...
@@ -630,7 +630,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
auto
kvs
=
ids
[
shard_idx
];
auto
value_ptr
=
value_ptrs
[
shard_idx
];
size_t
kv_size
=
kvs
.
size
();
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
// 发送RPC请求
auto
*
push_request
=
closure
->
request
(
shard_idx
);
push_request
->
set_cmd_id
(
PS_PUSH_SPARSE_PARAM
);
...
...
@@ -638,13 +638,14 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
push_request
->
set_client_id
(
_client_id
);
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
()));
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
update_size
(
));
push_data_ptr
+=
accessor
->
update_size
(
);
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
...
@@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t
table_id
)
{
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_pull_dense"
);
auto
*
accessor
=
table_accessor
(
table_id
);
auto
fea_dim
=
accessor
->
GetTableInfo
(
FEA_DIM
);
auto
select_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
size_t
request_call_num
=
_server_channels
.
size
();
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
// callback 将各shard结果,顺序填入region
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
request_call_num
,
num_per_shard
,
regions
,
region_num
,
...
...
@@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t
region_idx
=
0
;
// 当前填充的region偏移
size_t
region_data_idx
=
0
;
// 当前填充的region内data偏移
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
select_size
();
size_t
shard_data_size
=
num_per_shard
*
accessor
->
GetTableInfo
(
SELECT_SIZE
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
PS_PULL_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
...
...
@@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std
::
vector
<
std
::
vector
<
Region
>>
regions_partition
(
request_call_num
);
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
update_size
(
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
size_t
shard_data_size
=
num_per_shard
*
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
size_t
current_region_idx
=
0
;
size_t
current_region_data_idx
=
0
;
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
...
...
@@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
auto
value_ptr
=
value_ptrs
[
shard_idx
];
size_t
kv_size
=
kvs
.
size
();
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
// 发送RPC请求
auto
*
push_request
=
closure
->
request
(
shard_idx
);
...
...
@@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request
->
set_client_id
(
_client_id
);
push_request
->
add_params
((
char
*
)
&
kv_size
,
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
()));
push_data
->
resize
(
kv_size
*
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
kvs
.
data
(),
kv_size
*
sizeof
(
uint64_t
));
push_data_ptr
+=
kv_size
*
sizeof
(
uint64_t
);
for
(
int
i
=
0
;
i
<
kv_size
;
++
i
)
{
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
update_size
(
));
push_data_ptr
+=
accessor
->
update_size
(
);
memcpy
(
push_data_ptr
,
value_ptr
[
i
],
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
...
@@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std
::
future
<
int
>
fut
=
promise
->
get_future
();
auto
*
accessor
=
table_accessor
(
table_id
);
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
closure
->
request
(
i
)
->
set_cmd_id
(
PS_PUSH_DENSE_TABLE
);
closure
->
request
(
i
)
->
set_table_id
(
table_id
);
...
...
@@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
}
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
select_size
();
size_t
value_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
...
...
@@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
}
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
select_size
(
);
size_t
value_size
=
accessor
->
GetTableInfo
(
SELECT_SIZE
);
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
shard_sorted_kvs
,
value_size
](
void
*
done
)
{
...
...
@@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
uint32_t
num
,
void
*
done
,
int
pserver_idx
)
{
auto
*
accessor
=
table_accessor
(
table_id
);
size_t
value_size
=
accessor
->
update_size
(
);
size_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
DownpourBrpcClosure
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
auto
promise
=
std
::
make_shared
<
std
::
promise
<
int32_t
>>
();
closure
->
add_promise
(
promise
);
...
...
@@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
shard_kv_data
.
kv_num
=
0
;
continue
;
}
uint32_t
value_size
=
accessor
->
update_size
();
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
for
(
size_t
kv_idx
=
0
;
kv_idx
<
sorted_kv_size
;
++
kv_idx
)
{
shard_kv_data
.
key_list
[
kv_idx
]
=
sorted_kv_list
[
kv_idx
].
first
;
shard_kv_data
.
value_list
[
kv_idx
].
assign
(
...
...
@@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {
void
sparse_local_merge
(
ValueAccessor
*
accessor
,
float
*
merge_data
,
const
float
*
another_data
)
{
size_t
col_num
=
accessor
->
update_size
(
)
/
sizeof
(
float
);
size_t
col_num
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
float
*
merge_data_shell
[
col_num
];
const
float
*
another_data_shell
[
col_num
];
for
(
int
i
=
0
;
i
<
col_num
;
++
i
)
{
...
...
@@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
ValueAccessor
*
accessor
)
{
size_t
merged_kv_count
=
0
;
uint64_t
min_key
=
UINT64_MAX
;
uint32_t
value_size
=
accessor
->
update_size
(
);
uint32_t
value_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
thread_local
std
::
vector
<
std
::
pair
<
uint64_t
,
const
float
*>>
sorted_kv_list
;
sorted_kv_list
.
clear
();
...
...
@@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
push_request
->
add_params
(
reinterpret_cast
<
char
*>
(
&
merged_kv_count
),
sizeof
(
uint32_t
));
// NOLINT
auto
*
push_data
=
push_request
->
mutable_data
();
int
update_size
=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
push_data
->
resize
(
merged_kv_count
*
(
sizeof
(
uint64_t
)
+
accessor
->
update_size
(
)));
(
sizeof
(
uint64_t
)
+
accessor
->
GetTableInfo
(
UPDATE_SIZE
)));
char
*
push_data_ptr
=
const_cast
<
char
*>
(
push_data
->
data
());
memcpy
(
push_data_ptr
,
merged_key_list
.
data
(),
merged_kv_count
*
sizeof
(
uint64_t
));
...
...
@@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
const
char
*
task_data_ptr
=
merged_value_list
[
i
].
data
();
memcpy
(
push_data_ptr
,
(
float
*
)(
task_data_ptr
),
// NOLINT
accessor
->
update_size
(
));
push_data_ptr
+=
accessor
->
update_size
(
);
accessor
->
GetTableInfo
(
UPDATE_SIZE
));
push_data_ptr
+=
accessor
->
GetTableInfo
(
UPDATE_SIZE
);
}
PsService_Stub
rpc_stub
(
get_sparse_channel
(
shard_idx
));
closure
->
cntl
(
shard_idx
)
->
set_request_compress_type
(
...
...
@@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t
region_num
,
size_t
table_id
)
{
auto
*
accessor
=
table_accessor
(
table_id
);
int
fea_dim
=
accessor
->
GetTableInfo
(
FEA_DIM
);
int
update_dim
=
accessor
->
GetTableInfo
(
UPDATE_DIM
);
auto
push_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense"
);
auto
parse_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_parse"
);
...
...
@@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t
request_call_num
=
_server_channels
.
size
();
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
// 将region数据拷贝到转置矩阵中
async_task
->
data
()
->
resize
(
num_per_shard
*
request_call_num
*
accessor
->
update_dim
(
));
accessor
->
GetTableInfo
(
UPDATE_DIM
));
float
*
data
=
async_task
->
data
()
->
data
();
size_t
data_size
=
async_task
->
data
()
->
size
();
uint32_t
pos
=
0
;
...
...
@@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
auto
timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_rpc"
);
closure
->
add_timer
(
timer
);
uint32_t
num_per_shard
=
dense_dim_per_shard
(
accessor
->
fea_dim
(
),
request_call_num
);
dense_dim_per_shard
(
accessor
->
GetTableInfo
(
FEA_DIM
),
request_call_num
);
auto
send_timer
=
std
::
make_shared
<
CostTimer
>
(
"pserver_client_push_dense_send"
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
...
...
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
浏览文件 @
2089b485
...
...
@@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
res_data
->
resize
(
num
*
table
->
value_accesor
()
->
select_size
()
/
sizeof
(
float
));
table
->
pull_dense
(
res_data
->
data
(),
num
);
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
pull_context
.
values
=
res_data
->
data
();
table_context
.
num
=
num
;
table
->
Pull
(
table_context
);
// table->pull_dense(res_data->data(), num);
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
res_data
->
size
()
*
sizeof
(
float
));
...
...
@@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
|--4B---|----------------|
*/
uint32_t
num
=
*
(
const
uint32_t
*
)(
request
.
data
().
data
());
const
float
*
values
=
TableContext
table_context
;
table_context
.
value_type
=
Dense
;
table_context
.
push_context
.
values
=
(
const
float
*
)(
request
.
data
().
data
()
+
sizeof
(
uint32_t
));
if
(
table
->
push_dense
(
values
,
num
)
!=
0
)
{
table_context
.
num
=
num
;
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if
(
table
->
Push
(
table_context
)
!=
0
)
{
// if (table->push_dense(values, num) != 0) {
set_response_code
(
response
,
-
1
,
"push_dense failed"
);
}
...
...
@@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,
auto
res_data
=
butil
::
get_object
<
std
::
vector
<
float
>>
();
res_data
->
resize
(
num
*
dim
);
table
->
pull_sparse
(
res_data
->
data
(),
value
);
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
pull_context
.
pull_value
=
value
;
table_context
.
pull_context
.
values
=
res_data
->
data
();
table
->
Pull
(
table_context
);
// table->pull_sparse(res_data->data(), value);
cntl
->
response_attachment
().
append
((
char
*
)(
res_data
->
data
()),
res_data
->
size
()
*
sizeof
(
float
));
...
...
@@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const
uint64_t
*
keys
=
(
const
uint64_t
*
)
push_data
.
data
();
const
float
*
values
=
TableContext
table_context
;
table_context
.
value_type
=
Sparse
;
table_context
.
push_context
.
keys
=
(
const
uint64_t
*
)
push_data
.
data
();
table_context
.
push_context
.
values
=
(
const
float
*
)(
push_data
.
data
()
+
sizeof
(
uint64_t
)
*
num
);
if
(
table
->
push_sparse
(
keys
,
values
,
num
)
!=
0
)
{
table_context
.
num
=
num
;
// const uint64_t *keys = (const uint64_t *)push_data.data();
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if
(
table
->
Push
(
table_context
)
!=
0
)
{
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code
(
response
,
-
1
,
"push_sparse error"
);
}
return
0
;
...
...
paddle/fluid/distributed/ps/service/ps_client.h
浏览文件 @
2089b485
...
...
@@ -86,8 +86,8 @@ struct RequestContext {
TrainingMode
training_mode
;
// 1 for async, 2 for geo, 3 for sync
TrainingPhase
training_phase
;
// 1 for init, 2 for train
ValueType
value_type
;
// 1 for sparse, 2 for dense
void
*
keys
;
void
**
sparse_values
;
// for sparse values
uint64_t
*
keys
;
float
**
sparse_values
;
// for sparse values
Region
*
dense_values
;
// for dense values
PushContext
push_context
;
size_t
num
;
...
...
paddle/fluid/distributed/ps/service/ps_local_client.cc
浏览文件 @
2089b485
...
...
@@ -126,11 +126,13 @@ std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
Region
*
dense_region
=
reinterpret_cast
<
Region
*>
(
pull_context
.
dense_values
);
pull_dense
(
dense_region
,
pull_context
.
num
,
pull_context
.
table
);
}
else
{
// pull sparse
uint64_t
*
keys
=
reinterpret_cast
<
uint64_t
*>
(
pull_context
.
keys
);
char
**
select_values
=
reinterpret_cast
<
char
**>
(
pull_context
.
sparse_values
);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t
table_id
=
pull_context
.
table
;
size_t
num
=
pull_context
.
num
;
pull_sparse_ptr
(
select_values
,
table_id
,
keys
,
num
);
pull_sparse_ptr
(
reinterpret_cast
<
char
**>
(
pull_context
.
sparse_values
),
table_id
,
pull_context
.
keys
,
num
);
}
}
...
...
paddle/fluid/distributed/ps/table/accessor.h
浏览文件 @
2089b485
...
...
@@ -56,6 +56,17 @@ struct AccessorInfo {
size_t
fea_dim
;
};
enum
InfoKey
{
DIM
=
0
,
SIZE
=
1
,
SELECT_SIZE
=
2
,
SELECT_DIM
=
3
,
UPDATE_SIZE
=
4
,
UPDATE_DIM
=
5
,
MF_SIZE
=
6
,
FEA_DIM
=
7
};
class
ValueAccessor
{
public:
ValueAccessor
()
{}
...
...
@@ -79,7 +90,8 @@ class ValueAccessor {
}
virtual
int
initialize
()
=
0
;
virtual
void
GetTableInfo
(
AccessorInfo
&
info
)
=
0
;
virtual
void
SetTableInfo
(
AccessorInfo
&
info
)
=
0
;
virtual
size_t
GetTableInfo
(
InfoKey
key
)
=
0
;
// value维度
virtual
size_t
dim
()
=
0
;
...
...
paddle/fluid/distributed/ps/table/common_dense_table.cc
浏览文件 @
2089b485
...
...
@@ -138,7 +138,7 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
int32_t
CommonDenseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
if
(
context
.
pu
ll
_context
.
values
!=
nullptr
)
{
if
(
context
.
pu
sh
_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
return
push_dense
(
values
,
context
.
num
);
}
...
...
@@ -220,7 +220,7 @@ int32_t CommonDenseTable::load(const std::string& path,
}
size_t
dim_num_per_file
=
_config
.
accessor
().
fea_dim
()
/
file_list
.
size
()
+
1
;
// param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
size_t
dim_num_per_shard
=
_
value_accesor
->
fea_dim
()
/
_shard_num
+
1
;
size_t
dim_num_per_shard
=
_
table_info
.
fea_dim
/
_shard_num
+
1
;
size_t
start_dim_idx
=
dim_num_per_shard
*
_shard_idx
;
size_t
start_file_idx
=
start_dim_idx
/
dim_num_per_file
;
size_t
end_file_idx
=
(
start_dim_idx
+
param_dim_
)
/
dim_num_per_file
;
...
...
paddle/fluid/distributed/ps/table/common_sparse_table.cc
浏览文件 @
2089b485
...
...
@@ -370,7 +370,7 @@ int32_t CommonSparseTable::Pull(TableContext& context) {
int32_t
CommonSparseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
pu
ll
_context
.
values
!=
nullptr
)
{
if
(
context
.
pu
sh
_context
.
values
!=
nullptr
)
{
const
float
*
values
=
context
.
push_context
.
values
;
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
values
,
context
.
num
);
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.cc
浏览文件 @
2089b485
...
...
@@ -38,16 +38,39 @@ int CtrCommonAccessor::initialize() {
return
0
;
}
void
CtrCommonAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
CtrCommonAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
CtrCommonAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
CtrCommonAccessor
::
dim
()
{
return
common_feature_value
.
dim
();
}
size_t
CtrCommonAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/ctr_accessor.h
浏览文件 @
2089b485
...
...
@@ -137,7 +137,8 @@ class CtrCommonAccessor : public ValueAccessor {
virtual
int
initialize
();
virtual
~
CtrCommonAccessor
()
{}
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
浏览文件 @
2089b485
...
...
@@ -37,16 +37,39 @@ int DownpourCtrDoubleAccessor::initialize() {
return
0
;
}
void
DownpourCtrDoubleAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
DownpourCtrDoubleAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
DownpourCtrDoubleAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
DownpourCtrDoubleAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrDoubleFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.h
浏览文件 @
2089b485
...
...
@@ -168,7 +168,8 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor
()
{}
virtual
~
DownpourCtrDoubleAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/depends/sparse_utils.h
浏览文件 @
2089b485
...
...
@@ -24,6 +24,7 @@ namespace paddle {
namespace
distributed
{
struct
PullSparseValue
{
PullSparseValue
()
{}
explicit
PullSparseValue
(
int
numel
,
int
dim
)
:
numel_
(
numel
),
dim_
(
dim
),
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.cc
浏览文件 @
2089b485
...
...
@@ -37,16 +37,39 @@ int DownpourCtrAccessor::initialize() {
return
0
;
}
void
DownpourCtrAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
DownpourCtrAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
DownpourCtrAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
DownpourCtrAccessor
::
dim
()
{
auto
embedx_dim
=
_config
.
embedx_dim
();
return
DownpourCtrFeatureValue
::
dim
(
embedx_dim
);
...
...
paddle/fluid/distributed/ps/table/downpour_ctr_accessor.h
浏览文件 @
2089b485
...
...
@@ -160,7 +160,8 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual
~
DownpourCtrAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
浏览文件 @
2089b485
...
...
@@ -88,7 +88,8 @@ int32_t MemorySparseTable::load(const std::string& path,
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
feature_value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
feature_value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
omp_set_num_threads
(
thread_num
);
...
...
@@ -173,7 +174,8 @@ int32_t MemorySparseTable::load_local_fs(const std::string& path,
size_t
file_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
size_t
feature_value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
feature_value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
int
thread_num
=
_real_local_shard_num
<
15
?
_real_local_shard_num
:
15
;
omp_set_num_threads
(
thread_num
);
...
...
@@ -407,7 +409,7 @@ int32_t MemorySparseTable::Push(TableContext& context) {
CHECK
(
context
.
value_type
==
Sparse
);
const
uint64_t
*
keys
=
context
.
push_context
.
keys
;
return
push_sparse
(
keys
,
context
.
push_context
.
ptr_
values
,
context
.
num
);
return
push_sparse
(
keys
,
context
.
push_context
.
values
,
context
.
num
);
}
int32_t
MemorySparseTable
::
pull_sparse
(
float
*
pull_values
,
...
...
@@ -415,9 +417,10 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
CostTimer
timer
(
"pserver_sparse_select_all"
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
const
size_t
value_size
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
mf_value_size
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
select_value_size
=
_value_accesor
->
select_size
()
/
sizeof
(
float
);
const
size_t
value_size
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_size
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
select_value_size
=
_value_accesor
->
GetTableInfo
(
SELECT_SIZE
)
/
sizeof
(
float
);
// std::atomic<uint32_t> missed_keys{0};
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
...
...
@@ -475,7 +478,6 @@ int32_t MemorySparseTable::pull_sparse(float* pull_values,
for
(
size_t
shard_id
=
0
;
shard_id
<
tasks
.
size
();
++
shard_id
)
{
tasks
[
shard_id
].
wait
();
}
return
0
;
}
...
...
@@ -541,9 +543,10 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
}
const
size_t
value_col
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
update_size
()
/
sizeof
(
float
);
const
size_t
value_col
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
for
(
size_t
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
...
...
@@ -618,9 +621,10 @@ int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
task_keys
[
shard_id
].
push_back
({
keys
[
i
],
i
});
}
size_t
value_col
=
_value_accesor
->
size
()
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
mf_size
()
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
update_size
()
/
sizeof
(
float
);
size_t
value_col
=
_value_accesor
->
GetTableInfo
(
SIZE
)
/
sizeof
(
float
);
size_t
mf_value_col
=
_value_accesor
->
GetTableInfo
(
MF_SIZE
)
/
sizeof
(
float
);
size_t
update_value_col
=
_value_accesor
->
GetTableInfo
(
UPDATE_SIZE
)
/
sizeof
(
float
);
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_task_pool_size
]
->
enqueue
(
...
...
paddle/fluid/distributed/ps/table/sparse_accessor.cc
浏览文件 @
2089b485
...
...
@@ -38,16 +38,39 @@ int SparseAccessor::initialize() {
return
0
;
}
void
SparseAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
SparseAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
SparseAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
size_t
SparseAccessor
::
dim
()
{
return
sparse_feature_value
.
dim
();
}
size_t
SparseAccessor
::
dim_size
(
size_t
dim
)
{
...
...
paddle/fluid/distributed/ps/table/sparse_accessor.h
浏览文件 @
2089b485
...
...
@@ -123,7 +123,8 @@ class SparseAccessor : public ValueAccessor {
};
SparseAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
virtual
~
SparseAccessor
()
{}
// value维度
...
...
paddle/fluid/distributed/ps/table/table.cc
浏览文件 @
2089b485
...
...
@@ -103,6 +103,7 @@ int32_t Table::initialize_accessor() {
return
-
1
;
}
_value_accesor
.
reset
(
accessor
);
// _value_accesor->SetTableInfo(_table_info);
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/table.h
浏览文件 @
2089b485
...
...
@@ -37,7 +37,7 @@ enum ValueType { Sparse = 0, Dense = 1 };
struct
PullContext
{
const
uint64_t
*
keys
;
const
PullSparseValue
pull_value
;
PullSparseValue
pull_value
;
float
*
values
;
char
**
ptr_values
;
};
...
...
@@ -53,7 +53,7 @@ struct TableContext {
PullContext
pull_context
;
TablePushContext
push_context
;
size_t
num
;
bool
use_ptr
;
bool
use_ptr
=
false
;
};
class
Table
{
...
...
@@ -164,6 +164,7 @@ class Table {
TableParameter
_config
;
float
*
_global_lr
=
nullptr
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
AccessorInfo
_table_info
;
AfsClient
_afs_client
;
};
REGISTER_PSCORE_REGISTERER
(
Table
);
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.cc
浏览文件 @
2089b485
...
...
@@ -20,16 +20,39 @@ namespace distributed {
int
CommMergeAccessor
::
initialize
()
{
return
0
;
}
void
CommMergeAccessor
::
G
etTableInfo
(
AccessorInfo
&
info
)
{
void
CommMergeAccessor
::
S
etTableInfo
(
AccessorInfo
&
info
)
{
info
.
dim
=
dim
();
info
.
size
=
size
();
info
.
select_dim
=
select_dim
();
info
.
select_size
=
select_size
();
info
.
update_dim
=
update_dim
();
info
.
update_size
=
update_size
();
info
.
mf_size
=
mf_size
();
info
.
fea_dim
=
fea_dim
();
}
size_t
CommMergeAccessor
::
GetTableInfo
(
InfoKey
key
)
{
switch
(
key
)
{
case
DIM
:
return
dim
();
case
SIZE
:
return
size
();
case
SELECT_DIM
:
return
select_dim
();
case
SELECT_SIZE
:
return
select_size
();
case
UPDATE_DIM
:
return
update_dim
();
case
UPDATE_SIZE
:
return
update_size
();
case
MF_SIZE
:
return
mf_size
();
case
FEA_DIM
:
return
fea_dim
();
}
return
0
;
}
// value 维度
size_t
CommMergeAccessor
::
dim
()
{
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/tensor_accessor.h
浏览文件 @
2089b485
...
...
@@ -30,7 +30,8 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor
()
{}
virtual
~
CommMergeAccessor
()
{}
virtual
int
initialize
();
virtual
void
GetTableInfo
(
AccessorInfo
&
info
);
virtual
void
SetTableInfo
(
AccessorInfo
&
info
);
virtual
size_t
GetTableInfo
(
InfoKey
key
);
// value维度
virtual
size_t
dim
();
// value各个维度的size
...
...
paddle/fluid/distributed/ps/wrapper/fleet.cc
浏览文件 @
2089b485
...
...
@@ -337,9 +337,21 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
pull_result_ptr
.
push_back
(
output_data
+
output_len
);
}
}
auto
status
=
worker_ptr_
->
pull_sparse
(
pull_result_ptr
.
data
(),
table_id
,
fea_keys
.
data
(),
fea_keys
.
size
(),
is_training
);
// ps client pull sparse
// construct client request context
RequestContext
req_context
;
req_context
.
value_type
=
Sparse
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
sparse_values
=
pull_result_ptr
.
data
();
req_context
.
keys
=
fea_keys
.
data
();
req_context
.
num
=
fea_keys
.
size
();
req_context
.
is_training
=
is_training
;
auto
status
=
worker_ptr_
->
Pull
(
req_context
);
// auto status =
// worker_ptr_->pull_sparse(pull_result_ptr.data(), table_id,
// fea_keys.data(), fea_keys.size(),
// is_training);
status
.
wait
();
auto
ret
=
status
.
get
();
if
(
ret
!=
0
)
{
...
...
@@ -366,7 +378,14 @@ void FleetWrapper::PullDenseVarsAsync(
paddle
::
distributed
::
Region
reg
(
w
,
tensor
->
numel
());
regions
[
i
]
=
std
::
move
(
reg
);
}
auto
status
=
worker_ptr_
->
pull_dense
(
regions
.
data
(),
regions
.
size
(),
tid
);
RequestContext
req_context
;
req_context
.
value_type
=
Dense
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
tid
;
req_context
.
dense_values
=
regions
.
data
();
req_context
.
num
=
regions
.
size
();
auto
status
=
worker_ptr_
->
Pull
(
req_context
);
// auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
pull_dense_status
->
push_back
(
std
::
move
(
status
));
}
...
...
@@ -451,8 +470,15 @@ void FleetWrapper::PushDenseVarsAsync(
<<
g
[
tensor
->
numel
()
-
1
];
}
auto
push_status
=
worker_ptr_
->
push_dense
(
regions
.
data
(),
regions
.
size
(),
table_id
);
RequestContext
req_context
;
req_context
.
value_type
=
Dense
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
push_context
.
push_dense_values
=
regions
.
data
();
req_context
.
num
=
regions
.
size
();
// auto push_status =
// worker_ptr_->push_dense(regions.data(), regions.size(), table_id);
auto
push_status
=
worker_ptr_
->
Push
(
req_context
);
}
void
FleetWrapper
::
PushSparseVarsAsync
(
...
...
@@ -624,9 +650,19 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_g_vec
[
i
]
=
push_values
.
at
(
i
).
data
();
}
auto
status
=
worker_ptr_
->
push_sparse
(
table_id
,
push_keys
.
data
(),
(
const
float
**
)
push_g_vec
.
data
(),
push_keys
.
size
());
// ps client push sparse
// construct request context
RequestContext
req_context
;
req_context
.
value_type
=
Sparse
;
req_context
.
training_mode
=
Async
;
req_context
.
table
=
table_id
;
req_context
.
push_context
.
push_values
=
(
const
float
**
)
push_g_vec
.
data
();
req_context
.
push_context
.
keys
=
push_keys
.
data
();
req_context
.
num
=
push_keys
.
size
();
auto
status
=
worker_ptr_
->
Push
(
req_context
);
// auto status = worker_ptr_->push_sparse(table_id, push_keys.data(),
// (const float**)push_g_vec.data(),
// push_keys.size());
}
void
FleetWrapper
::
LoadModel
(
const
std
::
string
&
path
,
const
int
mode
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录