Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
9552cf55
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9552cf55
编写于
9月 18, 2019
作者:
L
linan17
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of
ssh://icode.baidu.com:8235/baidu/feed-mlarch/paddle-trainer
Change-Id: If03dc2ea6e1a8b9bcedb9d5c9fa9dbcc44d41396
上级
2a933c78
7654080c
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
335 addition
and
46 deletion
+335
-46
BCLOUD
BCLOUD
+1
-0
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+10
-0
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
...ain/custom_trainer/feed/executor/multi_thread_executor.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+83
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+5
-1
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+163
-40
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+51
-4
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
...in/custom_trainer/feed/unit_test/test_archive_dataitem.cc
+20
-0
未找到文件。
BCLOUD
浏览文件 @
9552cf55
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
9552cf55
...
@@ -68,6 +68,16 @@ public:
...
@@ -68,6 +68,16 @@ public:
std
::
string
data
;
//样本数据, maybe压缩格式
std
::
string
data
;
//样本数据, maybe压缩格式
};
};
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
>>
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
DataItem
&
x
)
{
return
ar
>>
x
.
id
>>
x
.
data
;
}
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
<<
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
const
DataItem
&
x
)
{
return
ar
<<
x
.
id
<<
x
.
data
;
}
typedef
std
::
shared_ptr
<
Pipeline
<
DataItem
,
SampleInstance
>>
SampleInstancePipe
;
typedef
std
::
shared_ptr
<
Pipeline
<
DataItem
,
SampleInstance
>>
SampleInstancePipe
;
inline
SampleInstancePipe
make_sample_instance_channel
()
{
inline
SampleInstancePipe
make_sample_instance_channel
()
{
return
std
::
make_shared
<
Pipeline
<
DataItem
,
SampleInstance
>>
();
return
std
::
make_shared
<
Pipeline
<
DataItem
,
SampleInstance
>>
();
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
浏览文件 @
9552cf55
...
@@ -243,7 +243,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
...
@@ -243,7 +243,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
for
(
auto
&
monitor
:
_monitors
)
{
for
(
auto
&
monitor
:
_monitors
)
{
if
(
monitor
->
need_compute_result
(
epoch_id
))
{
if
(
monitor
->
need_compute_result
(
epoch_id
))
{
monitor
->
compute_result
();
monitor
->
compute_result
();
ENVLOG_WORKER_MASTER_NOTICE
(
"[Monitor]%s, monitor:%s,
,
result:%s"
,
ENVLOG_WORKER_MASTER_NOTICE
(
"[Monitor]%s, monitor:%s, result:%s"
,
_train_exe_name
.
c_str
(),
monitor
->
get_name
().
c_str
(),
monitor
->
format_result
().
c_str
());
_train_exe_name
.
c_str
(),
monitor
->
get_name
().
c_str
(),
monitor
->
format_result
().
c_str
());
_trainer_context
->
monitor_ssm
<<
_train_exe_name
<<
":"
<<
_trainer_context
->
monitor_ssm
<<
_train_exe_name
<<
":"
<<
monitor
->
get_name
()
<<
":"
<<
monitor
->
format_result
()
<<
","
;
monitor
->
get_name
()
<<
":"
<<
monitor
->
format_result
()
<<
","
;
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
9552cf55
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
//load trainer config
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
cache_dict
.
reset
(
new
SignCacheDict
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
//environment
//environment
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
9552cf55
...
@@ -16,6 +16,8 @@ namespace feed {
...
@@ -16,6 +16,8 @@ namespace feed {
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
Process
::
initialize
(
context_ptr
);
int
ret
=
Process
::
initialize
(
context_ptr
);
auto
&
config
=
_context_ptr
->
trainer_config
;
auto
&
config
=
_context_ptr
->
trainer_config
;
_is_dump_cache_model
=
config
[
"dump_cache_model"
].
as
<
bool
>
(
false
);
_cache_load_converter
=
config
[
"load_cache_converter"
].
as
<
std
::
string
>
(
""
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
if
(
config
[
"executor"
])
{
if
(
config
[
"executor"
])
{
_executors
.
resize
(
config
[
"executor"
].
size
());
_executors
.
resize
(
config
[
"executor"
].
size
());
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return
0
;
return
0
;
}
}
// 更新各节点存储的CacheModel
int
LearnerProcess
::
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
if
(
!
epoch_accessor
->
need_save_model
(
epoch_id
,
way
))
{
return
0
;
}
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/"
,
param
.
table_id
()));
if
(
!
fs
->
exists
(
cache_model_path
))
{
continue
;
}
auto
&
cache_dict
=
*
(
_context_ptr
->
cache_dict
.
get
());
cache_dict
.
clear
();
cache_dict
.
reserve
(
_cache_sign_max_num
);
auto
cache_file_list
=
fs
->
list
(
fs
->
path_join
(
cache_model_path
,
"part*"
));
for
(
auto
&
cache_path
:
cache_file_list
)
{
auto
cache_file
=
fs
->
open_read
(
cache_path
,
_cache_load_converter
);
char
*
buffer
=
nullptr
;
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
cache_file
.
get
()))
!=
-
1
)
{
if
(
line_len
<=
1
)
{
continue
;
}
char
*
data_ptr
=
NULL
;
cache_dict
.
append
(
strtoul
(
buffer
,
&
data_ptr
,
10
));
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
}
}
break
;
}
}
return
0
;
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
}
}
timer
.
Pause
();
timer
.
Pause
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
// save cache model, 只有inference需要cache_model
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
(
way
==
ModelSaveWay
::
ModelSaveInferenceBase
||
way
==
ModelSaveWay
::
ModelSaveInferenceDelta
))
{
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
double
cache_threshold
=
0.0
;
auto
status
=
ps_client
->
get_cache_threshold
(
param
.
table_id
(),
cache_threshold
);
CHECK
(
status
.
get
()
==
0
)
<<
"CacheThreshold Get failed!"
;
status
=
ps_client
->
cache_shuffle
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
),
std
::
to_string
(
cache_threshold
));
CHECK
(
status
.
get
()
==
0
)
<<
"Cache Shuffler Failed"
;
status
=
ps_client
->
save_cache
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
));
auto
feature_size
=
status
.
get
();
CHECK
(
feature_size
>=
0
)
<<
"Cache Save Failed"
;
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/sparse_cache.meta"
,
param
.
table_id
()));
auto
cache_meta_file
=
fs
->
open_write
(
cache_model_path
,
""
);
auto
meta
=
string
::
format_string
(
"file_prefix:part
\n
part_num:%d
\n
key_num:%d
\n
"
,
param
.
sparse_table_cache_file_num
(),
feature_size
);
CHECK
(
fwrite
(
meta
.
c_str
(),
meta
.
size
(),
1
,
cache_meta_file
.
get
())
==
1
)
<<
"Cache Meta Failed"
;
if
(
feature_size
>
_cache_sign_max_num
)
{
_cache_sign_max_num
=
feature_size
;
}
}
}
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
return
all_ret
;
return
all_ret
;
}
}
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
{
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
update_cache_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
}
else
{
}
else
{
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
浏览文件 @
9552cf55
...
@@ -22,9 +22,13 @@ protected:
...
@@ -22,9 +22,13 @@ protected:
virtual
int
load_model
(
uint64_t
epoch_id
);
virtual
int
load_model
(
uint64_t
epoch_id
);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
);
private:
private:
bool
_startup_dump_inference_base
;
//启动立即dump base
bool
_is_dump_cache_model
;
// 是否进行cache dump
uint32_t
_cache_sign_max_num
=
0
;
// cache sign最大个数
std
::
string
_cache_load_converter
;
// cache加载的前置转换脚本
bool
_startup_dump_inference_base
;
// 启动立即dump base
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
};
};
...
...
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
浏览文件 @
9552cf55
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include <bthread/butex.h>
namespace
paddle
{
namespace
paddle
{
namespace
custom_trainer
{
namespace
custom_trainer
{
...
@@ -40,73 +41,195 @@ public:
...
@@ -40,73 +41,195 @@ public:
Shuffler
::
initialize
(
config
,
context_ptr
);
Shuffler
::
initialize
(
config
,
context_ptr
);
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
4
);
// 最大并发发送数
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
4
);
// 最大并发发送数
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
1024
);
// 最大包个数,一次发送package个数据
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
1024
);
// 最大包个数,一次发送package个数据
_shuffle_data_msg_type
=
config
[
"shuffle_data_msg_type"
].
as
<
int
>
(
3
);
// c2c msg type
_finish_msg_type
=
config
[
"finish_msg_type"
].
as
<
int
>
(
4
);
// c2c msg type
reset_channel
();
auto
binded
=
std
::
bind
(
&
GlobalShuffler
::
get_client2client_msg
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
,
std
::
placeholders
::
_3
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_shuffle_data_msg_type
,
binded
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_finish_msg_type
,
binded
);
return
0
;
return
0
;
}
}
// 所有worker必须都调用shuffle,并且shuffler同时只能有一个shuffle任务
virtual
int
shuffle
(
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
{
virtual
int
shuffle
(
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
{
uint32_t
send_count
=
0
;
uint32_t
send_count
=
0
;
uint32_t
package_size
=
_max_package_size
;
uint32_t
package_size
=
_max_package_size
;
uint32_t
concurrent_num
=
_max_concurrent_num
;
uint32_t
concurrent_num
=
_max_concurrent_num
;
uint32_t
current_wait_idx
=
0
;
::
paddle
::
framework
::
Channel
<
DataItem
>
input_channel
=
::
paddle
::
framework
::
MakeChannel
<
DataItem
>
(
data_channel
);
data_channel
.
swap
(
input_channel
);
set_channel
(
data_channel
);
auto
*
environment
=
_trainer_context
->
environment
.
get
();
auto
*
environment
=
_trainer_context
->
environment
.
get
();
auto
worker_num
=
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
auto
worker_num
=
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
std
::
vector
<
std
::
vector
<
std
::
future
<
int
>>>
waits
(
concurrent_num
);
std
::
vector
<
std
::
vector
<
std
::
future
<
int
>>>
waits
(
concurrent_num
);
std
::
vector
<
DataItem
>
send_buffer
(
concurrent_num
*
package_size
);
std
::
vector
<
DataItem
>
send_buffer
(
package_size
);
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
request_data_buffer
(
worker_num
);
std
::
vector
<
std
::
vector
<
DataItem
>>
send_buffer_worker
(
worker_num
);
while
(
true
)
{
auto
read_size
=
data_channel
->
Read
(
concurrent_num
*
package_size
,
&
send_buffer
[
0
]);
int
status
=
0
;
// >0: finish; =0: running; <0: fail
while
(
status
==
0
)
{
// update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
concurrent_num
=
1
;
package_size
=
_max_concurrent_num
/
2
;
}
else
{
package_size
=
_max_package_size
;
concurrent_num
=
_max_concurrent_num
;
}
for
(
uint32_t
current_wait_idx
=
0
;
status
==
0
&&
current_wait_idx
<
concurrent_num
;
++
current_wait_idx
)
{
auto
read_size
=
input_channel
->
Read
(
package_size
,
send_buffer
.
data
());
if
(
read_size
==
0
)
{
if
(
read_size
==
0
)
{
status
=
1
;
break
;
break
;
}
}
for
(
size_t
idx
=
0
;
idx
<
read_size
;
idx
+=
package_size
)
{
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
// data shard && seriliaze
send_buffer_worker
.
clear
();
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
request_data_buffer
[
i
].
Clear
();
}
}
for
(
size_t
i
=
idx
;
i
<
package_size
&&
i
<
read_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
read_size
;
++
i
)
{
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
// TODO Serialize To Arcive
send_buffer_worker
[
worker_idx
].
push_back
(
std
::
move
(
send_buffer
[
i
]));
//request_data_buffer[worker_idx] << send_buffer[i];
}
}
std
::
string
data_vec
[
worker_num
];
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
auto
&
buffer
=
request_data_buffer
[
i
];
data_vec
[
i
].
assign
(
buffer
.
Buffer
(),
buffer
.
Length
());
}
// wait async done
for
(
auto
&
wait_s
:
waits
[
current_wait_idx
])
{
for
(
auto
&
wait_s
:
waits
[
current_wait_idx
])
{
if
(
!
wait_s
.
valid
())
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
break
;
break
;
}
}
CHECK
(
wait_s
.
get
()
==
0
);
}
if
(
status
!=
0
)
{
break
;
}
waits
[
current_wait_idx
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
!
send_buffer_worker
[
i
].
empty
())
{
waits
[
current_wait_idx
].
push_back
(
send_shuffle_data
(
i
,
send_buffer_worker
[
i
]));
}
}
}
}
for
(
auto
&
waits_s
:
waits
)
{
for
(
auto
&
wait_s
:
waits_s
)
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
}
}
}
VLOG
(
5
)
<<
"start send finish, worker_num: "
<<
worker_num
;
waits
[
0
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
0
].
push_back
(
send_finish
(
i
));
}
VLOG
(
5
)
<<
"wait all finish"
;
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
waits
[
0
][
i
].
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send finish "
<<
i
;
status
=
-
1
;
}
}
VLOG
(
5
)
<<
"finish shuffler, status: "
<<
status
;
return
status
<
0
?
status
:
0
;
}
}
// send shuffle data
private:
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
/*
waits
[
current_wait_idx
][
i
]
=
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
3
,
i
*
2
,
data_vec
[
i
]);
1. 部分c2c send_shuffle_data先到, 此时channel未设置, 等待wait_channel
2. shuffle中调用set_channel, 先reset_wait_num, 再解锁channel
3. 当接收到所有worker的finish请求后,先reset_channel, 再同时返回
*/
bool
wait_channel
()
{
VLOG
(
5
)
<<
"wait_channel"
;
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_channel_mutex
);
return
_out_channel
!=
nullptr
;
}
void
reset_channel
()
{
VLOG
(
5
)
<<
"reset_channel"
;
_channel_mutex
.
lock
();
if
(
_out_channel
!=
nullptr
)
{
_out_channel
->
Close
();
}
_out_channel
=
nullptr
;
}
void
reset_wait_num
()
{
_wait_num_mutex
.
lock
();
_wait_num
=
_trainer_context
->
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
VLOG
(
5
)
<<
"reset_wait_num: "
<<
_wait_num
;
}
void
set_channel
(
paddle
::
framework
::
Channel
<
DataItem
>&
channel
)
{
VLOG
(
5
)
<<
"set_channel"
;
// 在节点开始写入channel之前,重置wait_num
CHECK
(
_out_channel
==
nullptr
);
_out_channel
=
channel
;
reset_wait_num
();
_channel_mutex
.
unlock
();
}
}
// update status
int32_t
finish_write_channel
()
{
// 如果在训练期,则限速shuffle
int
wait_num
=
--
_wait_num
;
// 如果在wait状态,全速shuffle
VLOG
(
5
)
<<
"finish_write_channel, wait_num: "
<<
wait_num
;
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
concurrent_num
=
1
;
if
(
wait_num
==
0
)
{
package_size
=
_max_concurrent_num
/
2
;
reset_channel
();
_wait_num_mutex
.
unlock
();
}
else
{
}
else
{
package_size
=
_max_package_size
;
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_wait_num_mutex
);
concurrent_num
=
_max_concurrent_num
;
}
}
++
current_wait_idx
;
return
0
;
current_wait_idx
=
current_wait_idx
>=
concurrent_num
?
0
:
current_wait_idx
;
}
}
int32_t
write_to_channel
(
std
::
vector
<
DataItem
>&&
items
)
{
size_t
items_size
=
items
.
size
();
VLOG
(
5
)
<<
"write_to_channel, items_size: "
<<
items_size
;
return
_out_channel
->
Write
(
std
::
move
(
items
))
==
items_size
?
0
:
-
1
;
}
}
return
0
;
int32_t
get_client2client_msg
(
int
msg_type
,
int
from_client
,
const
std
::
string
&
msg
)
{
// wait channel
if
(
!
wait_channel
())
{
LOG
(
FATAL
)
<<
"out_channel is null"
;
return
-
1
;
}
VLOG
(
5
)
<<
"get c2c msg, type: "
<<
msg_type
<<
", from_client: "
<<
from_client
<<
", msg_size: "
<<
msg
.
size
();
if
(
msg_type
==
_shuffle_data_msg_type
)
{
paddle
::
framework
::
BinaryArchive
ar
;
ar
.
SetReadBuffer
(
const_cast
<
char
*>
(
msg
.
data
()),
msg
.
size
(),
[](
char
*
){});
std
::
vector
<
DataItem
>
items
;
ar
>>
items
;
return
write_to_channel
(
std
::
move
(
items
));
}
else
if
(
msg_type
==
_finish_msg_type
)
{
return
finish_write_channel
();
}
LOG
(
FATAL
)
<<
"no such msg type: "
<<
msg_type
;
return
-
1
;
}
std
::
future
<
int32_t
>
send_shuffle_data
(
int
to_client_id
,
std
::
vector
<
DataItem
>&
items
)
{
VLOG
(
5
)
<<
"send_shuffle_data, to_client_id: "
<<
to_client_id
<<
", items_size: "
<<
items
.
size
();
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
items
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_shuffle_data_msg_type
,
to_client_id
,
std
::
string
(
ar
.
Buffer
(),
ar
.
Length
()));
}
std
::
future
<
int32_t
>
send_finish
(
int
to_client_id
)
{
VLOG
(
5
)
<<
"send_finish, to_client_id: "
<<
to_client_id
;
static
const
std
::
string
empty_str
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_finish_msg_type
,
to_client_id
,
empty_str
);
}
}
private:
uint32_t
_max_package_size
=
0
;
uint32_t
_max_package_size
=
0
;
uint32_t
_max_concurrent_num
=
0
;
uint32_t
_max_concurrent_num
=
0
;
int
_shuffle_data_msg_type
=
3
;
int
_finish_msg_type
=
4
;
bthread
::
Mutex
_channel_mutex
;
paddle
::
framework
::
Channel
<
DataItem
>
_out_channel
=
nullptr
;
bthread
::
Mutex
_wait_num_mutex
;
std
::
atomic
<
int
>
_wait_num
;
};
};
REGIST_CLASS
(
Shuffler
,
GlobalShuffler
);
REGIST_CLASS
(
Shuffler
,
GlobalShuffler
);
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
9552cf55
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving
=
1
// 模型存储状态
Saving
=
1
// 模型存储状态
};
};
const
uint32_t
SignCacheMaxValueNum
=
13
;
struct
SignCacheData
{
SignCacheData
()
{
memset
(
cache_value
,
0
,
sizeof
(
float
)
*
SignCacheMaxValueNum
);
}
uint32_t
idx
;
float
cache_value
[
SignCacheMaxValueNum
];
};
class
SignCacheDict
{
class
SignCacheDict
{
public:
public:
int32_t
sign2index
(
uint64_t
sign
)
{
inline
int32_t
sign2index
(
uint64_t
sign
)
{
auto
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
-
1
;
return
-
1
;
}
}
return
itr
->
second
.
idx
;
}
uint64_t
index2sign
(
int32_t
index
)
{
inline
uint64_t
index2sign
(
int32_t
index
)
{
if
(
index
>=
_sign_list
.
size
())
{
return
0
;
return
0
;
}
}
return
_sign_list
[
index
];
}
inline
void
reserve
(
uint32_t
size
)
{
_sign_list
.
reserve
(
size
);
_sign2data_map
.
reserve
(
size
);
}
inline
void
clear
()
{
_sign_list
.
clear
();
_sign2data_map
.
clear
();
}
inline
void
append
(
uint64_t
sign
)
{
if
(
_sign2data_map
.
find
(
sign
)
!=
_sign2data_map
.
end
())
{
return
;
}
SignCacheData
data
;
data
.
idx
=
_sign_list
.
size
();
_sign_list
.
push_back
(
sign
);
_sign2data_map
.
emplace
(
sign
,
std
::
move
(
data
));
}
inline
SignCacheData
*
data
(
uint64_t
sign
)
{
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>::
iterator
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
nullptr
;
}
return
const_cast
<
SignCacheData
*>
(
&
(
itr
->
second
));
}
private:
std
::
vector
<
uint64_t
>
_sign_list
;
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>
_sign2data_map
;
};
};
class
TrainerContext
{
class
TrainerContext
{
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
0 → 100644
浏览文件 @
9552cf55
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
TEST
(
Archive
,
DataItem
)
{
paddle
::
custom_trainer
::
feed
::
DataItem
item
;
paddle
::
custom_trainer
::
feed
::
DataItem
item2
;
item
.
id
=
"123"
;
item
.
data
=
"name"
;
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
item
;
ar
>>
item2
;
ASSERT_EQ
(
item
.
id
,
item2
.
id
);
ASSERT_EQ
(
item
.
data
,
item2
.
data
);
item
.
id
+=
"~"
;
item
.
data
+=
"~"
;
ASSERT_NE
(
item
.
id
,
item2
.
id
);
ASSERT_NE
(
item
.
data
,
item2
.
data
);
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录