Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
929a9e80
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
929a9e80
编写于
10月 25, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Google naming conventions
上级
c555948c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
320 addition
and
524 deletion
+320
-524
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+145
-145
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+85
-93
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+46
-43
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+42
-241
paddle/fluid/framework/datafeed_creator.cc
paddle/fluid/framework/datafeed_creator.cc
+1
-1
paddle/fluid/framework/datafeed_creator.h
paddle/fluid/framework/datafeed_creator.h
+1
-1
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
929a9e80
...
...
@@ -37,13 +37,13 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
std
::
mutex
ExecutorThreadWorker
::
_s_locker_for_pick_file
;
unsigned
int
ExecutorThreadWorker
::
_s_current_file_idx
=
0
;
size_t
ExecutorThreadWorker
::
_s_current_finished_file_cnt
=
0
;
unsigned
int
ExecutorThreadWorker
::
_s_current_epoch
=
0
;
int
ExecutorThreadWorker
::
_s_current_save_epoch
=
0
;
bool
ExecutorThreadWorker
::
_s_is_first_worker
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
_s_thread_filelist
;
std
::
mutex
ExecutorThreadWorker
::
s_locker_for_pick_file_
;
unsigned
int
ExecutorThreadWorker
::
s_current_file_idx_
=
0
;
size_t
ExecutorThreadWorker
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
ExecutorThreadWorker
::
s_current_epoch_
=
0
;
int
ExecutorThreadWorker
::
s_current_save_epoch_
=
0
;
bool
ExecutorThreadWorker
::
s_is_first_worker_
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
s_thread_filelist_
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
@@ -142,33 +142,33 @@ static void save_model(
}
// end save_model
void
ExecutorThreadWorker
::
add_train_f
ile
(
const
std
::
string
&
file
)
{
_s_thread_filelist
.
push_back
(
file
);
void
ExecutorThreadWorker
::
AddTrainF
ile
(
const
std
::
string
&
file
)
{
s_thread_filelist_
.
push_back
(
file
);
}
void
ExecutorThreadWorker
::
create_thread_o
perators
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadO
perators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_op_names
.
clear
();
op_names_
.
clear
();
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
_op_names
.
push_back
(
op_desc
->
Type
());
op_names_
.
push_back
(
op_desc
->
Type
());
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
_ops
.
push_back
(
local_op_ptr
);
ops_
.
push_back
(
local_op_ptr
);
continue
;
}
}
void
ExecutorThreadWorker
::
create_thread_s
cope
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadS
cope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_thread_scope
=
&
_root_scope
->
NewScope
();
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
_root_scope
->
Var
(
var
->
Name
());
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create Persistable var[%s] finished",
// var->Name().c_str());
}
else
{
auto
*
ptr
=
_thread_scope
->
Var
(
var
->
Name
());
auto
*
ptr
=
thread_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create unpersistable var[%s] finished",
// var->Name().c_str());
...
...
@@ -176,33 +176,33 @@ void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) {
}
}
void
ExecutorThreadWorker
::
set_dataf
eed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
_local_reader
=
datafeed
;
void
ExecutorThreadWorker
::
SetDataF
eed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
local_reader_
=
datafeed
;
}
void
ExecutorThreadWorker
::
binding_datafeed_m
emory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
_local_reader
->
get_use_slot_a
lias
();
void
ExecutorThreadWorker
::
BindingDataFeedM
emory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
local_reader_
->
GetUseSlotA
lias
();
for
(
auto
name
:
input_feed
)
{
_local_reader
->
add_feed_var
(
_thread_scope
->
Var
(
name
),
name
);
local_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
void
ExecutorThreadWorker
::
set_inspect_var_n
ame
(
void
ExecutorThreadWorker
::
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
inspect_var_name_
=
inspect_var_name
;
}
void
ExecutorThreadWorker
::
set_model_param_n
ames
(
void
ExecutorThreadWorker
::
SetModelParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
model_param_names_
=
param_names
;
}
void
ExecutorThreadWorker
::
set_sparse_comm_d
ata
(
void
ExecutorThreadWorker
::
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
)
{
_sparse_comm_data
=
param_names
;
sparse_comm_data_
=
param_names
;
}
void
ExecutorThreadWorker
::
set_d
evice
()
{
void
ExecutorThreadWorker
::
SetD
evice
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
...
...
@@ -214,7 +214,7 @@ void ExecutorThreadWorker::set_device() {
42
,
43
,
44
,
45
,
46
,
47
};
unsigned
int
i
=
this
->
_thread_id
;
unsigned
int
i
=
this
->
thread_id_
;
if
(
i
<
sizeof
(
priority
)
/
sizeof
(
unsigned
))
{
unsigned
proc
=
priority
[
i
];
...
...
@@ -235,55 +235,55 @@ void ExecutorThreadWorker::set_device() {
}
}
void
ExecutorThreadWorker
::
update_epoch_n
um
()
{
_s_current_finished_file_cnt
++
;
void
ExecutorThreadWorker
::
UpdateEpochN
um
()
{
s_current_finished_file_cnt_
++
;
if
(
_s_current_finished_file_cnt
>=
_s_thread_filelist
.
size
())
{
_s_current_finished_file_cnt
=
0
;
_s_current_epoch
++
;
if
(
s_current_finished_file_cnt_
>=
s_thread_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
}
}
const
char
*
ExecutorThreadWorker
::
pick_one_f
ile
()
{
const
char
*
ExecutorThreadWorker
::
PickOneF
ile
()
{
std
::
string
file_to_be_preocessed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
_s_locker_for_pick_file
);
if
(
_s_current_file_idx
>=
_s_thread_filelist
.
size
())
{
std
::
random_shuffle
(
_s_thread_filelist
.
begin
(),
_s_thread_filelist
.
end
());
_s_current_file_idx
=
0
;
//
_s_current_epoch
++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
_thread_id
<<
": finish traing for epoch "
<<
_s_current_epoch
+
1
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
if
(
s_current_file_idx_
>=
s_thread_filelist_
.
size
())
{
std
::
random_shuffle
(
s_thread_filelist_
.
begin
(),
s_thread_filelist_
.
end
());
s_current_file_idx_
=
0
;
//
s_current_epoch_
++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
}
file_to_be_preocessed
=
_s_thread_filelist
[
_s_current_file_idx
];
file_to_be_preocessed
=
s_thread_filelist_
[
s_current_file_idx_
];
_s_current_file_idx
++
;
s_current_file_idx_
++
;
return
file_to_be_preocessed
.
c_str
();
}
void
ExecutorThreadWorker
::
t
rain
()
{
void
ExecutorThreadWorker
::
T
rain
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
set_d
evice
();
SetD
evice
();
#ifdef LOCAL_PROF
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
// int total_batch = 0;
for
(
auto
&
op
:
_ops
)
{
for
(
auto
&
op
:
ops_
)
{
op_name
.
push_back
(
op
->
Type
());
}
op_total_time
.
resize
(
_ops
.
size
());
op_total_time
.
resize
(
ops_
.
size
());
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
op_total_time
[
i
]
=
0.0
;
}
#endif
std
::
string
inspect_key
=
"inspect"
;
if
(
!
_inspect_var_name
.
empty
())
{
inspect_key
=
_inspect_var_name
.
substr
(
0
,
_inspect_var_name
.
find_first_of
(
'_'
));
if
(
!
inspect_var_name_
.
empty
())
{
inspect_key
=
inspect_var_name_
.
substr
(
0
,
inspect_var_name_
.
find_first_of
(
'_'
));
}
for
(
unsigned
i
=
0
;
i
<
_max_epoch
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
max_epoch_
;
++
i
)
{
LOG
(
ERROR
)
<<
"epoch: "
<<
i
;
#ifdef LOCAL_PROF
Timer
timeline
;
...
...
@@ -292,14 +292,14 @@ void ExecutorThreadWorker::train() {
#endif
float
total_inspect
=
0
;
int
batch_num
=
1
;
while
(
i
==
_s_current_epoch
)
{
const
char
*
filename
=
pick_one_f
ile
();
_local_reader
->
set_f
ile
(
filename
);
while
(
i
==
s_current_epoch_
)
{
const
char
*
filename
=
PickOneF
ile
();
local_reader_
->
SetF
ile
(
filename
);
while
(
true
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
bool
flag
=
_local_reader
->
read_b
atch
();
bool
flag
=
local_reader_
->
ReadB
atch
();
if
(
!
flag
)
{
break
;
}
...
...
@@ -312,11 +312,11 @@ void ExecutorThreadWorker::train() {
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
_ops
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
_ops
[
i
]
->
Run
(
*
_thread_scope
,
_place
);
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
#ifdef LOCAL_PROF
timeline
.
pause
();
op_total_time
[
i
]
+=
timeline
.
elapsed_sec
();
...
...
@@ -325,17 +325,17 @@ void ExecutorThreadWorker::train() {
}
batch_num
++
;
float
avg_inspect
=
0.0
;
if
(
!
_inspect_var_name
.
empty
())
{
avg_inspect
=
_thread_scope
->
FindVar
(
_inspect_var_name
)
if
(
!
inspect_var_name_
.
empty
())
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_name_
)
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
}
total_inspect
+=
avg_inspect
;
_thread_scope
->
DropKids
();
thread_scope_
->
DropKids
();
}
update_epoch_n
um
();
UpdateEpochN
um
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
i
+
1
<<
" called: "
<<
memory
::
memory_usage
(
_place
);
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
#ifdef LOCAL_PROF
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
...
...
@@ -354,12 +354,12 @@ void ExecutorThreadWorker::train() {
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
;
#endif
if
(
_thread_id
==
0
)
{
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
_model_prefix
.
c_str
(),
model_prefix_
.
c_str
(),
i
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
...
...
@@ -367,55 +367,55 @@ void ExecutorThreadWorker::train() {
//
// currently comment it
LOG
(
ERROR
)
<<
"Going to save model "
<<
modelfile
;
save_model
(
_main_program
,
_thread_scope
,
_model_param_names
,
save_model
(
main_program_
,
thread_scope_
,
model_param_names_
,
model_filename
,
true
);
}
}
}
void
ExecutorThreadWorker
::
set_thread_i
d
(
int
tid
)
{
_thread_id
=
tid
;
void
ExecutorThreadWorker
::
SetThreadI
d
(
int
tid
)
{
thread_id_
=
tid
;
}
void
ExecutorThreadWorker
::
set_p
lace
(
const
platform
::
Place
&
place
)
{
_place
=
place
;
void
ExecutorThreadWorker
::
SetP
lace
(
const
platform
::
Place
&
place
)
{
place_
=
place
;
}
void
ExecutorThreadWorker
::
set_main_p
rogram
(
void
ExecutorThreadWorker
::
SetMainP
rogram
(
const
ProgramDesc
&
main_program_desc
)
{
_main_program
.
reset
(
new
ProgramDesc
(
main_program_desc
));
main_program_
.
reset
(
new
ProgramDesc
(
main_program_desc
));
}
void
ExecutorThreadWorker
::
set_root_s
cope
(
Scope
*
g_scope
)
{
_root_scope
=
g_scope
;
void
ExecutorThreadWorker
::
SetRootS
cope
(
Scope
*
g_scope
)
{
root_scope_
=
g_scope
;
}
void
ExecutorThreadWorker
::
set_max_training_e
poch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
void
ExecutorThreadWorker
::
SetMaxTrainingE
poch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
MultiExecutor
::
MultiExecutor
(
const
platform
::
Place
&
place
)
:
_place
(
place
)
{}
MultiExecutor
::
MultiExecutor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
void
MultiExecutor
::
init_root_s
cope
(
Scope
*
scope
)
{
_root_scope
=
scope
;
void
MultiExecutor
::
InitRootS
cope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
MultiExecutor
::
set_max_training_e
poch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
void
MultiExecutor
::
SetMaxTrainingE
poch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
void
MultiExecutor
::
set_datafeed_n
ame
(
const
char
*
feedname
)
{
_feed_name
=
std
::
string
(
feedname
);
void
MultiExecutor
::
SetDataFeedN
ame
(
const
char
*
feedname
)
{
feed_name_
=
std
::
string
(
feedname
);
}
void
MultiExecutor
::
set_model_p
refix
(
const
std
::
string
&
model_prefix
)
{
_model_prefix
=
model_prefix
;
void
MultiExecutor
::
SetModelP
refix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
}
void
MultiExecutor
::
run_startup_p
rogram
(
const
ProgramDesc
&
program
,
void
MultiExecutor
::
RunStartupP
rogram
(
const
ProgramDesc
&
program
,
Scope
*
scope
)
{
auto
&
block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
block
.
AllVars
())
{
...
...
@@ -447,7 +447,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// param_dict.size(), ops.size());
for
(
auto
&
op
:
ops
)
{
op
->
Run
(
*
scope
,
_place
);
op
->
Run
(
*
scope
,
place_
);
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for
(
auto
&
op
:
ops
)
{
...
...
@@ -456,7 +456,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// LOGERR("run startup program done.");
}
std
::
unique_ptr
<
ProgramDesc
>
MultiExecutor
::
load_desc_from_f
ile
(
std
::
unique_ptr
<
ProgramDesc
>
MultiExecutor
::
LoadDescFromF
ile
(
const
std
::
string
&
f
)
{
std
::
string
program_desc_str
;
read_binary_file
(
f
,
&
program_desc_str
);
...
...
@@ -464,102 +464,102 @@ std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
return
program
;
}
void
MultiExecutor
::
set_dense_comm_t
ensor
(
void
MultiExecutor
::
SetDenseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
)
{
_dense_comm_tensor
.
resize
(
dense_comm_tensor
.
size
());
dense_comm_tensor_
.
resize
(
dense_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
dense_comm_tensor
.
size
();
++
i
)
{
_dense_comm_tensor
[
i
]
=
dense_comm_tensor
[
i
];
dense_comm_tensor_
[
i
]
=
dense_comm_tensor
[
i
];
}
}
void
MultiExecutor
::
set_sparse_comm_t
ensor
(
void
MultiExecutor
::
SetSparseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
)
{
_sparse_comm_tensor
.
resize
(
sparse_comm_tensor
.
size
());
sparse_comm_tensor_
.
resize
(
sparse_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
sparse_comm_tensor
.
size
();
++
i
)
{
_sparse_comm_tensor
[
i
]
=
sparse_comm_tensor
[
i
];
sparse_comm_tensor_
[
i
]
=
sparse_comm_tensor
[
i
];
}
}
void
MultiExecutor
::
set_sparse_comm_d
ata
(
void
MultiExecutor
::
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
)
{
_sparse_comm_data
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
_sparse_comm_data
.
size
();
sparse_comm_data_
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
sparse_comm_data_
.
size
();
}
void
MultiExecutor
::
set_filel
ist
(
const
char
*
filelist
)
{
_filelist
.
clear
();
void
MultiExecutor
::
SetFileL
ist
(
const
char
*
filelist
)
{
filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
_filelist
.
push_back
(
filename
);
filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
MultiExecutor
::
set_filel
ist
(
std
::
vector
<
std
::
string
>
tfiles
)
{
_filelist
.
clear
();
_filelist
.
insert
(
_filelist
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
void
MultiExecutor
::
SetFileL
ist
(
std
::
vector
<
std
::
string
>
tfiles
)
{
filelist_
.
clear
();
filelist_
.
insert
(
filelist_
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
return
;
}
void
MultiExecutor
::
set_inspect_var_n
ame
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
void
MultiExecutor
::
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
)
{
inspect_var_name_
=
inspect_var_name
;
}
void
MultiExecutor
::
set_param_n
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
void
MultiExecutor
::
SetParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
model_param_names_
=
param_names
;
}
void
MultiExecutor
::
set_thread_n
um
(
const
int
thread_num
)
{
_thread_num
=
thread_num
;
void
MultiExecutor
::
SetThreadN
um
(
const
int
thread_num
)
{
thread_num_
=
thread_num
;
}
void
MultiExecutor
::
prepare_t
hreads
(
const
ProgramDesc
&
host_program
)
{
_workers
.
resize
(
_thread_num
);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
_workers
[
i
].
reset
(
new
ExecutorThreadWorker
);
_workers
[
i
]
->
set_thread_i
d
(
i
);
_workers
[
i
]
->
create_thread_o
perators
(
host_program
);
_workers
[
i
]
->
set_root_scope
(
_root_scope
);
_workers
[
i
]
->
set_place
(
_place
);
_workers
[
i
]
->
set_max_training_epoch
(
_max_epoch
);
_workers
[
i
]
->
create_thread_s
cope
(
host_program
);
_workers
[
i
]
->
set_inspect_var_name
(
_inspect_var_name
);
_workers
[
i
]
->
set_model_param_names
(
_model_param_names
);
_workers
[
i
]
->
set_sparse_comm_data
(
_sparse_comm_data
);
_workers
[
i
]
->
set_main_p
rogram
(
host_program
);
_workers
[
i
]
->
set_model_prefix
(
_model_prefix
);
void
MultiExecutor
::
PrepareT
hreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadI
d
(
i
);
workers_
[
i
]
->
CreateThreadO
perators
(
host_program
);
workers_
[
i
]
->
SetRootScope
(
root_scope_
);
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadS
cope
(
host_program
);
workers_
[
i
]
->
SetInspectVarName
(
inspect_var_name_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetSparseCommData
(
sparse_comm_data_
);
workers_
[
i
]
->
SetMainP
rogram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
}
for
(
unsigned
i
=
0
;
i
<
_filelist
.
size
();
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
filelist_
.
size
();
++
i
)
{
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
_workers
[
0
]
->
add_train_file
(
_filelist
[
i
]);
workers_
[
0
]
->
AddTrainFile
(
filelist_
[
i
]);
}
// mpi_wrapper::ModelParam model_param(true);
//
_workers
[0]->register_parallel_training_param(model_param);
//
workers_
[0]->register_parallel_training_param(model_param);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
// new a datafeed here
std
::
shared_ptr
<
DataFeed
>
local_feed
=
create_datafeed
(
_feed_name
.
c_str
());
local_feed
->
init
(
_data_feed_param
);
local_feed
->
set_batch_size
(
_batch_size
);
_workers
[
i
]
->
set_dataf
eed
(
local_feed
);
_workers
[
i
]
->
binding_datafeed_m
emory
();
_workers
[
i
]
->
set_thread_i
d
(
i
);
std
::
shared_ptr
<
DataFeed
>
local_feed
=
CreateDataFeed
(
feed_name_
.
c_str
());
local_feed
->
Init
(
data_feed_param_
);
local_feed
->
SetBatchSize
(
batch_size_
);
workers_
[
i
]
->
SetDataF
eed
(
local_feed
);
workers_
[
i
]
->
BindingDataFeedM
emory
();
workers_
[
i
]
->
SetThreadI
d
(
i
);
}
}
void
MultiExecutor
::
run_multi_e
xecutor
(
const
ProgramDesc
&
host_program
)
{
void
MultiExecutor
::
RunMultiE
xecutor
(
const
ProgramDesc
&
host_program
)
{
// thread binding here?
prepare_t
hreads
(
host_program
);
for
(
unsigned
i
=
0
;
i
<
_thread_num
;
++
i
)
{
_threads
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
t
rain
,
_workers
[
i
].
get
()));
PrepareT
hreads
(
host_program
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
T
rain
,
workers_
[
i
].
get
()));
}
for
(
auto
&
th
:
_threads
)
{
for
(
auto
&
th
:
threads_
)
{
th
.
join
();
}
}
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
929a9e80
...
...
@@ -36,137 +36,129 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker
()
{}
virtual
~
ExecutorThreadWorker
()
{}
void
create_thread_scope
(
const
framework
::
ProgramDesc
&
program
);
void
set_datafeed
(
const
DataFeed
&
datafeed
);
void
set_thread_id
(
int
tid
);
void
create_thread_operators
(
const
framework
::
ProgramDesc
&
program
);
void
set_root_scope
(
Scope
*
g_scope
);
void
set_device
();
virtual
void
add_fid_set
();
void
set_comm_batch
(
int
comm_batch
)
{
_comm_batch
=
comm_batch
;
}
void
add_train_file
(
const
std
::
string
&
filename
);
void
set_main_program
(
const
ProgramDesc
&
main_program_desc
);
void
set_place
(
const
paddle
::
platform
::
Place
&
place
);
void
set_max_training_epoch
(
const
int
max_epoch
);
void
binding_datafeed_memory
();
void
set_model_prefix
(
const
std
::
string
&
prefix
)
{
_model_prefix
=
prefix
;
}
void
set_inspect_var_name
(
const
std
::
string
&
inspect_var_name
);
void
set_model_param_names
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
set_sparse_comm_data
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
set_datafeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
virtual
void
mpi_train
();
void
gpu_train
();
void
train
();
virtual
const
char
*
pick_one_file
();
void
update_epoch_num
();
virtual
void
set_dense_comm_tensor
(
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
SetDataFeed
(
const
DataFeed
&
datafeed
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
framework
::
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
virtual
void
AddFidSet
();
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
AddTrainFile
(
const
std
::
string
&
filename
);
void
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
);
void
SetPlace
(
const
paddle
::
platform
::
Place
&
place
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
void
BindingDataFeedMemory
();
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
Train
();
virtual
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
virtual
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{}
virtual
void
i
nitialize
()
{}
virtual
void
I
nitialize
()
{}
public:
static
std
::
mutex
_s_locker_for_pick_file
;
static
unsigned
int
_s_current_file_idx
;
static
size_t
_s_current_finished_file_cnt
;
static
unsigned
int
_s_current_epoch
;
static
int
_s_current_save_epoch
;
static
std
::
vector
<
std
::
string
>
_s_thread_filelist
;
// filelist
static
bool
_s_is_first_worker
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
vector
<
std
::
string
>
s_thread_filelist_
;
// filelist
static
bool
s_is_first_worker_
;
protected:
// thread index
int
_thread_id
;
// current training file
int
_cur_fileidx
;
int
thread_id_
;
// max epoch for each thread
unsigned
int
_max_epoch
;
unsigned
int
max_epoch_
;
// instances learned currently
int
_comm_batch
;
std
::
string
_model_prefix
;
std
::
vector
<
std
::
string
>
_op_names
;
int
comm_batch_
;
std
::
string
model_prefix_
;
std
::
vector
<
std
::
string
>
op_names_
;
// local ops for forward and backward
std
::
vector
<
OperatorBase
*>
_ops
;
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
std
::
unique_ptr
<
framework
::
ProgramDesc
>
_main_program
;
std
::
unique_ptr
<
framework
::
ProgramDesc
>
main_program_
;
// binary data reader
std
::
shared_ptr
<
DataFeed
>
_local_reader
;
std
::
shared_ptr
<
DataFeed
>
local_reader_
;
std
::
string
_inspect_var_name
;
std
::
vector
<
std
::
string
>
_model_param_names
;
std
::
map
<
std
::
string
,
int
>
_sparse_comm_data
;
std
::
vector
<
int
>
_ids_buffer
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
// execution place
platform
::
Place
_place
;
platform
::
Place
place_
;
// root scope for model parameters
Scope
*
_root_scope
;
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
Scope
*
_thread_scope
;
Scope
*
thread_scope_
;
};
class
MultiExecutor
{
public:
explicit
MultiExecutor
(
const
platform
::
Place
&
place
);
virtual
~
MultiExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
load_desc_from_f
ile
(
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromF
ile
(
const
std
::
string
&
filename
);
void
init_root_s
cope
(
Scope
*
scope
);
void
set_inspect_var_n
ame
(
const
std
::
string
&
inspect_var_name
);
void
set_param_n
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
set_max_training_e
poch
(
const
int
max_epoch
);
Scope
*
get_root_scope
()
{
return
_root_scope
;
}
void
set_thread_n
um
(
const
int
thread_num
);
void
set_batch_size
(
const
int
batch_size
)
{
_batch_size
=
batch_size
;
}
void
set_filel
ist
(
const
char
*
filelist
);
void
set_filel
ist
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
set_datafeed_n
ame
(
const
char
*
feedname
);
void
set_data_feed_p
aram
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
_data_feed_param
=
feed_param
;
void
InitRootS
cope
(
Scope
*
scope
);
void
SetInspectVarN
ame
(
const
std
::
string
&
inspect_var_name
);
void
SetParamN
ames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetMaxTrainingE
poch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetThreadN
um
(
const
int
thread_num
);
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetFileL
ist
(
const
char
*
filelist
);
void
SetFileL
ist
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
SetDataFeedN
ame
(
const
char
*
feedname
);
void
SetDataFeedP
aram
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
data_feed_param_
=
feed_param
;
}
void
set_comm_b
atch
(
int
comm_batch
)
{
_comm_batch
=
comm_batch
;
void
SetCommB
atch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
set_model_p
refix
(
const
std
::
string
&
model_prefix
);
void
set_dense_comm_t
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
void
set_sparse_comm_t
ensor
(
void
SetModelP
refix
(
const
std
::
string
&
model_prefix
);
void
SetDenseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
void
SetSparseCommT
ensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
);
void
set_sparse_comm_d
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
prepare_t
hreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
run_startup_p
rogram
(
const
framework
::
ProgramDesc
&
program
,
void
SetSparseCommD
ata
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
PrepareT
hreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
RunStartupP
rogram
(
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
);
void
run_multi_e
xecutor
(
const
ProgramDesc
&
host_program
);
void
RunMultiE
xecutor
(
const
ProgramDesc
&
host_program
);
public:
unsigned
int
_thread_num
;
datafeed
::
DataFeedParameter
_data_feed_param
;
int
_max_epoch
;
int
_batch_size
;
int
_comm_batch
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
_workers
;
std
::
vector
<
std
::
thread
>
_threads
;
std
::
vector
<
std
::
string
>
_filelist
;
std
::
string
_inspect_var_name
;
std
::
vector
<
std
::
string
>
_model_param_names
;
std
::
vector
<
std
::
string
>
_dense_comm_tensor
;
std
::
vector
<
std
::
string
>
_sparse_comm_tensor
;
std
::
map
<
std
::
string
,
int
>
_sparse_comm_data
;
int
node_num
;
std
::
string
_model_prefix
;
ProgramDesc
_host_program
;
std
::
string
_feed_name
;
Scope
*
_root_scope
;
platform
::
Place
_place
;
unsigned
int
thread_num_
;
datafeed
::
DataFeedParameter
data_feed_param_
;
int
max_epoch_
;
int
batch_size_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
dense_comm_tensor_
;
std
::
vector
<
std
::
string
>
sparse_comm_tensor_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
std
::
string
model_prefix_
;
std
::
string
feed_name_
;
Scope
*
root_scope_
;
platform
::
Place
place_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
929a9e80
...
...
@@ -38,111 +38,114 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
framework
{
void
TextClassDataFeed
::
i
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
void
TextClassDataFeed
::
I
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
{
// hard coding for a specific datafeed
_feed_vec
.
resize
(
2
);
// _feed_vec[0].reset(new LoDTensor);
// _feed_vec[1].reset(new LoDTensor);
_all_slot_ids
=
{
0
,
1
};
_use_slot_ids
=
{
0
,
1
};
_use_slot_alias
=
{
"words"
,
"label"
};
_file_content_buffer_host
.
reset
(
new
char
[
200
*
1024
*
1024
],
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_
=
{
0
,
1
};
use_slot_ids_
=
{
0
,
1
};
use_slot_alias_
=
{
"words"
,
"label"
};
file_content_buffer_host_
.
reset
(
new
char
[
200
*
1024
*
1024
],
[](
char
*
p
)
{
delete
[]
p
;});
_file_content_buffer
=
_file_content_buffer_host
.
get
();
_file_content_buffer_ptr
=
_file_content_buffer
;
_batch_id_host
.
reset
(
new
int
[
10240
*
1024
],
file_content_buffer_
=
file_content_buffer_host_
.
get
();
file_content_buffer_ptr_
=
file_content_buffer_
;
batch_id_host_
.
reset
(
new
int
[
10240
*
1024
],
[](
int
*
p
)
{
delete
[]
p
;});
// max word num in a batch
_label_host
.
reset
(
new
int
[
10240
],
batch_id_buffer_
=
batch_id_host_
.
get
();
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
_batch_id_buffer
=
_batch_id_host
.
get
();
_label_ptr
=
_label_host
.
get
();
label_ptr_
=
label_host_
.
get
();
}
// todo: use elegant implemention for this function
bool
TextClassDataFeed
::
read_b
atch
()
{
bool
TextClassDataFeed
::
ReadB
atch
()
{
paddle
::
framework
::
Vector
<
size_t
>
offset
;
int
tlen
=
0
;
int
llen
=
0
;
int
inst_idx
=
0
;
offset
.
resize
(
_batch_size
+
1
);
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
while
(
inst_idx
<
_batch_size
)
{
while
(
inst_idx
<
batch_size_
)
{
int
ptr_offset
=
0
;
if
(
_file_content_buffer_ptr
-
_file_content_buffer
>=
_file_size
)
{
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
break
;
}
memcpy
(
reinterpret_cast
<
char
*>
(
&
llen
),
_file_content_buffer_ptr
+
ptr_offset
,
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
memcpy
(
reinterpret_cast
<
char
*>
(
_batch_id_buffer
+
tlen
),
_file_content_buffer_ptr
+
ptr_offset
,
memcpy
(
reinterpret_cast
<
char
*>
(
batch_id_buffer_
+
tlen
),
file_content_buffer_ptr_
+
ptr_offset
,
llen
*
sizeof
(
int
));
tlen
+=
llen
;
offset
[
inst_idx
+
1
]
=
offset
[
inst_idx
]
+
llen
;
ptr_offset
+=
sizeof
(
int
)
*
llen
;
memcpy
(
reinterpret_cast
<
char
*>
(
_label_ptr
+
inst_idx
),
_file_content_buffer_ptr
+
ptr_offset
,
memcpy
(
reinterpret_cast
<
char
*>
(
label_ptr_
+
inst_idx
),
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
_file_content_buffer_ptr
+=
ptr_offset
;
file_content_buffer_ptr_
+=
ptr_offset
;
inst_idx
++
;
}
if
(
inst_idx
!=
_batch_size
)
{
if
(
inst_idx
!=
batch_size_
)
{
return
false
;
}
LoD
input_lod
{
offset
};
paddle
::
framework
::
Vector
<
size_t
>
label_offset
;
label_offset
.
resize
(
_batch_size
+
1
);
for
(
int
i
=
0
;
i
<=
_batch_size
;
++
i
)
{
label_offset
.
resize
(
batch_size_
+
1
);
for
(
int
i
=
0
;
i
<=
batch_size_
;
++
i
)
{
label_offset
[
i
]
=
i
;
}
LoD
label_lod
{
label_offset
};
int64_t
*
input_ptr
=
_feed_vec
[
0
]
->
mutable_data
<
int64_t
>
(
int64_t
*
input_ptr
=
feed_vec_
[
0
]
->
mutable_data
<
int64_t
>
(
{
static_cast
<
int64_t
>
(
offset
.
back
()),
1
},
platform
::
CPUPlace
());
int64_t
*
label_ptr
=
_feed_vec
[
1
]
->
mutable_data
<
int64_t
>
({
_batch_size
,
1
},
int64_t
*
label_ptr
=
feed_vec_
[
1
]
->
mutable_data
<
int64_t
>
({
batch_size_
,
1
},
platform
::
CPUPlace
());
for
(
unsigned
int
i
=
0
;
i
<
offset
.
back
();
++
i
)
{
input_ptr
[
i
]
=
static_cast
<
int64_t
>
(
_batch_id_buffer
[
i
]);
input_ptr
[
i
]
=
static_cast
<
int64_t
>
(
batch_id_buffer_
[
i
]);
}
for
(
int
i
=
0
;
i
<
_batch_size
;
++
i
)
{
label_ptr
[
i
]
=
static_cast
<
int64_t
>
(
_label_ptr
[
i
]);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
label_ptr
[
i
]
=
static_cast
<
int64_t
>
(
label_ptr_
[
i
]);
}
_feed_vec
[
0
]
->
set_lod
(
input_lod
);
_feed_vec
[
1
]
->
set_lod
(
label_lod
);
feed_vec_
[
0
]
->
set_lod
(
input_lod
);
feed_vec_
[
1
]
->
set_lod
(
label_lod
);
return
true
;
}
void
TextClassDataFeed
::
add_feed_v
ar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
_use_slot_alias
.
size
();
++
i
)
{
if
(
name
==
_use_slot_alias
[
i
])
{
_feed_vec
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
void
TextClassDataFeed
::
AddFeedV
ar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
}
}
}
bool
TextClassDataFeed
::
set_f
ile
(
const
char
*
filename
)
{
bool
TextClassDataFeed
::
SetF
ile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
int
filesize
=
read_whole_file
(
filename
,
_file_content_buffer
);
int
filesize
=
ReadWholeFile
(
filename
,
file_content_buffer_
);
// todo , remove magic number
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
}
_file_content_buffer_ptr
=
_file_content_buffer
;
_file_size
=
filesize
;
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
return
true
;
}
int
TextClassDataFeed
::
read_whole_f
ile
(
const
std
::
string
&
filename
,
int
TextClassDataFeed
::
ReadWholeF
ile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
std
::
ifstream
ifs
(
filename
.
c_str
(),
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
929a9e80
...
...
@@ -35,47 +35,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
uint64_t
FeatureKey
;
struct
FeatureItem
{
FeatureItem
()
{}
FeatureItem
(
FeatureKey
sign_
,
uint16_t
slot_
)
{
sign
()
=
sign_
;
slot
()
=
slot_
;
}
FeatureKey
&
sign
()
{
return
*
(
reinterpret_cast
<
FeatureKey
*>
(
sign_buffer
()));
}
const
FeatureKey
&
sign
()
const
{
return
*
(
const
FeatureKey
*
)
sign_buffer
();
}
uint16_t
&
slot
()
{
return
_slot
;
}
const
uint16_t
&
slot
()
const
{
return
_slot
;
}
private:
char
_sign
[
sizeof
(
FeatureKey
)];
uint16_t
_slot
;
char
*
sign_buffer
()
const
{
return
(
char
*
)
_sign
;
}
};
// Record(average:14031B) is smaller than Sample(average:16530B)
struct
Record
{
int
show
,
click
;
std
::
vector
<
FeatureItem
>
feas
;
std
::
string
lineid
;
std
::
string
tags
;
};
struct
Gauc
{
int
show
,
click
;
uint64_t
fea
;
...
...
@@ -89,241 +48,83 @@ struct Instance {
std
::
vector
<
Gauc
>
gauc_vec
;
};
struct
Sample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
)
{
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
std
::
stringstream
ss
(
input
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
struct
TeacherStudentSample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
float
q_score
;
void
print
()
{
LOG
(
ERROR
)
<<
"label: "
<<
label
<<
" score: "
<<
q_score
;
for
(
auto
&
slot
:
feas
)
{
for
(
auto
&
fea
:
slot
.
second
)
{
LOG
(
ERROR
)
<<
"slot: "
<<
slot
.
first
<<
" fea: "
<<
fea
;
}
}
}
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
,
Gauc
&
gauc
)
{
// NOLINT
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
gauc
.
show
=
1
;
gauc
.
click
=
label
;
gauc
.
lineid
=
input
.
substr
(
0
,
end
);
gauc
.
fea
=
0
;
size_t
dnn_start
=
input
.
find
(
"*"
);
if
(
dnn_start
==
std
::
string
::
npos
)
{
q_score
=
-
1.0
;
}
else
{
dnn_start
+=
1
;
size_t
dnn_end
=
input
.
find
(
' '
,
dnn_start
);
q_score
=
static_cast
<
float
>
(
atof
(
input
.
substr
(
dnn_start
,
dnn_end
-
dnn_start
).
c_str
()));
}
size_t
head_pos
=
input
.
find
(
"
\t
"
);
std
::
string
head
=
input
.
substr
(
0
,
head_pos
);
std
::
stringstream
ss
(
head
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
if
(
slot_id
==
6048
)
{
gauc
.
fea
=
feature_id
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
i
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
=
0
;
virtual
void
I
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
=
0
;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
check_f
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
set_f
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
read_b
atch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
get_all_slot_i
ds
()
{
return
_all_slot_ids
;
virtual
bool
CheckF
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
SetF
ile
(
const
char
*
filename
)
=
0
;
virtual
bool
ReadB
atch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
GetAllSlotI
ds
()
{
return
all_slot_ids_
;
}
virtual
const
std
::
vector
<
uint16_t
>&
get_use_slot_i
ds
()
{
return
_use_slot_ids
;
virtual
const
std
::
vector
<
uint16_t
>&
GetUseSlotI
ds
()
{
return
use_slot_ids_
;
}
virtual
const
std
::
vector
<
std
::
string
>&
get_use_slot_a
lias
()
{
return
_use_slot_alias
;
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlotA
lias
()
{
return
use_slot_alias_
;
}
virtual
void
add_feed_v
ar
(
Variable
*
var
,
virtual
void
AddFeedV
ar
(
Variable
*
var
,
const
std
::
string
&
name
)
=
0
;
virtual
void
bind_s
cope
(
Scope
*
scope
)
=
0
;
virtual
void
set_batch_size
(
int
batch
)
{
_default_batch_size
=
batch
;
}
virtual
int
get_batch_size
()
{
return
_batch_size
;
}
virtual
void
set_buffer_s
ize
(
int
buffer_size
)
{}
virtual
void
BindS
cope
(
Scope
*
scope
)
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferS
ize
(
int
buffer_size
)
{}
std
::
vector
<
LoDTensor
*>&
get_feed_v
ec
()
{
return
_feed_vec
;
std
::
vector
<
LoDTensor
*>&
GetFeedV
ec
()
{
return
feed_vec_
;
}
virtual
std
::
vector
<
LoDTensor
*>&
get_feed_v
ec
(
const
Instance
&
ins
)
{
virtual
std
::
vector
<
LoDTensor
*>&
GetFeedV
ec
(
const
Instance
&
ins
)
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
return
_feed_vec
;
return
feed_vec_
;
}
protected:
std
::
vector
<
uint16_t
>
_all_slot_ids
;
std
::
vector
<
uint16_t
>
_use_slot_ids
;
std
::
vector
<
std
::
string
>
_use_slot_alias
;
std
::
vector
<
LoDTensor
*>
_feed_vec
;
int
_default_batch_size
;
int
_batch_size
;
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
std
::
vector
<
std
::
string
>
use_slot_alias_
;
std
::
vector
<
LoDTensor
*>
feed_vec_
;
int
default_batch_size_
;
int
batch_size_
;
};
class
TextClassDataFeed
:
public
DataFeed
{
public:
virtual
~
TextClassDataFeed
()
{}
virtual
void
i
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
);
virtual
bool
read_b
atch
();
virtual
void
add_feed_v
ar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
bind_s
cope
(
Scope
*
scope
)
{}
virtual
bool
set_f
ile
(
const
char
*
filename
);
virtual
void
I
nit
(
const
datafeed
::
DataFeedParameter
&
feed_param
);
virtual
bool
ReadB
atch
();
virtual
void
AddFeedV
ar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
BindS
cope
(
Scope
*
scope
)
{}
virtual
bool
SetF
ile
(
const
char
*
filename
);
virtual
bool
check_f
ile
(
const
char
*
filename
)
{
virtual
bool
CheckF
ile
(
const
char
*
filename
)
{
// TODO(xxx)
return
false
;
}
void
set_batch_size
(
int
batch
)
{
_batch_size
=
batch
;}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
private:
int
read_whole_f
ile
(
const
std
::
string
&
filename
,
char
*
buffer
);
char
*
_file_content_buffer
;
char
*
_file_content_buffer_ptr
;
int
*
_batch_id_buffer
;
int
*
_label_ptr
;
int
_file_size
;
std
::
vector
<
std
::
string
>
_names
;
std
::
shared_ptr
<
char
>
_file_content_buffer_host
;
std
::
shared_ptr
<
int
>
_batch_id_host
;
std
::
shared_ptr
<
int
>
_label_host
;
int
ReadWholeF
ile
(
const
std
::
string
&
filename
,
char
*
buffer
);
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/datafeed_creator.cc
浏览文件 @
929a9e80
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/datafeed_creator.h"
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
create_dataf
eed
(
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
CreateDataF
eed
(
const
char
*
datafeed_class
)
{
if
(
strcmp
(
datafeed_class
,
"TextClass"
)
==
0
)
{
return
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
(
...
...
paddle/fluid/framework/datafeed_creator.h
浏览文件 @
929a9e80
...
...
@@ -17,6 +17,6 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
create_dataf
eed
(
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
CreateDataF
eed
(
const
char
*
datafeed_class
);
#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录