Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleRec
提交
9552cf55
P
PaddleRec
项目概览
PaddlePaddle
/
PaddleRec
通知
68
Star
12
Fork
5
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
27
列表
看板
标记
里程碑
合并请求
10
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
27
Issue
27
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9552cf55
编写于
9月 18, 2019
作者:
L
linan17
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of
ssh://icode.baidu.com:8235/baidu/feed-mlarch/paddle-trainer
Change-Id: If03dc2ea6e1a8b9bcedb9d5c9fa9dbcc44d41396
上级
2a933c78
7654080c
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
335 addition
and
46 deletion
+335
-46
BCLOUD
BCLOUD
+1
-0
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+10
-0
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
...ain/custom_trainer/feed/executor/multi_thread_executor.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+83
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+5
-1
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+163
-40
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+51
-4
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
...in/custom_trainer/feed/unit_test/test_archive_dataitem.cc
+20
-0
未找到文件。
BCLOUD
浏览文件 @
9552cf55
...
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
9552cf55
...
...
@@ -68,6 +68,16 @@ public:
std
::
string
data
;
//样本数据, maybe压缩格式
};
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
>>
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
DataItem
&
x
)
{
return
ar
>>
x
.
id
>>
x
.
data
;
}
template
<
class
AR
>
paddle
::
framework
::
Archive
<
AR
>&
operator
<<
(
paddle
::
framework
::
Archive
<
AR
>&
ar
,
const
DataItem
&
x
)
{
return
ar
<<
x
.
id
<<
x
.
data
;
}
typedef
std
::
shared_ptr
<
Pipeline
<
DataItem
,
SampleInstance
>>
SampleInstancePipe
;
inline
SampleInstancePipe
make_sample_instance_channel
()
{
return
std
::
make_shared
<
Pipeline
<
DataItem
,
SampleInstance
>>
();
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
浏览文件 @
9552cf55
...
...
@@ -243,7 +243,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
for
(
auto
&
monitor
:
_monitors
)
{
if
(
monitor
->
need_compute_result
(
epoch_id
))
{
monitor
->
compute_result
();
ENVLOG_WORKER_MASTER_NOTICE
(
"[Monitor]%s, monitor:%s,
,
result:%s"
,
ENVLOG_WORKER_MASTER_NOTICE
(
"[Monitor]%s, monitor:%s, result:%s"
,
_train_exe_name
.
c_str
(),
monitor
->
get_name
().
c_str
(),
monitor
->
format_result
().
c_str
());
_trainer_context
->
monitor_ssm
<<
_train_exe_name
<<
":"
<<
monitor
->
get_name
()
<<
":"
<<
monitor
->
format_result
()
<<
","
;
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
9552cf55
...
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
cache_dict
.
reset
(
new
SignCacheDict
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
//environment
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
9552cf55
...
...
@@ -16,6 +16,8 @@ namespace feed {
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
Process
::
initialize
(
context_ptr
);
auto
&
config
=
_context_ptr
->
trainer_config
;
_is_dump_cache_model
=
config
[
"dump_cache_model"
].
as
<
bool
>
(
false
);
_cache_load_converter
=
config
[
"load_cache_converter"
].
as
<
std
::
string
>
(
""
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
if
(
config
[
"executor"
])
{
_executors
.
resize
(
config
[
"executor"
].
size
());
...
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return
0
;
}
// 更新各节点存储的CacheModel
int
LearnerProcess
::
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
if
(
!
epoch_accessor
->
need_save_model
(
epoch_id
,
way
))
{
return
0
;
}
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/"
,
param
.
table_id
()));
if
(
!
fs
->
exists
(
cache_model_path
))
{
continue
;
}
auto
&
cache_dict
=
*
(
_context_ptr
->
cache_dict
.
get
());
cache_dict
.
clear
();
cache_dict
.
reserve
(
_cache_sign_max_num
);
auto
cache_file_list
=
fs
->
list
(
fs
->
path_join
(
cache_model_path
,
"part*"
));
for
(
auto
&
cache_path
:
cache_file_list
)
{
auto
cache_file
=
fs
->
open_read
(
cache_path
,
_cache_load_converter
);
char
*
buffer
=
nullptr
;
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
cache_file
.
get
()))
!=
-
1
)
{
if
(
line_len
<=
1
)
{
continue
;
}
char
*
data_ptr
=
NULL
;
cache_dict
.
append
(
strtoul
(
buffer
,
&
data_ptr
,
10
));
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
}
}
break
;
}
}
return
0
;
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
...
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
}
timer
.
Pause
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
// save cache model, 只有inference需要cache_model
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
(
way
==
ModelSaveWay
::
ModelSaveInferenceBase
||
way
==
ModelSaveWay
::
ModelSaveInferenceDelta
))
{
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
double
cache_threshold
=
0.0
;
auto
status
=
ps_client
->
get_cache_threshold
(
param
.
table_id
(),
cache_threshold
);
CHECK
(
status
.
get
()
==
0
)
<<
"CacheThreshold Get failed!"
;
status
=
ps_client
->
cache_shuffle
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
),
std
::
to_string
(
cache_threshold
));
CHECK
(
status
.
get
()
==
0
)
<<
"Cache Shuffler Failed"
;
status
=
ps_client
->
save_cache
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
));
auto
feature_size
=
status
.
get
();
CHECK
(
feature_size
>=
0
)
<<
"Cache Save Failed"
;
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/sparse_cache.meta"
,
param
.
table_id
()));
auto
cache_meta_file
=
fs
->
open_write
(
cache_model_path
,
""
);
auto
meta
=
string
::
format_string
(
"file_prefix:part
\n
part_num:%d
\n
key_num:%d
\n
"
,
param
.
sparse_table_cache_file_num
(),
feature_size
);
CHECK
(
fwrite
(
meta
.
c_str
(),
meta
.
size
(),
1
,
cache_meta_file
.
get
())
==
1
)
<<
"Cache Meta Failed"
;
if
(
feature_size
>
_cache_sign_max_num
)
{
_cache_sign_max_num
=
feature_size
;
}
}
}
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
return
all_ret
;
}
...
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
update_cache_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
}
else
{
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
浏览文件 @
9552cf55
...
...
@@ -22,9 +22,13 @@ protected:
virtual
int
load_model
(
uint64_t
epoch_id
);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
);
private:
bool
_startup_dump_inference_base
;
//启动立即dump base
bool
_is_dump_cache_model
;
// 是否进行cache dump
uint32_t
_cache_sign_max_num
=
0
;
// cache sign最大个数
std
::
string
_cache_load_converter
;
// cache加载的前置转换脚本
bool
_startup_dump_inference_base
;
// 启动立即dump base
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
};
...
...
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
浏览文件 @
9552cf55
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include <bthread/butex.h>
namespace
paddle
{
namespace
custom_trainer
{
...
...
@@ -40,73 +41,195 @@ public:
Shuffler
::
initialize
(
config
,
context_ptr
);
_max_concurrent_num
=
config
[
"max_concurrent_num"
].
as
<
int
>
(
4
);
// 最大并发发送数
_max_package_size
=
config
[
"max_package_size"
].
as
<
int
>
(
1024
);
// 最大包个数,一次发送package个数据
_shuffle_data_msg_type
=
config
[
"shuffle_data_msg_type"
].
as
<
int
>
(
3
);
// c2c msg type
_finish_msg_type
=
config
[
"finish_msg_type"
].
as
<
int
>
(
4
);
// c2c msg type
reset_channel
();
auto
binded
=
std
::
bind
(
&
GlobalShuffler
::
get_client2client_msg
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
,
std
::
placeholders
::
_3
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_shuffle_data_msg_type
,
binded
);
_trainer_context
->
pslib
->
ps_client
()
->
registe_client2client_msg_handler
(
_finish_msg_type
,
binded
);
return
0
;
}
// 所有worker必须都调用shuffle,并且shuffler同时只能有一个shuffle任务
virtual
int
shuffle
(
::
paddle
::
framework
::
Channel
<
DataItem
>&
data_channel
)
{
uint32_t
send_count
=
0
;
uint32_t
package_size
=
_max_package_size
;
uint32_t
concurrent_num
=
_max_concurrent_num
;
uint32_t
current_wait_idx
=
0
;
::
paddle
::
framework
::
Channel
<
DataItem
>
input_channel
=
::
paddle
::
framework
::
MakeChannel
<
DataItem
>
(
data_channel
);
data_channel
.
swap
(
input_channel
);
set_channel
(
data_channel
);
auto
*
environment
=
_trainer_context
->
environment
.
get
();
auto
worker_num
=
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
std
::
vector
<
std
::
vector
<
std
::
future
<
int
>>>
waits
(
concurrent_num
);
std
::
vector
<
DataItem
>
send_buffer
(
concurrent_num
*
package_size
);
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
request_data_buffer
(
worker_num
);
while
(
true
)
{
auto
read_size
=
data_channel
->
Read
(
concurrent_num
*
package_size
,
&
send_buffer
[
0
]);
std
::
vector
<
DataItem
>
send_buffer
(
package_size
);
std
::
vector
<
std
::
vector
<
DataItem
>>
send_buffer_worker
(
worker_num
);
int
status
=
0
;
// >0: finish; =0: running; <0: fail
while
(
status
==
0
)
{
// update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
concurrent_num
=
1
;
package_size
=
_max_concurrent_num
/
2
;
}
else
{
package_size
=
_max_package_size
;
concurrent_num
=
_max_concurrent_num
;
}
for
(
uint32_t
current_wait_idx
=
0
;
status
==
0
&&
current_wait_idx
<
concurrent_num
;
++
current_wait_idx
)
{
auto
read_size
=
input_channel
->
Read
(
package_size
,
send_buffer
.
data
());
if
(
read_size
==
0
)
{
status
=
1
;
break
;
}
for
(
size_t
idx
=
0
;
idx
<
read_size
;
idx
+=
package_size
)
{
// data shard && seriliaze
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
request_data_buffer
[
i
].
Clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
send_buffer_worker
.
clear
();
}
for
(
size_t
i
=
idx
;
i
<
package_size
&&
i
<
read_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
read_size
;
++
i
)
{
auto
worker_idx
=
_shuffle_key_func
(
send_buffer
[
i
].
id
)
%
worker_num
;
// TODO Serialize To Arcive
//request_data_buffer[worker_idx] << send_buffer[i];
send_buffer_worker
[
worker_idx
].
push_back
(
std
::
move
(
send_buffer
[
i
]));
}
std
::
string
data_vec
[
worker_num
];
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
auto
&
buffer
=
request_data_buffer
[
i
];
data_vec
[
i
].
assign
(
buffer
.
Buffer
(),
buffer
.
Length
());
}
// wait async done
for
(
auto
&
wait_s
:
waits
[
current_wait_idx
])
{
if
(
!
wait_s
.
valid
())
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
break
;
}
CHECK
(
wait_s
.
get
()
==
0
);
}
if
(
status
!=
0
)
{
break
;
}
waits
[
current_wait_idx
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
!
send_buffer_worker
[
i
].
empty
())
{
waits
[
current_wait_idx
].
push_back
(
send_shuffle_data
(
i
,
send_buffer_worker
[
i
]));
}
}
}
}
for
(
auto
&
waits_s
:
waits
)
{
for
(
auto
&
wait_s
:
waits_s
)
{
if
(
wait_s
.
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send shuffle data"
;
status
=
-
1
;
}
}
}
VLOG
(
5
)
<<
"start send finish, worker_num: "
<<
worker_num
;
waits
[
0
].
clear
();
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
0
].
push_back
(
send_finish
(
i
));
}
VLOG
(
5
)
<<
"wait all finish"
;
for
(
int
i
=
0
;
i
<
worker_num
;
++
i
)
{
if
(
waits
[
0
][
i
].
get
()
!=
0
)
{
LOG
(
WARNING
)
<<
"fail to send finish "
<<
i
;
status
=
-
1
;
}
}
VLOG
(
5
)
<<
"finish shuffler, status: "
<<
status
;
return
status
<
0
?
status
:
0
;
}
// send shuffle data
for
(
size_t
i
=
0
;
i
<
worker_num
;
++
i
)
{
waits
[
current_wait_idx
][
i
]
=
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
3
,
i
*
2
,
data_vec
[
i
]);
private:
/*
1. 部分c2c send_shuffle_data先到, 此时channel未设置, 等待wait_channel
2. shuffle中调用set_channel, 先reset_wait_num, 再解锁channel
3. 当接收到所有worker的finish请求后,先reset_channel, 再同时返回
*/
bool
wait_channel
()
{
VLOG
(
5
)
<<
"wait_channel"
;
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_channel_mutex
);
return
_out_channel
!=
nullptr
;
}
void
reset_channel
()
{
VLOG
(
5
)
<<
"reset_channel"
;
_channel_mutex
.
lock
();
if
(
_out_channel
!=
nullptr
)
{
_out_channel
->
Close
();
}
_out_channel
=
nullptr
;
}
void
reset_wait_num
()
{
_wait_num_mutex
.
lock
();
_wait_num
=
_trainer_context
->
environment
->
node_num
(
EnvironmentRole
::
WORKER
);
VLOG
(
5
)
<<
"reset_wait_num: "
<<
_wait_num
;
}
void
set_channel
(
paddle
::
framework
::
Channel
<
DataItem
>&
channel
)
{
VLOG
(
5
)
<<
"set_channel"
;
// 在节点开始写入channel之前,重置wait_num
CHECK
(
_out_channel
==
nullptr
);
_out_channel
=
channel
;
reset_wait_num
();
_channel_mutex
.
unlock
();
}
// update status
// 如果在训练期,则限速shuffle
// 如果在wait状态,全速shuffle
if
(
_trainer_context
->
is_status
(
TrainerStatus
::
Training
))
{
concurrent_num
=
1
;
package_size
=
_max_concurrent_num
/
2
;
int32_t
finish_write_channel
()
{
int
wait_num
=
--
_wait_num
;
VLOG
(
5
)
<<
"finish_write_channel, wait_num: "
<<
wait_num
;
// 同步所有worker,在所有写入完成后,c2c_msg返回前,重置channel
if
(
wait_num
==
0
)
{
reset_channel
();
_wait_num_mutex
.
unlock
();
}
else
{
package_size
=
_max_package_size
;
concurrent_num
=
_max_concurrent_num
;
std
::
lock_guard
<
bthread
::
Mutex
>
lock
(
_wait_num_mutex
);
}
++
current_wait_idx
;
current_wait_idx
=
current_wait_idx
>=
concurrent_num
?
0
:
current_wait_idx
;
return
0
;
}
int32_t
write_to_channel
(
std
::
vector
<
DataItem
>&&
items
)
{
size_t
items_size
=
items
.
size
();
VLOG
(
5
)
<<
"write_to_channel, items_size: "
<<
items_size
;
return
_out_channel
->
Write
(
std
::
move
(
items
))
==
items_size
?
0
:
-
1
;
}
return
0
;
int32_t
get_client2client_msg
(
int
msg_type
,
int
from_client
,
const
std
::
string
&
msg
)
{
// wait channel
if
(
!
wait_channel
())
{
LOG
(
FATAL
)
<<
"out_channel is null"
;
return
-
1
;
}
VLOG
(
5
)
<<
"get c2c msg, type: "
<<
msg_type
<<
", from_client: "
<<
from_client
<<
", msg_size: "
<<
msg
.
size
();
if
(
msg_type
==
_shuffle_data_msg_type
)
{
paddle
::
framework
::
BinaryArchive
ar
;
ar
.
SetReadBuffer
(
const_cast
<
char
*>
(
msg
.
data
()),
msg
.
size
(),
[](
char
*
){});
std
::
vector
<
DataItem
>
items
;
ar
>>
items
;
return
write_to_channel
(
std
::
move
(
items
));
}
else
if
(
msg_type
==
_finish_msg_type
)
{
return
finish_write_channel
();
}
LOG
(
FATAL
)
<<
"no such msg type: "
<<
msg_type
;
return
-
1
;
}
std
::
future
<
int32_t
>
send_shuffle_data
(
int
to_client_id
,
std
::
vector
<
DataItem
>&
items
)
{
VLOG
(
5
)
<<
"send_shuffle_data, to_client_id: "
<<
to_client_id
<<
", items_size: "
<<
items
.
size
();
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
items
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_shuffle_data_msg_type
,
to_client_id
,
std
::
string
(
ar
.
Buffer
(),
ar
.
Length
()));
}
std
::
future
<
int32_t
>
send_finish
(
int
to_client_id
)
{
VLOG
(
5
)
<<
"send_finish, to_client_id: "
<<
to_client_id
;
static
const
std
::
string
empty_str
;
return
_trainer_context
->
pslib
->
ps_client
()
->
send_client2client_msg
(
_finish_msg_type
,
to_client_id
,
empty_str
);
}
private:
uint32_t
_max_package_size
=
0
;
uint32_t
_max_concurrent_num
=
0
;
int
_shuffle_data_msg_type
=
3
;
int
_finish_msg_type
=
4
;
bthread
::
Mutex
_channel_mutex
;
paddle
::
framework
::
Channel
<
DataItem
>
_out_channel
=
nullptr
;
bthread
::
Mutex
_wait_num_mutex
;
std
::
atomic
<
int
>
_wait_num
;
};
REGIST_CLASS
(
Shuffler
,
GlobalShuffler
);
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
9552cf55
...
...
@@ -3,6 +3,7 @@
#include <memory>
#include <vector>
#include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving
=
1
// 模型存储状态
};
const
uint32_t
SignCacheMaxValueNum
=
13
;
struct
SignCacheData
{
SignCacheData
()
{
memset
(
cache_value
,
0
,
sizeof
(
float
)
*
SignCacheMaxValueNum
);
}
uint32_t
idx
;
float
cache_value
[
SignCacheMaxValueNum
];
};
class
SignCacheDict
{
public:
int32_t
sign2index
(
uint64_t
sign
)
{
inline
int32_t
sign2index
(
uint64_t
sign
)
{
auto
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
-
1
;
}
return
itr
->
second
.
idx
;
}
uint64_t
index2sign
(
int32_t
index
)
{
inline
uint64_t
index2sign
(
int32_t
index
)
{
if
(
index
>=
_sign_list
.
size
())
{
return
0
;
}
return
_sign_list
[
index
];
}
inline
void
reserve
(
uint32_t
size
)
{
_sign_list
.
reserve
(
size
);
_sign2data_map
.
reserve
(
size
);
}
inline
void
clear
()
{
_sign_list
.
clear
();
_sign2data_map
.
clear
();
}
inline
void
append
(
uint64_t
sign
)
{
if
(
_sign2data_map
.
find
(
sign
)
!=
_sign2data_map
.
end
())
{
return
;
}
SignCacheData
data
;
data
.
idx
=
_sign_list
.
size
();
_sign_list
.
push_back
(
sign
);
_sign2data_map
.
emplace
(
sign
,
std
::
move
(
data
));
}
inline
SignCacheData
*
data
(
uint64_t
sign
)
{
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>::
iterator
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
nullptr
;
}
return
const_cast
<
SignCacheData
*>
(
&
(
itr
->
second
));
}
private:
std
::
vector
<
uint64_t
>
_sign_list
;
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>
_sign2data_map
;
};
class
TrainerContext
{
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_archive_dataitem.cc
0 → 100644
浏览文件 @
9552cf55
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
TEST
(
Archive
,
DataItem
)
{
paddle
::
custom_trainer
::
feed
::
DataItem
item
;
paddle
::
custom_trainer
::
feed
::
DataItem
item2
;
item
.
id
=
"123"
;
item
.
data
=
"name"
;
paddle
::
framework
::
BinaryArchive
ar
;
ar
<<
item
;
ar
>>
item2
;
ASSERT_EQ
(
item
.
id
,
item2
.
id
);
ASSERT_EQ
(
item
.
data
,
item2
.
data
);
item
.
id
+=
"~"
;
item
.
data
+=
"~"
;
ASSERT_NE
(
item
.
id
,
item2
.
id
);
ASSERT_NE
(
item
.
data
,
item2
.
data
);
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录