Commit 84b0ec97 (unverified)

Accessor 20211112 2 (#37181)

Authored by zhaocaibei123 on Nov 15, 2021; committed via GitHub on Nov 15, 2021.
Parent: 12339fa0
Showing 6 changed files with 461 additions and 49 deletions (+461 -49):

- paddle/fluid/distributed/fleet.cc (+164 -18)
- paddle/fluid/distributed/fleet.h (+7 -2)
- paddle/fluid/distributed/service/communicator.cc (+31 -4)
- paddle/fluid/distributed/service/communicator.h (+21 -5)
- paddle/fluid/distributed/table/common_dense_table.cc (+222 -0)
- paddle/fluid/distributed/table/common_dense_table.h (+16 -20)
paddle/fluid/distributed/fleet.cc
```diff
@@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
 std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
   VLOG(3) << "Going to get client info";
-  return pserver_ptr_->get_client_info();
-  return std::vector<uint64_t>();
+  auto* communicator = Communicator::GetInstance();
+  std::vector<uint64_t> res = communicator->GetClientInfo();
+  return res;
 }
 
 void FleetWrapper::CreateClient2ClientConnection() {
-  VLOG(3) << "Going to create client2client connection";
-  pserver_ptr_->create_client2client_connection(
+  VLOG(1) << "Going to create client2client connection";
+  auto* communicator = Communicator::GetInstance();
+  communicator->_worker_ptr->create_client2client_connection(
       client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
       client2client_max_retry_);
 }
```
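The pattern in this first hunk recurs throughout the commit: FleetWrapper stops going through its own pserver_ptr_ and instead resolves the process-wide Communicator singleton on each call. A minimal sketch of that delegation, using illustrative stand-ins (DemoCommunicator, DemoFleetWrapper) rather than Paddle's actual classes:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

class DemoCommunicator {
 public:
  static DemoCommunicator* GetInstance() {
    static DemoCommunicator instance;  // created once, shared by all callers
    return &instance;
  }
  std::vector<uint64_t> GetClientInfo() const { return {1001, 1002}; }

 private:
  DemoCommunicator() = default;
};

class DemoFleetWrapper {
 public:
  // Mirrors the new GetClientsInfo(): no locally owned client, just delegate.
  std::vector<uint64_t> GetClientsInfo() {
    return DemoCommunicator::GetInstance()->GetClientInfo();
  }
};

int main() {
  DemoFleetWrapper fleet;
  for (uint64_t id : fleet.GetClientsInfo()) std::cout << id << "\n";
}
```

A function-local static (a Meyers singleton) is one common way to get thread-safe lazy construction for such a GetInstance().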
```diff
@@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync(
     const std::vector<std::string>& var_names,
     std::vector<std::future<int32_t>>* push_sparse_status,
     float scale_datanorm, int batch_size) {
-  auto* communicator = Communicator::GetInstance();
-  PADDLE_ENFORCE_EQ(
-      communicator->Check(table_id), true,
-      platform::errors::InvalidArgument(
-          "can not find table: %s, please check your config", table_id));
-  communicator->Send(var_names, scope);
+  auto place = platform::CPUPlace();
+  std::vector<paddle::distributed::Region> regions;
+  for (auto& t : var_names) {
+    Variable* var = scope.FindVar(t);
+    CHECK(var != nullptr) << "var[" << t << "] not found";
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    float* g = tensor->mutable_data<float>(place);
+    paddle::distributed::Region reg(g, tensor->numel());
+    regions.emplace_back(std::move(reg));
+    VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
+            << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
+            << g[tensor->numel() - 1];
+  }
+
+  auto* communicator =
+      dynamic_cast<AsyncCommunicator*>(Communicator::GetInstance());
+  auto push_status = communicator->_worker_ptr->push_dense(
+      regions.data(), regions.size(), table_id);
+
+  communicator->PushDensePostProcessing();
 }
 
 void FleetWrapper::PushSparseVarsAsync(
```
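The new PushDenseVarsAsync body gathers each dense variable into a paddle::distributed::Region, essentially a non-owning (pointer, length) view over the tensor's CPU buffer, and ships all views in one push_dense call. A minimal sketch of that idea, with DemoRegion and push_dense_demo as illustrative stand-ins rather than Paddle's actual types:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct DemoRegion {
  float* data;   // non-owning view into the tensor buffer
  size_t size;   // number of floats
};

// Pretend RPC: just reports how many floats each region would ship.
void push_dense_demo(const DemoRegion* regions, size_t n, uint64_t table_id) {
  for (size_t i = 0; i < n; ++i)
    std::cout << "table " << table_id << " region " << i << " floats "
              << regions[i].size << "\n";
}

int main() {
  std::vector<float> w0(8, 0.5f), w1(4, 1.5f);
  std::vector<DemoRegion> regions;
  regions.push_back({w0.data(), w0.size()});  // views only, no copies
  regions.push_back({w1.data(), w1.size()});
  push_dense_demo(regions.data(), regions.size(), /*table_id=*/0);
}
```

Batching views instead of copying keeps the push path allocation-free; the buffers must stay alive until the asynchronous push completes.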
```diff
@@ -417,10 +433,140 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
   return;
 }
 
-void FleetWrapper::LoadModel(const std::string& path,
-                             const std::string& mode) {
-  auto* communicator = Communicator::GetInstance();
-  auto ret = communicator->_worker_ptr->load(path, mode);
-  // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
+void FleetWrapper::PushSparseFromTensorAsync(
+    const uint64_t table_id, int fea_dim, uint64_t padding_id,
+    platform::Place place, std::vector<const LoDTensor*>* inputs,
+    const LoDTensor* shows, const LoDTensor* clks,
+    std::vector<LoDTensor*>* outputs) {
+  int batch_size = -1;
+  for (auto* input : *inputs) {
+    int cur_batch_size =
+        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
+    if (batch_size == -1) {
+      batch_size = cur_batch_size;
+    } else {
+      CHECK(batch_size == cur_batch_size);  // NOLINT
+    }
+  }
+  CHECK(batch_size > 0);  // NOLINT
+
+  int show_size =
+      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
+  CHECK(show_size == batch_size || show_size == 1);
+  int clk_size =
+      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
+  CHECK(clk_size == batch_size || clk_size == 1);
+
+  std::vector<float> g;
+  for (framework::LoDTensor* g_tensor : *outputs) {
+    float* g_ori = g_tensor->data<float>();
+    // no cvm
+    if (true) {  // TODO(zhaocaibei123): add config
+                 // scale_sparse_gradient_with_batch_size_
+      Eigen::Map<
+          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+          g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
+      g_mat.rightCols(fea_dim) *= batch_size;
+    }
+
+    size_t origin = g.size();
+    size_t add = g_tensor->numel();
+    g.resize(origin + add);
+    memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
+  }
+
+  std::vector<uint64_t> push_keys;
+  push_keys.reserve(MAX_FEASIGN_NUM / 100);
+  std::vector<std::vector<float>> push_values;
+  push_values.reserve(MAX_FEASIGN_NUM / 100);
+  size_t output_len = 0;
+  size_t input_idx = 0;
+
+  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;
+
+  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
+  // const long int* show_tensor = shows->data<int64_t>();
+  // const long int* clk_tensor = clks->data<int64_t>();
+  const int64_t* show_tensor = shows->data<int64_t>();
+  const int64_t* clk_tensor = clks->data<int64_t>();
+
+  for (size_t index = 0; index < inputs->size(); ++index) {
+    const framework::LoDTensor* tensor = inputs->at(index);
+    const int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+
+    if (tensor->lod().size() > 0) {
+      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
+        for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
+             ++j, output_len += fea_dim) {
+          uint64_t real_id = static_cast<uint64_t>(ids[j]);
+          if (real_id == padding_id) {
+            continue;
+          }
+          push_keys.emplace_back(real_id);
+          push_values.emplace_back(fea_dim + 3);
+          // slot show clk grad... consistent with CtrCommonPushValue defined in
+          // ctr_accessor.h
+          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+          push_values.back()[1] =
+              (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+          push_values.back()[2] =
+              (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+
+          float* data = push_values.back().data() + 3;
+
+          memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
+
+          ++input_idx;
+        }
+      }
+    } else {
+      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
+        uint64_t real_id = static_cast<uint64_t>(ids[i]);
+        if (real_id == padding_id) {
+          continue;
+        }
+        push_keys.emplace_back(real_id);
+        push_values.emplace_back(fea_dim + 3);
+        // slot show clk grad... consistent with CtrCommonPushValue defined in
+        // ctr_accessor.h
+        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+        push_values.back()[1] =
+            (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+        push_values.back()[2] =
+            (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+
+        float* data = push_values.back().data() + 3;
+
+        memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
+
+        ++input_idx;
+      }
+    }
+  }
+  VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
+  CHECK(output_len == g.size());
+
+  std::vector<float*> push_g_vec(input_idx, nullptr);
+
+  for (auto i = 0u; i < push_keys.size(); ++i) {
+    push_g_vec[i] = push_values.at(i).data();
+  }
+
+  auto* communicator = Communicator::GetInstance();
+  PADDLE_ENFORCE_EQ(
+      communicator->Check(table_id), true,
+      platform::errors::InvalidArgument(
+          "can not find table: %s, please check your config", table_id));
+  auto status = communicator->_worker_ptr->push_sparse(
+      table_id, push_keys.data(), (const float**)push_g_vec.data(),
+      push_keys.size());
+}
+
+void FleetWrapper::LoadModel(const std::string& path, const int mode) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret = communicator->_worker_ptr->load(path, std::to_string(mode));
   ret.wait();
   if (ret.get() != 0) {
     LOG(ERROR) << "load model from path:" << path << " failed";
```
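The inner loops above build one push value per feature key with fea_dim + 3 floats, laid out as [slot, show, clk, grad...] to stay consistent with CtrCommonPushValue in ctr_accessor.h. A small sketch of that layout; build_push_value is an illustrative helper, not Paddle API:

```cpp
#include <cstring>
#include <iostream>
#include <vector>

std::vector<float> build_push_value(int fea_dim, float show, float clk,
                                    const float* grad) {
  std::vector<float> v(fea_dim + 3);
  v[0] = 2;     // slot (hard-coded to 2 in the diff, pending a TODO)
  v[1] = show;  // show/clk statistics travel with the gradient
  v[2] = clk;
  std::memcpy(v.data() + 3, grad, sizeof(float) * fea_dim);
  return v;
}

int main() {
  const int fea_dim = 4;
  float grad[fea_dim] = {0.1f, 0.2f, 0.3f, 0.4f};
  auto v = build_push_value(fea_dim, /*show=*/1.0f, /*clk=*/0.0f, grad);
  for (float x : v) std::cout << x << " ";  // 2 1 0 0.1 0.2 0.3 0.4
  std::cout << "\n";
}
```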
```diff
@@ -562,16 +708,16 @@ void FleetWrapper::ClientFlush() {
 int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
                                                    MsgHandlerFunc handler) {
-  VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
-  VLOG(3) << "pserver_ptr_=" << pserver_ptr_;
-  VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr;
-  return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
-                                                                      handler);
+  VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
+  auto* communicator = Communicator::GetInstance();
+  return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
+                                                                      handler);
 }
 
 std::future<int32_t> FleetWrapper::SendClientToClientMsg(
     int msg_type, int to_client_id, const std::string& msg) {
-  return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type,
-                                                           to_client_id, msg);
+  auto* communicator = Communicator::GetInstance();
+  return communicator->_worker_ptr->send_client2client_msg(msg_type,
+                                                           to_client_id, msg);
 }
```
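Both functions in this hunk forward to the PS client's client-to-client messaging. A minimal sketch of the registration-and-dispatch idea behind registe_client2client_msg_handler, with a hypothetical in-process registry standing in for the RPC layer:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

using MsgHandler = std::function<int(int, const std::string&)>;
std::map<int, MsgHandler> g_handlers;  // keyed by msg_type

int register_handler(int msg_type, MsgHandler h) {
  g_handlers[msg_type] = std::move(h);
  return 0;
}

// Invoked when a message of msg_type arrives from another client.
int dispatch(int msg_type, int from_client, const std::string& msg) {
  auto it = g_handlers.find(msg_type);
  return it == g_handlers.end() ? -1 : it->second(from_client, msg);
}

int main() {
  register_handler(0, [](int from, const std::string& m) {
    std::cout << "from client " << from << ": " << m << "\n";
    return 0;
  });
  dispatch(0, /*from_client=*/7, "hello");
}
```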
paddle/fluid/distributed/fleet.h
```diff
@@ -157,7 +157,12 @@ class FleetWrapper {
       const std::vector<std::string>& input_names,
       std::vector<const LoDTensor*>* inputs,    // NOLINT
       std::vector<const LoDTensor*>* outputs);  // NOLINT
+  void PushSparseFromTensorAsync(const uint64_t table_id, int fea_dim,
+                                 uint64_t padding_id, platform::Place place,
+                                 std::vector<const LoDTensor*>* inputs,
+                                 const LoDTensor* shows, const LoDTensor* clicks,
+                                 std::vector<LoDTensor*>* outputs);
 
   // Push sparse variables to server in Async mode
   // Param<In>: scope, table_id, fea_keys, sparse_grad_names
   // Param<Out>: push_values, push_sparse_status
```
```diff
@@ -200,7 +205,7 @@ class FleetWrapper {
   void PrintTableStat(const uint64_t table_id);
   // mode = 0, load all feature
   // mode = 1, load delta feature, which means load diff
-  void LoadModel(const std::string& path, const std::string& mode);
+  void LoadModel(const std::string& path, const int mode);
   // mode = 0, load all feature
   // mode = 1, load delta feature, which means load diff
   void LoadModelOneTable(const uint64_t table_id, const std::string& path,
```
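A hedged usage sketch of the new integer-mode overload; per the comments, mode 0 loads all features and mode 1 loads only the delta. DemoFleet is an illustrative stand-in that only shows how the int is forwarded as a string, mirroring the fleet.cc change:

```cpp
#include <iostream>
#include <string>

struct DemoFleet {
  void LoadModel(const std::string& path, const int mode) {
    // The real code forwards std::to_string(mode) to the PS client's load().
    std::cout << "load(" << path << ", \"" << std::to_string(mode) << "\")\n";
  }
};

int main() {
  DemoFleet fleet;
  fleet.LoadModel("/ckpt", 0);  // mode 0: load all features
  fleet.LoadModel("/ckpt", 1);  // mode 1: load delta (diff) only
}
```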
paddle/fluid/distributed/service/communicator.cc
```diff
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/distributed/service/communicator.h"
+#include <google/protobuf/text_format.h>
 
 #include "gflags/gflags.h"
```
```diff
@@ -87,7 +88,7 @@ void Communicator::InitBrpcClient(
   servers_ = host_sign_list.size();
   _ps_env = paddle::distributed::PaddlePSEnvironment();
   _ps_env.set_ps_servers(&host_sign_list, servers_);
-  _worker_ptr = std::shared_ptr<paddle::distributed::PSClient>(
+  _worker_ptr = std::unique_ptr<paddle::distributed::PSClient>(
       paddle::distributed::PSClientFactory::create(_ps_param));
   _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env, trainer_id_);
```
```diff
@@ -95,6 +96,19 @@ void Communicator::InitBrpcClient(
   return;
 }
 
+std::vector<uint64_t> Communicator::GetClientInfo() {
+  std::vector<uint64_t> res = _ps_env.get_client_info();
+  for (auto rr : res) {
+    VLOG(2) << "Communicator::GetClientInfo " << rr;
+  }
+  return res;
+}
+
+int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
+  int node = host_sign_list.size();
+  return _ps_env.set_ps_clients(host_sign_list.data(), node);
+}
+
 void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
                                 int table_id, Scope *scope) {
   platform::RecordEvent record_event("Communicator->RpcRecvDense");
```
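GetClientInfo and SetClients are thin wrappers over the PS environment's client registry. A minimal sketch of that bookkeeping, assuming DemoPSEnv as an illustrative stand-in for PaddlePSEnvironment and uint64 host signatures encoding host:port pairs:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

class DemoPSEnv {
 public:
  int set_ps_clients(const uint64_t* sign_list, int node_num) {
    clients_.assign(sign_list, sign_list + node_num);  // seed the registry
    return 0;
  }
  std::vector<uint64_t> get_client_info() const { return clients_; }

 private:
  std::vector<uint64_t> clients_;
};

int main() {
  DemoPSEnv env;
  std::vector<uint64_t> signs = {0x7f00000100001f90ULL, 0x7f00000100001f91ULL};
  env.set_ps_clients(signs.data(), static_cast<int>(signs.size()));
  for (auto s : env.get_client_info()) std::cout << std::hex << s << "\n";
}
```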
```diff
@@ -130,6 +144,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
       LoDTensor *tensor = var->GetMutable<LoDTensor>();
+      VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
+              << platform::is_gpu_place(tensor->place());
+      float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
+      VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
+              << table_id << " Temp_data[0] " << temp_recv_data[0]
+              << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
       if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
         LoDTensor *temp_tensor =
```
```diff
@@ -519,6 +538,7 @@ void AsyncCommunicator::SendByCommunicator() {
         MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
       }
     }
+
     if (ctx.is_tensor_table) {
       SendGlobalStep(ctx, merged_var_num, send_scope_.get());
     } else if (ctx.is_sparse) {
```
```diff
@@ -547,6 +567,13 @@ void AsyncCommunicator::SendByCommunicator() {
   return;
 }
 
+void AsyncCommunicator::PushDensePostProcessing() {
+  if (independent_recv_) {
+    grad_num_.fetch_add(1, std::memory_order_relaxed);
+  }
+  return;
+}
+
 void AsyncCommunicator::MainThread() {
   VLOG(3) << "AsyncCommunicator MainThread start and wait";
```
```diff
@@ -627,13 +654,13 @@ void AsyncCommunicator::Start() {
 }
 
 void AsyncCommunicator::Stop() {
-  VLOG(1) << "Communicator stop";
-  _worker_ptr->finalize_worker();
-  VLOG(0) << "Communicator finalize_worker done";
+  VLOG(1) << "Communicator stop begin";
   running_ = false;
   if (!communicator_) {
     VLOG(0) << "Communicator is not inited, do nothing";
   } else {
+    _worker_ptr->finalize_worker();
+    VLOG(1) << "client finalize_worker done";
     if (recv_thread_) {
       VLOG(1) << "stop recv thread";
       recv_thread_->join();
```
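The revised Stop() flips the running flag first and only finalizes the worker when the communicator was actually initialized, then joins helper threads. A minimal sketch of that guarded shutdown ordering, with an illustrative stand-in class rather than the Paddle one:

```cpp
#include <atomic>
#include <iostream>
#include <memory>
#include <thread>

struct DemoStopper {
  std::atomic<bool> running{true};
  bool inited = true;
  std::unique_ptr<std::thread> recv_thread;

  void Stop() {
    running = false;  // signal loops before tearing anything down
    if (!inited) {
      std::cout << "not inited, do nothing\n";
    } else {
      std::cout << "finalize worker\n";  // stands in for finalize_worker()
      if (recv_thread) {
        recv_thread->join();  // join helper threads after finalize
        recv_thread.reset();
      }
    }
  }
};

int main() {
  DemoStopper s;
  s.recv_thread = std::make_unique<std::thread>([&] { while (s.running) {} });
  s.Stop();
}
```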
paddle/fluid/distributed/service/communicator.h
```diff
@@ -245,6 +245,11 @@ class Communicator {
   virtual void InitBrpcClient(const std::string &dist_desc,
                               const std::vector<std::string> &host_sign_list);
 
+  virtual std::vector<uint64_t> GetClientInfo();
+  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT
+
   // 1. recv dense param
   virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                             int table_id, Scope *scope);
```
```diff
@@ -271,6 +276,7 @@ class Communicator {
   virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
 
+  // note: only for pull dense param first before training
+  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);
+
   virtual void Start() = 0;
```
```diff
@@ -296,6 +302,13 @@ class Communicator {
     rets.wait();
   }
 
+  virtual void CreateC2CConnection(int pserver_timeout_ms,
+                                   int pserver_connect_timeout_ms,
+                                   int max_retry) {
+    _worker_ptr->create_client2client_connection(
+        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
+  }
+
   virtual void BarrierTriggerDecrement() {}
 
   virtual void BarrierTriggerReset(int init_counter) {}
```
```diff
@@ -342,13 +355,13 @@ class Communicator {
   PSClient *GetPsClient() { return _worker_ptr.get(); }
 
-  std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
-    return _worker_ptr;
+  std::unique_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
+    return std::move(_worker_ptr);
   }
 
   RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }
 
-  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker
+  std::unique_ptr<PSClient> _worker_ptr;  // pointer to worker
 
  protected:
   bool running_ = false;
```
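The switch from std::shared_ptr to std::unique_ptr makes the Communicator the single owner of the PSClient, so GetPsClientPtr() must move the pointer out and leaves _worker_ptr empty afterwards, something callers have to account for. A small self-contained illustration of that transfer semantics (DemoClient and DemoComm are stand-ins, not Paddle types):

```cpp
#include <iostream>
#include <memory>

struct DemoClient {
  void ping() { std::cout << "ping\n"; }
};

struct DemoComm {
  std::unique_ptr<DemoClient> worker = std::make_unique<DemoClient>();
  DemoClient* Get() { return worker.get(); }  // borrow, no ownership change
  std::unique_ptr<DemoClient> Take() {
    return std::move(worker);  // ownership transfers; member becomes null
  }
};

int main() {
  DemoComm comm;
  comm.Get()->ping();        // fine: still owned by comm
  auto owned = comm.Take();  // ownership moves to the caller
  owned->ping();
  std::cout << (comm.worker == nullptr) << "\n";  // prints 1: member is empty
}
```

With shared_ptr the old getter could hand out co-ownership; with unique_ptr the getter is necessarily destructive, so Get()/GetPsClient() remains the borrowing accessor.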
```diff
@@ -434,6 +447,8 @@ class AsyncCommunicator : public Communicator {
   virtual void BarrierWeakUp() {}
 
+  void PushDensePostProcessing();
+
  protected:
   std::unordered_map<std::string,
                      std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
```
```diff
@@ -542,14 +557,15 @@ class GeoCommunicator : public AsyncCommunicator {
                 Scope *recv_scope) override;
 
   void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
-  void InitDense(std::vector<std::string> &varnames, int table_id);
+  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
   void InitSparse(const std::string &var_name, int table_id);
 
   void SendDense(const CommContext &send_ctx);
   void RecvDense(const CommContext &send_ctx);
 
   std::vector<int64_t> MergeSparseIds(const std::string &varname);
-  void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids,
+  void SendSparse(const std::string &varname,
+                  std::vector<int64_t> &sparse_ids,  // NOLINT
                   int table_id, int ep_idx);
   void RecvSparse(const std::string &varname, int table_id, int ep_idx);
```
paddle/fluid/distributed/table/common_dense_table.cc
```diff
@@ -19,6 +19,8 @@
 namespace paddle {
 namespace distributed {
 
+int FLAGS_pslib_table_save_max_retry_dense = 3;
+
 void CommonDenseTable::create_initializer(const std::string& attr,
                                           const std::string& name) {
   auto slices = string::split_string<std::string>(attr, "&");
```
```diff
@@ -56,6 +58,7 @@ int32_t CommonDenseTable::initialize_value() {
   auto common = _config.common();
   int size = static_cast<int>(common.params().size());
   values_.resize(size);
+  total_dim_ = 0;
   for (int x = 0; x < size; ++x) {
     auto& varname = common.params()[x];
     auto& dim = common.dims()[x];
```
```diff
@@ -63,7 +66,9 @@ int32_t CommonDenseTable::initialize_value() {
       param_dim_ = dim;
       param_idx_ = x;
     }
     auto& initializer = common.initializers()[x];
+    total_dim_ += dim;
+
     create_initializer(initializer, varname);
     values_[x].resize(dim);
```
```diff
@@ -74,6 +79,22 @@ int32_t CommonDenseTable::initialize_value() {
     }
   }
 
+  fixed_len_params_dim_ = 0;
+  for (int x = 0; x < size; ++x) {
+    auto& dim = common.dims()[x];
+    if (dim != param_dim_) {
+      fixed_len_params_dim_ += dim;
+    } else {
+      param_col_ids_.push_back(x);
+    }
+  }
+  if (_config.common().name() == "adam_d2sum") {
+    param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
+  }
+
+  VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
+          << " fixed_len_params_dim: " << fixed_len_params_dim_;
+
   pull_reservoir_ = ReservoirValue<float>(param_dim_);
   return 0;
 }
```
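The new bookkeeping records which value columns have the parameter dimension (param_col_ids_), sums the remaining dims into fixed_len_params_dim_, and for the adam_d2sum layout reserves a placeholder column for avg_w. A small sketch of that computation with made-up dims:

```cpp
#include <iostream>
#include <vector>

int main() {
  // e.g. param, d2sum, g2sum share param_dim; two scalar hyperparam columns
  std::vector<int> dims = {100, 100, 100, 1, 1};
  int param_dim = 100, fixed_len_params_dim = 0;
  std::vector<int> param_col_ids;
  for (int x = 0; x < static_cast<int>(dims.size()); ++x) {
    if (dims[x] != param_dim) fixed_len_params_dim += dims[x];
    else param_col_ids.push_back(x);
  }
  bool adam_d2sum = true;  // stands in for _config.common().name() check
  if (adam_d2sum)
    param_col_ids.insert(param_col_ids.begin() + 1, -1);  // avg_w placeholder
  for (int id : param_col_ids) std::cout << id << " ";    // 0 -1 1 2
  std::cout << "| fixed_len: " << fixed_len_params_dim << "\n";  // 2
}
```

The -1 entry is skipped on load and written out as a literal "0" on save, which matches the avg_w handling in the save loop below.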
```diff
@@ -89,6 +110,9 @@ int32_t CommonDenseTable::initialize_optimizer() {
   } else if (name == "adam") {
     optimizer_ = std::make_shared<DAdam>(common, &values_);
     optimizer_->set_global_lr(_global_lr);
+  } else if (name == "adam_d2sum") {
+    optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
+    // optimizer_->set_global_lr(_global_lr); // no use
   } else if (name == "sum") {
     optimizer_ = std::make_shared<DSUM>(common, &values_);
   } else {
```
```diff
@@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
+  VLOG(2) << "debug CommonDenseTable::_push_dense done";
   return 0;
 }
 
+int32_t CommonDenseTable::load(const std::string& path,
+                               const std::string& param) {
+  if (param_dim_ <= 0) {
+    return 0;
+  }
+  std::string table_path = table_dir(path);
+  auto file_list = _afs_client.list(table_path);
+  std::sort(file_list.begin(), file_list.end());
+  for (auto ff : file_list) {
+    VLOG(1) << "load dense table file list: " << ff;
+  }
+  size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
+  // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
+  size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1;
+  size_t start_dim_idx = dim_num_per_shard * _shard_idx;
+  size_t start_file_idx = start_dim_idx / dim_num_per_file;
+  size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
+  end_file_idx =
+      end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1;
+  VLOG(2) << "load dense table start_file_idx: " << start_file_idx
+          << " end_file_idx: " << end_file_idx;
+
+  int load_param = atoi(param.c_str());
+  FsChannelConfig channel_config;
+
+  channel_config.converter = _value_accesor->converter(load_param).converter;
+  channel_config.deconverter =
+      _value_accesor->converter(load_param).deconverter;
+  bool is_read_failed = false;
+  int err_no = 0;
+  int retry_num = 0;
+  do {
+    is_read_failed = false;
+    try {
+      size_t dim_idx = 0;
+      float data_buffer[5];
+      float* data_buff_ptr = data_buffer;
+      std::string line_data;
+      int size = static_cast<int>(values_.size());
+      auto common = _config.common();
+
+      for (int i = start_file_idx; i < end_file_idx + 1; ++i) {
+        channel_config.path = file_list[i];
+        err_no = 0;
+        auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
+        size_t file_start_idx = start_dim_idx - i * dim_num_per_file;
+
+        // not all file contains param and the length of last file containing
+        // param may not equal to others
+        size_t file_dim_idx = 0;
+        for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) {
+          if (read_channel->read_line(line_data) != 0) {
+            break;
+          }
+          if (dim_idx >= param_dim_) {
+            break;
+          }
+          if (file_dim_idx < file_start_idx) {
+            continue;
+          }
+          auto str_len =
+              paddle::string::str_to_float(line_data.data(), data_buff_ptr);
+          CHECK(str_len == param_col_ids_.size())
+              << "expect " << param_col_ids_.size() << " float, but got "
+              << str_len;
+          for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
+            if (param_col_ids_[col_idx] < 0) {
+              continue;
+            }
+            values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx];
+            VLOG(2) << "CommonDenseTable::load param x: "
+                    << param_col_ids_[col_idx] << " y: " << dim_idx
+                    << " value: " << values_[param_col_ids_[col_idx]][dim_idx]
+                    << " line " << file_dim_idx;
+          }
+          ++dim_idx;
+        }
+        read_channel->close();
+        VLOG(1) << "DownpourDenseTable load done " << channel_config.path
+                << " file_start_idx: " << file_start_idx
+                << " dim_idx: " << dim_idx;
+        if (err_no == -1) {
+          if (retry_num > FLAGS_pslib_table_save_max_retry_dense) {
+            LOG(ERROR) << "DownpourDenseTable load failed reach max limit!";
+            exit(-1);
+          }
+          ++retry_num;
+          --i;
+          LOG(ERROR)
+              << "DownpourDenseTable load failed after read , retry it! path:"
+              << channel_config.path << ", retry_num=" << retry_num;
+          continue;
+        }
+        retry_num = 0;
+        start_dim_idx += file_dim_idx - file_start_idx;
+        LOG(INFO) << "DownpourDenseTable load success, path:"
+                  << channel_config.path;
+      }
+    } catch (...) {
+      is_read_failed = true;
+      LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:"
+                 << channel_config.path;
+    }
+  } while (is_read_failed);
+  return 0;
+}
+
+int32_t CommonDenseTable::save(const std::string& path,
+                               const std::string& param) {
+  int save_param = atoi(param.c_str());
+  uint32_t feasign_size;
+  VLOG(0) << "CommonDenseTable::save path " << path;
+
+  FsChannelConfig channel_config;
+  if (_config.compress_in_save()) {
+    channel_config.path = paddle::string::format_string(
+        "%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
+  } else {
+    channel_config.path = paddle::string::format_string(
+        "%s/part-%03d", table_dir(path).c_str(), _shard_idx);
+  }
+  _afs_client.remove(channel_config.path);
+  channel_config.converter = _value_accesor->converter(save_param).converter;
+  channel_config.deconverter =
+      _value_accesor->converter(save_param).deconverter;
+
+  bool is_write_failed = false;
+  std::vector<std::vector<std::string>> result_buffer_param(
+      param_dim_, std::vector<std::string>());
+  std::vector<std::string> result_buffer_fixed_len;
+  result_buffer_fixed_len.reserve(fixed_len_params_dim_);
+
+  auto common = _config.common();
+  int size = static_cast<int>(common.params().size());
+  std::ostringstream os;
+  for (int x = 0; x < size; ++x) {
+    auto& varname = common.params()[x];
+    auto& dim = common.dims()[x];
+    VLOG(0) << "CommonDenseTable::save dim " << x << " size: " << dim;
+    for (int y = 0; y < dim; ++y) {
+      os.clear();
+      os.str("");
+      os << values_[x][y];
+      if (dim == param_dim_) {
+        result_buffer_param[y].emplace_back(std::move(os.str()));
+      } else {
+        result_buffer_fixed_len.emplace_back(std::move(os.str()));
+      }
+    }
+  }
+
+  int retry_num = 0;
+  int err_no = 0;
+  do {
+    err_no = 0;
+    is_write_failed = false;
+    feasign_size = 0;
+    // 40M
+    auto write_channel =
+        _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
+    for (auto& t : result_buffer_param) {
+      if (_config.common().name() == "adam_d2sum") {
+        t.insert(t.begin() + 1, "0");  // avg_w
+      }
+      if (0 !=
+          write_channel->write_line(paddle::string::join_strings(t, ' '))) {
+        ++retry_num;
+        is_write_failed = true;
+        LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
+                      "path:"
+                   << channel_config.path << ", retry_num=" << retry_num;
+        break;
+      }
+    }
+    ++feasign_size;
+    write_channel->close();
+    if (err_no == -1) {
+      ++retry_num;
+      is_write_failed = true;
+      LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! "
+                 << "path:" << channel_config.path
+                 << ", retry_num=" << retry_num;
+    }
+    if (is_write_failed) {
+      _afs_client.remove(channel_config.path);
+    }
+    if (retry_num >
+        paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) {
+      LOG(ERROR) << "DownpourDenseTable save failed reach max limit!";
+      exit(-1);
+    }
+  } while (is_write_failed);
+  LOG(INFO) << "DownpourDenseTable save success, path:" << channel_config.path;
+
+  return feasign_size;
+}
+
 }  // namespace distributed
 }  // namespace paddle
```
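The save/load pair partitions the dense parameter row-wise: each saved file holds roughly fea_dim / file_count + 1 rows and each shard owns roughly fea_dim / shard_num + 1 rows, so load maps the shard's global row range onto a clamped window of files. A standalone sketch of that index arithmetic with made-up sizes:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  size_t fea_dim = 1000, file_count = 4, shard_num = 3;
  size_t dim_num_per_file = fea_dim / file_count + 1;   // rows per saved file
  size_t dim_num_per_shard = fea_dim / shard_num + 1;   // rows per shard
  size_t param_dim = dim_num_per_shard;                 // what this shard owns
  for (size_t shard_idx = 0; shard_idx < shard_num; ++shard_idx) {
    size_t start_dim_idx = dim_num_per_shard * shard_idx;
    size_t start_file_idx = start_dim_idx / dim_num_per_file;
    size_t end_file_idx = (start_dim_idx + param_dim) / dim_num_per_file;
    end_file_idx = std::min(end_file_idx, file_count - 1);  // clamp to list
    std::cout << "shard " << shard_idx << " reads files [" << start_file_idx
              << ", " << end_file_idx << "]\n";
  }
}
```

Within each file, rows before file_start_idx are skipped and reading stops once the shard's param_dim_ rows are filled, which is what the inner loop in load() does.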
paddle/fluid/distributed/table/common_dense_table.h
```diff
@@ -32,33 +32,26 @@ class DenseOptimizer;
 class CommonDenseTable : public DenseTable {
  public:
-  explicit CommonDenseTable() {}
+  CommonDenseTable() {}
   virtual ~CommonDenseTable() {}
-  virtual int32_t initialize() override;
-  virtual int32_t initialize_shard() override { return 0; }
+  int32_t initialize() override;
+  int32_t initialize_shard() override { return 0; }
   virtual void create_initializer(const std::string& attr,
                                   const std::string& name);
   virtual int32_t initialize_value();
   virtual int32_t initialize_optimizer();
-  virtual int32_t pull_dense(float* pull_values, size_t num) override;
-  virtual int32_t push_dense_param(const float* values, size_t num) override;
-  virtual int32_t push_dense(const float* values, size_t num) override;
-  virtual int32_t pour() override;
-  virtual int32_t set_global_lr(float* lr) override;
-
-  int32_t load(const std::string& path, const std::string& param) override {
-    VLOG(0) << "WARNING: dense variables will load on No.0 trainer";
-    return 0;
-  }
-
-  int32_t save(const std::string& path, const std::string& param) override {
-    VLOG(0) << "WARNING: dense variables will save on No.0 trainer";
-    return 0;
-  }
-
-  virtual int32_t flush() override { return 0; }
-  virtual int32_t shrink(const std::string& param) override { return 0; }
-  virtual void clear() override { return; }
+  int32_t pull_dense(float* pull_values, size_t num) override;
+  int32_t push_dense_param(const float* values, size_t num) override;
+  int32_t push_dense(const float* values, size_t num) override;
+  int32_t pour() override;
+  int32_t set_global_lr(float* lr) override;
+
+  int32_t load(const std::string& path, const std::string& param) override;
+  int32_t save(const std::string& path, const std::string& param) override;
+
+  int32_t flush() override { return 0; }
+  int32_t shrink(const std::string& param) override { return 0; }
+  void clear() override { return; }
 
  protected:
   int32_t _push_dense(const float* values, size_t num);
```
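Besides turning the load/save stubs into real declarations, the header cleanup drops the redundant virtual keyword on functions already marked override, following the usual guideline that override implies virtual. The two spellings are equivalent:

```cpp
#include <cstdint>

struct Base {
  virtual ~Base() = default;
  virtual int32_t flush() { return -1; }
};

struct Derived : Base {
  int32_t flush() override { return 0; }  // `override` already implies virtual
};

int main() {
  Derived d;
  return d.flush();
}
```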
```diff
@@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable {
   ReservoirValue<float> pull_reservoir_;
   std::unordered_map<std::string, Initializer*> initializers_;
   std::unordered_map<std::string, int> names_index_;
+  int total_dim_ = 0;
+  int fixed_len_params_dim_ = 0;    // used for save/load
+  std::vector<int> param_col_ids_;  // used for save/load
 };
 
 }  // namespace distributed
```