Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
162d5fc7
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
162d5fc7
编写于
9月 17, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
global_shuffle
Change-Id: I431825fe28853a81151febcd0a45fd77584a6e2a
上级
229964e4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
193 addition
and
40 deletion
+193
-40
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+10
-0
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+163
-40
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
...in/custom_trainer/feed/unit_test/test_archive_dataitem.cc
+20
-0
未找到文件。
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
162d5fc7
...
...
@@ -68,6 +68,16 @@ public:
std
::
string
data
;
//样本数据, maybe压缩格式
};
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
>>
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
DataItem
&
x
)
{
return
ar
>>
x
.
id
>>
x
.
data
;
}
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
<<
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
const
DataItem
&
x
)
{
return
ar
<<
x
.
id
<<
x
.
data
;
}
typedef
std
::
shared_ptr
<
Pipeline
<
DataItem
,
SampleInstance
>>
SampleInstancePipe
;
inline
SampleInstancePipe
make_sample_instance_channel
()
{
return
std
::
make_shared
<
Pipeline
<
DataItem
,
SampleInstance
>>
();
...
...
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
浏览文件 @
162d5fc7
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include <bthread/butex.h>
namespace
paddle
{
namespace
custom_trainer
{
...
...
@@ -40,73 +41,195 @@ public:
Shuffler
::
initialize
(
config
,
context_ptr
);
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
4
);
// 最大并发发送数
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
1024
);
// 最大包个数,一次发送package个数据
_shuffle_data_msg_type
=
config
[
"shuffle_data_msg_type"
].
as
<
int
>
(
3
);
// c2c msg type
_finish_msg_type
=
config
[
"finish_msg_type"
].
as
<
int
>
(
4
);
// c2c msg type
reset_channel
();
auto
binded
=
std
::
bind
(
&
GlobalShuffler
::
get_client2client_msg
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
,
std
::
placeholders
::
_3
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_shuffle_data_msg_type
,
binded
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_finish_msg_type
,
binded
);
return
0
;
}
// 所有worker必须都调用shuffle,并且shuffler同时只能有一个shuffle任务
virtual
int
shuffle
(
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
{
uint32_t
send_count
=
0
;
uint32_t
package_size
=
_max_package_size
;
uint32_t
concurrent_num
=
_max_concurrent_num
;
uint32_t
current_wait_idx
=
0
;
::
paddle
::
framework
::
Channel
<
DataItem
>
input_channel
=
::
paddle
::
framework
::
MakeChannel
<
DataItem
>
(
data_channel
);
data_channel
.
swap
(
input_channel
);
set_channel
(
data_channel
);
auto
*
environment
=
_trainer_context
->
environment
.
get
();
auto
worker_num
=
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
std
::
vector
<
std
::
vector
<
std
::
future
<
int
>>>
waits
(
concurrent_num
);
std
::
vector
<
DataItem
>
send_buffer
(
concurrent_num
*
package_size
);
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
request_data_buffer
(
worker_num
);
while
(
true
)
{
auto
read_size
=
data_channel
->
Read
(
concurrent_num
*
package_size
,
&
send_buffer
[
0
]);
if
(
read_size
==
0
)
{
break
;
std
::
vector
<
DataItem
>
send_buffer
(
package_size
);
std
::
vector
<
std
::
vector
<
DataItem
>>
send_buffer_worker
(
worker_num
);
int
status
=
0
;
// >0: finish; =0: running; <0: fail
while
(
status
==
0
)
{
// update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
concurrent_num
=
1
;
package_size
=
_max_concurrent_num
/
2
;
}
else
{
package_size
=
_max_package_size
;
concurrent_num
=
_max_concurrent_num
;
}
for
(
size_t
idx
=
0
;
idx
<
read_size
;
idx
+=
package_size
)
{
// data shard && seriliaze
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
request_data_buffer
[
i
].
Clear
();
for
(
uint32_t
current_wait_idx
=
0
;
status
==
0
&&
current_wait_idx
<
concurrent_num
;
++
current_wait_idx
)
{
auto
read_size
=
input_channel
->
Read
(
package_size
,
send_buffer
.
data
());
if
(
read_size
==
0
)
{
status
=
1
;
break
;
}
for
(
size_t
i
=
idx
;
i
<
package_size
&&
i
<
read_size
;
++
i
)
{
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
// TODO Serialize To Arcive
//request_data_buffer[worker_idx] << send_buffer[i];
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
send_buffer_worker
.
clear
();
}
std
::
string
data_vec
[
worker_num
];
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
auto
&
buffer
=
request_data_buffer
[
i
];
data_vec
[
i
].
assign
(
buffer
.
Buffer
(),
buffer
.
Length
());
for
(
int
i
=
0
;
i
<
read_size
;
++
i
)
{
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
send_buffer_worker
[
worker_idx
].
push_back
(
std
::
move
(
send_buffer
[
i
]));
}
// wait async done
for
(
auto
&
wait_s
:
waits
[
current_wait_idx
])
{
if
(
!
wait_s
.
valid
())
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
break
;
}
CHECK
(
wait_s
.
get
()
==
0
);
}
// send shuffle data
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
current_wait_idx
][
i
]
=
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
3
,
i
*
2
,
data_vec
[
i
]);
if
(
status
!=
0
)
{
break
;
}
waits
[
current_wait_idx
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
!
send_buffer_worker
[
i
].
empty
())
{
waits
[
current_wait_idx
].
push_back
(
send_shuffle_data
(
i
,
send_buffer_worker
[
i
]));
}
}
// update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
concurrent_num
=
1
;
package_size
=
_max_concurrent_num
/
2
;
}
else
{
package_size
=
_max_package_size
;
concurrent_num
=
_max_concurrent_num
;
}
}
for
(
auto
&
waits_s
:
waits
)
{
for
(
auto
&
wait_s
:
waits_s
)
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
}
++
current_wait_idx
;
current_wait_idx
=
current_wait_idx
>=
concurrent_num
?
0
:
current_wait_idx
;
}
}
return
0
;
VLOG
(
5
)
<<
"start send finish, worker_num: "
<<
worker_num
;
waits
[
0
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
0
].
push_back
(
send_finish
(
i
));
}
VLOG
(
5
)
<<
"wait all finish"
;
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
waits
[
0
][
i
].
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send finish "
<<
i
;
status
=
-
1
;
}
}
VLOG
(
5
)
<<
"finish shuffler, status: "
<<
status
;
return
status
<
0
?
status
:
0
;
}
private:
/*
1. 部分c2c send_shuffle_data先到, 此时channel未设置, 等待wait_channel
2. shuffle中调用set_channel, 先reset_wait_num, 再解锁channel
3. 当接收到所有worker的finish请求后,先reset_channel, 再同时返回
*/
bool
wait_channel
()
{
VLOG
(
5
)
<<
"wait_channel"
;
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_channel_mutex
);
return
_out_channel
!=
nullptr
;
}
void
reset_channel
()
{
VLOG
(
5
)
<<
"reset_channel"
;
_channel_mutex
.
lock
();
if
(
_out_channel
!=
nullptr
)
{
_out_channel
->
Close
();
}
_out_channel
=
nullptr
;
}
void
reset_wait_num
()
{
_wait_num_mutex
.
lock
();
_wait_num
=
_trainer_context
->
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
VLOG
(
5
)
<<
"reset_wait_num: "
<<
_wait_num
;
}
void
set_channel
(
paddle
::
framework
::
Channel
<
DataItem
>&
channel
)
{
VLOG
(
5
)
<<
"set_channel"
;
// 在节点开始写入channel之前,重置wait_num
CHECK
(
_out_channel
==
nullptr
);
_out_channel
=
channel
;
reset_wait_num
();
_channel_mutex
.
unlock
();
}
int32_t
finish_write_channel
()
{
int
wait_num
=
--
_wait_num
;
VLOG
(
5
)
<<
"finish_write_channel, wait_num: "
<<
wait_num
;
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if
(
wait_num
==
0
)
{
reset_channel
();
_wait_num_mutex
.
unlock
();
}
else
{
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_wait_num_mutex
);
}
return
0
;
}
int32_t
write_to_channel
(
std
::
vector
<
DataItem
>&&
items
)
{
size_t
items_size
=
items
.
size
();
VLOG
(
5
)
<<
"write_to_channel, items_size: "
<<
items_size
;
return
_out_channel
->
Write
(
std
::
move
(
items
))
==
items_size
?
0
:
-
1
;
}
int32_t
get_client2client_msg
(
int
msg_type
,
int
from_client
,
const
std
::
string
&
msg
)
{
// wait channel
if
(
!
wait_channel
())
{
LOG
(
FATAL
)
<<
"out_channel is null"
;
return
-
1
;
}
VLOG
(
5
)
<<
"get c2c msg, type: "
<<
msg_type
<<
", from_client: "
<<
from_client
<<
", msg_size: "
<<
msg
.
size
();
if
(
msg_type
==
_shuffle_data_msg_type
)
{
paddle
::
framework
::
BinaryArchive
ar
;
ar
.
SetReadBuffer
(
const_cast
<
char
*>
(
msg
.
data
()),
msg
.
size
(),
[](
char
*
){});
std
::
vector
<
DataItem
>
items
;
ar
>>
items
;
return
write_to_channel
(
std
::
move
(
items
));
}
else
if
(
msg_type
==
_finish_msg_type
)
{
return
finish_write_channel
();
}
LOG
(
FATAL
)
<<
"no such msg type: "
<<
msg_type
;
return
-
1
;
}
std
::
future
<
int32_t
>
send_shuffle_data
(
int
to_client_id
,
std
::
vector
<
DataItem
>&
items
)
{
VLOG
(
5
)
<<
"send_shuffle_data, to_client_id: "
<<
to_client_id
<<
", items_size: "
<<
items
.
size
();
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
items
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_shuffle_data_msg_type
,
to_client_id
,
std
::
string
(
ar
.
Buffer
(),
ar
.
Length
()));
}
std
::
future
<
int32_t
>
send_finish
(
int
to_client_id
)
{
VLOG
(
5
)
<<
"send_finish, to_client_id: "
<<
to_client_id
;
static
const
std
::
string
empty_str
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_finish_msg_type
,
to_client_id
,
empty_str
);
}
uint32_t
_max_package_size
=
0
;
uint32_t
_max_concurrent_num
=
0
;
int
_shuffle_data_msg_type
=
3
;
int
_finish_msg_type
=
4
;
bthread
::
Mutex
_channel_mutex
;
paddle
::
framework
::
Channel
<
DataItem
>
_out_channel
=
nullptr
;
bthread
::
Mutex
_wait_num_mutex
;
std
::
atomic
<
int
>
_wait_num
;
};
REGIST_CLASS
(
Shuffler
,
GlobalShuffler
);
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
0 → 100644
浏览文件 @
162d5fc7
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
TEST
(
Archive
,
DataItem
)
{
paddle
::
custom_trainer
::
feed
::
DataItem
item
;
paddle
::
custom_trainer
::
feed
::
DataItem
item2
;
item
.
id
=
"123"
;
item
.
data
=
"name"
;
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
item
;
ar
>>
item2
;
ASSERT_EQ
(
item
.
id
,
item2
.
id
);
ASSERT_EQ
(
item
.
data
,
item2
.
data
);
item
.
id
+=
"~"
;
item
.
data
+=
"~"
;
ASSERT_NE
(
item
.
id
,
item2
.
id
);
ASSERT_NE
(
item
.
data
,
item2
.
data
);
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录