Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ca8c4f3e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca8c4f3e
编写于
11月 17, 2021
作者:
Z
zhaocaibei123
提交者:
GitHub
11月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update dataset (#37194)
上级
54d2626a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
67 addition
and
29 deletion
+67
-29
paddle/fluid/distributed/fleet.cc
paddle/fluid/distributed/fleet.cc
+9
-2
paddle/fluid/distributed/service/communicator.cc
paddle/fluid/distributed/service/communicator.cc
+0
-13
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+28
-13
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+26
-1
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+4
-0
未找到文件。
paddle/fluid/distributed/fleet.cc
浏览文件 @
ca8c4f3e
...
@@ -710,8 +710,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
...
@@ -710,8 +710,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc
handler
)
{
MsgHandlerFunc
handler
)
{
VLOG
(
1
)
<<
"calling FleetWrapper::RegisterClientToClientMsgHandler"
;
VLOG
(
1
)
<<
"calling FleetWrapper::RegisterClientToClientMsgHandler"
;
auto
*
communicator
=
Communicator
::
GetInstance
();
auto
*
communicator
=
Communicator
::
GetInstance
();
return
communicator
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
// for unittest which does not call fleet.init_worker() first
handler
);
if
(
communicator
==
nullptr
)
{
VLOG
(
0
)
<<
"FleetWrapper::RegisterClientToClientMsgHandler communicator is "
"null"
;
return
-
1
;
}
else
{
return
communicator
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
}
}
}
std
::
future
<
int32_t
>
FleetWrapper
::
SendClientToClientMsg
(
std
::
future
<
int32_t
>
FleetWrapper
::
SendClientToClientMsg
(
...
...
paddle/fluid/distributed/service/communicator.cc
浏览文件 @
ca8c4f3e
...
@@ -368,20 +368,7 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
...
@@ -368,20 +368,7 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
VLOG
(
1
)
<<
"push dense param to table "
<<
table_id
VLOG
(
1
)
<<
"push dense param to table "
<<
table_id
<<
" from 0' trainer done"
;
<<
" from 0' trainer done"
;
}
}
BarrierWithTable
(
1
);
}
else
{
BarrierWithTable
(
1
);
for
(
auto
&
iter
:
recv_varname_to_ctx
)
{
auto
&
table_id
=
iter
.
first
;
auto
&
varnames
=
iter
.
second
;
RpcRecvDense
(
varnames
,
table_id
,
recv_scope_
);
VLOG
(
1
)
<<
"pull dense param to table "
<<
table_id
<<
" from 0' trainer done"
;
}
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
+
trainer_id_
*
10
));
BarrierWithTable
(
1
);
return
;
return
;
}
}
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
ca8c4f3e
...
@@ -19,6 +19,10 @@
...
@@ -19,6 +19,10 @@
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/platform/timer.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/fleet.h"
#endif
#if defined _WIN32 || defined __APPLE__
#if defined _WIN32 || defined __APPLE__
#else
#else
#define _LINUX
#define _LINUX
...
@@ -208,13 +212,17 @@ void DatasetImpl<T>::CreateChannel() {
...
@@ -208,13 +212,17 @@ void DatasetImpl<T>::CreateChannel() {
// if sent message between workers, should first call this function
// if sent message between workers, should first call this function
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
RegisterClientToClientMsgHandler
()
{
void
DatasetImpl
<
T
>::
RegisterClientToClientMsgHandler
()
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
#ifdef PADDLE_WITH_PSCORE
VLOG
(
3
)
<<
"RegisterClientToClientMsgHandler"
;
auto
fleet_ptr
=
distributed
::
FleetWrapper
::
GetInstance
();
#else
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
#endif
VLOG
(
1
)
<<
"RegisterClientToClientMsgHandler"
;
fleet_ptr
->
RegisterClientToClientMsgHandler
(
fleet_ptr
->
RegisterClientToClientMsgHandler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
});
VLOG
(
3
)
<<
"RegisterClientToClientMsgHandler done"
;
VLOG
(
1
)
<<
"RegisterClientToClientMsgHandler done"
;
}
}
static
void
compute_left_batch_num
(
const
int
ins_num
,
const
int
thread_num
,
static
void
compute_left_batch_num
(
const
int
ins_num
,
const
int
thread_num
,
std
::
vector
<
std
::
pair
<
int
,
int
>>*
offset
,
std
::
vector
<
std
::
pair
<
int
,
int
>>*
offset
,
...
@@ -523,7 +531,7 @@ void DatasetImpl<T>::LocalShuffle() {
...
@@ -523,7 +531,7 @@ void DatasetImpl<T>::LocalShuffle() {
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end, no data to shuffle"
;
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end, no data to shuffle"
;
return
;
return
;
}
}
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
input_channel_
->
Close
();
input_channel_
->
Close
();
std
::
vector
<
T
>
data
;
std
::
vector
<
T
>
data
;
input_channel_
->
ReadAll
(
data
);
input_channel_
->
ReadAll
(
data
);
...
@@ -540,11 +548,14 @@ void DatasetImpl<T>::LocalShuffle() {
...
@@ -540,11 +548,14 @@ void DatasetImpl<T>::LocalShuffle() {
}
}
void
MultiSlotDataset
::
GlobalShuffle
(
int
thread_num
)
{
void
MultiSlotDataset
::
GlobalShuffle
(
int
thread_num
)
{
#ifdef PADDLE_WITH_PSLIB
VLOG
(
3
)
<<
"MultiSlotDataset::GlobalShuffle() begin"
;
VLOG
(
3
)
<<
"MultiSlotDataset::GlobalShuffle() begin"
;
platform
::
Timer
timeline
;
platform
::
Timer
timeline
;
timeline
.
Start
();
timeline
.
Start
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
#ifdef PADDLE_WITH_PSCORE
auto
fleet_ptr
=
distributed
::
FleetWrapper
::
GetInstance
();
#else
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
#endif
if
(
!
input_channel_
||
input_channel_
->
Size
()
==
0
)
{
if
(
!
input_channel_
||
input_channel_
->
Size
()
==
0
)
{
VLOG
(
3
)
<<
"MultiSlotDataset::GlobalShuffle() end, no data to shuffle"
;
VLOG
(
3
)
<<
"MultiSlotDataset::GlobalShuffle() end, no data to shuffle"
;
...
@@ -576,7 +587,12 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
...
@@ -576,7 +587,12 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
};
};
auto
global_shuffle_func
=
[
this
,
get_client_id
]()
{
auto
global_shuffle_func
=
[
this
,
get_client_id
]()
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
#ifdef PADDLE_WITH_PSCORE
auto
fleet_ptr
=
distributed
::
FleetWrapper
::
GetInstance
();
#else
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
#endif
// auto fleet_ptr = framework::FleetWrapper::GetInstance();
std
::
vector
<
Record
>
data
;
std
::
vector
<
Record
>
data
;
while
(
this
->
input_channel_
->
Read
(
data
))
{
while
(
this
->
input_channel_
->
Read
(
data
))
{
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
ars
(
this
->
trainer_num_
);
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
ars
(
this
->
trainer_num_
);
...
@@ -633,7 +649,6 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
...
@@ -633,7 +649,6 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
timeline
.
Pause
();
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end, cost time="
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -936,7 +951,7 @@ int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id,
...
@@ -936,7 +951,7 @@ int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id,
}
}
CHECK
(
ar
.
Cursor
()
==
ar
.
Finish
());
CHECK
(
ar
.
Cursor
()
==
ar
.
Finish
());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
// not use random because it doesn't perform well here.
// not use random because it doesn't perform well here.
// to make sure each channel get data equally, we just put data to
// to make sure each channel get data equally, we just put data to
// channel one by one.
// channel one by one.
...
@@ -976,7 +991,7 @@ void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) {
...
@@ -976,7 +991,7 @@ void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) {
void
MultiSlotDataset
::
PostprocessInstance
()
{
void
MultiSlotDataset
::
PostprocessInstance
()
{
// divide pv instance, and merge to input_channel_
// divide pv instance, and merge to input_channel_
if
(
enable_pv_merge_
)
{
if
(
enable_pv_merge_
)
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
std
::
shuffle
(
input_records_
.
begin
(),
input_records_
.
end
(),
std
::
shuffle
(
input_records_
.
begin
(),
input_records_
.
end
(),
fleet_ptr
->
LocalRandomEngine
());
fleet_ptr
->
LocalRandomEngine
());
input_channel_
->
Open
();
input_channel_
->
Open
();
...
@@ -1014,7 +1029,7 @@ void MultiSlotDataset::PreprocessInstance() {
...
@@ -1014,7 +1029,7 @@ void MultiSlotDataset::PreprocessInstance() {
if
(
!
enable_pv_merge_
)
{
// means to use Record
if
(
!
enable_pv_merge_
)
{
// means to use Record
this
->
LocalShuffle
();
this
->
LocalShuffle
();
}
else
{
// means to use Pv
}
else
{
// means to use Pv
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
input_channel_
->
Close
();
input_channel_
->
Close
();
std
::
vector
<
PvInstance
>
pv_data
;
std
::
vector
<
PvInstance
>
pv_data
;
input_channel_
->
ReadAll
(
input_records_
);
input_channel_
->
ReadAll
(
input_records_
);
...
@@ -1073,7 +1088,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
...
@@ -1073,7 +1088,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
}
}
CHECK
(
multi_output_channel_
.
size
()
!=
0
);
// NOLINT
CHECK
(
multi_output_channel_
.
size
()
!=
0
);
// NOLINT
auto
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr_
=
framework
::
FleetWrapper
::
GetInstance
();
std
::
vector
<
std
::
unordered_map
<
uint64_t
,
std
::
vector
<
float
>>>&
std
::
vector
<
std
::
unordered_map
<
uint64_t
,
std
::
vector
<
float
>>>&
local_map_tables
=
fleet_ptr_
->
GetLocalTable
();
local_map_tables
=
fleet_ptr_
->
GetLocalTable
();
local_map_tables
.
resize
(
shard_num
);
local_map_tables
.
resize
(
shard_num
);
...
@@ -1315,7 +1330,7 @@ void MultiSlotDataset::MergeByInsId() {
...
@@ -1315,7 +1330,7 @@ void MultiSlotDataset::MergeByInsId() {
LOG
(
WARNING
)
<<
"total drop ins num: "
<<
drop_ins_num
;
LOG
(
WARNING
)
<<
"total drop ins num: "
<<
drop_ins_num
;
results
.
shrink_to_fit
();
results
.
shrink_to_fit
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
framework
::
FleetWrapper
::
GetInstance
();
std
::
shuffle
(
results
.
begin
(),
results
.
end
(),
fleet_ptr
->
LocalRandomEngine
());
std
::
shuffle
(
results
.
begin
(),
results
.
end
(),
fleet_ptr
->
LocalRandomEngine
());
channel_data
->
Open
();
channel_data
->
Open
();
channel_data
->
Write
(
std
::
move
(
results
));
channel_data
->
Write
(
std
::
move
(
results
));
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
ca8c4f3e
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include <string>
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if defined PADDLE_WITH_PSCORE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/communicator.h"
...
@@ -153,7 +154,20 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
...
@@ -153,7 +154,20 @@ void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if
(
need_dump_field_
||
need_dump_param_
)
{
if
(
need_dump_field_
||
need_dump_param_
)
{
InitDumpEnv
();
InitDumpEnv
();
}
}
VLOG
(
3
)
<<
"init other env done."
;
#ifdef PADDLE_WITH_PSCORE
// pull dense param first
auto
communicator
=
paddle
::
distributed
::
Communicator
::
GetInstance
();
// for unittest which call train_from_dataset but does not call
// fleet.init_worker() first
if
(
communicator
==
nullptr
)
{
VLOG
(
0
)
<<
"MultiTrainer::InitOtherEnv Communicator is null!"
;
}
else
{
auto
&
recv_ctx
=
communicator
->
GetRecvCtxMap
();
communicator
->
PullDense
(
recv_ctx
);
VLOG
(
3
)
<<
"init other env done."
;
}
#endif
}
}
Scope
*
MultiTrainer
::
GetWorkerScope
(
int
thread_id
)
{
Scope
*
MultiTrainer
::
GetWorkerScope
(
int
thread_id
)
{
...
@@ -253,6 +267,17 @@ void MultiTrainer::Finalize() {
...
@@ -253,6 +267,17 @@ void MultiTrainer::Finalize() {
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
MergeDenseParam
();
MergeDenseParam
();
#endif
#endif
#if defined PADDLE_WITH_PSCORE
auto
communicator
=
paddle
::
distributed
::
Communicator
::
GetInstance
();
// for unittest which does not call fleet.init_worker() first
if
(
communicator
==
nullptr
)
{
VLOG
(
0
)
<<
"MultiTrainer::Finalize communicator is null!"
;
}
else
{
communicator
->
_worker_ptr
->
flush
();
VLOG
(
1
)
<<
"MultiTrainer::Finalize ps client flush done"
;
}
#endif
root_scope_
->
DropKids
();
root_scope_
->
DropKids
();
}
}
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
ca8c4f3e
...
@@ -577,8 +577,12 @@ class TheOnePSRuntime(RuntimeBase):
...
@@ -577,8 +577,12 @@ class TheOnePSRuntime(RuntimeBase):
else
:
else
:
init_params
=
dense_map
init_params
=
dense_map
import
paddle.distributed.fleet
as
fleet
if
not
is_test
:
if
not
is_test
:
self
.
_communicator
.
init_params
(
init_params
)
self
.
_communicator
.
init_params
(
init_params
)
fleet
.
util
.
barrier
()
self
.
_communicator
.
pull_dense
(
init_params
)
fleet
.
util
.
barrier
()
if
not
self
.
_communicator
.
is_running
():
if
not
self
.
_communicator
.
is_running
():
self
.
_communicator
.
start
()
self
.
_communicator
.
start
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录