Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
92a98ca7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92a98ca7
编写于
11月 19, 2018
作者:
B
barrierye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add MultiSlotDataFeed
上级
78c3380b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
491 addition
and
261 deletion
+491
-261
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+197
-177
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+283
-84
paddle/fluid/framework/data_feed.proto
paddle/fluid/framework/data_feed.proto
+11
-0
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
92a98ca7
...
...
@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
DEFINE_bool
(
is_text_feed
,
false
,
"is_text_feed"
);
namespace
paddle
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClassDataFeed
::
s_filelist_
;
std
::
mutex
TextClassDataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClassDataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClassDataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClassDataFeed
::
s_current_epoch_
=
0
;
int
TextClassDataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClassDataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClassDataFeed
::
s_condition_epoch_start_
;
bool
TextClassDataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClassDataFeed
::
Init
()
{
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_
=
{
0
,
1
};
use_slot_ids_
=
{
0
,
1
};
use_slot_alias_
=
{
"words"
,
"label"
};
file_content_buffer_host_
.
reset
(
new
char
[
200
*
1024
*
1024
],
[](
char
*
p
)
{
delete
[]
p
;});
file_content_buffer_
=
file_content_buffer_host_
.
get
();
file_content_buffer_ptr_
=
file_content_buffer_
;
batch_id_host_
.
reset
(
new
int
[
10240
*
1024
],
[](
int
*
p
)
{
delete
[]
p
;});
// max word num in a batch
batch_id_buffer_
=
batch_id_host_
.
get
();
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
label_ptr_
=
label_host_
.
get
();
field_names_
.
clear
();
}
TextClassDataFeed
::
TextClassDataFeed
()
{
Init
();
}
// todo: use elegant implemention for this function
bool
TextClassDataFeed
::
ReadBatch
()
{
paddle
::
framework
::
Vector
<
size_t
>
offset
;
int
tlen
=
0
;
int
llen
=
0
;
int
inst_idx
=
0
;
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
while
(
inst_idx
<
batch_size_
)
{
int
ptr_offset
=
0
;
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
break
;
std
::
vector
<
std
::
string
>
DataFeed
::
filelist_
;
size_t
DataFeed
::
file_idx_
;
std
::
mutex
DataFeed
::
mutex_for_pick_file_
;
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
if
(
CheckInit
()
==
false
)
{
return
;}
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
if
(
name
==
use_slots_
[
i
])
{
if
(
use_slots_is_dense_
[
i
])
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
Tensor
>
());
}
else
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
memcpy
(
reinterpret_cast
<
char
*>
(
&
llen
),
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
memcpy
(
reinterpret_cast
<
char
*>
(
batch_id_buffer_
+
tlen
),
file_content_buffer_ptr_
+
ptr_offset
,
llen
*
sizeof
(
int
));
tlen
+=
llen
;
offset
[
inst_idx
+
1
]
=
offset
[
inst_idx
]
+
llen
;
ptr_offset
+=
sizeof
(
int
)
*
llen
;
memcpy
(
reinterpret_cast
<
char
*>
(
label_ptr_
+
inst_idx
),
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
file_content_buffer_ptr_
+=
ptr_offset
;
inst_idx
++
;
}
}
if
(
inst_idx
!=
batch_size_
)
{
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
if
(
CheckInit
()
==
false
)
{
return
false
;}
if
(
files
.
size
()
==
0
)
{
LOG
(
ERROR
)
<<
"error: you have set an empty filelist"
;
return
false
;
}
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
LoD
input_lod
{
offset
};
paddle
::
framework
::
Vector
<
size_t
>
label_offset
;
label_offset
.
resize
(
batch_size_
+
1
);
for
(
int
i
=
0
;
i
<=
batch_size_
;
++
i
)
{
label_offset
[
i
]
=
i
;
}
finish_set_filelist_
=
true
;
return
true
;
}
LoD
label_lod
{
label_offset
};
int64_t
*
input_ptr
=
feed_vec_
[
0
]
->
mutable_data
<
int64_t
>
(
{
static_cast
<
int64_t
>
(
offset
.
back
()),
1
},
platform
::
CPUPlace
());
int64_t
*
label_ptr
=
feed_vec_
[
1
]
->
mutable_data
<
int64_t
>
({
batch_size_
,
1
},
platform
::
CPUPlace
());
for
(
unsigned
int
i
=
0
;
i
<
offset
.
back
();
++
i
)
{
input_ptr
[
i
]
=
static_cast
<
int64_t
>
(
batch_id_buffer_
[
i
]);
}
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
label_ptr
[
i
]
=
static_cast
<
int64_t
>
(
label_ptr_
[
i
]);
bool
DataFeed
::
PickOneFile
(
std
::
string
&
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
if
(
file_idx_
==
filelist_
.
size
())
{
return
false
;
}
feed_vec_
[
0
]
->
set_lod
(
input_lod
);
feed_vec_
[
1
]
->
set_lod
(
label_lod
);
filename
=
filelist_
[
file_idx_
++
];
return
true
;
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
)
{
Init
();
SetBatchSize
(
data_feed
.
batch_size_
)
;
SetFieldNames
(
data_feed
.
field_names_
)
;
bool
DataFeed
::
CheckInit
(
)
{
if
(
finish_init_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: initialization did not succeed"
;
return
false
;
}
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
}
bool
DataFeed
::
CheckSetFileList
()
{
if
(
finish_set_filelist_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: set filelist did not succeed"
;
return
false
;
}
bool
DataFeed
::
CheckStart
()
{
if
(
finish_start_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: Datafeed has not started running yet"
;
return
false
;
}
template
<
typename
T
>
void
PrivateQueueDataFeed
<
T
>::
SetQueueSize
(
int
queue_size
)
{
if
(
!
CheckInit
())
{
return
;}
if
(
queue_size
<=
0
)
{
LOG
(
ERROR
)
<<
"error: illegal queue size: "
<<
queue_size
;
return
;
}
queue_size_
=
queue_size
;
queue_
.
ReCap
(
queue_size_
);
}
void
TextClassDataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
"Opening file %s fail"
,
filelist
);
template
<
typename
T
>
bool
PrivateQueueDataFeed
<
T
>::
Start
()
{
if
(
!
(
CheckSetFileList
()))
{
return
false
;}
read_thread_
=
std
::
thread
(
&
PrivateQueueDataFeed
::
ReadThread
,
this
);
read_thread_
.
detach
();
finish_start_
=
true
;
return
true
;
}
template
<
typename
T
>
void
PrivateQueueDataFeed
<
T
>::
ReadThread
(){
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
s_filelist_
.
push_back
(
filename
);
while
(
PickOneFile
(
filename
))
{
file_
.
open
(
filename
.
c_str
());
// is_text_feed
if
(
!
file_
.
is_open
())
{
LOG
(
ERROR
)
<<
"error: open file<"
<<
filename
<<
"> fail"
;
}
T
instance
;
while
(
ParseOneInstance
(
instance
))
{
queue_
.
Send
(
instance
);
}
file_
.
close
();
}
fin
.
c
lose
();
queue_
.
C
lose
();
}
void
TextClassDataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
template
<
typename
T
>
bool
PrivateQueueDataFeed
<
T
>::
Next
(){
if
(
!
CheckStart
())
{
return
false
;}
int
index
=
0
;
T
instance
;
T
ins_vec
(
use_slots_
.
size
());
while
(
index
<
default_batch_size_
)
{
if
(
!
queue_
.
Receive
(
&
instance
))
{
break
;
}
AddInstanceToInsVec
(
ins_vec
,
instance
,
index
++
);
}
batch_size_
=
index
;
PutToFeedVec
(
ins_vec
);
return
batch_size_
!=
0
;
}
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
return
false
;
void
MultiSlotDataFeed
::
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{
finish_init_
=
false
;
finish_set_filelist_
=
false
;
finish_start_
=
false
;
if
(
!
data_feed_desc
.
has_multi_slot_desc
()){
LOG
(
ERROR
)
<<
"error: multi_slot_desc has not been set"
;
return
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
int
filesize
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
file_content_buffer_
,
filesize
);
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
paddle
::
MultiSlotDesc
multi_slot_desc
=
data_feed_desc
.
multi_slot_desc
();
size_t
all_slot_num
=
multi_slot_desc
.
slots_size
();
all_slots_
.
resize
(
all_slot_num
);
all_slots_type_
.
resize
(
all_slot_num
);
use_slots_index_
.
resize
(
all_slot_num
);
use_slots_
.
clear
();
use_slots_is_dense_
.
clear
();
for
(
size_t
i
=
0
;
i
<
all_slot_num
;
++
i
)
{
auto
&
slot
=
multi_slot_desc
.
slots
(
i
);
all_slots_
[
i
]
=
slot
.
name
();
all_slots_type_
[
i
]
=
slot
.
type
();
use_slots_index_
[
i
]
=
slot
.
use
()
?
use_slots_
.
size
()
:
-
1
;
if
(
slot
.
use
())
{
use_slots_
.
push_back
(
all_slots_
[
i
]);
use_slots_is_dense_
.
push_back
(
slot
.
dense
());
}
}
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
// todo , remove magic number
feed_vec_
.
resize
(
use_slots_
.
size
());
return
true
;
finish_init_
=
true
;
}
void
TextClassDataFeed
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
#if 1
LOG
(
WARNING
)
<<
"UpdateEpochNum: epoch = "
<<
s_current_epoch_
;
#endif
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
false
;
bool
MultiSlotDataFeed
::
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>&
instance
)
{
std
::
string
line
;
if
(
getline
(
file_
,
line
))
{
int
use_slots_num
=
use_slots_
.
size
();
instance
.
resize
(
use_slots_num
);
//parse line
const
char
*
str
=
line
.
c_str
();
char
*
endptr
=
(
char
*
)
str
;
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
int
idx
=
use_slots_index_
[
i
];
int
num
=
(
int
)
strtol
(
&
str
[
pos
],
&
endptr
,
10
);
if
(
num
==
0
)
{
LOG
(
ERROR
)
<<
"error: the number of ids can not be zero, you need padding it"
;
exit
(
-
1
);
}
if
(
idx
!=
-
1
)
{
instance
[
idx
].
SetType
(
all_slots_type_
[
i
]);
if
(
instance
[
idx
].
GetType
()[
0
]
==
'f'
)
{
// float
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
float
feasign
=
(
float
)
strtof
(
endptr
,
&
endptr
);
instance
[
idx
].
AddValue
(
feasign
);
}
}
else
if
(
instance
[
idx
].
GetType
()[
0
]
==
'u'
){
// uint64
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
uint64_t
feasign
=
(
uint64_t
)
strtoull
(
endptr
,
&
endptr
,
10
);
instance
[
idx
].
AddValue
(
feasign
);
}
}
pos
=
endptr
-
str
;
}
else
{
for
(
int
j
=
0
;
j
<=
num
;
++
j
)
{
pos
=
line
.
find_first_of
(
' '
,
pos
+
1
);
}
}
}
}
else
{
return
false
;
}
return
true
;
}
void
TextClassDataFeed
::
StartOneEpoch
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
void
MultiSlotDataFeed
::
AddInstanceToInsVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
,
std
::
vector
<
MultiSlotType
>&
instance
,
int
index
)
{
if
(
index
==
0
)
{
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
)
{
ins_vec
[
i
].
SetType
(
instance
[
i
].
GetType
())
;
}
}
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
){
ins_vec
[
i
].
AddIns
(
instance
[
i
])
;
}
s_condition_epoch_start_
.
notify_all
();
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
}
const
char
*
TextClassDataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
// One epoch has run over
// Wait for next epoch
if
(
s_current_file_idx_
>=
s_filelist_
.
size
())
{
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
return
NULL
;
void
MultiSlotDataFeed
::
PutToFeedVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
)
{
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
auto
&
type
=
ins_vec
[
i
].
GetType
();
auto
&
offset
=
ins_vec
[
i
].
GetOffset
();
int
total_instance
=
static_cast
<
int
>
(
offset
.
back
());
if
(
type
[
0
]
==
'f'
)
{
// float
auto
&
feasign
=
ins_vec
[
i
].
GetFloatData
();
if
(
feed_vec_
[
i
].
IsDense
())
{
int
size_in_each_batch
=
total_instance
/
batch_size_
;
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
mutable_data
<
float
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
}
else
{
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
mutable_data
<
float
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
else
if
(
type
[
0
]
==
'u'
)
{
// uint64
// no uint64_t type
auto
&
feasign
=
ins_vec
[
i
].
GetUint64Data
();
if
(
feed_vec_
[
i
].
IsDense
())
{
int
size_in_each_batch
=
total_instance
/
batch_size_
;
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
mutable_data
<
int64_t
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
int64_t
));
}
else
{
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
mutable_data
<
int64_t
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
uint64_t
));
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
}
file_to_be_processed
=
s_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
return
file_to_be_processed
.
c_str
();
}
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
92a98ca7
...
...
@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <condition_variable> // NOLINT
#include <fstream>
#include <deque>
#include <atomic>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"
namespace
paddle
{
namespace
framework
{
struct
Gauc
{
int
show
,
click
;
uint64_t
fea
;
std
::
string
lineid
;
};
struct
Instance
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
feed_vec_buffer
;
std
::
vector
<
std
::
vector
<
int
>>
feed_vec_lod
;
std
::
vector
<
float
>
other_label
;
std
::
vector
<
Gauc
>
gauc_vec
;
class
MixTensor
{
public:
MixTensor
(){}
MixTensor
(
LoDTensor
*
lodtensor
)
{
is_dense_
=
false
;
lodtensor_
=
lodtensor
;
}
MixTensor
(
Tensor
*
tensor
)
{
is_dense_
=
true
;
tensor_
=
tensor
;
}
bool
IsDense
()
{
return
is_dense_
;}
LoDTensor
*
GetLoDTensor
(){
if
(
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a dense var return a LoDTensor ptr"
;
return
NULL
;
}
return
lodtensor_
;
}
Tensor
*
GetTensor
(){
if
(
!
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a sparse var return a Tensor ptr"
;
return
NULL
;
}
return
tensor_
;
}
private:
bool
is_dense_
;
LoDTensor
*
lodtensor_
;
Tensor
*
tensor_
;
};
class
DataFeed
{
template
<
typename
T
>
class
BlockingQueue
{
public:
DataFeed
()
:
default_batch_size_
(
1
),
batch_size_
(
0
),
thread_id_
(
0
)
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
()
=
0
;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
CheckFile
(
const
char
*
filename
)
=
0
;
virtual
bool
SetFile
(
const
char
*
filename
)
=
0
;
virtual
bool
ReadBatch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
GetAllSlotIds
()
{
return
all_slot_ids_
;
explicit
BlockingQueue
(
size_t
capacity
=
32
)
:
capacity_
(
capacity
),
closed_
(
false
)
{
size_
.
store
(
0
);
}
v
irtual
const
std
::
vector
<
uint16_t
>&
GetUseSlotIds
(
)
{
return
use_slot_ids_
;
v
oid
ReCap
(
size_t
capacity
)
{
capacity_
=
capacity
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlotAlias
()
{
return
use_slot_alias_
;
}
bool
Send
(
const
T
&
elem
)
{
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
<
capacity_
||
closed_
;});
if
(
closed_
)
{
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
return
false
;
}
queue_
.
push_back
(
elem
);
c
=
size_
.
load
();
size_
.
fetch_add
(
1
);
}
if
(
c
+
1
<
capacity_
)
{
send_cv_
.
notify_one
();
}
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
=
0
;
virtual
void
BindScope
(
Scope
*
scope
)
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
StartOneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
if
(
c
==
0
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
receive_cv_
.
notify_one
();
}
return
true
;
}
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
bool
Receive
(
T
*
elem
)
{
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
receive_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
!=
0
||
closed_
;});
if
(
size_
.
load
()
!=
0
)
{
*
elem
=
queue_
.
front
();
queue_
.
pop_front
();
c
=
size_
.
load
();
size_
.
fetch_sub
(
1
);
}
else
{
return
false
;
}
}
if
(
c
>
1
)
{
receive_cv_
.
notify_one
();
}
if
(
c
==
capacity_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
notify_one
();
}
return
true
;
}
virtual
std
::
vector
<
LoDTensor
*>&
GetFeedVec
(
const
Instance
&
ins
)
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
return
feed_vec_
;
void
Close
()
{
std
::
lock_guard
<
std
::
mutex
>
lock1
(
send_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock2
(
receive_mutex_
);
closed_
=
true
;
send_cv_
.
notify_all
();
receive_cv_
.
notify_all
();
}
int
GetThreadId
()
{
return
thread_id_
;}
void
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;}
private:
size_t
capacity_
;
std
::
atomic_size_t
size_
;
bool
closed_
;
std
::
deque
<
T
>
queue_
;
mutable
std
::
mutex
send_mutex_
;
mutable
std
::
mutex
receive_mutex_
;
mutable
std
::
condition_variable
send_cv_
;
mutable
std
::
condition_variable
receive_cv_
;
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
// for some datafeeds may not be able to implement this interface
virtual
bool
CheckFile
(
const
char
*
filename
)
{
LOG
(
ERROR
)
<<
"error: The function CheckFile is not implemented"
;
return
false
;
}
virtual
bool
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
);
virtual
bool
Start
()
=
0
;
virtual
bool
Next
()
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
// for subclass with queue
virtual
void
SetQueueSize
(
int
queue_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetQueueSize is not implemented"
;
}
// for subclass with buffer
virtual
void
SetBufferSize
(
int
buffer_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetBufferSize is not implemented"
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetAllSlots
()
{
return
all_slots_
;}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlots
()
{
return
use_slots_
;}
std
::
vector
<
MixTensor
>&
GetFeedVec
()
{
return
feed_vec_
;}
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
protected:
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
std
::
vector
<
std
::
string
>
use_slot_alias_
;
std
::
vector
<
LoDTensor
*>
feed_vec_
;
// Check if it is executed in this order:
// Init -> SetFileList/BindingMemory -> Start -> Next
virtual
bool
CheckInit
();
virtual
bool
CheckSetFileList
();
virtual
bool
CheckStart
();
virtual
bool
PickOneFile
(
std
::
string
&
filename
);
static
std
::
vector
<
std
::
string
>
filelist_
;
static
size_t
file_idx_
;
static
std
::
mutex
mutex_for_pick_file_
;
std
::
vector
<
std
::
string
>
use_slots_
;
std
::
vector
<
bool
>
use_slots_is_dense_
;
std
::
vector
<
std
::
string
>
all_slots_
;
std
::
vector
<
std
::
string
>
all_slots_type_
;
std
::
vector
<
int
>
use_slots_index_
;
// -1: not used; >=0: the index of use_slots_
std
::
vector
<
MixTensor
>
feed_vec_
;
int
default_batch_size_
;
int
batch_size_
;
int
thread_id_
;
bool
finish_init_
;
bool
finish_set_filelist_
;
bool
finish_binding_memory_
;
bool
finish_start_
;
};
class
TextClassDataFeed
:
public
DataFeed
{
template
<
typename
T
>
class
PrivateQueueDataFeed
:
public
DataFeed
{
public:
TextClassDataFeed
();
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
);
PrivateQueueDataFeed
()
{}
virtual
~
PrivateQueueDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
Start
();
virtual
bool
Next
();
// no buffer
virtual
void
SetQueueSize
(
int
queue_size
);
protected:
virtual
void
ReadThread
();
virtual
bool
ParseOneInstance
(
T
&
instance
)
=
0
;
virtual
void
AddInstanceToInsVec
(
T
&
vec_ins
,
T
&
instance
,
int
index
)
=
0
;
virtual
void
PutToFeedVec
(
T
&
ins_vec
)
=
0
;
std
::
thread
read_thread_
;
// the thread for read files
/* using ifstream one line and one line parse is faster
* than using fread one buffer and one buffer parse.
* for 601M JingPai data:
* ifstream one line and one line parse: 6034 ms
* fread one buffer and one buffer parse: 7097 ms */
std
::
ifstream
file_
;
size_t
queue_size_
;
// The elements in the queue are one piece of data,
// with multiple fields in each piece of data
BlockingQueue
<
T
>
queue_
;
};
class
MultiSlotType
{
public:
MultiSlotType
()
{
float_feasign_
.
clear
();
uint64_feasign_
.
clear
();
offset_
.
resize
(
1
);
offset_
[
0
]
=
0
;
}
~
MultiSlotType
()
{}
void
SetType
(
std
::
string
&
type
)
{
if
(
!
CheckType
(
type
))
{
return
;}
type_
=
type
;
}
std
::
vector
<
size_t
>&
GetOffset
()
{
return
offset_
;
}
void
AddValue
(
float
v
)
{
if
(
!
CheckFloat
())
{
return
;}
float_feasign_
.
push_back
(
v
);
}
void
AddValue
(
uint64_t
v
)
{
if
(
!
CheckUint64
())
{
return
;}
uint64_feasign_
.
push_back
(
v
);
}
void
AddIns
(
MultiSlotType
&
ins
)
{
if
(
ins
.
GetType
()[
0
]
==
'f'
)
{
//float
if
(
!
CheckFloat
())
{
return
;}
auto
&
vec
=
ins
.
GetFloatData
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
float_feasign_
.
insert
(
float_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
else
if
(
ins
.
GetType
()[
0
]
==
'u'
)
{
//uint64
if
(
!
CheckUint64
())
{
return
;}
auto
&
vec
=
ins
.
GetUint64Data
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
uint64_feasign_
.
insert
(
uint64_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
}
std
::
vector
<
float
>&
GetFloatData
()
{
return
float_feasign_
;
}
std
::
vector
<
uint64_t
>&
GetUint64Data
()
{
return
uint64_feasign_
;
}
std
::
string
&
GetType
()
{
return
type_
;
}
private:
bool
CheckType
(
std
::
string
&
type
)
{
if
(
type
!=
"uint64"
&&
type
!=
"float"
)
{
// check in here
LOG
(
ERROR
)
<<
"error: here is no this type"
;
return
false
;
}
return
true
;
}
bool
CheckFloat
()
{
if
(
type_
[
0
]
!=
'f'
)
{
//float
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to float slot"
;
return
false
;
}
return
true
;
}
bool
CheckUint64
()
{
if
(
type_
[
0
]
!=
'u'
)
{
//uint64
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to uint64 slot"
;
return
false
;
}
return
true
;
}
std
::
vector
<
float
>
float_feasign_
;
std
::
vector
<
uint64_t
>
uint64_feasign_
;
std
::
string
type_
;
std
::
vector
<
size_t
>
offset_
;
};
class
MultiSlotDataFeed
:
public
PrivateQueueDataFeed
<
std
::
vector
<
MultiSlotType
>>
{
public:
MultiSlotDataFeed
()
{}
virtual
~
MultiSlotDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
);
//TODO: virtual bool CheckFile();
protected:
virtual
void
AddInstanceToInsVec
(
std
::
vector
<
MultiSlotType
>&
vec_ins
,
std
::
vector
<
MultiSlotType
>&
instance
,
int
index
);
virtual
bool
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>&
instance
);
virtual
void
PutToFeedVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
);
};
//TODO: to be deleted
class
TextClassDataFeed
:
public
DataFeed
{
public:
virtual
~
TextClassDataFeed
()
{}
virtual
void
Init
();
virtual
bool
ReadBatch
();
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{}
virtual
bool
Start
()
{
return
false
;};
//TODO
virtual
bool
Next
()
{
return
false
;};
//TODO
virtual
bool
ReadBatch
()
{
return
false
;}
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{}
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
bool
SetFile
(
const
char
*
filename
);
virtual
bool
SetFile
(
const
char
*
filename
)
{
return
false
;}
virtual
bool
CheckFile
(
const
char
*
filename
)
{
// TODO(xxx)
return
false
;
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
StartOneEpoch
();
void
WaitNextEpoch
();
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
public:
static
void
SetFileList
(
const
char
*
filelist
);
private:
const
char
*
PickOneFile
();
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
private:
int
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
return
-
1
;}
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
field_
names_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
static
std
::
vector
<
std
::
string
>
s_filelist_
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
mutex
s_locker_epoch_start_
;
static
std
::
condition_variable
s_condition_epoch_start_
;
static
bool
s_epoch_start_flag_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.proto
浏览文件 @
92a98ca7
...
...
@@ -17,5 +17,16 @@ package paddle;
message
DataFeedDesc
{
optional
string
name
=
1
;
optional
int32
batch
=
2
[
default
=
32
];
optional
MultiSlotDesc
multi_slot_desc
=
3
;
}
message
MultiSlotDesc
{
repeated
Slot
slots
=
1
;
}
message
Slot
{
required
string
name
=
1
;
required
string
type
=
2
;
optional
bool
dense
=
3
[
default
=
false
];
optional
bool
use
=
4
[
default
=
true
];
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录