Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
92a98ca7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92a98ca7
编写于
11月 19, 2018
作者:
B
barrierye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add MultiSlotDataFeed
上级
78c3380b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
491 addition
and
261 deletion
+491
-261
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+197
-177
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+283
-84
paddle/fluid/framework/data_feed.proto
paddle/fluid/framework/data_feed.proto
+11
-0
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
92a98ca7
...
...
@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
DEFINE_bool
(
is_text_feed
,
false
,
"is_text_feed"
);
namespace
paddle
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClassDataFeed
::
s_filelist_
;
std
::
mutex
TextClassDataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClassDataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClassDataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClassDataFeed
::
s_current_epoch_
=
0
;
int
TextClassDataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClassDataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClassDataFeed
::
s_condition_epoch_start_
;
bool
TextClassDataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClassDataFeed
::
Init
()
{
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_
=
{
0
,
1
};
use_slot_ids_
=
{
0
,
1
};
use_slot_alias_
=
{
"words"
,
"label"
};
file_content_buffer_host_
.
reset
(
new
char
[
200
*
1024
*
1024
],
[](
char
*
p
)
{
delete
[]
p
;});
file_content_buffer_
=
file_content_buffer_host_
.
get
();
file_content_buffer_ptr_
=
file_content_buffer_
;
batch_id_host_
.
reset
(
new
int
[
10240
*
1024
],
[](
int
*
p
)
{
delete
[]
p
;});
// max word num in a batch
batch_id_buffer_
=
batch_id_host_
.
get
();
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
label_ptr_
=
label_host_
.
get
();
field_names_
.
clear
();
}
TextClassDataFeed
::
TextClassDataFeed
()
{
Init
();
}
// todo: use an elegant implementation for this function
bool
TextClassDataFeed
::
ReadBatch
()
{
paddle
::
framework
::
Vector
<
size_t
>
offset
;
int
tlen
=
0
;
int
llen
=
0
;
int
inst_idx
=
0
;
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
while
(
inst_idx
<
batch_size_
)
{
int
ptr_offset
=
0
;
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
break
;
std
::
vector
<
std
::
string
>
DataFeed
::
filelist_
;
size_t
DataFeed
::
file_idx_
;
std
::
mutex
DataFeed
::
mutex_for_pick_file_
;
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
if
(
CheckInit
()
==
false
)
{
return
;}
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
if
(
name
==
use_slots_
[
i
])
{
if
(
use_slots_is_dense_
[
i
])
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
Tensor
>
());
}
else
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
memcpy
(
reinterpret_cast
<
char
*>
(
&
llen
),
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
memcpy
(
reinterpret_cast
<
char
*>
(
batch_id_buffer_
+
tlen
),
file_content_buffer_ptr_
+
ptr_offset
,
llen
*
sizeof
(
int
));
tlen
+=
llen
;
offset
[
inst_idx
+
1
]
=
offset
[
inst_idx
]
+
llen
;
ptr_offset
+=
sizeof
(
int
)
*
llen
;
memcpy
(
reinterpret_cast
<
char
*>
(
label_ptr_
+
inst_idx
),
file_content_buffer_ptr_
+
ptr_offset
,
sizeof
(
int
));
ptr_offset
+=
sizeof
(
int
);
file_content_buffer_ptr_
+=
ptr_offset
;
inst_idx
++
;
}
}
if
(
inst_idx
!=
batch_size_
)
{
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
if
(
CheckInit
()
==
false
)
{
return
false
;}
if
(
files
.
size
()
==
0
)
{
LOG
(
ERROR
)
<<
"error: you have set an empty filelist"
;
return
false
;
}
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
LoD
input_lod
{
offset
};
paddle
::
framework
::
Vector
<
size_t
>
label_offset
;
label_offset
.
resize
(
batch_size_
+
1
);
for
(
int
i
=
0
;
i
<=
batch_size_
;
++
i
)
{
label_offset
[
i
]
=
i
;
}
finish_set_filelist_
=
true
;
return
true
;
}
LoD
label_lod
{
label_offset
};
int64_t
*
input_ptr
=
feed_vec_
[
0
]
->
mutable_data
<
int64_t
>
(
{
static_cast
<
int64_t
>
(
offset
.
back
()),
1
},
platform
::
CPUPlace
());
int64_t
*
label_ptr
=
feed_vec_
[
1
]
->
mutable_data
<
int64_t
>
({
batch_size_
,
1
},
platform
::
CPUPlace
());
for
(
unsigned
int
i
=
0
;
i
<
offset
.
back
();
++
i
)
{
input_ptr
[
i
]
=
static_cast
<
int64_t
>
(
batch_id_buffer_
[
i
]);
}
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
label_ptr
[
i
]
=
static_cast
<
int64_t
>
(
label_ptr_
[
i
]);
bool
DataFeed
::
PickOneFile
(
std
::
string
&
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
if
(
file_idx_
==
filelist_
.
size
())
{
return
false
;
}
feed_vec_
[
0
]
->
set_lod
(
input_lod
);
feed_vec_
[
1
]
->
set_lod
(
label_lod
);
filename
=
filelist_
[
file_idx_
++
];
return
true
;
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
)
{
Init
();
SetBatchSize
(
data_feed
.
batch_size_
)
;
SetFieldNames
(
data_feed
.
field_names_
)
;
bool
DataFeed
::
CheckInit
(
)
{
if
(
finish_init_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: initialization did not succeed"
;
return
false
;
}
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
}
bool
DataFeed
::
CheckSetFileList
()
{
if
(
finish_set_filelist_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: set filelist did not succeed"
;
return
false
;
}
bool
DataFeed
::
CheckStart
()
{
if
(
finish_start_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: Datafeed has not started running yet"
;
return
false
;
}
// Sets the capacity of the instance queue.
//
// Does nothing when the feed has not been initialized (CheckInit logs the
// reason) or when queue_size is not positive (an error is logged here).
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  const bool initialized = CheckInit();
  if (!initialized) {
    return;
  }
  if (queue_size > 0) {
    queue_size_ = queue_size;
    queue_.ReCap(queue_size_);
    return;
  }
  LOG(ERROR) << "error: illegal queue size: " << queue_size;
}
void
TextClassDataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
"Opening file %s fail"
,
filelist
);
// Launches the reader thread that runs ReadThread() on this object.
//
// Returns false without starting anything when the file list has not been
// set (CheckSetFileList logs the reason); returns true once the thread has
// been launched and finish_start_ has been set.
template <typename T>
bool PrivateQueueDataFeed<T>::Start() {
  const bool filelist_ready = CheckSetFileList();
  if (!filelist_ready) {
    return false;
  }
  read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
  // The thread is detached right away; it is never joined by this class.
  read_thread_.detach();
  finish_start_ = true;
  return true;
}
template
<
typename
T
>
void
PrivateQueueDataFeed
<
T
>::
ReadThread
(){
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
s_filelist_
.
push_back
(
filename
);
while
(
PickOneFile
(
filename
))
{
file_
.
open
(
filename
.
c_str
());
// is_text_feed
if
(
!
file_
.
is_open
())
{
LOG
(
ERROR
)
<<
"error: open file<"
<<
filename
<<
"> fail"
;
}
T
instance
;
while
(
ParseOneInstance
(
instance
))
{
queue_
.
Send
(
instance
);
}
file_
.
close
();
}
fin
.
c
lose
();
queue_
.
C
lose
();
}
void
TextClassDataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
// Assembles the next batch from the instance queue and publishes it to the
// feed variables.
//
// Up to default_batch_size_ instances are taken from queue_; the loop stops
// early when Receive() reports that no more instances are available. The
// number actually collected is stored in batch_size_.
//
// Returns false when the feed has not been started or when no instance could
// be collected; returns true when a (possibly partial) batch was published.
template <typename T>
bool PrivateQueueDataFeed<T>::Next() {
  if (!CheckStart()) {
    return false;
  }
  int index = 0;
  T instance;
  T ins_vec(use_slots_.size());
  while (index < default_batch_size_) {
    if (!queue_.Receive(&instance)) {
      break;
    }
    AddInstanceToInsVec(ins_vec, instance, index++);
  }
  batch_size_ = index;
  // Publish only non-empty batches: PutToFeedVec divides by batch_size_ for
  // dense slots, so calling it with an empty batch would divide by zero.
  if (batch_size_ != 0) {
    PutToFeedVec(ins_vec);
  }
  return batch_size_ != 0;
}
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
return
false
;
void
MultiSlotDataFeed
::
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{
finish_init_
=
false
;
finish_set_filelist_
=
false
;
finish_start_
=
false
;
if
(
!
data_feed_desc
.
has_multi_slot_desc
()){
LOG
(
ERROR
)
<<
"error: multi_slot_desc has not been set"
;
return
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
int
filesize
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
file_content_buffer_
,
filesize
);
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
paddle
::
MultiSlotDesc
multi_slot_desc
=
data_feed_desc
.
multi_slot_desc
();
size_t
all_slot_num
=
multi_slot_desc
.
slots_size
();
all_slots_
.
resize
(
all_slot_num
);
all_slots_type_
.
resize
(
all_slot_num
);
use_slots_index_
.
resize
(
all_slot_num
);
use_slots_
.
clear
();
use_slots_is_dense_
.
clear
();
for
(
size_t
i
=
0
;
i
<
all_slot_num
;
++
i
)
{
auto
&
slot
=
multi_slot_desc
.
slots
(
i
);
all_slots_
[
i
]
=
slot
.
name
();
all_slots_type_
[
i
]
=
slot
.
type
();
use_slots_index_
[
i
]
=
slot
.
use
()
?
use_slots_
.
size
()
:
-
1
;
if
(
slot
.
use
())
{
use_slots_
.
push_back
(
all_slots_
[
i
]);
use_slots_is_dense_
.
push_back
(
slot
.
dense
());
}
}
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
// todo , remove magic number
feed_vec_
.
resize
(
use_slots_
.
size
());
return
true
;
finish_init_
=
true
;
}
void
TextClassDataFeed
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
#if 1
LOG
(
WARNING
)
<<
"UpdateEpochNum: epoch = "
<<
s_current_epoch_
;
#endif
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
false
;
bool
MultiSlotDataFeed
::
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>&
instance
)
{
std
::
string
line
;
if
(
getline
(
file_
,
line
))
{
int
use_slots_num
=
use_slots_
.
size
();
instance
.
resize
(
use_slots_num
);
//parse line
const
char
*
str
=
line
.
c_str
();
char
*
endptr
=
(
char
*
)
str
;
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
int
idx
=
use_slots_index_
[
i
];
int
num
=
(
int
)
strtol
(
&
str
[
pos
],
&
endptr
,
10
);
if
(
num
==
0
)
{
LOG
(
ERROR
)
<<
"error: the number of ids can not be zero, you need padding it"
;
exit
(
-
1
);
}
if
(
idx
!=
-
1
)
{
instance
[
idx
].
SetType
(
all_slots_type_
[
i
]);
if
(
instance
[
idx
].
GetType
()[
0
]
==
'f'
)
{
// float
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
float
feasign
=
(
float
)
strtof
(
endptr
,
&
endptr
);
instance
[
idx
].
AddValue
(
feasign
);
}
}
else
if
(
instance
[
idx
].
GetType
()[
0
]
==
'u'
){
// uint64
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
uint64_t
feasign
=
(
uint64_t
)
strtoull
(
endptr
,
&
endptr
,
10
);
instance
[
idx
].
AddValue
(
feasign
);
}
}
pos
=
endptr
-
str
;
}
else
{
for
(
int
j
=
0
;
j
<=
num
;
++
j
)
{
pos
=
line
.
find_first_of
(
' '
,
pos
+
1
);
}
}
}
}
else
{
return
false
;
}
return
true
;
}
void
TextClassDataFeed
::
StartOneEpoch
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
void
MultiSlotDataFeed
::
AddInstanceToInsVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
,
std
::
vector
<
MultiSlotType
>&
instance
,
int
index
)
{
if
(
index
==
0
)
{
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
)
{
ins_vec
[
i
].
SetType
(
instance
[
i
].
GetType
())
;
}
}
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
){
ins_vec
[
i
].
AddIns
(
instance
[
i
])
;
}
s_condition_epoch_start_
.
notify_all
();
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
}
const
char
*
TextClassDataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
// One epoch has run over
// Wait for next epoch
if
(
s_current_file_idx_
>=
s_filelist_
.
size
())
{
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
return
NULL
;
void
MultiSlotDataFeed
::
PutToFeedVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
)
{
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
auto
&
type
=
ins_vec
[
i
].
GetType
();
auto
&
offset
=
ins_vec
[
i
].
GetOffset
();
int
total_instance
=
static_cast
<
int
>
(
offset
.
back
());
if
(
type
[
0
]
==
'f'
)
{
// float
auto
&
feasign
=
ins_vec
[
i
].
GetFloatData
();
if
(
feed_vec_
[
i
].
IsDense
())
{
int
size_in_each_batch
=
total_instance
/
batch_size_
;
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
mutable_data
<
float
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
}
else
{
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
mutable_data
<
float
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
else
if
(
type
[
0
]
==
'u'
)
{
// uint64
// no uint64_t type
auto
&
feasign
=
ins_vec
[
i
].
GetUint64Data
();
if
(
feed_vec_
[
i
].
IsDense
())
{
int
size_in_each_batch
=
total_instance
/
batch_size_
;
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
mutable_data
<
int64_t
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
int64_t
));
}
else
{
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
mutable_data
<
int64_t
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
uint64_t
));
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
}
file_to_be_processed
=
s_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
return
file_to_be_processed
.
c_str
();
}
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
92a98ca7
...
...
@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <condition_variable> // NOLINT
#include <fstream>
#include <deque>
#include <atomic>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"
namespace
paddle
{
namespace
framework
{
struct
Gauc
{
int
show
,
click
;
uint64_t
fea
;
std
::
string
lineid
;
};
struct
Instance
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
feed_vec_buffer
;
std
::
vector
<
std
::
vector
<
int
>>
feed_vec_lod
;
std
::
vector
<
float
>
other_label
;
std
::
vector
<
Gauc
>
gauc_vec
;
class
MixTensor
{
public:
MixTensor
(){}
MixTensor
(
LoDTensor
*
lodtensor
)
{
is_dense_
=
false
;
lodtensor_
=
lodtensor
;
}
MixTensor
(
Tensor
*
tensor
)
{
is_dense_
=
true
;
tensor_
=
tensor
;
}
bool
IsDense
()
{
return
is_dense_
;}
LoDTensor
*
GetLoDTensor
(){
if
(
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a dense var return a LoDTensor ptr"
;
return
NULL
;
}
return
lodtensor_
;
}
Tensor
*
GetTensor
(){
if
(
!
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a sparse var return a Tensor ptr"
;
return
NULL
;
}
return
tensor_
;
}
private:
bool
is_dense_
;
LoDTensor
*
lodtensor_
;
Tensor
*
tensor_
;
};
class
DataFeed
{
template
<
typename
T
>
class
BlockingQueue
{
public:
DataFeed
()
:
default_batch_size_
(
1
),
batch_size_
(
0
),
thread_id_
(
0
)
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
()
=
0
;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
CheckFile
(
const
char
*
filename
)
=
0
;
virtual
bool
SetFile
(
const
char
*
filename
)
=
0
;
virtual
bool
ReadBatch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
GetAllSlotIds
()
{
return
all_slot_ids_
;
explicit
BlockingQueue
(
size_t
capacity
=
32
)
:
capacity_
(
capacity
),
closed_
(
false
)
{
size_
.
store
(
0
);
}
v
irtual
const
std
::
vector
<
uint16_t
>&
GetUseSlotIds
(
)
{
return
use_slot_ids_
;
v
oid
ReCap
(
size_t
capacity
)
{
capacity_
=
capacity
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlotAlias
()
{
return
use_slot_alias_
;
}
bool
Send
(
const
T
&
elem
)
{
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
<
capacity_
||
closed_
;});
if
(
closed_
)
{
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
return
false
;
}
queue_
.
push_back
(
elem
);
c
=
size_
.
load
();
size_
.
fetch_add
(
1
);
}
if
(
c
+
1
<
capacity_
)
{
send_cv_
.
notify_one
();
}
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
=
0
;
virtual
void
BindScope
(
Scope
*
scope
)
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
StartOneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
if
(
c
==
0
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
receive_cv_
.
notify_one
();
}
return
true
;
}
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
bool
Receive
(
T
*
elem
)
{
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
receive_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
!=
0
||
closed_
;});
if
(
size_
.
load
()
!=
0
)
{
*
elem
=
queue_
.
front
();
queue_
.
pop_front
();
c
=
size_
.
load
();
size_
.
fetch_sub
(
1
);
}
else
{
return
false
;
}
}
if
(
c
>
1
)
{
receive_cv_
.
notify_one
();
}
if
(
c
==
capacity_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
notify_one
();
}
return
true
;
}
virtual
std
::
vector
<
LoDTensor
*>&
GetFeedVec
(
const
Instance
&
ins
)
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
return
feed_vec_
;
void
Close
()
{
std
::
lock_guard
<
std
::
mutex
>
lock1
(
send_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock2
(
receive_mutex_
);
closed_
=
true
;
send_cv_
.
notify_all
();
receive_cv_
.
notify_all
();
}
int
GetThreadId
()
{
return
thread_id_
;}
void
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;}
private:
size_t
capacity_
;
std
::
atomic_size_t
size_
;
bool
closed_
;
std
::
deque
<
T
>
queue_
;
mutable
std
::
mutex
send_mutex_
;
mutable
std
::
mutex
receive_mutex_
;
mutable
std
::
condition_variable
send_cv_
;
mutable
std
::
condition_variable
receive_cv_
;
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
// for some datafeeds may not be able to implement this interface
virtual
bool
CheckFile
(
const
char
*
filename
)
{
LOG
(
ERROR
)
<<
"error: The function CheckFile is not implemented"
;
return
false
;
}
virtual
bool
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
);
virtual
bool
Start
()
=
0
;
virtual
bool
Next
()
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
// for subclass with queue
virtual
void
SetQueueSize
(
int
queue_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetQueueSize is not implemented"
;
}
// for subclass with buffer
virtual
void
SetBufferSize
(
int
buffer_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetBufferSize is not implemented"
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetAllSlots
()
{
return
all_slots_
;}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlots
()
{
return
use_slots_
;}
std
::
vector
<
MixTensor
>&
GetFeedVec
()
{
return
feed_vec_
;}
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
protected:
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
std
::
vector
<
std
::
string
>
use_slot_alias_
;
std
::
vector
<
LoDTensor
*>
feed_vec_
;
// Check if it is executed in this order:
// Init -> SetFileList/BindingMemory -> Start -> Next
virtual
bool
CheckInit
();
virtual
bool
CheckSetFileList
();
virtual
bool
CheckStart
();
virtual
bool
PickOneFile
(
std
::
string
&
filename
);
static
std
::
vector
<
std
::
string
>
filelist_
;
static
size_t
file_idx_
;
static
std
::
mutex
mutex_for_pick_file_
;
std
::
vector
<
std
::
string
>
use_slots_
;
std
::
vector
<
bool
>
use_slots_is_dense_
;
std
::
vector
<
std
::
string
>
all_slots_
;
std
::
vector
<
std
::
string
>
all_slots_type_
;
std
::
vector
<
int
>
use_slots_index_
;
// -1: not used; >=0: the index of use_slots_
std
::
vector
<
MixTensor
>
feed_vec_
;
int
default_batch_size_
;
int
batch_size_
;
int
thread_id_
;
bool
finish_init_
;
bool
finish_set_filelist_
;
bool
finish_binding_memory_
;
bool
finish_start_
;
};
class
TextClassDataFeed
:
public
DataFeed
{
template
<
typename
T
>
class
PrivateQueueDataFeed
:
public
DataFeed
{
public:
TextClassDataFeed
();
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
);
PrivateQueueDataFeed
()
{}
virtual
~
PrivateQueueDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
Start
();
virtual
bool
Next
();
// no buffer
virtual
void
SetQueueSize
(
int
queue_size
);
protected:
virtual
void
ReadThread
();
virtual
bool
ParseOneInstance
(
T
&
instance
)
=
0
;
virtual
void
AddInstanceToInsVec
(
T
&
vec_ins
,
T
&
instance
,
int
index
)
=
0
;
virtual
void
PutToFeedVec
(
T
&
ins_vec
)
=
0
;
std
::
thread
read_thread_
;
// the thread for read files
/* using ifstream one line and one line parse is faster
* than using fread one buffer and one buffer parse.
* for 601M JingPai data:
* ifstream one line and one line parse: 6034 ms
* fread one buffer and one buffer parse: 7097 ms */
std
::
ifstream
file_
;
size_t
queue_size_
;
// The elements in the queue are one piece of data,
// with multiple fields in each piece of data
BlockingQueue
<
T
>
queue_
;
};
class
MultiSlotType
{
public:
MultiSlotType
()
{
float_feasign_
.
clear
();
uint64_feasign_
.
clear
();
offset_
.
resize
(
1
);
offset_
[
0
]
=
0
;
}
~
MultiSlotType
()
{}
void
SetType
(
std
::
string
&
type
)
{
if
(
!
CheckType
(
type
))
{
return
;}
type_
=
type
;
}
std
::
vector
<
size_t
>&
GetOffset
()
{
return
offset_
;
}
void
AddValue
(
float
v
)
{
if
(
!
CheckFloat
())
{
return
;}
float_feasign_
.
push_back
(
v
);
}
void
AddValue
(
uint64_t
v
)
{
if
(
!
CheckUint64
())
{
return
;}
uint64_feasign_
.
push_back
(
v
);
}
void
AddIns
(
MultiSlotType
&
ins
)
{
if
(
ins
.
GetType
()[
0
]
==
'f'
)
{
//float
if
(
!
CheckFloat
())
{
return
;}
auto
&
vec
=
ins
.
GetFloatData
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
float_feasign_
.
insert
(
float_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
else
if
(
ins
.
GetType
()[
0
]
==
'u'
)
{
//uint64
if
(
!
CheckUint64
())
{
return
;}
auto
&
vec
=
ins
.
GetUint64Data
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
uint64_feasign_
.
insert
(
uint64_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
}
std
::
vector
<
float
>&
GetFloatData
()
{
return
float_feasign_
;
}
std
::
vector
<
uint64_t
>&
GetUint64Data
()
{
return
uint64_feasign_
;
}
std
::
string
&
GetType
()
{
return
type_
;
}
private:
bool
CheckType
(
std
::
string
&
type
)
{
if
(
type
!=
"uint64"
&&
type
!=
"float"
)
{
// check in here
LOG
(
ERROR
)
<<
"error: here is no this type"
;
return
false
;
}
return
true
;
}
bool
CheckFloat
()
{
if
(
type_
[
0
]
!=
'f'
)
{
//float
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to float slot"
;
return
false
;
}
return
true
;
}
bool
CheckUint64
()
{
if
(
type_
[
0
]
!=
'u'
)
{
//uint64
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to uint64 slot"
;
return
false
;
}
return
true
;
}
std
::
vector
<
float
>
float_feasign_
;
std
::
vector
<
uint64_t
>
uint64_feasign_
;
std
::
string
type_
;
std
::
vector
<
size_t
>
offset_
;
};
// Data feed whose queue element is a std::vector<MultiSlotType>, one entry
// per used slot. It supplies the parsing and batching hooks declared by
// PrivateQueueDataFeed.
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}

  // Configures the feed from the given description.
  virtual void Init(paddle::DataFeedDesc& data_feed_desc);
  // TODO: virtual bool CheckFile();

 protected:
  // Appends `instance` to the batch accumulator `vec_ins`; `index` is the
  // position of the instance within the batch.
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins,
                                   std::vector<MultiSlotType>& instance,
                                   int index);

  // Parses one instance from the currently open file; returns false when no
  // further instance can be read.
  virtual bool ParseOneInstance(std::vector<MultiSlotType>& instance);

  // Copies the accumulated batch into the feed variables.
  virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec);
};
//TODO: to be deleted
class
TextClassDataFeed
:
public
DataFeed
{
public:
virtual
~
TextClassDataFeed
()
{}
virtual
void
Init
();
virtual
bool
ReadBatch
();
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{}
virtual
bool
Start
()
{
return
false
;};
//TODO
virtual
bool
Next
()
{
return
false
;};
//TODO
virtual
bool
ReadBatch
()
{
return
false
;}
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{}
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
bool
SetFile
(
const
char
*
filename
);
virtual
bool
SetFile
(
const
char
*
filename
)
{
return
false
;}
virtual
bool
CheckFile
(
const
char
*
filename
)
{
// TODO(xxx)
return
false
;
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
StartOneEpoch
();
void
WaitNextEpoch
();
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
public:
static
void
SetFileList
(
const
char
*
filelist
);
private:
const
char
*
PickOneFile
();
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
private:
int
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
return
-
1
;}
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
field_
names_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
static
std
::
vector
<
std
::
string
>
s_filelist_
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
mutex
s_locker_epoch_start_
;
static
std
::
condition_variable
s_condition_epoch_start_
;
static
bool
s_epoch_start_flag_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.proto
浏览文件 @
92a98ca7
...
...
@@ -17,5 +17,16 @@ package paddle;
// Top-level description of a data feed.
message
DataFeedDesc
{
// Optional name of the data feed.
optional
string
name
=
1
;
// Batch setting; defaults to 32 when not given.
optional
int32
batch
=
2
[
default
=
32
];
// Slot layout. MultiSlotDataFeed::Init logs an error and returns early when
// this field has not been set.
optional
MultiSlotDesc
multi_slot_desc
=
3
;
}
// Ordered list of slot descriptions; MultiSlotDataFeed::Init walks them by
// index.
message
MultiSlotDesc
{
repeated
Slot
slots
=
1
;
}
// Description of a single slot.
message
Slot
{
// Slot name; feed variables are matched to used slots by this name.
required
string
name
=
1
;
// Value type of the slot. MultiSlotType accepts "uint64" and "float".
required
string
type
=
2
;
// When true the slot is fed through a Tensor, otherwise through a LoDTensor.
// Defaults to false.
optional
bool
dense
=
3
[
default
=
false
];
// Whether the slot is used. Unused slots are skipped while parsing.
// Defaults to true.
optional
bool
use
=
4
[
default
=
true
];
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录