Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
b23282f3
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b23282f3
编写于
9月 18, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix shuffler bug
上级
7654080c
变更
3
显示空白变更内容
内联
并排
Showing
3 changed files
with
43 additions
and
12 deletions
+43
-12
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+4
-0
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+17
-7
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+22
-5
未找到文件。
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
b23282f3
...
...
@@ -77,6 +77,7 @@ int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
return
0
;
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
ContextStatusGurad
status_guard
(
_context_ptr
,
TrainerStatus
::
Saving
);
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
...
...
@@ -154,6 +155,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
if
(
!
environment
->
is_master_node
(
EnvironmentRole
::
WORKER
))
{
return
0
;
}
VLOG
(
2
)
<<
"Start Load Model"
;
auto
*
fs
=
_context_ptr
->
file_system
.
get
();
std
::
set
<
uint32_t
>
loaded_table_set
;
auto
model_dir
=
_context_ptr
->
epoch_accessor
->
checkpoint_path
();
...
...
@@ -177,6 +179,7 @@ int LearnerProcess::load_model(uint64_t epoch_id) {
loaded_table_set
.
insert
(
itr
.
first
);
}
}
VLOG
(
2
)
<<
"Finish Load Model"
;
return
0
;
}
...
...
@@ -223,6 +226,7 @@ int LearnerProcess::run() {
//Step2. 运行训练网络
{
ContextStatusGurad
status_guard
(
_context_ptr
,
TrainerStatus
::
Training
);
std
::
map
<
std
::
string
,
paddle
::
framework
::
Channel
<
DataItem
>>
backup_input_map
;
for
(
auto
&
executor
:
_executors
)
{
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
...
...
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
浏览文件 @
b23282f3
...
...
@@ -39,8 +39,8 @@ public:
virtual
int
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
Shuffler
::
initialize
(
config
,
context_ptr
);
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
4
);
// 最大并发发送数
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
1024
);
// 最大包个数,一次发送package个数据
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
6
);
// 最大并发发送数
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
256
);
// 最大包个数,一次发送package个数据
_shuffle_data_msg_type
=
config
[
"shuffle_data_msg_type"
].
as
<
int
>
(
3
);
// c2c msg type
_finish_msg_type
=
config
[
"finish_msg_type"
].
as
<
int
>
(
4
);
// c2c msg type
...
...
@@ -62,6 +62,8 @@ public:
data_channel
.
swap
(
input_channel
);
set_channel
(
data_channel
);
_item_send_count
=
0
;
_item_receive_count
=
0
;
auto
*
environment
=
_trainer_context
->
environment
.
get
();
auto
worker_num
=
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
std
::
vector
<
std
::
vector
<
std
::
future
<
int
>>>
waits
(
concurrent_num
);
...
...
@@ -86,8 +88,9 @@ public:
status
=
1
;
break
;
}
_item_send_count
+=
read_size
;
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
send_buffer_worker
.
clear
();
send_buffer_worker
[
i
]
.
clear
();
}
for
(
int
i
=
0
;
i
<
read_size
;
++
i
)
{
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
...
...
@@ -119,19 +122,19 @@ public:
}
}
}
VLOG
(
5
)
<<
"start send finish, worker_num: "
<<
worker_num
;
VLOG
(
2
)
<<
"start send finish, worker_num: "
<<
worker_num
;
waits
[
0
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
0
].
push_back
(
send_finish
(
i
));
}
VLOG
(
5
)
<<
"wait all finish"
;
VLOG
(
2
)
<<
"wait all finish"
;
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
waits
[
0
][
i
].
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send finish "
<<
i
;
status
=
-
1
;
}
}
VLOG
(
5
)
<<
"finish shuffler, status: "
<<
status
;
VLOG
(
2
)
<<
"finish shuffler_send_channel, total_send:"
<<
_item_send_count
;
return
status
<
0
?
status
:
0
;
}
...
...
@@ -174,6 +177,7 @@ private:
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if
(
wait_num
==
0
)
{
reset_channel
();
VLOG
(
2
)
<<
"finish shuffle_receive_channel, receive_count: "
<<
_item_receive_count
;
_wait_num_mutex
.
unlock
();
}
else
{
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_wait_num_mutex
);
...
...
@@ -182,7 +186,7 @@ private:
}
// Moves a batch of received items into the output channel and adds the batch
// size to the receive counter.
// Returns 0 when Write() reports the same count as the batch size, -1 otherwise.
int32_t write_to_channel(std::vector<DataItem>&& items) {
    // Capture the size first: `items` is moved from below.
    const size_t batch_size = items.size();
    VLOG(5) << "write_to_channel, items_size: " << batch_size;
    _item_receive_count += batch_size;
    const auto written = _out_channel->Write(std::move(items));
    if (written == batch_size) {
        return 0;
    }
    return -1;
}
...
...
@@ -207,6 +211,8 @@ private:
}
std
::
future
<
int32_t
>
send_shuffle_data
(
int
to_client_id
,
std
::
vector
<
DataItem
>&
items
)
{
// server端也开启了client, worker节点为偶数编号
to_client_id
=
2
*
to_client_id
;
VLOG
(
5
)
<<
"send_shuffle_data, to_client_id: "
<<
to_client_id
<<
", items_size: "
<<
items
.
size
();
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
items
;
...
...
@@ -215,6 +221,8 @@ private:
}
std
::
future
<
int32_t
>
send_finish
(
int
to_client_id
)
{
// server端也开启了client, worker节点为偶数编号
to_client_id
=
2
*
to_client_id
;
VLOG
(
5
)
<<
"send_finish, to_client_id: "
<<
to_client_id
;
static
const
std
::
string
empty_str
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_finish_msg_type
,
to_client_id
,
empty_str
);
...
...
@@ -230,6 +238,8 @@ private:
bthread
::
Mutex
_wait_num_mutex
;
std
::
atomic
<
int
>
_wait_num
;
std
::
atomic
<
uint32_t
>
_item_send_count
;
std
::
atomic
<
uint32_t
>
_item_receive_count
;
};
REGIST_CLASS
(
Shuffler
,
GlobalShuffler
);
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
b23282f3
...
...
@@ -95,20 +95,23 @@ private:
class
TrainerContext
{
public:
TrainerContext
()
{
trainer_status
.
resize
(
2
,
0
);
}
// Shortcut accessor: returns whatever pslib->ps_client() returns.
inline paddle::ps::PSClient* ps_client() {
    auto* client = pslib->ps_client();
    return client;
}
// Reports whether `status` is currently switched on for this context.
// Each TrainerStatus value owns one slot in `trainer_status`; a value that
// has no slot yet is reported as off instead of reading out of bounds.
inline bool is_status(TrainerStatus status) {
    const auto status_idx = static_cast<uint32_t>(status);
    if (status_idx >= trainer_status.size()) {
        return false;
    }
    return trainer_status[status_idx] > 0;
}

// Switches `status` on (on == true) or off (on == false).
// Not thread-safe; in fact the thread safety of every TrainerContext member
// depends on the thread safety of the member itself.
inline void set_status(TrainerStatus status, bool on) {
    const auto status_idx = static_cast<uint32_t>(status);
    if (status_idx >= trainer_status.size()) {
        // Grow on demand so a status beyond the existing slots never writes
        // past the end of the vector.
        trainer_status.resize(static_cast<size_t>(status_idx) + 1, 0);
    }
    trainer_status[status_idx] = on ? 1 : 0;
}

// Current trainer status. The trainer can be in several states at once, so
// every TrainerStatus value gets its own slot (non-zero means active).
std::vector<uint32_t> trainer_status;
YAML
::
Node
trainer_config
;
paddle
::
platform
::
CPUPlace
cpu_place
;
...
...
@@ -122,6 +125,20 @@ public:
std
::
shared_ptr
<
SignCacheDict
>
cache_dict
;
//大模型cache词典
};
// Scoped status marker: switches `status` on for `context` when constructed
// and switches it off again when destroyed.
// (The "Gurad" spelling is kept as-is because call sites use this name.)
class ContextStatusGurad {
public:
    // A null `context` turns the guard into a no-op instead of dereferencing
    // a null pointer.
    ContextStatusGurad(TrainerContext* context, TrainerStatus status) :
        _status(status), _context(context) {  // listed in declaration order
        if (_context != nullptr) {
            _context->set_status(_status, true);
        }
    }

    virtual ~ContextStatusGurad() {
        if (_context != nullptr) {
            _context->set_status(_status, false);
        }
    }

    // A copy would switch the status off as soon as the first of the two
    // objects is destroyed, so copying is disabled.
    ContextStatusGurad(const ContextStatusGurad&) = delete;
    ContextStatusGurad& operator=(const ContextStatusGurad&) = delete;

private:
    TrainerStatus _status;
    TrainerContext* _context = nullptr;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录