Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c555948c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c555948c
编写于
10月 24, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
AsyncExecutor: C++ side
上级
2256fae4
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
1342 addition
and
1 deletion
+1342
-1
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+5
-0
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+570
-0
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+175
-0
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+162
-0
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+333
-0
paddle/fluid/framework/datafeed_creator.cc
paddle/fluid/framework/datafeed_creator.cc
+26
-0
paddle/fluid/framework/datafeed_creator.h
paddle/fluid/framework/datafeed_creator.h
+22
-0
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
proto/FeedDataParameter.proto
proto/FeedDataParameter.proto
+48
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
c555948c
...
@@ -174,6 +174,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
...
@@ -174,6 +174,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor
)
fast_threaded_ssa_graph_executor
)
endif
()
# NOT WIN32
endif
()
# NOT WIN32
cc_library
(
async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc
DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_test
(
var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
...
...
paddle/fluid/framework/async_executor.cc
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/async_executor.h"
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <map>
#include <algorithm>
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace
paddle
{
namespace
framework
{
// Definitions of the static data members declared in ExecutorThreadWorker.
// They are shared by every worker instance.

// Held by pick_one_file() while it touches the shared file list and index.
std::mutex ExecutorThreadWorker::_s_locker_for_pick_file;
// Index into _s_thread_filelist of the next file to hand out.
unsigned int ExecutorThreadWorker::_s_current_file_idx = 0;
// Files finished so far; update_epoch_num() resets it when an epoch ends.
size_t ExecutorThreadWorker::_s_current_finished_file_cnt = 0;
// Epoch counter advanced by update_epoch_num().
unsigned int ExecutorThreadWorker::_s_current_epoch = 0;
// NOTE(review): not referenced by any code visible in this file.
int ExecutorThreadWorker::_s_current_save_epoch = 0;
// NOTE(review): not referenced by any code visible in this file.
bool ExecutorThreadWorker::_s_is_first_worker = false;
// Training files registered through add_train_file().
std::vector<std::string> ExecutorThreadWorker::_s_thread_filelist;
// Makes `var` hold an object of the C++ type that corresponds to `var_type`.
// RAW variables are left untouched (their GetMutable call happens inside the
// operator). Any other unlisted type raises through PADDLE_THROW.
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
  switch (var_type) {
    case proto::VarType::LOD_TENSOR:
      var->GetMutable<LoDTensor>();
      break;
    case proto::VarType::SELECTED_ROWS:
      var->GetMutable<SelectedRows>();
      break;
    case proto::VarType::FEED_MINIBATCH:
    case proto::VarType::FETCH_LIST:
      // Both feed and fetch variables are stored as a FeedFetchList.
      var->GetMutable<FeedFetchList>();
      break;
    case proto::VarType::STEP_SCOPES:
      var->GetMutable<std::vector<Scope>>();
      break;
    case proto::VarType::LOD_RANK_TABLE:
      var->GetMutable<LoDRankTable>();
      break;
    case proto::VarType::LOD_TENSOR_ARRAY:
      var->GetMutable<LoDTensorArray>();
      break;
    case proto::VarType::PLACE_LIST:
      var->GetMutable<platform::PlaceList>();
      break;
    case proto::VarType::READER:
      var->GetMutable<ReaderHolder>();
      break;
    case proto::VarType::RAW:
      // GetMutable will be called in operator
      break;
    default:
      PADDLE_THROW(
          "Variable type %d is not in "
          "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
          "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
          var_type);
  }
}
static
void
read_binary_file
(
const
std
::
string
&
filename
,
std
::
string
*
content
)
{
std
::
string
&
contents
=
*
content
;
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
fin
.
good
())
{
LOG
(
ERROR
)
<<
"Cannot open file "
<<
filename
.
c_str
();
}
fin
.
seekg
(
0
,
std
::
ios
::
end
);
contents
.
clear
();
contents
.
resize
(
fin
.
tellg
());
fin
.
seekg
(
0
,
std
::
ios
::
beg
);
fin
.
read
(
&
contents
[
0
],
contents
.
size
());
fin
.
close
();
}
static
void
save_model
(
const
std
::
unique_ptr
<
ProgramDesc
>
&
main_program
,
Scope
*
scope
,
const
std
::
vector
<
std
::
string
>
&
param_names
,
const
std
::
string
&
model_name
,
bool
save_combine
)
{
auto
place
=
platform
::
CPUPlace
();
const
BlockDesc
&
global_block
=
main_program
->
Block
(
0
);
std
::
vector
<
std
::
string
>
paralist
;
for
(
auto
*
var
:
global_block
.
AllVars
())
{
bool
is_model_param
=
false
;
for
(
auto
param_name
:
param_names
)
{
if
(
var
->
Name
()
==
param_name
)
{
is_model_param
=
true
;
break
;
}
}
if
(
!
is_model_param
)
continue
;
if
(
!
save_combine
)
{
LOG
(
ERROR
)
<<
"model var name: "
<<
var
->
Name
().
c_str
();
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"file_path"
,
model_name
+
"/"
+
var
->
Name
()});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"save"
,
{{
"X"
,
{
var
->
Name
()}}},
{},
attrs
);
save_op
->
Run
(
*
scope
,
place
);
}
else
{
paralist
.
push_back
(
var
->
Name
());
}
}
if
(
save_combine
)
{
std
::
sort
(
paralist
.
begin
(),
paralist
.
end
());
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"file_path"
,
model_name
});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"save_combine"
,
{{
"X"
,
paralist
}},
{},
attrs
);
save_op
->
Run
(
*
scope
,
place
);
}
}
// end save_model
// Appends `file` to the file list shared by all workers.
void ExecutorThreadWorker::add_train_file(const std::string& file) {
  _s_thread_filelist.emplace_back(file);
}
void
ExecutorThreadWorker
::
create_thread_operators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_op_names
.
clear
();
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
_op_names
.
push_back
(
op_desc
->
Type
());
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
_ops
.
push_back
(
local_op_ptr
);
continue
;
}
}
void
ExecutorThreadWorker
::
create_thread_scope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
_thread_scope
=
&
_root_scope
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
_root_scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create Persistable var[%s] finished",
// var->Name().c_str());
}
else
{
auto
*
ptr
=
_thread_scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("create unpersistable var[%s] finished",
// var->Name().c_str());
}
}
}
void
ExecutorThreadWorker
::
set_datafeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
_local_reader
=
datafeed
;
}
// For every slot alias reported by the data feed, gets (or creates) the
// variable of that name in the thread scope and registers it with the feed.
void ExecutorThreadWorker::binding_datafeed_memory() {
  const std::vector<std::string>& slot_aliases =
      _local_reader->get_use_slot_alias();
  for (const auto& alias : slot_aliases) {
    auto* feed_var = _thread_scope->Var(alias);
    _local_reader->add_feed_var(feed_var, alias);
  }
}
void
ExecutorThreadWorker
::
set_inspect_var_name
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
}
void
ExecutorThreadWorker
::
set_model_param_names
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
}
void
ExecutorThreadWorker
::
set_sparse_comm_data
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
)
{
_sparse_comm_data
=
param_names
;
}
void
ExecutorThreadWorker
::
set_device
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
};
unsigned
int
i
=
this
->
_thread_id
;
if
(
i
<
sizeof
(
priority
)
/
sizeof
(
unsigned
))
{
unsigned
proc
=
priority
[
i
];
cpu_set_t
mask
;
CPU_ZERO
(
&
mask
);
CPU_SET
(
proc
,
&
mask
);
if
(
-
1
==
sched_setaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
{
LOG
(
ERROR
)
<<
"WARNING: Failed to set thread affinity for thread "
<<
i
;
}
else
{
CPU_ZERO
(
&
mask
);
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
&&
CPU_ISSET
(
proc
,
&
mask
))
{
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
}
}
}
}
void
ExecutorThreadWorker
::
update_epoch_num
()
{
_s_current_finished_file_cnt
++
;
if
(
_s_current_finished_file_cnt
>=
_s_thread_filelist
.
size
())
{
_s_current_finished_file_cnt
=
0
;
_s_current_epoch
++
;
}
}
const
char
*
ExecutorThreadWorker
::
pick_one_file
()
{
std
::
string
file_to_be_preocessed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
_s_locker_for_pick_file
);
if
(
_s_current_file_idx
>=
_s_thread_filelist
.
size
())
{
std
::
random_shuffle
(
_s_thread_filelist
.
begin
(),
_s_thread_filelist
.
end
());
_s_current_file_idx
=
0
;
// _s_current_epoch++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
_thread_id
<<
": finish traing for epoch "
<<
_s_current_epoch
+
1
;
}
file_to_be_preocessed
=
_s_thread_filelist
[
_s_current_file_idx
];
_s_current_file_idx
++
;
return
file_to_be_preocessed
.
c_str
();
}
void
ExecutorThreadWorker
::
train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
set_device
();
#ifdef LOCAL_PROF
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
// int total_batch = 0;
for
(
auto
&
op
:
_ops
)
{
op_name
.
push_back
(
op
->
Type
());
}
op_total_time
.
resize
(
_ops
.
size
());
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
op_total_time
[
i
]
=
0.0
;
}
#endif
std
::
string
inspect_key
=
"inspect"
;
if
(
!
_inspect_var_name
.
empty
())
{
inspect_key
=
_inspect_var_name
.
substr
(
0
,
_inspect_var_name
.
find_first_of
(
'_'
));
}
for
(
unsigned
i
=
0
;
i
<
_max_epoch
;
++
i
)
{
LOG
(
ERROR
)
<<
"epoch: "
<<
i
;
#ifdef LOCAL_PROF
Timer
timeline
;
double
total_time
=
0.0
;
double
read_time
=
0.0
;
#endif
float
total_inspect
=
0
;
int
batch_num
=
1
;
while
(
i
==
_s_current_epoch
)
{
const
char
*
filename
=
pick_one_file
();
_local_reader
->
set_file
(
filename
);
while
(
true
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
bool
flag
=
_local_reader
->
read_batch
();
if
(
!
flag
)
{
break
;
}
#ifdef LOCAL_PROF
timeline
.
pause
();
read_time
+=
timeline
.
elapsed_sec
();
total_time
+=
timeline
.
elapsed_sec
();
#endif
if
(
!
flag
)
{
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
_ops
.
size
();
++
i
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
_ops
[
i
]
->
Run
(
*
_thread_scope
,
_place
);
#ifdef LOCAL_PROF
timeline
.
pause
();
op_total_time
[
i
]
+=
timeline
.
elapsed_sec
();
total_time
+=
timeline
.
elapsed_sec
();
#endif
}
batch_num
++
;
float
avg_inspect
=
0.0
;
if
(
!
_inspect_var_name
.
empty
())
{
avg_inspect
=
_thread_scope
->
FindVar
(
_inspect_var_name
)
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
}
total_inspect
+=
avg_inspect
;
_thread_scope
->
DropKids
();
}
update_epoch_num
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
i
+
1
<<
" called: "
<<
memory
::
memory_usage
(
_place
);
#ifdef LOCAL_PROF
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
std
::
cerr
<<
"op_name:["
<<
i
<<
"]["
<<
op_name
[
i
]
<<
"]"
<<
" op_mean_time:["
<<
op_total_time
[
i
]
<<
"s]"
<<
std
::
endl
;
}
std
::
cerr
<<
"read time: "
<<
read_time
<<
"s"
<<
std
::
endl
;
#endif
}
#ifdef LOCAL_PROF
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
<<
", total_time: "
<<
total_time
;
#else
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
;
#endif
if
(
_thread_id
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
_model_prefix
.
c_str
(),
i
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG
(
ERROR
)
<<
"Going to save model "
<<
modelfile
;
save_model
(
_main_program
,
_thread_scope
,
_model_param_names
,
model_filename
,
true
);
}
}
}
void
ExecutorThreadWorker
::
set_thread_id
(
int
tid
)
{
_thread_id
=
tid
;
}
// Stores the place on which this worker runs its operators.
void ExecutorThreadWorker::set_place(const platform::Place& place) {
  this->_place = place;
}
void
ExecutorThreadWorker
::
set_main_program
(
const
ProgramDesc
&
main_program_desc
)
{
_main_program
.
reset
(
new
ProgramDesc
(
main_program_desc
));
}
// Stores the root scope pointer; the worker does not take ownership.
void ExecutorThreadWorker::set_root_scope(Scope* g_scope) {
  this->_root_scope = g_scope;
}
void
ExecutorThreadWorker
::
set_max_training_epoch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
}
// Builds an executor that runs on `place`. Only _place is set here.
MultiExecutor::MultiExecutor(const platform::Place& place)
    : _place(place) {
}
// Stores the root scope pointer; the executor does not take ownership.
void MultiExecutor::init_root_scope(Scope* scope) {
  this->_root_scope = scope;
}
void
MultiExecutor
::
set_max_training_epoch
(
int
max_epoch
)
{
_max_epoch
=
max_epoch
;
}
// Stores the data feed name later passed to create_datafeed().
// `feedname` must be a valid NUL-terminated string.
void MultiExecutor::set_datafeed_name(const char* feedname) {
  std::string name(feedname);
  _feed_name = name;
}
void
MultiExecutor
::
set_model_prefix
(
const
std
::
string
&
model_prefix
)
{
_model_prefix
=
model_prefix
;
}
// Runs the startup `program` against `scope` on this executor's place.
//
// First every persistable variable of block 0 is created in `scope`. Then
// the ops of block 0 are scanned in order; an op is selected when at least
// one of its output argument names has not been produced by an earlier op.
// The selected ops are run in order.
//
// The operators are owned by std::unique_ptr, so they are destroyed even if
// Run() throws; the previous code released them to raw pointers and deleted
// them only after every Run() had returned.
void MultiExecutor::run_startup_program(const ProgramDesc& program,
                                        Scope* scope) {
  auto& block = program.Block(0);
  for (auto& var : block.AllVars()) {
    if (var->Persistable()) {
      auto* ptr = scope->Var(var->Name());
      CreateTensor(ptr, var->GetType());
      // LOGERR("Persistable Var Name:%s", var->Name().c_str());
    }
  }

  // Output names already seen (the mapped value is unused).
  std::map<std::string, int> param_dict;
  std::vector<std::unique_ptr<OperatorBase>> ops;
  for (auto& op_desc : block.AllOps()) {
    std::vector<std::string> param_name_vec = op_desc->OutputArgumentNames();
    bool need_to_run = false;
    for (auto& name : param_name_vec) {
      // Every new name must be recorded, so do not stop at the first one.
      if (param_dict.emplace(name, 1).second) {
        need_to_run = true;
      }
    }
    if (need_to_run) {
      ops.push_back(OpRegistry::CreateOp(*op_desc));
    }
  }
  // LOGERR("There are %d parameters in startup program, %d op needs to run",
  //        param_dict.size(), ops.size());

  for (auto& op : ops) {
    op->Run(*scope, _place);
  }
  // LOGERR("run startup program done.");
}
// Reads the file `f` as raw bytes and constructs a ProgramDesc from them.
std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
    const std::string& f) {
  std::string serialized_desc;
  read_binary_file(f, &serialized_desc);
  return std::unique_ptr<ProgramDesc>(new ProgramDesc(serialized_desc));
}
// Replaces the stored dense communication tensor names with a copy of
// `dense_comm_tensor`.
void MultiExecutor::set_dense_comm_tensor(
    const std::vector<std::string>& dense_comm_tensor) {
  _dense_comm_tensor = dense_comm_tensor;
}
// Replaces the stored sparse communication tensor names with a copy of
// `sparse_comm_tensor`.
void MultiExecutor::set_sparse_comm_tensor(
    const std::vector<std::string>& sparse_comm_tensor) {
  _sparse_comm_tensor = sparse_comm_tensor;
}
void
MultiExecutor
::
set_sparse_comm_data
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
)
{
_sparse_comm_data
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
_sparse_comm_data
.
size
();
}
void
MultiExecutor
::
set_filelist
(
const
char
*
filelist
)
{
_filelist
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
_filelist
.
push_back
(
filename
);
}
fin
.
close
();
}
// Replaces the stored file list with the entries of `tfiles`.
void MultiExecutor::set_filelist(std::vector<std::string> tfiles) {
  _filelist.assign(tfiles.begin(), tfiles.end());
}
void
MultiExecutor
::
set_inspect_var_name
(
const
std
::
string
&
inspect_var_name
)
{
_inspect_var_name
=
inspect_var_name
;
}
void
MultiExecutor
::
set_param_names
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
_model_param_names
=
param_names
;
}
void
MultiExecutor
::
set_thread_num
(
const
int
thread_num
)
{
_thread_num
=
thread_num
;
}
// Creates and configures one ExecutorThreadWorker per requested thread.
//
// Each worker gets its thread id, operators, scopes, place, epoch limit,
// inspect variable, parameter names, sparse comm data, a copy of
// `host_program` and the model prefix. The executor's file list is then
// registered through worker 0 (the list is static and shared), and each
// worker receives its own data feed bound to its thread scope.
//
// With zero threads there is nothing to prepare and the function returns
// early; the previous code indexed _workers[0] on an empty vector. The
// redundant second set_thread_id() call was dropped.
//
// NOTE(review): create_datafeed() is assumed to return a usable feed for
// _feed_name; a null result is not handled here — confirm.
// NOTE(review): the shared static file list is appended to on every call and
// never cleared, so calling this twice registers each file twice — confirm
// whether that is intended.
void MultiExecutor::prepare_threads(const ProgramDesc& host_program) {
  _workers.resize(_thread_num);
  if (_thread_num == 0) {
    return;
  }

  for (unsigned i = 0; i < _thread_num; ++i) {
    _workers[i].reset(new ExecutorThreadWorker);
    _workers[i]->set_thread_id(i);
    _workers[i]->create_thread_operators(host_program);
    _workers[i]->set_root_scope(_root_scope);
    _workers[i]->set_place(_place);
    _workers[i]->set_max_training_epoch(_max_epoch);
    _workers[i]->create_thread_scope(host_program);
    _workers[i]->set_inspect_var_name(_inspect_var_name);
    _workers[i]->set_model_param_names(_model_param_names);
    _workers[i]->set_sparse_comm_data(_sparse_comm_data);
    _workers[i]->set_main_program(host_program);
    _workers[i]->set_model_prefix(_model_prefix);
  }

  // suppose at least one trainer thread here, and
  // filelist is static so that we only add filelist once
  for (const auto& train_file : _filelist) {
    _workers[0]->add_train_file(train_file);
  }
  // mpi_wrapper::ModelParam model_param(true);
  // _workers[0]->register_parallel_training_param(model_param);

  for (unsigned i = 0; i < _thread_num; ++i) {
    // new a datafeed here
    std::shared_ptr<DataFeed> local_feed = create_datafeed(_feed_name.c_str());
    local_feed->init(_data_feed_param);
    local_feed->set_batch_size(_batch_size);
    _workers[i]->set_datafeed(local_feed);
    _workers[i]->binding_datafeed_memory();
  }
}
// Prepares the workers for `host_program`, starts one thread per worker
// running ExecutorThreadWorker::train, and blocks until all have finished.
//
// The joined thread objects are discarded afterwards. Previously they stayed
// in _threads, so a second call tried to join them again and std::thread
// threw std::system_error.
void MultiExecutor::run_multi_executor(const ProgramDesc& host_program) {
  // thread binding here?
  prepare_threads(host_program);

  _threads.reserve(_threads.size() + _thread_num);
  for (unsigned i = 0; i < _thread_num; ++i) {
    _threads.push_back(
        std::thread(&ExecutorThreadWorker::train, _workers[i].get()));
  }

  for (auto& th : _threads) {
    if (th.joinable()) {
      th.join();
    }
  }
  _threads.clear();
}
}
// end namespace framework
}
// end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/framework/async_executor.h
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#define PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <map>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
framework
{
// Makes `var` hold an object of the C++ type selected by `var_type`.
// The supported types are listed in the definition in async_executor.cc.
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
// One training worker. MultiExecutor creates one instance per thread,
// configures it through the setters below and runs train() on its own
// std::thread.
//
// The scalar and pointer members now have default member initializers; the
// empty constructor used to leave them indeterminate until a setter ran.
//
// NOTE(review): _ops holds raw operator pointers that the destructor does
// not delete — confirm who is meant to own them.
class ExecutorThreadWorker {
 public:
  ExecutorThreadWorker() {}
  virtual ~ExecutorThreadWorker() {}
  // Creates the thread scope and the variables of block 0 of `program`.
  void create_thread_scope(const framework::ProgramDesc& program);
  void set_datafeed(const DataFeed& datafeed);
  void set_thread_id(int tid);
  // Instantiates the operators of block 0 of `program` into _ops.
  void create_thread_operators(const framework::ProgramDesc& program);
  void set_root_scope(Scope* g_scope);
  // Pins the calling thread to a processor chosen by thread id.
  void set_device();
  virtual void add_fid_set();
  void set_comm_batch(int comm_batch) { _comm_batch = comm_batch; }
  // Appends `filename` to the file list shared by all workers.
  void add_train_file(const std::string& filename);
  void set_main_program(const ProgramDesc& main_program_desc);
  void set_place(const paddle::platform::Place& place);
  void set_max_training_epoch(const int max_epoch);
  // Registers the thread-scope variables named by the feed's slot aliases.
  void binding_datafeed_memory();
  void set_model_prefix(const std::string& prefix) { _model_prefix = prefix; }
  void set_inspect_var_name(const std::string& inspect_var_name);
  void set_model_param_names(const std::vector<std::string>& param_names);
  void set_sparse_comm_data(const std::map<std::string, int>& param_names);
  void set_datafeed(const std::shared_ptr<DataFeed>& datafeed);
  virtual void mpi_train();
  void gpu_train();
  // Thread entry point: runs the training epochs.
  void train();
  // Returns the path of the next file to train on.
  virtual const char* pick_one_file();
  // Counts one finished file and advances the shared epoch when all are done.
  void update_epoch_num();
  virtual void set_dense_comm_tensor(
      const std::vector<std::string>& param_names) {}
  virtual void initialize() {}

 public:
  // State shared by all workers.
  static std::mutex _s_locker_for_pick_file;
  static unsigned int _s_current_file_idx;
  static size_t _s_current_finished_file_cnt;
  static unsigned int _s_current_epoch;
  static int _s_current_save_epoch;
  static std::vector<std::string> _s_thread_filelist;  // filelist
  static bool _s_is_first_worker;

 protected:
  // thread index
  int _thread_id = 0;
  // current training file
  int _cur_fileidx = 0;
  // max epoch for each thread
  unsigned int _max_epoch = 0;
  // instances learned currently
  int _comm_batch = 0;
  std::string _model_prefix;
  std::vector<std::string> _op_names;
  // local ops for forward and backward
  std::vector<OperatorBase*> _ops;
  // main program for training
  std::unique_ptr<framework::ProgramDesc> _main_program;
  // binary data reader
  std::shared_ptr<DataFeed> _local_reader;
  std::string _inspect_var_name;
  std::vector<std::string> _model_param_names;
  std::map<std::string, int> _sparse_comm_data;
  std::vector<int> _ids_buffer;
  // execution place
  platform::Place _place;
  // root scope for model parameters
  Scope* _root_scope = nullptr;
  // a thread scope; its parent is the shared global (root) scope
  Scope* _thread_scope = nullptr;
};
// Drives multi-threaded training: holds the configuration, creates one
// ExecutorThreadWorker per thread and runs them to completion.
//
// The scalar and pointer members now have default member initializers; the
// constructor only sets _place, so they used to be indeterminate until the
// matching setter was called.
class MultiExecutor {
 public:
  explicit MultiExecutor(const platform::Place& place);
  virtual ~MultiExecutor() {}
  // Reads `filename` and constructs a ProgramDesc from its bytes.
  static std::unique_ptr<ProgramDesc> load_desc_from_file(
      const std::string& filename);
  void init_root_scope(Scope* scope);
  void set_inspect_var_name(const std::string& inspect_var_name);
  void set_param_names(const std::vector<std::string>& param_names);
  void set_max_training_epoch(const int max_epoch);
  Scope* get_root_scope() { return _root_scope; }
  void set_thread_num(const int thread_num);
  void set_batch_size(const int batch_size) { _batch_size = batch_size; }
  // Loads file names from the text file at path `filelist`.
  void set_filelist(const char* filelist);
  // Uses the given file names directly.
  void set_filelist(const std::vector<std::string> filelist);
  void set_datafeed_name(const char* feedname);
  void set_data_feed_param(const datafeed::DataFeedParameter& feed_param) {
    _data_feed_param = feed_param;
  }
  void set_comm_batch(int comm_batch) { _comm_batch = comm_batch; }
  void set_model_prefix(const std::string& model_prefix);
  void set_dense_comm_tensor(
      const std::vector<std::string>& dense_comm_tensor);
  void set_sparse_comm_tensor(
      const std::vector<std::string>& sparse_comm_tensor);
  void set_sparse_comm_data(
      const std::map<std::string, int>& sparse_comm_data);
  // Creates and configures the workers for `host_program`.
  virtual void prepare_threads(const framework::ProgramDesc& host_program);
  // Creates persistable variables in `scope` and runs the startup ops.
  void run_startup_program(const framework::ProgramDesc& program,
                           framework::Scope* scope);
  // Prepares the workers, runs them on threads and waits for them.
  void run_multi_executor(const ProgramDesc& host_program);

 public:
  unsigned int _thread_num = 0;
  datafeed::DataFeedParameter _data_feed_param;
  int _max_epoch = 0;
  int _batch_size = 0;
  int _comm_batch = 0;
  std::vector<std::shared_ptr<ExecutorThreadWorker> > _workers;
  std::vector<std::thread> _threads;
  std::vector<std::string> _filelist;
  std::string _inspect_var_name;
  std::vector<std::string> _model_param_names;
  std::vector<std::string> _dense_comm_tensor;
  std::vector<std::string> _sparse_comm_tensor;
  std::map<std::string, int> _sparse_comm_data;
  int node_num = 0;
  std::string _model_prefix;
  ProgramDesc _host_program;
  std::string _feed_name;
  Scope* _root_scope = nullptr;
  platform::Place _place;
};
}
// namespace framework
}
// namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/framework/data_feed.cc
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <fcntl.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <algorithm>
#include <utility>
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
// Defines the gflags boolean flag `is_text_feed`, default false.
// NOTE(review): no use of this flag is visible in this file.
DEFINE_bool
(
is_text_feed
,
false
,
"is_text_feed"
);
namespace
paddle
{
namespace
framework
{
// Sets up the feed with its hard-coded two slots ("words" and "label") and
// allocates the working buffers: 200 MiB for raw file content, 10240 * 1024
// ints for the word ids of a batch and 10240 ints for the labels of a batch.
// `feed_param` is not used.
void TextClassDataFeed::init(const datafeed::DataFeedParameter& feed_param) {
  // hard coding for a specific datafeed
  _feed_vec.resize(2);
  // _feed_vec[0].reset(new LoDTensor);
  // _feed_vec[1].reset(new LoDTensor);
  _all_slot_ids = {0, 1};
  _use_slot_ids = {0, 1};
  _use_slot_alias = {"words", "label"};

  auto free_chars = [](char* p) { delete[] p; };
  auto free_ints = [](int* p) { delete[] p; };

  char* file_storage = new char[200 * 1024 * 1024];
  _file_content_buffer_host.reset(file_storage, free_chars);
  _file_content_buffer = _file_content_buffer_host.get();
  _file_content_buffer_ptr = _file_content_buffer;

  // max word num in a batch
  int* id_storage = new int[10240 * 1024];
  _batch_id_host.reset(id_storage, free_ints);
  // max label in a batch
  int* label_storage = new int[10240];
  _label_host.reset(label_storage, free_ints);

  _batch_id_buffer = _batch_id_host.get();
  _label_ptr = _label_host.get();
}
// TODO: use a more elegant implementation for this function
//
// Decodes the next _batch_size instances from the in-memory file content
// into the two feed tensors. Each instance is laid out as
//   int term_count, int term_id[term_count], int label.
//
// Returns true when a full batch was produced. Returns false, without
// touching the tensors, when fewer than _batch_size instances remain or when
// the data fails validation.
//
// Every length read from the file is now validated before it is used as a
// memcpy size: a record must fit in the bytes left in the file and in the id
// buffer, and _batch_size must fit in the label buffer. A non-positive
// _batch_size is rejected; zero used to report success forever.
// NOTE(review): the two capacities below mirror the allocation sizes in
// init(); keep them in sync if those change.
bool TextClassDataFeed::read_batch() {
  const int64_t kIntBytes = static_cast<int64_t>(sizeof(int));
  const int64_t kMaxBatchIds = 10240 * 1024;  // ints in _batch_id_buffer
  const int64_t kMaxLabels = 10240;           // ints in _label_ptr

  if (_batch_size <= 0 || _batch_size > kMaxLabels) {
    return false;
  }

  paddle::framework::Vector<size_t> offset;
  int tlen = 0;
  int llen = 0;
  int inst_idx = 0;
  offset.resize(_batch_size + 1);
  offset[0] = 0;

  const int64_t file_bytes = static_cast<int64_t>(_file_size);
  while (inst_idx < _batch_size) {
    const int64_t remaining =
        file_bytes - (_file_content_buffer_ptr - _file_content_buffer);
    if (remaining <= 0) {
      break;
    }
    if (remaining < kIntBytes) {
      return false;  // truncated record header
    }

    int ptr_offset = 0;
    memcpy(&llen, _file_content_buffer_ptr + ptr_offset, sizeof(int));
    ptr_offset += sizeof(int);

    // A record is term_count + ids + label = (llen + 2) ints.
    const int64_t record_bytes = (static_cast<int64_t>(llen) + 2) * kIntBytes;
    if (llen < 0 || record_bytes > remaining ||
        static_cast<int64_t>(tlen) + llen > kMaxBatchIds) {
      return false;  // corrupt length or batch larger than the id buffer
    }

    memcpy(_batch_id_buffer + tlen, _file_content_buffer_ptr + ptr_offset,
           llen * sizeof(int));
    tlen += llen;
    offset[inst_idx + 1] = offset[inst_idx] + llen;
    ptr_offset += sizeof(int) * llen;

    memcpy(_label_ptr + inst_idx, _file_content_buffer_ptr + ptr_offset,
           sizeof(int));
    ptr_offset += sizeof(int);

    _file_content_buffer_ptr += ptr_offset;
    inst_idx++;
  }

  if (inst_idx != _batch_size) {
    return false;
  }

  LoD input_lod{offset};
  paddle::framework::Vector<size_t> label_offset;
  label_offset.resize(_batch_size + 1);
  for (int i = 0; i <= _batch_size; ++i) {
    label_offset[i] = i;
  }
  LoD label_lod{label_offset};

  int64_t* input_ptr = _feed_vec[0]->mutable_data<int64_t>(
      {static_cast<int64_t>(offset.back()), 1}, platform::CPUPlace());
  int64_t* label_ptr = _feed_vec[1]->mutable_data<int64_t>(
      {_batch_size, 1}, platform::CPUPlace());

  // Widen the 32-bit ids and labels read from the file to int64.
  const size_t total_ids = offset.back();
  for (size_t i = 0; i < total_ids; ++i) {
    input_ptr[i] = static_cast<int64_t>(_batch_id_buffer[i]);
  }
  for (int i = 0; i < _batch_size; ++i) {
    label_ptr[i] = static_cast<int64_t>(_label_ptr[i]);
  }

  _feed_vec[0]->set_lod(input_lod);
  _feed_vec[1]->set_lod(label_lod);
  return true;
}
// Binds `feed` to every slot whose alias equals `name`: the variable's
// LoDTensor becomes that slot's entry in _feed_vec. Unknown names are
// ignored.
void TextClassDataFeed::add_feed_var(Variable* feed, const std::string& name) {
  const size_t slot_count = _use_slot_alias.size();
  for (size_t slot = 0; slot < slot_count; ++slot) {
    if (_use_slot_alias[slot] != name) {
      continue;
    }
    _feed_vec[slot] = feed->GetMutable<LoDTensor>();
  }
}
bool
TextClassDataFeed
::
set_file
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
int
filesize
=
read_whole_file
(
filename
,
_file_content_buffer
);
// todo , remove magic number
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
}
_file_content_buffer_ptr
=
_file_content_buffer
;
_file_size
=
filesize
;
return
true
;
}
// Reads the entire file `filename` into `buffer`.
//
// The caller must make sure `buffer` can hold the whole file; this function
// cannot see the buffer's capacity.
//
// Returns the number of bytes read, or -1 when `buffer` is null, the file
// cannot be opened, its size cannot be determined or does not fit in an int,
// or fewer bytes than expected could be read. The size and read checks are
// new: a failed tellg() or a size above INT_MAX used to become a negative or
// wrapped count that was passed straight to read().
int TextClassDataFeed::read_whole_file(const std::string& filename,
                                       char* buffer) {
  if (buffer == nullptr) {
    return -1;
  }

  std::ifstream ifs(filename.c_str(), std::ios::binary);
  if (ifs.fail()) {
    return -1;
  }

  ifs.seekg(0, std::ios::end);
  const std::streamoff length = ifs.tellg();
  const int file_size = static_cast<int>(length);
  // The round trip fails when the length does not fit in an int.
  if (length < 0 || static_cast<std::streamoff>(file_size) != length) {
    return -1;
  }

  ifs.seekg(0, std::ios::beg);
  ifs.read(buffer, file_size);
  if (ifs.gcount() != static_cast<std::streamsize>(file_size)) {
    return -1;
  }
  return file_size;
}
}
// namespace framework
}
// namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/framework/data_feed.h
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <queue>
#include <mutex> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable> // NOLINT
#include <fstream>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "proto/FeedDataParameter.pb.h"
namespace
paddle
{
namespace
framework
{
// Integer type used for feature keys.
using FeatureKey = uint64_t;

// A feature key paired with a 16-bit slot id.
//
// The key is held as raw bytes in a char array and exposed through sign() as
// a FeatureKey reference into those bytes; the slot id follows it.
// NOTE(review): a char array carries no alignment guarantee for FeatureKey --
// confirm that every target platform tolerates the access done by sign().
struct FeatureItem {
  // Leaves both the key bytes and the slot id uninitialised.
  FeatureItem() {}

  // Stores `sign_` as the key and `slot_` as the slot id.
  FeatureItem(FeatureKey sign_, uint16_t slot_) {
    sign() = sign_;
    slot() = slot_;
  }

  // Mutable access to the key stored in the byte array.
  FeatureKey& sign() {
    FeatureKey* key = reinterpret_cast<FeatureKey*>(sign_buffer());
    return *key;
  }

  // Read-only access to the key stored in the byte array.
  const FeatureKey& sign() const {
    const FeatureKey* key = reinterpret_cast<const FeatureKey*>(sign_buffer());
    return *key;
  }

  // Mutable access to the slot id.
  uint16_t& slot() { return _slot; }

  // Read-only access to the slot id.
  const uint16_t& slot() const { return _slot; }

 private:
  char _sign[sizeof(FeatureKey)];  // raw storage for the key
  uint16_t _slot;

  // Start of the key storage; constness is dropped so that both sign()
  // overloads can share this helper.
  char* sign_buffer() const { return const_cast<char*>(_sign); }
};
// Plain aggregate describing one record: two integer fields (show, click),
// a list of FeatureItem entries and two strings (lineid, tags).
// Record(average:14031B) is smaller than Sample(average:16530B)
struct Record {
  int show;
  int click;
  std::vector<FeatureItem> feas;
  std::string lineid;
  std::string tags;
};
// Plain aggregate with two integer fields (show, click), one 64-bit feature
// value and a line-id string. TeacherStudentSample::from_string() fills one
// of these for every line it parses.
struct Gauc {
  int show;
  int click;
  uint64_t fea;
  std::string lineid;
};
// Plain aggregate bundling per-instance buffers: nested vectors of 64-bit
// values (feed_vec_buffer) and of ints (feed_vec_lod), a list of float
// labels and a list of Gauc entries.
// NOTE(review): the two nested vectors are presumably indexed the same way
// as DataFeed::_feed_vec -- confirm against the code that fills them.
struct Instance {
  std::vector<std::vector<uint64_t>> feed_vec_buffer;
  std::vector<std::vector<int>> feed_vec_lod;
  std::vector<float> other_label;
  std::vector<Gauc> gauc_vec;
};
struct
Sample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
)
{
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
std
::
stringstream
ss
(
input
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
struct
TeacherStudentSample
{
uint64_t
label
;
std
::
map
<
uint16_t
,
std
::
vector
<
uint64_t
>>
feas
;
float
q_score
;
void
print
()
{
LOG
(
ERROR
)
<<
"label: "
<<
label
<<
" score: "
<<
q_score
;
for
(
auto
&
slot
:
feas
)
{
for
(
auto
&
fea
:
slot
.
second
)
{
LOG
(
ERROR
)
<<
"slot: "
<<
slot
.
first
<<
" fea: "
<<
fea
;
}
}
}
bool
from_string
(
const
std
::
string
&
input
,
const
std
::
set
<
uint32_t
>&
slots
,
Gauc
&
gauc
)
{
// NOLINT
size_t
end
=
input
.
find_first_of
(
' '
);
if
(
end
==
std
::
string
::
npos
)
{
LOG
(
ERROR
)
<<
"[ERROR] Fail in parsing:"
<<
input
;
return
false
;
}
label
=
input
[
end
+
3
]
-
'0'
;
CHECK
(
label
==
0
||
label
==
1
)
<<
"invalid label:"
<<
label
;
gauc
.
show
=
1
;
gauc
.
click
=
label
;
gauc
.
lineid
=
input
.
substr
(
0
,
end
);
gauc
.
fea
=
0
;
size_t
dnn_start
=
input
.
find
(
"*"
);
if
(
dnn_start
==
std
::
string
::
npos
)
{
q_score
=
-
1.0
;
}
else
{
dnn_start
+=
1
;
size_t
dnn_end
=
input
.
find
(
' '
,
dnn_start
);
q_score
=
static_cast
<
float
>
(
atof
(
input
.
substr
(
dnn_start
,
dnn_end
-
dnn_start
).
c_str
()));
}
size_t
head_pos
=
input
.
find
(
"
\t
"
);
std
::
string
head
=
input
.
substr
(
0
,
head_pos
);
std
::
stringstream
ss
(
head
);
std
::
string
token
;
uint16_t
slot_id
=
0
;
uint64_t
feature_id
=
0
;
int
num_nonfeas_token
=
0
;
std
::
ostringstream
os
;
while
(
ss
>>
token
)
{
size_t
end
=
token
.
find_first_of
(
':'
);
if
(
end
==
std
::
string
::
npos
)
{
++
num_nonfeas_token
;
continue
;
}
try
{
slot_id
=
stoi
(
token
.
substr
(
end
+
1
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing slot id:"
<<
token
;
return
false
;
}
try
{
feature_id
=
stoull
(
token
.
substr
(
0
,
end
));
}
catch
(...)
{
LOG
(
ERROR
)
<<
"Error in parsing feature id:"
<<
token
;
return
false
;
}
if
(
slot_id
<=
0
)
{
LOG
(
ERROR
)
<<
"invalid slot:"
<<
slot_id
<<
" feasign:"
<<
feature_id
<<
" line:"
<<
input
;
return
false
;
}
if
(
slots
.
find
(
slot_id
)
==
slots
.
end
())
{
continue
;
}
if
(
slot_id
==
6048
)
{
gauc
.
fea
=
feature_id
;
}
feas
[
slot_id
].
push_back
(
feature_id
);
}
if
(
num_nonfeas_token
!=
4
)
{
LOG
(
ERROR
)
<<
"Format error. Invalid number of non-feasign token:"
<<
num_nonfeas_token
;
return
false
;
}
return
true
;
}
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
init
(
const
datafeed
::
DataFeedParameter
&
feed_param
)
=
0
;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
check_file
(
const
char
*
filename
)
=
0
;
virtual
bool
set_file
(
const
char
*
filename
)
=
0
;
virtual
bool
read_batch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
get_all_slot_ids
()
{
return
_all_slot_ids
;
}
virtual
const
std
::
vector
<
uint16_t
>&
get_use_slot_ids
()
{
return
_use_slot_ids
;
}
virtual
const
std
::
vector
<
std
::
string
>&
get_use_slot_alias
()
{
return
_use_slot_alias
;
}
virtual
void
add_feed_var
(
Variable
*
var
,
const
std
::
string
&
name
)
=
0
;
virtual
void
bind_scope
(
Scope
*
scope
)
=
0
;
virtual
void
set_batch_size
(
int
batch
)
{
_default_batch_size
=
batch
;
}
virtual
int
get_batch_size
()
{
return
_batch_size
;
}
virtual
void
set_buffer_size
(
int
buffer_size
)
{}
std
::
vector
<
LoDTensor
*>&
get_feed_vec
()
{
return
_feed_vec
;
}
virtual
std
::
vector
<
LoDTensor
*>&
get_feed_vec
(
const
Instance
&
ins
)
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
return
_feed_vec
;
}
protected:
std
::
vector
<
uint16_t
>
_all_slot_ids
;
std
::
vector
<
uint16_t
>
_use_slot_ids
;
std
::
vector
<
std
::
string
>
_use_slot_alias
;
std
::
vector
<
LoDTensor
*>
_feed_vec
;
int
_default_batch_size
;
int
_batch_size
;
};
class
TextClassDataFeed
:
public
DataFeed
{
public:
virtual
~
TextClassDataFeed
()
{}
virtual
void
init
(
const
datafeed
::
DataFeedParameter
&
feed_param
);
virtual
bool
read_batch
();
virtual
void
add_feed_var
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
bind_scope
(
Scope
*
scope
)
{}
virtual
bool
set_file
(
const
char
*
filename
);
virtual
bool
check_file
(
const
char
*
filename
)
{
// TODO(xxx)
return
false
;
}
void
set_batch_size
(
int
batch
)
{
_batch_size
=
batch
;}
private:
int
read_whole_file
(
const
std
::
string
&
filename
,
char
*
buffer
);
char
*
_file_content_buffer
;
char
*
_file_content_buffer_ptr
;
int
*
_batch_id_buffer
;
int
*
_label_ptr
;
int
_file_size
;
std
::
vector
<
std
::
string
>
_names
;
std
::
shared_ptr
<
char
>
_file_content_buffer_host
;
std
::
shared_ptr
<
int
>
_batch_id_host
;
std
::
shared_ptr
<
int
>
_label_host
;
};
}
// namespace framework
}
// namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/framework/datafeed_creator.cc
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/datafeed_creator.h"
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
create_datafeed
(
const
char
*
datafeed_class
)
{
if
(
strcmp
(
datafeed_class
,
"TextClass"
)
==
0
)
{
return
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
(
new
paddle
::
framework
::
TextClassDataFeed
);
}
return
NULL
;
}
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/framework/datafeed_creator.h
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
#define PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
// Factory for data feeds: returns the DataFeed implementation selected by
// the class name `datafeed_class`. The definition in datafeed_creator.cc
// recognises "TextClass" and returns an empty pointer for any other name.
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>
create_datafeed
(
const
char
*
datafeed_class
);
#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
c555948c
set
(
PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder
)
set
(
PYBIND_DEPS pybind python proto_desc memory executor
async_executor
prune feed_fetch_method pass_builder
)
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc
)
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc
)
if
(
NOT WIN32
)
if
(
NOT WIN32
)
list
(
APPEND PYBIND_DEPS parallel_executor profiler
)
list
(
APPEND PYBIND_DEPS parallel_executor profiler
)
...
...
proto/FeedDataParameter.proto
0 → 100644
浏览文件 @
c555948c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax
=
"proto2"
;
package
datafeed
;
// Top-level data feed configuration. Holds one optional sub-message for each
// of the three parameter groups defined in this file (FeedDataParameter,
// JointOneHotParameter, ACDXParameter).
message
DataFeedParameter
{
optional
FeedDataParameter
feed_data_param
=
1
;
optional
JointOneHotParameter
joint_onehot_data_param
=
2
;
optional
ACDXParameter
acdx_data_param
=
3
;
}
// Slot configuration: the list of slot ids, the used slot ids with their
// aliases, mods and types, plus two limits with defaults
// (max_batch_num = 128, max_feasign_num = 1000).
// NOTE(review): the use_slot_* lists look like parallel arrays with one
// entry per used slot -- confirm against the code that consumes them.
message
FeedDataParameter
{
repeated
int32
slot_id
=
1
;
repeated
int32
use_slot_id
=
2
;
repeated
string
use_slot_alias
=
3
;
repeated
uint64
use_slot_mod
=
4
;
repeated
int32
use_slot_type
=
5
;
optional
int32
max_batch_num
=
6
[
default
=
128
];
optional
int32
max_feasign_num
=
7
[
default
=
1000
];
}
// Parameter group with three limits that have defaults (max_batch_num = 128,
// max_title_num = 400, max_term_num = 1024), a required sampling_rate, and
// the same slot_id / use_slot_* lists as FeedDataParameter under different
// field numbers (5-9).
// sampling_rate is `required`: a message that leaves it unset fails proto2
// initialisation checks.
message
JointOneHotParameter
{
optional
int32
max_batch_num
=
1
[
default
=
128
];
optional
int32
max_title_num
=
2
[
default
=
400
];
optional
int32
max_term_num
=
3
[
default
=
1024
];
required
float
sampling_rate
=
4
;
repeated
int32
slot_id
=
5
;
repeated
int32
use_slot_id
=
6
;
repeated
string
use_slot_alias
=
7
;
repeated
uint64
use_slot_mod
=
8
;
repeated
int32
use_slot_type
=
9
;
}
// Parameter group with two limits that have defaults (max_batch_num = 128,
// max_term_num = 512). Field number 2 is not used by this message.
message
ACDXParameter
{
optional
int32
max_batch_num
=
1
[
default
=
128
];
optional
int32
max_term_num
=
3
[
default
=
512
];
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录