Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
929a9e80
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
929a9e80
编写于
10月 25, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Google naming conventions
上级
c555948c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
320 addition
and
524 deletion
+320
-524
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+145
-145
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+85
-93
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+46
-43
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+42
-241
paddle/fluid/framework/datafeed_creator.cc
paddle/fluid/framework/datafeed_creator.cc
+1
-1
paddle/fluid/framework/datafeed_creator.h
paddle/fluid/framework/datafeed_creator.h
+1
-1
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
929a9e80
...
...
@@ -37,13 +37,13 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
std
::
mutex
ExecutorThreadWorker
::
_s_locker_for_pick_file
;
unsigned
int
ExecutorThreadWorker
::
_s_current_file_idx
=
0
;
size_t
ExecutorThreadWorker
::
_s_current_finished_file_cnt
=
0
;
unsigned
int
ExecutorThreadWorker
::
_s_current_epoch
=
0
;
int
ExecutorThreadWorker
::
_s_current_save_epoch
=
0
;
bool
ExecutorThreadWorker
::
_s_is_first_worker
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
_s_thread_filelist
;
std
::
mutex
ExecutorThreadWorker
::
s_locker_for_pick_file_
;
unsigned
int
ExecutorThreadWorker
::
s_current_file_idx_
=
0
;
size_t
ExecutorThreadWorker
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
ExecutorThreadWorker
::
s_current_epoch_
=
0
;
int
ExecutorThreadWorker
::
s_current_save_epoch_
=
0
;
bool
ExecutorThreadWorker
::
s_is_first_worker_
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
s_thread_filelist_
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
@@ -142,33 +142,33 @@ static void save_model(
}
// end save_model
void
ExecutorThreadWorker
::
add_train_f
ile
(
const
std
::
string
&
file
)
{
_s_thread_filelist
.
push_back
(
file
);
void
ExecutorThreadWorker
::
AddTrainF
ile
(
const
std
::
string
&
file
)
{
s_thread_filelist_
.
push_back
(
file
);
}
void
ExecutorThreadWorker
::
create_thread_o
perators
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadO
perators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_op_names
.
clear
();
op_names_
.
clear
();
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
_op_names
.
push_back
(
op_desc
->
Type
());
op_names_
.
push_back
(
op_desc
->
Type
());
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
_ops
.
push_back
(
local_op_ptr
);
ops_
.
push_back
(
local_op_ptr
);
continue
;
}
}
void
ExecutorThreadWorker
::
create_thread_s
cope
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadS
cope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_thread_scope
=
&
_root_scope
->
NewScope
();
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
_root_scope
->
Var
(
var
->
Name
());
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create Persistable var[%s] finished",
// var->Name().c_str());
}
else
{
auto
*
ptr
=
_thread_scope
->
Var
(
var
->
Name
());
auto
*
ptr
=
thread_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create unpersistable var[%s] finished",
// var->Name().c_str());
...
...
@@ -176,33 +176,33 @@ void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) {
}
}
void
ExecutorThreadWorker
::
set_dataf
eed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
_local_reader
=
datafeed
;
void
ExecutorThreadWorker
::
SetDataF
eed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
local_reader_
=
datafeed
;
}
void
ExecutorThreadWorker
::
binding_datafeed_m
emory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
_local_reader
->
get_use_slot_a
lias
();
void
ExecutorThreadWorker
::
BindingDataFeedM
emory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
local_reader_
->
GetUseSlotA
lias
();
for
(
auto
name
:
input_feed
)
{
_local_reader
->
add_feed_var
(
_thread_scope
->
Var
(
name
),
name
);
local_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
void
ExecutorThreadWorker
::
set_inspect_var_n
ame
(
void
ExecutorThreadWorker
::
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
inspect_var_name_
=
inspect_var_name
;
}
void
ExecutorThreadWorker
::
set_model_param_n
ames
(
void
ExecutorThreadWorker
::
SetModelParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
model_param_names_
=
param_names
;
}
void
ExecutorThreadWorker
::
set_sparse_comm_d
ata
(
void
ExecutorThreadWorker
::
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
)
{
_sparse_comm_data
=
param_names
;
sparse_comm_data_
=
param_names
;
}
void
ExecutorThreadWorker
::
set_d
evice
()
{
void
ExecutorThreadWorker
::
SetD
evice
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
...
...
@@ -214,7 +214,7 @@ void ExecutorThreadWorker::set_device() {
42
,
43
,
44
,
45
,
46
,
47
};
unsigned
int
i
=
this
->
_thread_id
;
unsigned
int
i
=
this
->
thread_id_
;
if
(
i
<
sizeof
(
priority
)
/
sizeof
(
unsigned
))
{
unsigned
proc
=
priority
[
i
];
...
...
@@ -235,55 +235,55 @@ void ExecutorThreadWorker::set_device() {
}
}
void
ExecutorThreadWorker
::
update_epoch_n
um
()
{
_s_current_finished_file_cnt
++
;
void
ExecutorThreadWorker
::
UpdateEpochN
um
()
{
s_current_finished_file_cnt_
++
;
if
(
_s_current_finished_file_cnt
>=
_s_thread_filelist
.
size
())
{
_s_current_finished_file_cnt
=
0
;
_s_current_epoch
++
;
if
(
s_current_finished_file_cnt_
>=
s_thread_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
}
}
const
char
*
ExecutorThreadWorker
::
pick_one_f
ile
()
{
const
char
*
ExecutorThreadWorker
::
PickOneF
ile
()
{
std
::
string
file_to_be_preocessed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
_s_locker_for_pick_file
);
if
(
_s_current_file_idx
>=
_s_thread_filelist
.
size
())
{
std
::
random_shuffle
(
_s_thread_filelist
.
begin
(),
_s_thread_filelist
.
end
());
_s_current_file_idx
=
0
;
//
_s_current_epoch
++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
_thread_id
<<
": finish traing for epoch "
<<
_s_current_epoch
+
1
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
if
(
s_current_file_idx_
>=
s_thread_filelist_
.
size
())
{
std
::
random_shuffle
(
s_thread_filelist_
.
begin
(),
s_thread_filelist_
.
end
());
s_current_file_idx_
=
0
;
//
s_current_epoch_
++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
}
file_to_be_preocessed
=
_s_thread_filelist
[
_s_current_file_idx
];
file_to_be_preocessed
=
s_thread_filelist_
[
s_current_file_idx_
];
_s_current_file_idx
++
;
s_current_file_idx_
++
;
return
file_to_be_preocessed
.
c_str
();
}
void
ExecutorThreadWorker
::
t
rain
()
{
void
ExecutorThreadWorker
::
T
rain
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
set_d
evice
();
SetD
evice
();
#ifdef LOCAL_PROF
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
// int total_batch = 0;
for
(
auto
&
op
:
_ops
)
{
for
(
auto
&
op
:
ops_
)
{
op_name
.
push_back
(
op
->
Type
());
}
op_total_time
.
resize
(
_ops
.
size
());
op_total_time
.
resize
(
ops_
.
size
());
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
op_total_time
[
i
]
=
0.0
;
}
#endif
std
::
string
inspect_key
=
"inspect"
;
if
(
!
_inspect_var_name
.
empty
())
{
inspect_key
=
_inspect_var_name
.
substr
(
0
,
_inspect_var_name
.
find_first_of
(
'_'
));
if
(
!
inspect_var_name_
.
empty
())
{
inspect_key
=
inspect_var_name_
.
substr
(
0
,
inspect_var_name_
.
find_first_of
(
'_'
));
}
for
(
unsigned
i
=
0
;
i
<
_max_epoch
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
max_epoch_
;
++
i
)
{
LOG
(
ERROR
)
<<
"epoch: "
<<
i
;
#ifdef LOCAL_PROF
Timer
timeline
;
...
...
@@ -292,14 +292,14 @@ void ExecutorThreadWorker::train() {
#endif
float
total_inspect
=
0
;
int
batch_num
=
1
;
while
(
i
==
_s_current_epoch
)
{
const
char
*
filename
=
pick_one_f
ile
();
_local_reader
->
set_f
ile
(
filename
);
while
(
i
==
s_current_epoch_
)
{
const
char
*
filename
=
PickOneF
ile
();
local_reader_
->
SetF
ile
(
filename
);
while
(
true
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
bool
flag
=
_local_reader
->
read_b
atch
();
bool
flag
=
local_reader_
->
ReadB
atch
();
if
(
!
flag
)
{
break
;
}
...
...
@@ -312,11 +312,11 @@ void ExecutorThreadWorker::train() {
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
_ops
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
_ops
[
i
]
->
Run
(
*
_thread_scope
,
_place
);
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
#ifdef LOCAL_PROF
timeline
.
pause
();
op_total_time
[
i
]
+=
timeline
.
elapsed_sec
();
...
...
@@ -325,17 +325,17 @@ void ExecutorThreadWorker::train() {
}
batch_num
++
;
float
avg_inspect
=
0.0
;
if
(
!
_inspect_var_name
.
empty
())
{
avg_inspect
=
_thread_scope
->
FindVar
(
_inspect_var_name
)
if
(
!
inspect_var_name_
.
empty
())
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_name_
)
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
}
total_inspect
+=
avg_inspect
;
_thread_scope
->
DropKids
();
thread_scope_
->
DropKids
();
}
update_epoch_n
um
();
UpdateEpochN
um
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
i
+
1
<<
" called: "
<<
memory
::
memory_usage
(
_place
);
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
#ifdef LOCAL_PROF
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
...
...
@@ -354,12 +354,12 @@ void ExecutorThreadWorker::train() {
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
;
#endif
if
(
_thread_id
==
0
)
{
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
_model_prefix
.
c_str
(),
model_prefix_
.
c_str
(),
i
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
...
...
@@ -367,55 +367,55 @@ void ExecutorThreadWorker::train() {
//
// currently comment it
LOG
(
ERROR
)
<<
"Going to save model "
<<
modelfile
;
save_model
(
_main_program
,
_thread_scope
,
_model_param_names
,
save_model
(
main_program_
,
thread_scope_
,
model_param_names_
,
model_filename
,
true
);
}
}
}
void
ExecutorThreadWorker
::
set_thread_i
d
(
int
tid
)
{
_thread_id
=
tid
;
void
ExecutorThreadWorker
::
SetThreadI
d
(
int
tid
)
{
thread_id_
=
tid
;
}
void
ExecutorThreadWorker
::
set_p
lace
(
const
platform
::
Place
&
place
)
{
_place
=
place
;
void
ExecutorThreadWorker
::
SetP
lace
(
const
platform
::
Place
&
place
)
{
place_
=
place
;
}
void
ExecutorThreadWorker
::
set_main_p
rogram
(
void
ExecutorThreadWorker
::
SetMainP
rogram
(
const
ProgramDesc
&
main_program_desc
)
{
_main_program
.
reset
(
new
ProgramDesc
(
main_program_desc
));
main_program_
.
reset
(
new
ProgramDesc
(
main_program_desc
));
}
void
ExecutorThreadWorker
::
set_root_s
cope
(
Scope
*
g_scope
)
{
_root_scope
=
g_scope
;
void
ExecutorThreadWorker
::
SetRootS
cope
(
Scope
*
g_scope
)
{
root_scope_
=
g_scope
;
}
void
ExecutorThreadWorker
::
set_max_training_e
poch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
void
ExecutorThreadWorker
::
SetMaxTrainingE
poch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
MultiExecutor
::
MultiExecutor
(
const
platform
::
Place
&
place
)
:
_place
(
place
)
{}
MultiExecutor
::
MultiExecutor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
void
MultiExecutor
::
init_root_s
cope
(
Scope
*
scope
)
{
_root_scope
=
scope
;
void
MultiExecutor
::
InitRootS
cope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
MultiExecutor
::
set_max_training_e
poch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
void
MultiExecutor
::
SetMaxTrainingE
poch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
void
MultiExecutor
::
set_datafeed_n
ame
(
const
char
*
feedname
)
{
_feed_name
=
std
::
string
(
feedname
);
void
MultiExecutor
::
SetDataFeedN
ame
(
const
char
*
feedname
)
{
feed_name_
=
std
::
string
(
feedname
);
}
void
MultiExecutor
::
set_model_p
refix
(
const
std
::
string
&
model_prefix
)
{
_model_prefix
=
model_prefix
;
void
MultiExecutor
::
SetModelP
refix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
}
void
MultiExecutor
::
run_startup_p
rogram
(
const
ProgramDesc
&
program
,
void
MultiExecutor
::
RunStartupP
rogram
(
const
ProgramDesc
&
program
,
Scope
*
scope
)
{
auto
&
block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
block
.
AllVars
())
{
...
...
@@ -447,7 +447,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// param_dict.size(), ops.size());
for
(
auto
&
op
:
ops
)
{
op
->
Run
(
*
scope
,
_place
);
op
->
Run
(
*
scope
,
place_
);
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for
(
auto
&
op
:
ops
)
{
...
...
@@ -456,7 +456,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// LOGERR("run startup program done.");
}
std
::
unique_ptr
<
ProgramDesc
>
MultiExecutor
::
load_desc_from_f
ile
(
std
::
unique_ptr
<
ProgramDesc
>
MultiExecutor
::
LoadDescFromF
ile
(
const
std
::
string
&
f
)
{
std
::
string
program_desc_str
;
read_binary_file
(
f
,
&
program_desc_str
);
...
...
@@ -464,102 +464,102 @@ std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
return
program
;
}
void
MultiExecutor
::
set_dense_comm_t
ensor
(
void
MultiExecutor
::
SetDenseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
)
{
_dense_comm_tensor
.
resize
(
dense_comm_tensor
.
size
());
dense_comm_tensor_
.
resize
(
dense_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
dense_comm_tensor
.
size
();
++
i
)
{
_dense_comm_tensor
[
i
]
=
dense_comm_tensor
[
i
];
dense_comm_tensor_
[
i
]
=
dense_comm_tensor
[
i
];
}
}
void
MultiExecutor
::
set_sparse_comm_t
ensor
(
void
MultiExecutor
::
SetSparseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
)
{
_sparse_comm_tensor
.
resize
(
sparse_comm_tensor
.
size
());
sparse_comm_tensor_
.
resize
(
sparse_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
sparse_comm_tensor
.
size
();
++
i
)
{
_sparse_comm_tensor
[
i
]
=
sparse_comm_tensor
[
i
];
sparse_comm_tensor_
[
i
]
=
sparse_comm_tensor
[
i
];
}
}
void
MultiExecutor
::
set_sparse_comm_d
ata
(
void
MultiExecutor
::
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
)
{
_sparse_comm_data
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
_sparse_comm_data
.
size
();
sparse_comm_data_
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
sparse_comm_data_
.
size
();
}
void
MultiExecutor
::
set_filel
ist
(
const
char
*
filelist
)
{
_filelist
.
clear
();
void
MultiExecutor
::
SetFileL
ist
(
const
char
*
filelist
)
{
filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
_filelist
.
push_back
(
filename
);
filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
MultiExecutor
::
set_filel
ist
(
std
::
vector
<
std
::
string
>
tfiles
)
{
_filelist
.
clear
();
_filelist
.
insert
(
_filelist
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
void
MultiExecutor
::
SetFileL
ist
(
std
::
vector
<
std
::
string
>
tfiles
)
{
filelist_
.
clear
();
filelist_
.
insert
(
filelist_
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
return
;
}
void
MultiExecutor
::
set_inspect_var_n
ame
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
void
MultiExecutor
::
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
)
{
inspect_var_name_
=
inspect_var_name
;
}
void
MultiExecutor
::
set_param_n
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
void
MultiExecutor
::
SetParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
model_param_names_
=
param_names
;
}
void
MultiExecutor
::
set_thread_n
um
(
const
int
thread_num
)
{
_thread_num
=
thread_num
;
void
MultiExecutor
::
SetThreadN
um
(
const
int
thread_num
)
{
thread_num_
=
thread_num
;
}
void
MultiExecutor
::
prepare_t
hreads
(
const
ProgramDesc
&
host_program
)
{
_workers
.
resize
(
_thread_num
);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
_workers
[
i
].
reset
(
new
ExecutorThreadWorker
);
_workers
[
i
]
->
set_thread_i
d
(
i
);
_workers
[
i
]
->
create_thread_o
perators
(
host_program
);
_workers
[
i
]
->
set_root_scope
(
_root_scope
);
_workers
[
i
]
->
set_place
(
_place
);
_workers
[
i
]
->
set_max_training_epoch
(
_max_epoch
);
_workers
[
i
]
->
create_thread_s
cope
(
host_program
);
_workers
[
i
]
->
set_inspect_var_name
(
_inspect_var_name
);
_workers
[
i
]
->
set_model_param_names
(
_model_param_names
);
_workers
[
i
]
->
set_sparse_comm_data
(
_sparse_comm_data
);
_workers
[
i
]
->
set_main_p
rogram
(
host_program
);
_workers
[
i
]
->
set_model_prefix
(
_model_prefix
);
void
MultiExecutor
::
PrepareT
hreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadI
d
(
i
);
workers_
[
i
]
->
CreateThreadO
perators
(
host_program
);
workers_
[
i
]
->
SetRootScope
(
root_scope_
);
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadS
cope
(
host_program
);
workers_
[
i
]
->
SetInspectVarName
(
inspect_var_name_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetSparseCommData
(
sparse_comm_data_
);
workers_
[
i
]
->
SetMainP
rogram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
}
for
(
unsigned
i
=
0
;
i
<
_filelist
.
size
();
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
filelist_
.
size
();
++
i
)
{
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
_workers
[
0
]
->
add_train_file
(
_filelist
[
i
]);
workers_
[
0
]
->
AddTrainFile
(
filelist_
[
i
]);
}
// mpi_wrapper::ModelParam model_param(true);
//
_workers
[0]->register_parallel_training_param(model_param);
//
workers_
[0]->register_parallel_training_param(model_param);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
// new a datafeed here
std
::
shared_ptr
<
DataFeed
>
local_feed
=
create_datafeed
(
_feed_name
.
c_str
());
local_feed
->
init
(
_data_feed_param
);
local_feed
->
set_batch_size
(
_batch_size
);
_workers
[
i
]
->
set_dataf
eed
(
local_feed
);
_workers
[
i
]
->
binding_datafeed_m
emory
();
_workers
[
i
]
->
set_thread_i
d
(
i
);
std
::
shared_ptr
<
DataFeed
>
local_feed
=
CreateDataFeed
(
feed_name_
.
c_str
());
local_feed
->
Init
(
data_feed_param_
);
local_feed
->
SetBatchSize
(
batch_size_
);
workers_
[
i
]
->
SetDataF
eed
(
local_feed
);
workers_
[
i
]
->
BindingDataFeedM
emory
();
workers_
[
i
]
->
SetThreadI
d
(
i
);
}
}
void
MultiExecutor
::
run_multi_e
xecutor
(
const
ProgramDesc
&
host_program
)
{
void
MultiExecutor
::
RunMultiE
xecutor
(
const
ProgramDesc
&
host_program
)
{
// thread binding here?
prepare_t
hreads
(
host_program
);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
_threads
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
t
rain
,
_workers
[
i
].
get
()));
PrepareT
hreads
(
host_program
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
T
rain
,
workers_
[
i
].
get
()));
}
for
(
auto
&
th
:
_threads
)
{
for
(
auto
&
th
:
threads_
)
{
th
.
join
();
}
}
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
929a9e80
...
...
@@ -36,137 +36,129 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker
()
{}
virtual
~
ExecutorThreadWorker
()
{}
void
create_thread_scope
(
const
framework
::
ProgramDesc
&
program
);
void
set_datafeed
(
const
DataFeed
&
datafeed
);
void
set_thread_id
(
int
tid
);
void
create_thread_operators
(
const
framework
::
ProgramDesc
&
program
);
void
set_root_scope
(
Scope
*
g_scope
);
void
set_device
();
virtual
void
add_fid_set
();
void
set_comm_batch
(
int
comm_batch
)
{
_comm_batch
=
comm_batch
;
}
void
add_train_file
(
const
std
::
string
&
filename
);
void
set_main_program
(
const
ProgramDesc
&
main_program_desc
);
void
set_place
(
const
paddle
::
platform
::
Place
&
place
);
void
set_max_training_epoch
(
const
int
max_epoch
);
void
binding_datafeed_memory
();
void
set_model_prefix
(
const
std
::
string
&
prefix
)
{
_model_prefix
=
prefix
;
}
void
set_inspect_var_name
(
const
std
::
string
&
inspect_var_name
);
void
set_model_param_names
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
set_sparse_comm_data
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
set_datafeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
virtual
void
mpi_train
();
void
gpu_train
();
void
train
();
virtual
const
char
*
pick_one_file
();
void
update_epoch_num
();
virtual
void
set_dense_comm_tensor
(
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
SetDataFeed
(
const
DataFeed
&
datafeed
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
framework
::
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
virtual
void
AddFidSet
();
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
AddTrainFile
(
const
std
::
string
&
filename
);
void
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
);
void
SetPlace
(
const
paddle
::
platform
::
Place
&
place
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
void
BindingDataFeedMemory
();
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
Train
();
virtual
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
virtual
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{}
virtual
void
i
nitialize
()
{}
virtual
void
I
nitialize
()
{}
public:
static
std
::
mutex
_s_locker_for_pick_file
;
static
unsigned
int
_s_current_file_idx
;
static
size_t
_s_current_finished_file_cnt
;
static
unsigned
int
_s_current_epoch
;
static
int
_s_current_save_epoch
;
static
std
::
vector
<
std
::
string
>
_s_thread_filelist
;
// filelist
static
bool
_s_is_first_worker
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
vector
<
std
::
string
>
s_thread_filelist_
;
// filelist
static
bool
s_is_first_worker_
;
protected:
// thread index
int
_thread_id
;
// current training file
int
_cur_fileidx
;
int
thread_id_
;
// max epoch for each thread
unsigned
int
_max_epoch
;
unsigned
int
max_epoch_
;
// instances learned currently
int
_comm_batch
;
std
::
string
_model_prefix
;
std
::
vector
<
std
::
string
>
_op_names
;
int
comm_batch_
;
std
::
string
model_prefix_
;
std
::
vector
<
std
::
string
>
op_names_
;
// local ops for forward and backward
std
::
vector
<
OperatorBase
*>
_ops
;
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
std
::
unique_ptr
<
framework
::
ProgramDesc
>
_main_program
;
std
::
unique_ptr
<
framework
::
ProgramDesc
>
main_program_
;
// binary data reader
std
::
shared_ptr
<
DataFeed
>
_local_reader
;
std
::
shared_ptr
<
DataFeed
>
local_reader_
;
std
::
string
_inspect_var_name
;
std
::
vector
<
std
::
string
>
_model_param_names
;
std
::
map
<
std
::
string
,
int
>
_sparse_comm_data
;
std
::
vector
<
int
>
_ids_buffer
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
// execution place
platform
::
Place
_place
;
platform
::
Place
place_
;
// root scope for model parameters
Scope
*
_root_scope
;
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
Scope
*
_thread_scope
;
Scope
*
thread_scope_
;
};
class
MultiExecutor
{
public:
explicit
MultiExecutor
(
const
platform
::
Place
&
place
);
virtual
~
MultiExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
load_desc_from_f
ile
(
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromF
ile
(
const
std
::
string
&
filename
);
void
init_root_s
cope
(
Scope
*
scope
);
void
set_inspect_var_n
ame
(
const
std
::
string
&
inspect_var_name
);
void
set_param_n
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
set_max_training_e
poch
(
const
int
max_epoch
);
Scope
*
get_root_scope
()
{
return
_root_scope
;
}
void
set_thread_n
um
(
const
int
thread_num
);
void
set_batch_size
(
const
int
batch_size
)
{
_batch_size
=
batch_size
;
}
void
set_filel
ist
(
const
char
*
filelist
);
void
set_filel
ist
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
set_datafeed_n
ame
(
const
char
*
feedname
);
void
set_data_feed_p
aram
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
_data_feed_param
=
feed_param
;
void
InitRootS
cope
(
Scope
*
scope
);
void
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
);
void
SetParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetMaxTrainingE
poch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetThreadN
um
(
const
int
thread_num
);
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetFileL
ist
(
const
char
*
filelist
);
void
SetFileL
ist
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
SetDataFeedN
ame
(
const
char
*
feedname
);
void
SetDataFeedP
aram
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
data_feed_param_
=
feed_param
;
}
void
set_comm_b
atch
(
int
comm_batch
)
{
_comm_batch
=
comm_batch
;
void
SetCommB
atch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
set_model_p
refix
(
const
std
::
string
&
model_prefix
);
void
set_dense_comm_t
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
void
set_sparse_comm_t
ensor
(
void
SetModelP
refix
(
const
std
::
string
&
model_prefix
);
void
SetDenseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
void
SetSparseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
);
void
set_sparse_comm_d
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
prepare_t
hreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
run_startup_p
rogram
(
const
framework
::
ProgramDesc
&
program
,
void
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
PrepareT
hreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
RunStartupP
rogram
(
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
);
void
run_multi_e
xecutor
(
const
ProgramDesc
&
host_program
);
void
RunMultiE
xecutor
(
const
ProgramDesc
&
host_program
);
public:
unsigned
int
_thread_num
;
datafeed
::
DataFeedParameter
_data_feed_param
;
int
_max_epoch
;
int
_batch_size
;
int
_comm_batch
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
_workers
;
std
::
vector
<
std
::
thread
>
_threads
;
std
::
vector
<
std
::
string
>
_filelist
;
std
::
string
_inspect_var_name
;
std
::
vector
<
std
::
string
>
_model_param_names
;
std
::
vector
<
std
::
string
>
_dense_comm_tensor
;
std
::
vector
<
std
::
string
>
_sparse_comm_tensor
;
std
::
map
<
std
::
string
,
int
>
_sparse_comm_data
;
int
node_num
;
std
::
string
_model_prefix
;
ProgramDesc
_host_program
;
std
::
string
_feed_name
;
Scope
*
_root_scope
;
platform
::
Place
_place
;
unsigned
int
thread_num_
;
datafeed
::
DataFeedParameter
data_feed_param_
;
int
max_epoch_
;
int
batch_size_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
dense_comm_tensor_
;
std
::
vector
<
std
::
string
>
sparse_comm_tensor_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
std
::
string
model_prefix_
;
std
::
string
feed_name_
;
Scope
*
root_scope_
;
platform
::
Place
place_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
929a9e80
...
...
@@ -38,111 +38,114 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
framework
{
void
TextClassDataFeed
::
i
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
void
TextClassDataFeed
::
I
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
// hard coding for a specific datafeed
_feed_vec
.
resize
(
2
);
// _feed_vec[0].reset(new LoDTensor);
// _feed_vec[1].reset(new LoDTensor);
_all_slot_ids
=
{
0
,
1
};
_use_slot_ids
=
{
0
,
1
};
_use_slot_alias
=
{
"words"
,
"label"
};
_file_content_buffer_host
.
reset
(
new
char
[
200
*
1024
*
1024
],
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_
=
{
0
,
1
};
use_slot_ids_
=
{
0
,
1
};
use_slot_alias_
=
{
"words"
,
"label"
};
file_content_buffer_host_
.
reset
(
new
char
[
200
*
1024
*
1024
],
[](
char
*
p
)
{
delete
[]
p
;});
_file_content_buffer
=
_file_content_buffer_host
.
get
();
_file_content_buffer_ptr
=
_file_content_buffer
;
_batch_id_host
.
reset
(
new
int
[
10240
*
1024
],
file_content_buffer_
=
file_content_buffer_host_
.
get
();
file_content_buffer_ptr_
=
file_content_buffer_
;
batch_id_host_
.
reset
(
new
int
[
10240
*
1024
],
[](
int
*
p
)
{
delete
[]
p
;});
// max word num in a batch
_label_host
.
reset
(
new
int
[
10240
],
batch_id_buffer_
=
batch_id_host_
.
get
();
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
_batch_id_buffer
=
_batch_id_host
.
get
();
_label_ptr
=
_label_host
.
get
();
label_ptr_
=
label_host_
.
get
();
}
// todo: use elegant implemention for this function
// Parses one batch of instances out of the in-memory file buffer and copies
// it into the two bound feed tensors ("words" and "label").
// Per-instance binary layout: [int llen][llen x int term ids][int label].
// Returns false when fewer than batch_size_ instances remain in the buffer.
bool TextClassDataFeed::ReadBatch() {
  // LoD offsets for the "words" tensor: offset[i] = start index of instance i.
  paddle::framework::Vector<size_t> offset;
  int tlen = 0;      // total term ids staged so far in this batch
  int llen = 0;      // term count of the current instance
  int inst_idx = 0;  // instances consumed in this batch
  offset.resize(batch_size_ + 1);
  offset[0] = 0;
  while (inst_idx < batch_size_) {
    int ptr_offset = 0;
    // Stop when the read cursor has consumed the whole file content.
    if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
      break;
    }
    // Read the term count for this instance.
    memcpy(reinterpret_cast<char*>(&llen),
           file_content_buffer_ptr_ + ptr_offset,
           sizeof(int));
    ptr_offset += sizeof(int);
    // Copy the term ids into the staging buffer, appending after prior
    // instances of this batch.
    memcpy(reinterpret_cast<char*>(batch_id_buffer_ + tlen),
           file_content_buffer_ptr_ + ptr_offset,
           llen * sizeof(int));
    tlen += llen;
    offset[inst_idx + 1] = offset[inst_idx] + llen;
    ptr_offset += sizeof(int) * llen;
    // Read the single label that trails the term ids.
    memcpy(reinterpret_cast<char*>(label_ptr_ + inst_idx),
           file_content_buffer_ptr_ + ptr_offset,
           sizeof(int));
    ptr_offset += sizeof(int);
    // Advance the cursor past this whole instance.
    file_content_buffer_ptr_ += ptr_offset;
    inst_idx++;
  }
  // A partial batch at end-of-file is discarded rather than padded.
  if (inst_idx != batch_size_) {
    return false;
  }
  LoD input_lod{offset};
  // Every instance contributes exactly one label, so label offsets are 0..batch_size_.
  paddle::framework::Vector<size_t> label_offset;
  label_offset.resize(batch_size_ + 1);
  for (int i = 0; i <= batch_size_; ++i) {
    label_offset[i] = i;
  }
  LoD label_lod{label_offset};
  // Widen the staged int buffers into the int64 feed tensors on CPU.
  int64_t* input_ptr = feed_vec_[0]->mutable_data<int64_t>(
      {static_cast<int64_t>(offset.back()), 1},
      platform::CPUPlace());
  int64_t* label_ptr = feed_vec_[1]->mutable_data<int64_t>({batch_size_, 1},
                                                           platform::CPUPlace());
  for (unsigned int i = 0; i < offset.back(); ++i) {
    input_ptr[i] = static_cast<int64_t>(batch_id_buffer_[i]);
  }
  for (int i = 0; i < batch_size_; ++i) {
    label_ptr[i] = static_cast<int64_t>(label_ptr_[i]);
  }
  feed_vec_[0]->set_lod(input_lod);
  feed_vec_[1]->set_lod(label_lod);
  return true;
}
void
TextClassDataFeed
::
add_feed_v
ar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
_use_slot_alias
.
size
();
++
i
)
{
if
(
name
==
_use_slot_alias
[
i
])
{
_feed_vec
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
void
TextClassDataFeed
::
AddFeedV
ar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
}
}
}
bool
TextClassDataFeed
::
set_f
ile
(
const
char
*
filename
)
{
bool
TextClassDataFeed
::
SetF
ile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
int
filesize
=
read_whole_file
(
filename
,
_file_content_buffer
);
int
filesize
=
ReadWholeFile
(
filename
,
file_content_buffer_
);
// todo , remove magic number
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
}
_file_content_buffer_ptr
=
_file_content_buffer
;
_file_size
=
filesize
;
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
return
true
;
}
int
TextClassDataFeed
::
read_whole_f
ile
(
const
std
::
string
&
filename
,
int
TextClassDataFeed
::
ReadWholeF
ile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
std
::
ifstream
ifs
(
filename
.
c_str
(),
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
929a9e80
...
...
@@ -35,47 +35,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef uint64_t FeatureKey;

// Compact (sign, slot) feature pair. The sign is stored in a raw byte buffer
// so the struct stays trivially copyable; accessors reinterpret it in place.
struct FeatureItem {
  FeatureItem() {}

  FeatureItem(FeatureKey sign_, uint16_t slot_) {
    sign() = sign_;
    slot() = slot_;
  }

  // Mutable view of the feature sign stored in _sign.
  FeatureKey& sign() {
    return *(reinterpret_cast<FeatureKey*>(sign_buffer()));
  }

  // Read-only view of the feature sign. C-style cast replaced with
  // named casts for grep-ability and intent.
  const FeatureKey& sign() const {
    return *(reinterpret_cast<const FeatureKey*>(sign_buffer()));
  }

  uint16_t& slot() {
    return _slot;
  }

  const uint16_t& slot() const {
    return _slot;
  }

 private:
  char _sign[sizeof(FeatureKey)];
  uint16_t _slot;

  // Returns the raw sign storage; const_cast makes the const-stripping of
  // the old C-style cast explicit (callers only write via non-const sign()).
  char* sign_buffer() const {
    return const_cast<char*>(_sign);
  }
};
// Record(average:14031B) is smaller than Sample(average:16530B)
struct Record {
  int show, click;                // impression / click counters
  std::vector<FeatureItem> feas;  // parsed (sign, slot) feature pairs
  std::string lineid;             // line identifier from the source log
  std::string tags;
};
struct
Gauc
{
int
show
,
click
;
uint64_t
fea
;
...
...
@@ -89,241 +48,83 @@ struct Instance {
std
::
vector
<
Gauc
>
gauc_vec
;
};
struct
Sample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
)
{
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
std
::
stringstream
ss
(
input
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
struct
TeacherStudentSample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
float
q_score
;
void
print
()
{
LOG
(
ERROR
)
<<
"label: "
<<
label
<<
" score: "
<<
q_score
;
for
(
auto
&
slot
:
feas
)
{
for
(
auto
&
fea
:
slot
.
second
)
{
LOG
(
ERROR
)
<<
"slot: "
<<
slot
.
first
<<
" fea: "
<<
fea
;
}
}
}
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
,
Gauc
&
gauc
)
{
// NOLINT
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
gauc
.
show
=
1
;
gauc
.
click
=
label
;
gauc
.
lineid
=
input
.
substr
(
0
,
end
);
gauc
.
fea
=
0
;
size_t
dnn_start
=
input
.
find
(
"*"
);
if
(
dnn_start
==
std
::
string
::
npos
)
{
q_score
=
-
1.0
;
}
else
{
dnn_start
+=
1
;
size_t
dnn_end
=
input
.
find
(
' '
,
dnn_start
);
q_score
=
static_cast
<
float
>
(
atof
(
input
.
substr
(
dnn_start
,
dnn_end
-
dnn_start
).
c_str
()));
}
size_t
head_pos
=
input
.
find
(
"
\t
"
);
std
::
string
head
=
input
.
substr
(
0
,
head_pos
);
std
::
stringstream
ss
(
head
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
if
(
slot_id
==
6048
)
{
gauc
.
fea
=
feature_id
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
i
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
=
0
;
virtual
void
I
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
=
0
;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
check_f
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
set_f
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
read_b
atch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
get_all_slot_i
ds
()
{
return
_all_slot_ids
;
virtual
bool
CheckF
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
SetF
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
ReadB
atch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
GetAllSlotI
ds
()
{
return
all_slot_ids_
;
}
virtual
const
std
::
vector
<
uint16_t
>&
get_use_slot_i
ds
()
{
return
_use_slot_ids
;
virtual
const
std
::
vector
<
uint16_t
>&
GetUseSlotI
ds
()
{
return
use_slot_ids_
;
}
virtual
const
std
::
vector
<
std
::
string
>&
get_use_slot_a
lias
()
{
return
_use_slot_alias
;
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlotA
lias
()
{
return
use_slot_alias_
;
}
virtual
void
add_feed_v
ar
(
Variable
*
var
,
virtual
void
AddFeedV
ar
(
Variable
*
var
,
const
std
::
string
&
name
)
=
0
;
virtual
void
bind_s
cope
(
Scope
*
scope
)
=
0
;
virtual
void
set_batch_size
(
int
batch
)
{
_default_batch_size
=
batch
;
}
virtual
int
get_batch_size
()
{
return
_batch_size
;
}
virtual
void
set_buffer_s
ize
(
int
buffer_size
)
{}
virtual
void
BindS
cope
(
Scope
*
scope
)
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferS
ize
(
int
buffer_size
)
{}
std
::
vector
<
LoDTensor
*>&
get_feed_v
ec
()
{
return
_feed_vec
;
std
::
vector
<
LoDTensor
*>&
GetFeedV
ec
()
{
return
feed_vec_
;
}
virtual
std
::
vector
<
LoDTensor
*>&
get_feed_v
ec
(
const
Instance
&
ins
)
{
virtual
std
::
vector
<
LoDTensor
*>&
GetFeedV
ec
(
const
Instance
&
ins
)
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
return
_feed_vec
;
return
feed_vec_
;
}
protected:
std
::
vector
<
uint16_t
>
_all_slot_ids
;
std
::
vector
<
uint16_t
>
_use_slot_ids
;
std
::
vector
<
std
::
string
>
_use_slot_alias
;
std
::
vector
<
LoDTensor
*>
_feed_vec
;
int
_default_batch_size
;
int
_batch_size
;
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
std
::
vector
<
std
::
string
>
use_slot_alias_
;
std
::
vector
<
LoDTensor
*>
feed_vec_
;
int
default_batch_size_
;
int
batch_size_
;
};
// DataFeed for the text-classification demo: loads a whole binary file into
// memory and slices it into (words, label) batches.
class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
  virtual void Init(const datafeed::DataFeedParameter& feed_param);
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}  // this feed keeps no scope state
  virtual bool SetFile(const char* filename);
  virtual bool CheckFile(const char* filename) {
    // TODO(xxx)
    return false;
  }
  void SetBatchSize(int batch) {batch_size_ = batch;}

 private:
  // Reads the whole file into `buffer`; returns byte count, negative on error.
  int ReadWholeFile(const std::string& filename, char* buffer);
  char* file_content_buffer_;      // raw file content; owned via *_host_ below
  char* file_content_buffer_ptr_;  // read cursor into file_content_buffer_
  int* batch_id_buffer_;           // staging area for one batch of term ids
  int* label_ptr_;                 // staging area for one batch of labels
  int file_size_;                  // bytes currently loaded
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;  // owners of the buffers
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
};
}
// namespace framework
...
...
paddle/fluid/framework/datafeed_creator.cc
浏览文件 @
929a9e80
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/datafeed_creator.h"
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
create_dataf
eed
(
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
CreateDataF
eed
(
const
char
*
datafeed_class
)
{
if
(
strcmp
(
datafeed_class
,
"TextClass"
)
==
0
)
{
return
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
(
...
...
paddle/fluid/framework/datafeed_creator.h
浏览文件 @
929a9e80
...
...
@@ -17,6 +17,6 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
// Factory: builds the DataFeed implementation selected by its class name
// (e.g. "TextClass"); see datafeed_creator.cc for the supported names.
std::shared_ptr<paddle::framework::DataFeed> CreateDataFeed(
    const char* datafeed_class);
#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录