Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
92a98ca7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92a98ca7
编写于
11月 19, 2018
作者:
B
barrierye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add MultiSlotDataFeed
上级
78c3380b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
491 addition
and
261 deletion
+491
-261
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+197
-177
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+283
-84
paddle/fluid/framework/data_feed.proto
paddle/fluid/framework/data_feed.proto
+11
-0
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
92a98ca7
...
@@ -34,221 +34,241 @@ limitations under the License. */
...
@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.h"
DEFINE_bool
(
is_text_feed
,
false
,
"is_text_feed"
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClassDataFeed
::
s_filelist_
;
std
::
mutex
TextClassDataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClassDataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClassDataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClassDataFeed
::
s_current_epoch_
=
0
;
int
TextClassDataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClassDataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClassDataFeed
::
s_condition_epoch_start_
;
bool
TextClassDataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClassDataFeed
::
Init
()
{
std
::
vector
<
std
::
string
>
DataFeed
::
filelist_
;
// hard coding for a specific datafeed
size_t
DataFeed
::
file_idx_
;
feed_vec_
.
resize
(
2
);
std
::
mutex
DataFeed
::
mutex_for_pick_file_
;
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
all_slot_ids_
=
{
0
,
1
};
if
(
CheckInit
()
==
false
)
{
return
;}
use_slot_ids_
=
{
0
,
1
};
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
use_slot_alias_
=
{
"words"
,
"label"
};
if
(
name
==
use_slots_
[
i
])
{
if
(
use_slots_is_dense_
[
i
])
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
Tensor
>
());
}
else
{
feed_vec_
[
i
]
=
MixTensor
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
}
}
file_content_buffer_host_
.
reset
(
new
char
[
200
*
1024
*
1024
],
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
[](
char
*
p
)
{
delete
[]
p
;});
if
(
CheckInit
()
==
false
)
{
return
false
;}
file_content_buffer_
=
file_content_buffer_host_
.
get
();
if
(
files
.
size
()
==
0
)
{
file_content_buffer_ptr_
=
file_content_buffer_
;
LOG
(
ERROR
)
<<
"error: you have set an empty filelist"
;
return
false
;
}
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
batch_id_host_
.
reset
(
new
int
[
10240
*
1024
],
finish_set_filelist_
=
true
;
[](
int
*
p
)
{
delete
[]
p
;});
// max word num in a batch
return
true
;
batch_id_buffer_
=
batch_id_host_
.
get
();
}
label_host_
.
reset
(
new
int
[
10240
],
bool
DataFeed
::
PickOneFile
(
std
::
string
&
filename
)
{
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
label_ptr_
=
label_host_
.
get
();
if
(
file_idx_
==
filelist_
.
size
())
{
return
false
;
}
filename
=
filelist_
[
file_idx_
++
];
return
true
;
}
field_names_
.
clear
();
bool
DataFeed
::
CheckInit
()
{
if
(
finish_init_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: initialization did not succeed"
;
return
false
;
}
}
TextClassDataFeed
::
TextClassDataFeed
()
{
bool
DataFeed
::
CheckSetFileList
()
{
Init
();
if
(
finish_set_filelist_
)
{
return
true
;}
LOG
(
ERROR
)
<<
"error: set filelist did not succeed"
;
return
false
;
}
}
// todo: use elegant implemention for this function
bool
DataFeed
::
CheckStart
()
{
bool
TextClassDataFeed
::
ReadBatch
()
{
if
(
finish_start_
)
{
return
true
;}
paddle
::
framework
::
Vector
<
size_t
>
offset
;
LOG
(
ERROR
)
<<
"error: Datafeed has not started running yet"
;
int
tlen
=
0
;
return
false
;
int
llen
=
0
;
}
int
inst_idx
=
0
;
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
while
(
inst_idx
<
batch_size_
)
{
template
<
typename
T
>
int
ptr_offset
=
0
;
void
PrivateQueueDataFeed
<
T
>::
SetQueueSize
(
int
queue_size
)
{
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
if
(
!
CheckInit
())
{
return
;}
break
;
if
(
queue_size
<=
0
)
{
LOG
(
ERROR
)
<<
"error: illegal queue size: "
<<
queue_size
;
return
;
}
}
queue_size_
=
queue_size
;
queue_
.
ReCap
(
queue_size_
);
}
memcpy
(
reinterpret_cast
<
char
*>
(
&
llen
),
template
<
typename
T
>
file_content_buffer_ptr_
+
ptr_offset
,
bool
PrivateQueueDataFeed
<
T
>::
Start
()
{
sizeof
(
int
));
if
(
!
(
CheckSetFileList
()))
{
return
false
;}
ptr_offset
+=
sizeof
(
int
);
read_thread_
=
std
::
thread
(
&
PrivateQueueDataFeed
::
ReadThread
,
this
);
read_thread_
.
detach
();
memcpy
(
reinterpret_cast
<
char
*>
(
batch_id_buffer_
+
tlen
),
file_content_buffer_ptr_
+
ptr_offset
,
llen
*
sizeof
(
int
));
tlen
+=
llen
;
offset
[
inst_idx
+
1
]
=
offset
[
inst_idx
]
+
llen
;
ptr_offset
+=
sizeof
(
int
)
*
llen
;
memcpy
(
reinterpret_cast
<
char
*>
(
label_ptr_
+
inst_idx
),
finish_start_
=
true
;
file_content_buffer_ptr_
+
ptr_offset
,
return
true
;
sizeof
(
int
));
}
ptr_offset
+=
sizeof
(
int
);
file_content_buffer_ptr_
+=
ptr_offset
;
template
<
typename
T
>
inst_idx
++
;
void
PrivateQueueDataFeed
<
T
>::
ReadThread
(){
std
::
string
filename
;
while
(
PickOneFile
(
filename
))
{
file_
.
open
(
filename
.
c_str
());
// is_text_feed
if
(
!
file_
.
is_open
())
{
LOG
(
ERROR
)
<<
"error: open file<"
<<
filename
<<
"> fail"
;
}
}
T
instance
;
if
(
inst_idx
!=
batch_size_
)
{
while
(
ParseOneInstance
(
instance
)
)
{
return
false
;
queue_
.
Send
(
instance
)
;
}
}
file_
.
close
();
LoD
input_lod
{
offset
};
paddle
::
framework
::
Vector
<
size_t
>
label_offset
;
label_offset
.
resize
(
batch_size_
+
1
);
for
(
int
i
=
0
;
i
<=
batch_size_
;
++
i
)
{
label_offset
[
i
]
=
i
;
}
}
queue_
.
Close
();
}
LoD
label_lod
{
label_offset
};
template
<
typename
T
>
int64_t
*
input_ptr
=
feed_vec_
[
0
]
->
mutable_data
<
int64_t
>
(
bool
PrivateQueueDataFeed
<
T
>::
Next
(){
{
static_cast
<
int64_t
>
(
offset
.
back
()),
1
},
if
(
!
CheckStart
())
{
return
false
;}
platform
::
CPUPlace
());
int
index
=
0
;
int64_t
*
label_ptr
=
feed_vec_
[
1
]
->
mutable_data
<
int64_t
>
({
batch_size_
,
1
},
T
instance
;
platform
::
CPUPlace
());
T
ins_vec
(
use_slots_
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
offset
.
back
();
++
i
)
{
while
(
index
<
default_batch_size_
)
{
input_ptr
[
i
]
=
static_cast
<
int64_t
>
(
batch_id_buffer_
[
i
]);
if
(
!
queue_
.
Receive
(
&
instance
))
{
break
;
}
}
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
AddInstanceToInsVec
(
ins_vec
,
instance
,
index
++
);
label_ptr
[
i
]
=
static_cast
<
int64_t
>
(
label_ptr_
[
i
]);
}
}
feed_vec_
[
0
]
->
set_lod
(
input_lod
)
;
batch_size_
=
index
;
feed_vec_
[
1
]
->
set_lod
(
label_lod
);
PutToFeedVec
(
ins_vec
);
return
true
;
return
batch_size_
!=
0
;
}
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
)
{
void
MultiSlotDataFeed
::
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{
Init
();
finish_init_
=
false
;
SetBatchSize
(
data_feed
.
batch_size_
);
finish_set_filelist_
=
false
;
SetFieldNames
(
data_feed
.
field_names_
);
finish_start_
=
false
;
}
if
(
!
data_feed_desc
.
has_multi_slot_desc
()){
LOG
(
ERROR
)
<<
"error: multi_slot_desc has not been set"
;
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
return
;
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
}
}
paddle
::
MultiSlotDesc
multi_slot_desc
=
data_feed_desc
.
multi_slot_desc
();
size_t
all_slot_num
=
multi_slot_desc
.
slots_size
();
all_slots_
.
resize
(
all_slot_num
);
all_slots_type_
.
resize
(
all_slot_num
);
use_slots_index_
.
resize
(
all_slot_num
);
use_slots_
.
clear
();
use_slots_is_dense_
.
clear
();
for
(
size_t
i
=
0
;
i
<
all_slot_num
;
++
i
)
{
auto
&
slot
=
multi_slot_desc
.
slots
(
i
);
all_slots_
[
i
]
=
slot
.
name
();
all_slots_type_
[
i
]
=
slot
.
type
();
use_slots_index_
[
i
]
=
slot
.
use
()
?
use_slots_
.
size
()
:
-
1
;
if
(
slot
.
use
())
{
use_slots_
.
push_back
(
all_slots_
[
i
]);
use_slots_is_dense_
.
push_back
(
slot
.
dense
());
}
}
}
void
TextClassDataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
"Opening file %s fail"
,
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
s_filelist_
.
push_back
(
filename
);
}
}
fin
.
close
();
feed_vec_
.
resize
(
use_slots_
.
size
());
}
void
TextClassDataFeed
::
SetFieldNames
(
finish_init_
=
true
;
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
}
}
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
bool
MultiSlotDataFeed
::
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>&
instance
)
{
// termnum termid termid ... termid label
std
::
string
line
;
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
getline
(
file_
,
line
))
{
if
(
ifs
.
fail
())
{
int
use_slots_num
=
use_slots_
.
size
();
return
false
;
instance
.
resize
(
use_slots_num
);
//parse line
const
char
*
str
=
line
.
c_str
();
char
*
endptr
=
(
char
*
)
str
;
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
int
idx
=
use_slots_index_
[
i
];
int
num
=
(
int
)
strtol
(
&
str
[
pos
],
&
endptr
,
10
);
if
(
num
==
0
)
{
LOG
(
ERROR
)
<<
"error: the number of ids can not be zero, you need padding it"
;
exit
(
-
1
);
}
}
if
(
idx
!=
-
1
)
{
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
instance
[
idx
].
SetType
(
all_slots_type_
[
i
]);
int
filesize
=
ifs
.
tellg
();
if
(
instance
[
idx
].
GetType
()[
0
]
==
'f'
)
{
// float
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
ifs
.
read
(
file_content_buffer_
,
filesize
);
float
feasign
=
(
float
)
strtof
(
endptr
,
&
endptr
);
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
instance
[
idx
].
AddValue
(
feasign
);
}
}
else
if
(
instance
[
idx
].
GetType
()[
0
]
==
'u'
){
// uint64
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
uint64_t
feasign
=
(
uint64_t
)
strtoull
(
endptr
,
&
endptr
,
10
);
instance
[
idx
].
AddValue
(
feasign
);
}
}
pos
=
endptr
-
str
;
}
else
{
for
(
int
j
=
0
;
j
<=
num
;
++
j
)
{
pos
=
line
.
find_first_of
(
' '
,
pos
+
1
);
}
}
}
}
else
{
return
false
;
return
false
;
}
}
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
// todo , remove magic number
return
true
;
return
true
;
}
}
void
TextClassDataFeed
::
UpdateEpochNum
()
{
void
MultiSlotDataFeed
::
AddInstanceToInsVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
,
s_current_finished_file_cnt_
++
;
std
::
vector
<
MultiSlotType
>&
instance
,
int
index
)
{
if
(
index
==
0
)
{
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
)
{
s_current_finished_file_cnt_
=
0
;
ins_vec
[
i
].
SetType
(
instance
[
i
].
GetType
());
s_current_epoch_
++
;
#if 1
LOG
(
WARNING
)
<<
"UpdateEpochNum: epoch = "
<<
s_current_epoch_
;
#endif
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
false
;
}
}
}
}
}
for
(
size_t
i
=
0
;
i
<
instance
.
size
();
++
i
){
ins_vec
[
i
].
AddIns
(
instance
[
i
]);
void
TextClassDataFeed
::
StartOneEpoch
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
}
}
s_condition_epoch_start_
.
notify_all
();
}
}
void
MultiSlotDataFeed
::
PutToFeedVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
)
{
void
TextClassDataFeed
::
WaitNextEpoch
()
{
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
auto
&
type
=
ins_vec
[
i
].
GetType
();
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
auto
&
offset
=
ins_vec
[
i
].
GetOffset
();
}
int
total_instance
=
static_cast
<
int
>
(
offset
.
back
());
if
(
type
[
0
]
==
'f'
)
{
// float
const
char
*
TextClassDataFeed
::
PickOneFile
()
{
auto
&
feasign
=
ins_vec
[
i
].
GetFloatData
();
std
::
string
file_to_be_processed
;
if
(
feed_vec_
[
i
].
IsDense
())
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
int
size_in_each_batch
=
total_instance
/
batch_size_
;
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
// One epoch has run over
mutable_data
<
float
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
// Wait for next epoch
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
if
(
s_current_file_idx_
>=
s_filelist_
.
size
())
{
}
else
{
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
float
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
mutable_data
<
float
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
float
));
return
NULL
;
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
else
if
(
type
[
0
]
==
'u'
)
{
// uint64
// no uint64_t type
auto
&
feasign
=
ins_vec
[
i
].
GetUint64Data
();
if
(
feed_vec_
[
i
].
IsDense
())
{
int
size_in_each_batch
=
total_instance
/
batch_size_
;
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetTensor
()
->
mutable_data
<
int64_t
>
({
batch_size_
,
size_in_each_batch
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
int64_t
));
}
else
{
int64_t
*
tensor_ptr
=
feed_vec_
[
i
].
GetLoDTensor
()
->
mutable_data
<
int64_t
>
({
total_instance
,
1
},
platform
::
CPUPlace
());
memcpy
(
tensor_ptr
,
&
feasign
[
0
],
total_instance
*
sizeof
(
uint64_t
));
LoD
data_lod
{
offset
};
feed_vec_
[
i
].
GetLoDTensor
()
->
set_lod
(
data_lod
);
}
}
}
}
file_to_be_processed
=
s_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
return
file_to_be_processed
.
c_str
();
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
92a98ca7
...
@@ -27,136 +27,335 @@ limitations under the License. */
...
@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <condition_variable> // NOLINT
#include <condition_variable> // NOLINT
#include <fstream>
#include <fstream>
#include <deque>
#include <atomic>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
struct
Gauc
{
int
show
,
click
;
uint64_t
fea
;
std
::
string
lineid
;
};
struct
Instance
{
class
MixTensor
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
feed_vec_buffer
;
public:
std
::
vector
<
std
::
vector
<
int
>>
feed_vec_lod
;
MixTensor
(){}
std
::
vector
<
float
>
other_label
;
MixTensor
(
LoDTensor
*
lodtensor
)
{
std
::
vector
<
Gauc
>
gauc_vec
;
is_dense_
=
false
;
lodtensor_
=
lodtensor
;
}
MixTensor
(
Tensor
*
tensor
)
{
is_dense_
=
true
;
tensor_
=
tensor
;
}
bool
IsDense
()
{
return
is_dense_
;}
LoDTensor
*
GetLoDTensor
(){
if
(
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a dense var return a LoDTensor ptr"
;
return
NULL
;
}
return
lodtensor_
;
}
Tensor
*
GetTensor
(){
if
(
!
is_dense_
)
{
LOG
(
ERROR
)
<<
"error: let a sparse var return a Tensor ptr"
;
return
NULL
;
}
return
tensor_
;
}
private:
bool
is_dense_
;
LoDTensor
*
lodtensor_
;
Tensor
*
tensor_
;
};
};
class
DataFeed
{
template
<
typename
T
>
class
BlockingQueue
{
public:
public:
DataFeed
()
:
default_batch_size_
(
1
),
batch_size_
(
0
),
thread_id_
(
0
)
{}
explicit
BlockingQueue
(
size_t
capacity
=
32
)
virtual
~
DataFeed
()
{}
:
capacity_
(
capacity
),
closed_
(
false
)
{
virtual
void
Init
()
=
0
;
size_
.
store
(
0
);
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual
bool
CheckFile
(
const
char
*
filename
)
=
0
;
virtual
bool
SetFile
(
const
char
*
filename
)
=
0
;
virtual
bool
ReadBatch
()
=
0
;
virtual
const
std
::
vector
<
uint16_t
>&
GetAllSlotIds
()
{
return
all_slot_ids_
;
}
}
v
irtual
const
std
::
vector
<
uint16_t
>&
GetUseSlotIds
(
)
{
v
oid
ReCap
(
size_t
capacity
)
{
return
use_slot_ids_
;
capacity_
=
capacity
;
}
}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlotAlias
()
{
bool
Send
(
const
T
&
elem
)
{
return
use_slot_alias_
;
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
<
capacity_
||
closed_
;});
if
(
closed_
)
{
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
return
false
;
}
queue_
.
push_back
(
elem
);
c
=
size_
.
load
();
size_
.
fetch_add
(
1
);
}
if
(
c
+
1
<
capacity_
)
{
send_cv_
.
notify_one
();
}
}
virtual
void
AddFeedVar
(
Variable
*
var
,
if
(
c
==
0
)
{
const
std
::
string
&
name
)
=
0
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
virtual
void
BindScope
(
Scope
*
scope
)
=
0
;
receive_cv_
.
notify_one
();
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
return
true
;
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
}
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
StartOneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
bool
Receive
(
T
*
elem
)
{
return
feed_vec_
;
int
c
=
-
1
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
receive_mutex_
);
receive_cv_
.
wait
(
lock
,
[
&
]
{
return
size_
.
load
()
!=
0
||
closed_
;});
if
(
size_
.
load
()
!=
0
)
{
*
elem
=
queue_
.
front
();
queue_
.
pop_front
();
c
=
size_
.
load
();
size_
.
fetch_sub
(
1
);
}
else
{
return
false
;
}
}
if
(
c
>
1
)
{
receive_cv_
.
notify_one
();
}
if
(
c
==
capacity_
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
send_mutex_
);
send_cv_
.
notify_one
();
}
return
true
;
}
}
virtual
std
::
vector
<
LoDTensor
*>&
GetFeedVec
(
const
Instance
&
ins
)
{
void
Close
()
{
LOG
(
ERROR
)
<<
"use defalut get_feed_vec"
;
std
::
lock_guard
<
std
::
mutex
>
lock1
(
send_mutex_
);
return
feed_vec_
;
std
::
lock_guard
<
std
::
mutex
>
lock2
(
receive_mutex_
);
closed_
=
true
;
send_cv_
.
notify_all
();
receive_cv_
.
notify_all
();
}
}
int
GetThreadId
()
{
return
thread_id_
;}
private:
void
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;}
size_t
capacity_
;
std
::
atomic_size_t
size_
;
bool
closed_
;
std
::
deque
<
T
>
queue_
;
mutable
std
::
mutex
send_mutex_
;
mutable
std
::
mutex
receive_mutex_
;
mutable
std
::
condition_variable
send_cv_
;
mutable
std
::
condition_variable
receive_cv_
;
};
class
DataFeed
{
public:
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
// for some datafeeds may not be able to implement this interface
virtual
bool
CheckFile
(
const
char
*
filename
)
{
LOG
(
ERROR
)
<<
"error: The function CheckFile is not implemented"
;
return
false
;
}
virtual
bool
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
);
virtual
bool
Start
()
=
0
;
virtual
bool
Next
()
=
0
;
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
// for subclass with queue
virtual
void
SetQueueSize
(
int
queue_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetQueueSize is not implemented"
;
}
// for subclass with buffer
virtual
void
SetBufferSize
(
int
buffer_size
)
{
LOG
(
ERROR
)
<<
"error: The function SetBufferSize is not implemented"
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetAllSlots
()
{
return
all_slots_
;}
virtual
const
std
::
vector
<
std
::
string
>&
GetUseSlots
()
{
return
use_slots_
;}
std
::
vector
<
MixTensor
>&
GetFeedVec
()
{
return
feed_vec_
;}
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
protected:
protected:
std
::
vector
<
uint16_t
>
all_slot_ids_
;
// Check if it is executed in this order:
std
::
vector
<
uint16_t
>
use_slot_ids_
;
// Init -> SetFileList/BindingMemory -> Start -> Next
std
::
vector
<
std
::
string
>
use_slot_alias_
;
virtual
bool
CheckInit
();
std
::
vector
<
LoDTensor
*>
feed_vec_
;
virtual
bool
CheckSetFileList
();
virtual
bool
CheckStart
();
virtual
bool
PickOneFile
(
std
::
string
&
filename
);
static
std
::
vector
<
std
::
string
>
filelist_
;
static
size_t
file_idx_
;
static
std
::
mutex
mutex_for_pick_file_
;
std
::
vector
<
std
::
string
>
use_slots_
;
std
::
vector
<
bool
>
use_slots_is_dense_
;
std
::
vector
<
std
::
string
>
all_slots_
;
std
::
vector
<
std
::
string
>
all_slots_type_
;
std
::
vector
<
int
>
use_slots_index_
;
// -1: not used; >=0: the index of use_slots_
std
::
vector
<
MixTensor
>
feed_vec_
;
int
default_batch_size_
;
int
default_batch_size_
;
int
batch_size_
;
int
batch_size_
;
int
thread_id_
;
bool
finish_init_
;
bool
finish_set_filelist_
;
bool
finish_binding_memory_
;
bool
finish_start_
;
};
};
class
TextClassDataFeed
:
public
DataFeed
{
template
<
typename
T
>
class
PrivateQueueDataFeed
:
public
DataFeed
{
public:
public:
TextClassDataFeed
();
PrivateQueueDataFeed
()
{}
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
);
virtual
~
PrivateQueueDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
Start
();
virtual
bool
Next
();
// no buffer
virtual
void
SetQueueSize
(
int
queue_size
);
protected:
virtual
void
ReadThread
();
virtual
bool
ParseOneInstance
(
T
&
instance
)
=
0
;
virtual
void
AddInstanceToInsVec
(
T
&
vec_ins
,
T
&
instance
,
int
index
)
=
0
;
virtual
void
PutToFeedVec
(
T
&
ins_vec
)
=
0
;
std
::
thread
read_thread_
;
// the thread for read files
/* using ifstream one line and one line parse is faster
* than using fread one buffer and one buffer parse.
* for 601M JingPai data:
* ifstream one line and one line parse: 6034 ms
* fread one buffer and one buffer parse: 7097 ms */
std
::
ifstream
file_
;
size_t
queue_size_
;
// The elements in the queue are one piece of data,
// with multiple fields in each piece of data
BlockingQueue
<
T
>
queue_
;
};
class
MultiSlotType
{
public:
public:
virtual
~
TextClassDataFeed
()
{}
MultiSlotType
()
{
virtual
void
Init
();
float_feasign_
.
clear
();
virtual
bool
ReadBatch
();
uint64_feasign_
.
clear
();
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
offset_
.
resize
(
1
);
virtual
void
BindScope
(
Scope
*
scope
)
{}
offset_
[
0
]
=
0
;
virtual
bool
SetFile
(
const
char
*
filename
);
}
virtual
bool
CheckFile
(
const
char
*
filename
)
{
~
MultiSlotType
()
{}
// TODO(xxx)
void
SetType
(
std
::
string
&
type
)
{
if
(
!
CheckType
(
type
))
{
return
;}
type_
=
type
;
}
std
::
vector
<
size_t
>&
GetOffset
()
{
return
offset_
;
}
void
AddValue
(
float
v
)
{
if
(
!
CheckFloat
())
{
return
;}
float_feasign_
.
push_back
(
v
);
}
void
AddValue
(
uint64_t
v
)
{
if
(
!
CheckUint64
())
{
return
;}
uint64_feasign_
.
push_back
(
v
);
}
void
AddIns
(
MultiSlotType
&
ins
)
{
if
(
ins
.
GetType
()[
0
]
==
'f'
)
{
//float
if
(
!
CheckFloat
())
{
return
;}
auto
&
vec
=
ins
.
GetFloatData
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
float_feasign_
.
insert
(
float_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
else
if
(
ins
.
GetType
()[
0
]
==
'u'
)
{
//uint64
if
(
!
CheckUint64
())
{
return
;}
auto
&
vec
=
ins
.
GetUint64Data
();
offset_
.
push_back
(
offset_
.
back
()
+
vec
.
size
());
uint64_feasign_
.
insert
(
uint64_feasign_
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
}
std
::
vector
<
float
>&
GetFloatData
()
{
return
float_feasign_
;
}
std
::
vector
<
uint64_t
>&
GetUint64Data
()
{
return
uint64_feasign_
;
}
std
::
string
&
GetType
()
{
return
type_
;
}
private:
bool
CheckType
(
std
::
string
&
type
)
{
if
(
type
!=
"uint64"
&&
type
!=
"float"
)
{
// check in here
LOG
(
ERROR
)
<<
"error: here is no this type"
;
return
false
;
return
false
;
}
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
return
true
;
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
}
void
UpdateEpochNum
();
bool
CheckFloat
()
{
void
StartOneEpoch
();
if
(
type_
[
0
]
!=
'f'
)
{
//float
void
WaitNextEpoch
();
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to float slot"
;
return
false
;
}
return
true
;
}
bool
CheckUint64
()
{
if
(
type_
[
0
]
!=
'u'
)
{
//uint64
LOG
(
ERROR
)
<<
"error: add "
<<
type_
<<
" value to uint64 slot"
;
return
false
;
}
return
true
;
}
std
::
vector
<
float
>
float_feasign_
;
std
::
vector
<
uint64_t
>
uint64_feasign_
;
std
::
string
type_
;
std
::
vector
<
size_t
>
offset_
;
};
class
MultiSlotDataFeed
:
public
PrivateQueueDataFeed
<
std
::
vector
<
MultiSlotType
>>
{
public:
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
MultiSlotDataFeed
()
{}
virtual
~
MultiSlotDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
);
//TODO: virtual bool CheckFile();
protected:
virtual
void
AddInstanceToInsVec
(
std
::
vector
<
MultiSlotType
>&
vec_ins
,
std
::
vector
<
MultiSlotType
>&
instance
,
int
index
);
virtual
bool
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>&
instance
);
virtual
void
PutToFeedVec
(
std
::
vector
<
MultiSlotType
>&
ins_vec
);
};
//TODO: to be deleted
class
TextClassDataFeed
:
public
DataFeed
{
public:
public:
static
void
SetFileList
(
const
char
*
filelist
);
virtual
~
TextClassDataFeed
()
{}
virtual
void
Init
(
paddle
::
DataFeedDesc
&
data_feed_desc
)
{}
virtual
bool
Start
()
{
return
false
;};
//TODO
virtual
bool
Next
()
{
return
false
;};
//TODO
virtual
bool
ReadBatch
()
{
return
false
;}
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{}
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
bool
SetFile
(
const
char
*
filename
)
{
return
false
;}
private:
virtual
bool
CheckFile
(
const
char
*
filename
)
{
const
char
*
PickOneFile
();
// TODO(xxx)
return
false
;
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
private:
private:
int
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
return
-
1
;}
char
*
file_content_buffer_
;
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
*
label_ptr_
;
int
file_size_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
field_
names_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
static
std
::
vector
<
std
::
string
>
s_filelist_
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
mutex
s_locker_epoch_start_
;
static
std
::
condition_variable
s_condition_epoch_start_
;
static
bool
s_epoch_start_flag_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.proto
浏览文件 @
92a98ca7
...
@@ -17,5 +17,16 @@ package paddle;
...
@@ -17,5 +17,16 @@ package paddle;
message
DataFeedDesc
{
message
DataFeedDesc
{
optional
string
name
=
1
;
optional
string
name
=
1
;
optional
int32
batch
=
2
[
default
=
32
];
optional
int32
batch
=
2
[
default
=
32
];
optional
MultiSlotDesc
multi_slot_desc
=
3
;
}
}
message
MultiSlotDesc
{
repeated
Slot
slots
=
1
;
}
message
Slot
{
required
string
name
=
1
;
required
string
type
=
2
;
optional
bool
dense
=
3
[
default
=
false
];
optional
bool
use
=
4
[
default
=
true
];
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录