Unverified commit 84b0ec97 — authored Nov 15, 2021 by zhaocaibei123, committed via GitHub on Nov 15, 2021.

Accessor 20211112 2 (#37181)

Parent: 12339fa0

Showing 6 changed files with 461 additions and 49 deletions (+461, -49).
paddle/fluid/distributed/fleet.cc                      +164  -18
paddle/fluid/distributed/fleet.h                         +7   -2
paddle/fluid/distributed/service/communicator.cc        +31   -4
paddle/fluid/distributed/service/communicator.h         +21   -5
paddle/fluid/distributed/table/common_dense_table.cc   +222   -0
paddle/fluid/distributed/table/common_dense_table.h     +16  -20
paddle/fluid/distributed/fleet.cc

@@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
 std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
   VLOG(3) << "Going to get client info";
-  return pserver_ptr_->get_client_info();
-  return std::vector<uint64_t>();
+  auto* communicator = Communicator::GetInstance();
+  std::vector<uint64_t> res = communicator->GetClientInfo();
+  return res;
 }

 void FleetWrapper::CreateClient2ClientConnection() {
-  VLOG(3) << "Going to create client2client connection";
-  pserver_ptr_->create_client2client_connection(
+  VLOG(1) << "Going to create client2client connection";
+  auto* communicator = Communicator::GetInstance();
+  communicator->_worker_ptr->create_client2client_connection(
       client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
       client2client_max_retry_);
 }
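The hunk above shows the pattern used throughout this commit: FleetWrapper stops dereferencing its own pserver_ptr_ and instead asks the process-wide Communicator singleton for the client. Below is a minimal, self-contained sketch of that delegation pattern; DemoCommunicator and DemoWrapper are hypothetical stand-ins for illustration, not Paddle APIs.

#include <cstdint>
#include <iostream>
#include <vector>

// One instance per process, analogous to Communicator::GetInstance().
class DemoCommunicator {
 public:
  static DemoCommunicator* GetInstance() {
    static DemoCommunicator instance;
    return &instance;
  }
  std::vector<uint64_t> GetClientInfo() const { return {1001, 1002}; }
};

class DemoWrapper {
 public:
  // Mirrors the rewritten FleetWrapper::GetClientsInfo: no locally owned
  // pserver pointer, everything is routed through the singleton.
  std::vector<uint64_t> GetClientsInfo() {
    auto* communicator = DemoCommunicator::GetInstance();
    return communicator->GetClientInfo();
  }
};

int main() {
  DemoWrapper wrapper;
  for (uint64_t id : wrapper.GetClientsInfo()) std::cout << id << "\n";
  return 0;
}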
@@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync(
     const std::vector<std::string>& var_names,
     std::vector<std::future<int32_t>>* push_sparse_status,
     float scale_datanorm, int batch_size) {
-  auto* communicator = Communicator::GetInstance();
-  PADDLE_ENFORCE_EQ(
-      communicator->Check(table_id), true,
-      platform::errors::InvalidArgument(
-          "can not find table: %s, please check your config", table_id));
-  communicator->Send(var_names, scope);
+  auto place = platform::CPUPlace();
+  std::vector<paddle::distributed::Region> regions;
+  for (auto& t : var_names) {
+    Variable* var = scope.FindVar(t);
+    CHECK(var != nullptr) << "var[" << t << "] not found";
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    float* g = tensor->mutable_data<float>(place);
+    paddle::distributed::Region reg(g, tensor->numel());
+    regions.emplace_back(std::move(reg));
+    VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
+            << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
+            << g[tensor->numel() - 1];
+  }
+
+  auto* communicator =
+      dynamic_cast<AsyncCommunicator*>(Communicator::GetInstance());
+  auto push_status = communicator->_worker_ptr->push_dense(
+      regions.data(), regions.size(), table_id);
+
+  communicator->PushDensePostProcessing();
 }

 void FleetWrapper::PushSparseVarsAsync(
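In the rewritten PushDenseVarsAsync, each dense gradient tensor is wrapped in a paddle::distributed::Region (a pointer plus element count) and the whole batch of regions is handed to the PS client in one push_dense call. The standalone sketch below illustrates that packaging step; DemoRegion and the dummy push_dense are simplified assumptions, not the real PSClient API.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-in for paddle::distributed::Region.
struct DemoRegion {
  float* data;
  size_t size;
};

// Dummy stand-in for PSClient::push_dense: just reports what would be sent.
void push_dense(const DemoRegion* regions, size_t region_num, int table_id) {
  for (size_t i = 0; i < region_num; ++i) {
    std::cout << "table " << table_id << ": push region of " << regions[i].size
              << " floats\n";
  }
}

int main() {
  std::vector<float> grad_a(8, 0.1f), grad_b(4, 0.2f);  // two dense gradients
  std::vector<DemoRegion> regions;
  regions.push_back({grad_a.data(), grad_a.size()});
  regions.push_back({grad_b.data(), grad_b.size()});
  push_dense(regions.data(), regions.size(), /*table_id=*/0);
  return 0;
}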
@@ -417,10 +433,140 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
   return;
 }

-void FleetWrapper::LoadModel(const std::string& path,
-                             const std::string& mode) {
+void FleetWrapper::PushSparseFromTensorAsync(
+    const uint64_t table_id, int fea_dim, uint64_t padding_id,
+    platform::Place place, std::vector<const LoDTensor*>* inputs,
+    const LoDTensor* shows, const LoDTensor* clks,
+    std::vector<LoDTensor*>* outputs) {
+  int batch_size = -1;
+  for (auto* input : *inputs) {
+    int cur_batch_size =
+        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
+    if (batch_size == -1) {
+      batch_size = cur_batch_size;
+    } else {
+      CHECK(batch_size == cur_batch_size);  // NOLINT
+    }
+  }
+  CHECK(batch_size > 0);  // NOLINT
+
+  int show_size =
+      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
+  CHECK(show_size == batch_size || show_size == 1);
+  int clk_size =
+      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
+  CHECK(clk_size == batch_size || clk_size == 1);
+
+  std::vector<float> g;
+  for (framework::LoDTensor* g_tensor : *outputs) {
+    float* g_ori = g_tensor->data<float>();
+    // no cvm
+    if (true) {  // TODO(zhaocaibei123): add config
+                 // scale_sparse_gradient_with_batch_size_
+      Eigen::Map<
+          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+          g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
+      g_mat.rightCols(fea_dim) *= batch_size;
+    }
+
+    size_t origin = g.size();
+    size_t add = g_tensor->numel();
+    g.resize(origin + add);
+    memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
+  }
+
+  std::vector<uint64_t> push_keys;
+  push_keys.reserve(MAX_FEASIGN_NUM / 100);
+  std::vector<std::vector<float>> push_values;
+  push_values.reserve(MAX_FEASIGN_NUM / 100);
+  size_t output_len = 0;
+  size_t input_idx = 0;
+
+  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;
+
+  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
+  // const long int* show_tensor = shows->data<int64_t>();
+  // const long int* clk_tensor = clks->data<int64_t>();
+  const int64_t* show_tensor = shows->data<int64_t>();
+  const int64_t* clk_tensor = clks->data<int64_t>();
+
+  for (size_t index = 0; index < inputs->size(); ++index) {
+    const framework::LoDTensor* tensor = inputs->at(index);
+    const int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+
+    if (tensor->lod().size() > 0) {
+      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
+        for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
+             ++j, output_len += fea_dim) {
+          uint64_t real_id = static_cast<uint64_t>(ids[j]);
+          if (real_id == padding_id) {
+            continue;
+          }
+          push_keys.emplace_back(real_id);
+          push_values.emplace_back(fea_dim + 3);
+          // slot show clk grad... consistent with CtrCommonPushValue defined
+          // in ctr_accessor.h
+          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+          push_values.back()[1] =
+              (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+          push_values.back()[2] =
+              (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+
+          float* data = push_values.back().data() + 3;
+
+          memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
+
+          ++input_idx;
+        }
+      }
+    } else {
+      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
+        uint64_t real_id = static_cast<uint64_t>(ids[i]);
+        if (real_id == padding_id) {
+          continue;
+        }
+        push_keys.emplace_back(real_id);
+        push_values.emplace_back(fea_dim + 3);
+        // slot show clk grad... consistent with CtrCommonPushValue defined in
+        // ctr_accessor.h
+        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+        push_values.back()[1] =
+            (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+        push_values.back()[2] =
+            (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+
+        float* data = push_values.back().data() + 3;
+
+        memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
+
+        ++input_idx;
+      }
+    }
+  }
+
+  VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
+  CHECK(output_len == g.size());
+
+  std::vector<float*> push_g_vec(input_idx, nullptr);
+
+  for (auto i = 0u; i < push_keys.size(); ++i) {
+    push_g_vec[i] = push_values.at(i).data();
+  }
+
+  auto* communicator = Communicator::GetInstance();
+  PADDLE_ENFORCE_EQ(
+      communicator->Check(table_id), true,
+      platform::errors::InvalidArgument(
+          "can not find table: %s, please check your config", table_id));
+  auto status = communicator->_worker_ptr->push_sparse(
+      table_id, push_keys.data(), (const float**)push_g_vec.data(),
+      push_keys.size());
+}
+
+void FleetWrapper::LoadModel(const std::string& path, const int mode) {
   auto* communicator = Communicator::GetInstance();
-  auto ret = communicator->_worker_ptr->load(path, mode);
-  // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
+  auto ret = communicator->_worker_ptr->load(path, std::to_string(mode));
   ret.wait();
   if (ret.get() != 0) {
     LOG(ERROR) << "load model from path:" << path << " failed";
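The new PushSparseFromTensorAsync builds, for every non-padding key, a push value of fea_dim + 3 floats laid out as [slot, show, clk, grad...], which the in-code comment says must stay consistent with CtrCommonPushValue in ctr_accessor.h. The short sketch below assembles one such value with made-up demo numbers, only to make the layout concrete.

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int fea_dim = 4;
  std::vector<float> grad = {0.1f, 0.2f, 0.3f, 0.4f};  // gradient for one key

  std::vector<float> push_value(fea_dim + 3);
  push_value[0] = 2;     // slot (hard-coded in the hunk, marked TODO)
  push_value[1] = 1.0f;  // show
  push_value[2] = 0.0f;  // clk
  std::memcpy(push_value.data() + 3, grad.data(), sizeof(float) * fea_dim);

  for (float v : push_value) std::cout << v << " ";
  std::cout << "\n";  // prints: 2 1 0 0.1 0.2 0.3 0.4
  return 0;
}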
@@ -562,16 +708,16 @@ void FleetWrapper::ClientFlush() {
 int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
                                                    MsgHandlerFunc handler) {
-  VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
-  VLOG(3) << "pserver_ptr_=" << pserver_ptr_;
-  VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr;
-  return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
-                                                                      handler);
+  VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
+  auto* communicator = Communicator::GetInstance();
+  return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
+                                                                      handler);
 }

 std::future<int32_t> FleetWrapper::SendClientToClientMsg(
     int msg_type, int to_client_id, const std::string& msg) {
-  return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type,
-                                                           to_client_id, msg);
+  auto* communicator = Communicator::GetInstance();
+  return communicator->_worker_ptr->send_client2client_msg(msg_type,
+                                                           to_client_id, msg);
 }
paddle/fluid/distributed/fleet.h

@@ -157,7 +157,12 @@ class FleetWrapper {
       const std::vector<std::string>& input_names,
       std::vector<const LoDTensor*>* inputs,    // NOLINT
       std::vector<const LoDTensor*>* outputs);  // NOLINT

+  void PushSparseFromTensorAsync(const uint64_t table_id, int fea_dim,
+                                 uint64_t padding_id, platform::Place place,
+                                 std::vector<const LoDTensor*>* inputs,
+                                 const LoDTensor* shows, const LoDTensor* clicks,
+                                 std::vector<LoDTensor*>* outputs);
   // Push sparse variables to server in Async mode
   // Param<In>: scope, table_id, fea_keys, sparse_grad_names
   // Param<Out>: push_values, push_sparse_status

@@ -200,7 +205,7 @@ class FleetWrapper {
   void PrintTableStat(const uint64_t table_id);
   // mode = 0, load all feature
   // mode = 1, load delta feature, which means load diff
-  void LoadModel(const std::string& path, const std::string& mode);
+  void LoadModel(const std::string& path, const int mode);
   // mode = 0, load all feature
   // mode = 1, load delta feature, which means load diff
   void LoadModelOneTable(const uint64_t table_id, const std::string& path,
paddle/fluid/distributed/service/communicator.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/distributed/service/communicator.h"
 #include <google/protobuf/text_format.h>
 #include "gflags/gflags.h"

@@ -87,7 +88,7 @@ void Communicator::InitBrpcClient(
   servers_ = host_sign_list.size();
   _ps_env = paddle::distributed::PaddlePSEnvironment();
   _ps_env.set_ps_servers(&host_sign_list, servers_);
-  _worker_ptr = std::shared_ptr<paddle::distributed::PSClient>(
+  _worker_ptr = std::unique_ptr<paddle::distributed::PSClient>(
       paddle::distributed::PSClientFactory::create(_ps_param));
   _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env,
                          trainer_id_);

@@ -95,6 +96,19 @@ void Communicator::InitBrpcClient(
   return;
 }

+std::vector<uint64_t> Communicator::GetClientInfo() {
+  std::vector<uint64_t> res = _ps_env.get_client_info();
+  for (auto rr : res) {
+    VLOG(2) << "Communicator::GetClientInfo " << rr;
+  }
+  return res;
+}
+
+int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
+  int node = host_sign_list.size();
+  return _ps_env.set_ps_clients(host_sign_list.data(), node);
+}
+
 void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
                                 int table_id, Scope *scope) {
   platform::RecordEvent record_event("Communicator->RpcRecvDense");

@@ -130,6 +144,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
     LoDTensor *tensor = var->GetMutable<LoDTensor>();
     VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
             << platform::is_gpu_place(tensor->place());
+
+    float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
+    VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
+            << table_id << " Temp_data[0] " << temp_recv_data[0]
+            << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
     if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
       LoDTensor *temp_tensor =

@@ -519,6 +538,7 @@ void AsyncCommunicator::SendByCommunicator() {
         MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
       }
     }

     if (ctx.is_tensor_table) {
       SendGlobalStep(ctx, merged_var_num, send_scope_.get());
     } else if (ctx.is_sparse) {

@@ -547,6 +567,13 @@ void AsyncCommunicator::SendByCommunicator() {
   return;
 }

+void AsyncCommunicator::PushDensePostProcessing() {
+  if (independent_recv_) {
+    grad_num_.fetch_add(1, std::memory_order_relaxed);
+  }
+  return;
+}
+
 void AsyncCommunicator::MainThread() {
   VLOG(3) << "AsyncCommunicator MainThread start and wait";

@@ -627,13 +654,13 @@ void AsyncCommunicator::Start() {
 }

 void AsyncCommunicator::Stop() {
-  VLOG(1) << "Communicator stop";
-  _worker_ptr->finalize_worker();
-  VLOG(0) << "Communicator finalize_worker done";
+  VLOG(1) << "Communicator stop begin";
   running_ = false;
   if (!communicator_) {
     VLOG(0) << "Communicator is not inited, do nothing";
   } else {
+    _worker_ptr->finalize_worker();
+    VLOG(1) << "client finalize_worker done";
     if (recv_thread_) {
       VLOG(1) << "stop recv thread";
       recv_thread_->join();
paddle/fluid/distributed/service/communicator.h

@@ -245,6 +245,11 @@ class Communicator {
   virtual void InitBrpcClient(const std::string &dist_desc,
                               const std::vector<std::string> &host_sign_list);

+  virtual std::vector<uint64_t> GetClientInfo();
+  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT
+
   // 1. recv dense param
   virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                             int table_id, Scope *scope);

@@ -271,6 +276,7 @@ class Communicator {
   virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

+  // note: only for pull dense param first before training
   virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

   virtual void Start() = 0;

@@ -296,6 +302,13 @@ class Communicator {
     rets.wait();
   }

+  virtual void CreateC2CConnection(int pserver_timeout_ms,
+                                   int pserver_connect_timeout_ms,
+                                   int max_retry) {
+    _worker_ptr->create_client2client_connection(
+        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
+  }
+
   virtual void BarrierTriggerDecrement() {}

   virtual void BarrierTriggerReset(int init_counter) {}

@@ -342,13 +355,13 @@ class Communicator {
   PSClient *GetPsClient() { return _worker_ptr.get(); }

-  std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
-    return _worker_ptr;
+  std::unique_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
+    return std::move(_worker_ptr);
   }

   RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

-  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker
+  std::unique_ptr<PSClient> _worker_ptr;  // pointer to worker

  protected:
   bool running_ = false;

@@ -434,6 +447,8 @@ class AsyncCommunicator : public Communicator {
   virtual void BarrierWeakUp() {}

+  void PushDensePostProcessing();
+
  protected:
   std::unordered_map<std::string,
                      std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>

@@ -542,14 +557,15 @@ class GeoCommunicator : public AsyncCommunicator {
                  Scope *recv_scope) override;

   void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;

-  void InitDense(std::vector<std::string> &varnames, int table_id);
+  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT

   void InitSparse(const std::string &var_name, int table_id);

   void SendDense(const CommContext &send_ctx);

   void RecvDense(const CommContext &send_ctx);

   std::vector<int64_t> MergeSparseIds(const std::string &varname);
-  void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids,
+  void SendSparse(const std::string &varname,
+                  std::vector<int64_t> &sparse_ids,  // NOLINT
                   int table_id, int ep_idx);

   void RecvSparse(const std::string &varname, int table_id, int ep_idx);
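The communicator.cc and communicator.h hunks above change _worker_ptr from std::shared_ptr to std::unique_ptr, so the Communicator becomes the sole owner of the PS client and GetPsClientPtr() now hands ownership away with std::move instead of copying a shared handle. A small illustrative sketch of that ownership model follows; DemoClient and DemoCommunicator are assumed names, not Paddle types.

#include <iostream>
#include <memory>
#include <utility>

struct DemoClient {
  void finalize() { std::cout << "finalize worker\n"; }
};

class DemoCommunicator {
 public:
  DemoCommunicator() : worker_(new DemoClient) {}
  DemoClient* GetPsClient() { return worker_.get(); }  // borrow, no transfer
  std::unique_ptr<DemoClient> GetPsClientPtr() {       // transfer ownership
    return std::move(worker_);
  }

 private:
  std::unique_ptr<DemoClient> worker_;  // sole owner, like the new _worker_ptr
};

int main() {
  DemoCommunicator comm;
  comm.GetPsClient()->finalize();
  std::unique_ptr<DemoClient> owner = comm.GetPsClientPtr();
  // After the move, comm no longer owns the client; "owner" does.
  return 0;
}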
paddle/fluid/distributed/table/common_dense_table.cc

@@ -19,6 +19,8 @@
 namespace paddle {
 namespace distributed {

+int FLAGS_pslib_table_save_max_retry_dense = 3;
+
 void CommonDenseTable::create_initializer(const std::string& attr,
                                           const std::string& name) {
   auto slices = string::split_string<std::string>(attr, "&");

@@ -56,6 +58,7 @@ int32_t CommonDenseTable::initialize_value() {
   auto common = _config.common();
   int size = static_cast<int>(common.params().size());
   values_.resize(size);
+  total_dim_ = 0;
   for (int x = 0; x < size; ++x) {
     auto& varname = common.params()[x];
     auto& dim = common.dims()[x];

@@ -63,7 +66,9 @@ int32_t CommonDenseTable::initialize_value() {
       param_dim_ = dim;
       param_idx_ = x;
     }
     auto& initializer = common.initializers()[x];
+    total_dim_ += dim;

     create_initializer(initializer, varname);
     values_[x].resize(dim);

@@ -74,6 +79,22 @@ int32_t CommonDenseTable::initialize_value() {
     }
   }

+  fixed_len_params_dim_ = 0;
+  for (int x = 0; x < size; ++x) {
+    auto& dim = common.dims()[x];
+    if (dim != param_dim_) {
+      fixed_len_params_dim_ += dim;
+    } else {
+      param_col_ids_.push_back(x);
+    }
+  }
+  if (_config.common().name() == "adam_d2sum") {
+    param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
+  }
+
+  VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
+          << " fixed_len_params_dim: " << fixed_len_params_dim_;
+
   pull_reservoir_ = ReservoirValue<float>(param_dim_);
   return 0;
 }
@@ -89,6 +110,9 @@ int32_t CommonDenseTable::initialize_optimizer() {
   } else if (name == "adam") {
     optimizer_ = std::make_shared<DAdam>(common, &values_);
     optimizer_->set_global_lr(_global_lr);
+  } else if (name == "adam_d2sum") {
+    optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
+    // optimizer_->set_global_lr(_global_lr);  // no use
   } else if (name == "sum") {
     optimizer_ = std::make_shared<DSUM>(common, &values_);
   } else {
@@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
+  VLOG(2) << "debug CommonDenseTable::_push_dense done";
   return 0;
 }

+int32_t CommonDenseTable::load(const std::string& path,
+                               const std::string& param) {
+  if (param_dim_ <= 0) {
+    return 0;
+  }
+  std::string table_path = table_dir(path);
+  auto file_list = _afs_client.list(table_path);
+  std::sort(file_list.begin(), file_list.end());
+  for (auto ff : file_list) {
+    VLOG(1) << "load dense table file list: " << ff;
+  }
+  size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
+  // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1
+  size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1;
+  size_t start_dim_idx = dim_num_per_shard * _shard_idx;
+  size_t start_file_idx = start_dim_idx / dim_num_per_file;
+  size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
+  end_file_idx =
+      end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1;
+  VLOG(2) << "load dense table start_file_idx: " << start_file_idx
+          << " end_file_idx: " << end_file_idx;
+
+  int load_param = atoi(param.c_str());
+  FsChannelConfig channel_config;
+
+  channel_config.converter = _value_accesor->converter(load_param).converter;
+  channel_config.deconverter =
+      _value_accesor->converter(load_param).deconverter;
+  bool is_read_failed = false;
+  int err_no = 0;
+  int retry_num = 0;
+  do {
+    is_read_failed = false;
+    try {
+      size_t dim_idx = 0;
+      float data_buffer[5];
+      float* data_buff_ptr = data_buffer;
+      std::string line_data;
+      int size = static_cast<int>(values_.size());
+      auto common = _config.common();
+
+      for (int i = start_file_idx; i < end_file_idx + 1; ++i) {
+        channel_config.path = file_list[i];
+        err_no = 0;
+        auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
+        size_t file_start_idx = start_dim_idx - i * dim_num_per_file;
+
+        // not all file contains param and the length of last file containing
+        // param may not equal to others
+        size_t file_dim_idx = 0;
+        for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) {
+          if (read_channel->read_line(line_data) != 0) {
+            break;
+          }
+          if (dim_idx >= param_dim_) {
+            break;
+          }
+          if (file_dim_idx < file_start_idx) {
+            continue;
+          }
+          auto str_len =
+              paddle::string::str_to_float(line_data.data(), data_buff_ptr);
+          CHECK(str_len == param_col_ids_.size())
+              << "expect " << param_col_ids_.size() << " float, but got "
+              << str_len;
+          for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
+            if (param_col_ids_[col_idx] < 0) {
+              continue;
+            }
+            values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx];
+            VLOG(2) << "CommonDenseTable::load param x: "
+                    << param_col_ids_[col_idx] << " y: " << dim_idx
+                    << " value: " << values_[param_col_ids_[col_idx]][dim_idx]
+                    << " line " << file_dim_idx;
+          }
+          ++dim_idx;
+        }
+        read_channel->close();
+        VLOG(1) << "DownpourDenseTable load done " << channel_config.path
+                << " file_start_idx: " << file_start_idx
+                << " dim_idx: " << dim_idx;
+        if (err_no == -1) {
+          if (retry_num > FLAGS_pslib_table_save_max_retry_dense) {
+            LOG(ERROR) << "DownpourDenseTable load failed reach max limit!";
+            exit(-1);
+          }
+          ++retry_num;
+          --i;
+          LOG(ERROR)
+              << "DownpourDenseTable load failed after read , retry it! path:"
+              << channel_config.path << ", retry_num=" << retry_num;
+          continue;
+        }
+        retry_num = 0;
+        start_dim_idx += file_dim_idx - file_start_idx;
+        LOG(INFO) << "DownpourDenseTable load success, path:"
+                  << channel_config.path;
+      }
+    } catch (...) {
+      is_read_failed = true;
+      LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:"
+                 << channel_config.path;
+    }
+  } while (is_read_failed);
+  return 0;
+}
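The load() hunk above shards the saved dense rows across files and trainers with simple integer arithmetic. The worked example below, with made-up sizes (fea_dim = 1000, 4 saved files, 3 shards, this node is shard 1 owning param_dim_ = 334 rows), mirrors those formulas so the file-range computation is easy to check by hand.

#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  size_t fea_dim = 1000, file_count = 4, shard_num = 3, shard_idx = 1;
  size_t param_dim = 334;  // rows owned by this shard

  size_t dim_num_per_file = fea_dim / file_count + 1;    // 251 rows per file
  size_t dim_num_per_shard = fea_dim / shard_num + 1;    // 334 rows per shard
  size_t start_dim_idx = dim_num_per_shard * shard_idx;  // shard 1 starts at row 334
  size_t start_file_idx = start_dim_idx / dim_num_per_file;  // => file 1
  size_t end_file_idx =
      std::min((start_dim_idx + param_dim) / dim_num_per_file,  // => file 2
               file_count - 1);

  std::cout << "read files " << start_file_idx << ".." << end_file_idx
            << ", skip " << start_dim_idx - start_file_idx * dim_num_per_file
            << " rows in the first file\n";
  return 0;
}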
+int32_t CommonDenseTable::save(const std::string& path,
+                               const std::string& param) {
+  int save_param = atoi(param.c_str());
+  uint32_t feasign_size;
+  VLOG(0) << "CommonDenseTable::save path " << path;
+
+  FsChannelConfig channel_config;
+  if (_config.compress_in_save()) {
+    channel_config.path = paddle::string::format_string(
+        "%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
+  } else {
+    channel_config.path = paddle::string::format_string(
+        "%s/part-%03d", table_dir(path).c_str(), _shard_idx);
+  }
+  _afs_client.remove(channel_config.path);
+  channel_config.converter = _value_accesor->converter(save_param).converter;
+  channel_config.deconverter =
+      _value_accesor->converter(save_param).deconverter;
+
+  bool is_write_failed = false;
+  std::vector<std::vector<std::string>> result_buffer_param(
+      param_dim_, std::vector<std::string>());
+  std::vector<std::string> result_buffer_fixed_len;
+  result_buffer_fixed_len.reserve(fixed_len_params_dim_);
+
+  auto common = _config.common();
+  int size = static_cast<int>(common.params().size());
+  std::ostringstream os;
+  for (int x = 0; x < size; ++x) {
+    auto& varname = common.params()[x];
+    auto& dim = common.dims()[x];
+    VLOG(0) << "CommonDenseTable::save dim " << x << " size: " << dim;
+    for (int y = 0; y < dim; ++y) {
+      os.clear();
+      os.str("");
+      os << values_[x][y];
+      if (dim == param_dim_) {
+        result_buffer_param[y].emplace_back(std::move(os.str()));
+      } else {
+        result_buffer_fixed_len.emplace_back(std::move(os.str()));
+      }
+    }
+  }
+
+  int retry_num = 0;
+  int err_no = 0;
+  do {
+    err_no = 0;
+    is_write_failed = false;
+    feasign_size = 0;
+    // 40M
+    auto write_channel =
+        _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
+
+    for (auto& t : result_buffer_param) {
+      if (_config.common().name() == "adam_d2sum") {
+        t.insert(t.begin() + 1, "0");  // avg_w
+      }
+      if (0 !=
+          write_channel->write_line(paddle::string::join_strings(t, ' '))) {
+        ++retry_num;
+        is_write_failed = true;
+        LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
+                      "path:"
+                   << channel_config.path << ", retry_num=" << retry_num;
+        break;
+      }
+    }
+    ++feasign_size;
+    write_channel->close();
+    if (err_no == -1) {
+      ++retry_num;
+      is_write_failed = true;
+      LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! "
+                 << "path:" << channel_config.path
+                 << ", retry_num=" << retry_num;
+    }
+    if (is_write_failed) {
+      _afs_client.remove(channel_config.path);
+    }
+    if (retry_num >
+        paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) {
+      LOG(ERROR) << "DownpourDenseTable save failed reach max limit!";
+      exit(-1);
+    }
+  } while (is_write_failed);
+  LOG(INFO) << "DownpourDenseTable save success, path:"
+            << channel_config.path;
+
+  return feasign_size;
+}
+
 }  // namespace distributed
 }  // namespace paddle
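As far as the hunk above shows, save() emits one text line per row of the main parameter dim, with the tracked columns joined by single spaces and a literal "0" (the avg_w placeholder) inserted as the second column when the optimizer is adam_d2sum. The tiny sketch below only illustrates that line-building step; the column names and values are made up.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // One row across the columns tracked for this dim (e.g. param, moments...).
  std::vector<std::string> row = {"0.125", "0.01", "0.002"};
  bool adam_d2sum = true;
  if (adam_d2sum) row.insert(row.begin() + 1, "0");  // avg_w placeholder

  std::string line;
  for (size_t i = 0; i < row.size(); ++i) {
    if (i) line += ' ';
    line += row[i];
  }
  std::cout << line << "\n";  // prints: 0.125 0 0.01 0.002
  return 0;
}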
paddle/fluid/distributed/table/common_dense_table.h

@@ -32,33 +32,26 @@ class DenseOptimizer;
 class CommonDenseTable : public DenseTable {
  public:
-  explicit CommonDenseTable() {}
+  CommonDenseTable() {}
   virtual ~CommonDenseTable() {}
-  virtual int32_t initialize() override;
-  virtual int32_t initialize_shard() override { return 0; }
+  int32_t initialize() override;
+  int32_t initialize_shard() override { return 0; }
   virtual void create_initializer(const std::string& attr,
                                   const std::string& name);
   virtual int32_t initialize_value();
   virtual int32_t initialize_optimizer();
-  virtual int32_t pull_dense(float* pull_values, size_t num) override;
-  virtual int32_t push_dense_param(const float* values, size_t num) override;
-  virtual int32_t push_dense(const float* values, size_t num) override;
-  virtual int32_t pour() override;
-  virtual int32_t set_global_lr(float* lr) override;
+  int32_t pull_dense(float* pull_values, size_t num) override;
+  int32_t push_dense_param(const float* values, size_t num) override;
+  int32_t push_dense(const float* values, size_t num) override;
+  int32_t pour() override;
+  int32_t set_global_lr(float* lr) override;

-  int32_t load(const std::string& path, const std::string& param) override {
-    VLOG(0) << "WARNING: dense variables will load on No.0 trainer";
-    return 0;
-  }
+  int32_t load(const std::string& path, const std::string& param) override;
+  int32_t save(const std::string& path, const std::string& param) override;

-  int32_t save(const std::string& path, const std::string& param) override {
-    VLOG(0) << "WARNING: dense variables will save on No.0 trainer";
-    return 0;
-  }
-
-  virtual int32_t flush() override { return 0; }
-  virtual int32_t shrink(const std::string& param) override { return 0; }
-  virtual void clear() override { return; }
+  int32_t flush() override { return 0; }
+  int32_t shrink(const std::string& param) override { return 0; }
+  void clear() override { return; }

  protected:
   int32_t _push_dense(const float* values, size_t num);

@@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable {
   ReservoirValue<float> pull_reservoir_;
   std::unordered_map<std::string, Initializer*> initializers_;
   std::unordered_map<std::string, int> names_index_;
+  int total_dim_ = 0;
+  int fixed_len_params_dim_ = 0;    // used for save/load
+  std::vector<int> param_col_ids_;  // used for save/load
 };

 }  // namespace distributed