Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3fbb6644
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3fbb6644
编写于
10月 26, 2021
作者:
Y
yaoxuefeng
提交者:
GitHub
10月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add slot record dataset (#36200) (#36710)
上级
beb920cd
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
622 addition
and
46 deletion
+622
-46
paddle/fluid/framework/channel.h
paddle/fluid/framework/channel.h
+18
-2
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+103
-9
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+308
-9
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+147
-19
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+36
-4
paddle/fluid/framework/dataset_factory.cc
paddle/fluid/framework/dataset_factory.cc
+2
-1
paddle/fluid/platform/flags.cc
paddle/fluid/platform/flags.cc
+8
-0
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+0
-2
未找到文件。
paddle/fluid/framework/channel.h
浏览文件 @
3fbb6644
...
...
@@ -157,7 +157,19 @@ class ChannelObject {
p
.
resize
(
finished
);
return
finished
;
}
// read once only
size_t
ReadOnce
(
std
::
vector
<
T
>&
p
,
size_t
size
)
{
// NOLINT
if
(
size
==
0
)
{
return
0
;
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
p
.
resize
(
size
);
size_t
finished
=
Read
(
size
,
&
p
[
0
],
lock
,
true
);
p
.
resize
(
finished
);
Notify
();
return
finished
;
}
size_t
ReadAll
(
std
::
vector
<
T
>&
p
)
{
// NOLINT
p
.
clear
();
size_t
finished
=
0
;
...
...
@@ -241,17 +253,21 @@ class ChannelObject {
return
!
closed_
;
}
size_t
Read
(
size_t
n
,
T
*
p
,
std
::
unique_lock
<
std
::
mutex
>&
lock
)
{
// NOLINT
size_t
Read
(
size_t
n
,
T
*
p
,
std
::
unique_lock
<
std
::
mutex
>&
lock
,
// NOLINT
bool
once
=
false
)
{
// NOLINT
size_t
finished
=
0
;
CHECK
(
n
<=
MaxCapacity
()
-
reading_count_
);
reading_count_
+=
n
;
while
(
finished
<
n
&&
WaitForRead
(
lock
))
{
size_t
m
=
std
::
min
(
n
-
finished
,
data_
.
size
());
size_t
m
=
(
std
::
min
)
(
n
-
finished
,
data_
.
size
());
for
(
size_t
i
=
0
;
i
<
m
;
i
++
)
{
p
[
finished
++
]
=
std
::
move
(
data_
.
front
());
data_
.
pop_front
();
}
reading_count_
-=
m
;
if
(
once
&&
m
>
0
)
{
break
;
}
}
reading_count_
-=
n
-
finished
;
return
finished
;
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
3fbb6644
...
...
@@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() {
return
manager
;
}
class
BufferedLineFileReader
{
typedef
std
::
function
<
bool
()
>
SampleFunc
;
static
const
int
MAX_FILE_BUFF_SIZE
=
4
*
1024
*
1024
;
class
FILEReader
{
public:
explicit
FILEReader
(
FILE
*
fp
)
:
fp_
(
fp
)
{}
int
read
(
char
*
buf
,
int
len
)
{
return
fread
(
buf
,
sizeof
(
char
),
len
,
fp_
);
}
private:
FILE
*
fp_
;
};
public:
typedef
std
::
function
<
bool
(
const
std
::
string
&
)
>
LineFunc
;
private:
template
<
typename
T
>
int
read_lines
(
T
*
reader
,
LineFunc
func
,
int
skip_lines
)
{
int
lines
=
0
;
size_t
ret
=
0
;
char
*
ptr
=
NULL
;
char
*
eol
=
NULL
;
total_len_
=
0
;
error_line_
=
0
;
SampleFunc
spfunc
=
get_sample_func
();
std
::
string
x
;
while
(
!
is_error
()
&&
(
ret
=
reader
->
read
(
buff_
,
MAX_FILE_BUFF_SIZE
))
>
0
)
{
total_len_
+=
ret
;
ptr
=
buff_
;
eol
=
reinterpret_cast
<
char
*>
(
memchr
(
ptr
,
'\n'
,
ret
));
while
(
eol
!=
NULL
)
{
int
size
=
static_cast
<
int
>
((
eol
-
ptr
)
+
1
);
x
.
append
(
ptr
,
size
-
1
);
++
lines
;
if
(
lines
>
skip_lines
&&
spfunc
())
{
if
(
!
func
(
x
))
{
++
error_line_
;
}
}
x
.
clear
();
ptr
+=
size
;
ret
-=
size
;
eol
=
reinterpret_cast
<
char
*>
(
memchr
(
ptr
,
'\n'
,
ret
));
}
if
(
ret
>
0
)
{
x
.
append
(
ptr
,
ret
);
}
}
if
(
!
is_error
()
&&
!
x
.
empty
())
{
++
lines
;
if
(
lines
>
skip_lines
&&
spfunc
())
{
if
(
!
func
(
x
))
{
++
error_line_
;
}
}
}
return
lines
;
}
public:
BufferedLineFileReader
()
:
random_engine_
(
std
::
random_device
()()),
uniform_distribution_
(
0.0
f
,
1.0
f
)
{
total_len_
=
0
;
sample_line_
=
0
;
buff_
=
reinterpret_cast
<
char
*>
(
calloc
(
MAX_FILE_BUFF_SIZE
+
1
,
sizeof
(
char
)));
}
~
BufferedLineFileReader
()
{
free
(
buff_
);
}
int
read_file
(
FILE
*
fp
,
LineFunc
func
,
int
skip_lines
)
{
FILEReader
reader
(
fp
);
return
read_lines
<
FILEReader
>
(
&
reader
,
func
,
skip_lines
);
}
uint64_t
file_size
(
void
)
{
return
total_len_
;
}
void
set_sample_rate
(
float
r
)
{
sample_rate_
=
r
;
}
size_t
get_sample_line
()
{
return
sample_line_
;
}
bool
is_error
(
void
)
{
return
(
error_line_
>
10
);
}
private:
SampleFunc
get_sample_func
()
{
if
(
std
::
abs
(
sample_rate_
-
1.0
f
)
<
1e-5
f
)
{
return
[
this
](
void
)
{
return
true
;
};
}
return
[
this
](
void
)
{
return
(
uniform_distribution_
(
random_engine_
)
<
sample_rate_
);
};
}
private:
char
*
buff_
=
nullptr
;
uint64_t
total_len_
=
0
;
std
::
default_random_engine
random_engine_
;
std
::
uniform_real_distribution
<
float
>
uniform_distribution_
;
float
sample_rate_
=
1.0
f
;
size_t
sample_line_
=
0
;
size_t
error_line_
=
0
;
};
void
RecordCandidateList
::
ReSize
(
size_t
length
)
{
mutex_
.
lock
();
capacity_
=
length
;
...
...
@@ -301,7 +402,7 @@ int InMemoryDataFeed<T>::Next() {
<<
", thread_id="
<<
thread_id_
;
}
}
else
{
VLOG
(
3
)
<<
"enable heter
NEXT
: "
<<
offset_index_
VLOG
(
3
)
<<
"enable heter
next
: "
<<
offset_index_
<<
" batch_offsets: "
<<
batch_offsets_
.
size
();
if
(
offset_index_
>=
batch_offsets_
.
size
())
{
VLOG
(
3
)
<<
"offset_index: "
<<
offset_index_
...
...
@@ -318,14 +419,7 @@ int InMemoryDataFeed<T>::Next() {
VLOG
(
3
)
<<
"finish reading for heterps, batch size zero, thread_id="
<<
thread_id_
;
}
/*
if (offset_index_ == batch_offsets_.size() - 1) {
std::vector<Record> data;
output_channel_->ReadAll(data);
consume_channel_->Write(std::move(data));
}
*/
VLOG
(
3
)
<<
"#15 enable heter NEXT: "
<<
offset_index_
VLOG
(
3
)
<<
"enable heter next: "
<<
offset_index_
<<
" batch_offsets: "
<<
batch_offsets_
.
size
()
<<
" baych_size: "
<<
this
->
batch_size_
;
}
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
3fbb6644
...
...
@@ -39,8 +39,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_int32
(
record_pool_max_size
);
DECLARE_int32
(
slotpool_thread_num
);
DECLARE_bool
(
enable_slotpool_wait_release
);
DECLARE_bool
(
enable_slotrecord_reset_shrink
);
namespace
paddle
{
namespace
framework
{
class
DataFeedDesc
;
...
...
@@ -69,6 +75,50 @@ namespace framework {
// while (reader->Next()) {
// // trainer do something
// }
template
<
typename
T
>
struct
SlotValues
{
std
::
vector
<
T
>
slot_values
;
std
::
vector
<
uint32_t
>
slot_offsets
;
void
add_values
(
const
T
*
values
,
uint32_t
num
)
{
if
(
slot_offsets
.
empty
())
{
slot_offsets
.
push_back
(
0
);
}
if
(
num
>
0
)
{
slot_values
.
insert
(
slot_values
.
end
(),
values
,
values
+
num
);
}
slot_offsets
.
push_back
(
static_cast
<
uint32_t
>
(
slot_values
.
size
()));
}
T
*
get_values
(
int
idx
,
size_t
*
size
)
{
uint32_t
&
offset
=
slot_offsets
[
idx
];
(
*
size
)
=
slot_offsets
[
idx
+
1
]
-
offset
;
return
&
slot_values
[
offset
];
}
void
add_slot_feasigns
(
const
std
::
vector
<
std
::
vector
<
T
>>&
slot_feasigns
,
uint32_t
fea_num
)
{
slot_values
.
reserve
(
fea_num
);
int
slot_num
=
static_cast
<
int
>
(
slot_feasigns
.
size
());
slot_offsets
.
resize
(
slot_num
+
1
);
for
(
int
i
=
0
;
i
<
slot_num
;
++
i
)
{
auto
&
slot_val
=
slot_feasigns
[
i
];
slot_offsets
[
i
]
=
static_cast
<
uint32_t
>
(
slot_values
.
size
());
uint32_t
num
=
static_cast
<
uint32_t
>
(
slot_val
.
size
());
if
(
num
>
0
)
{
slot_values
.
insert
(
slot_values
.
end
(),
slot_val
.
begin
(),
slot_val
.
end
());
}
}
slot_offsets
[
slot_num
]
=
slot_values
.
size
();
}
void
clear
(
bool
shrink
)
{
slot_offsets
.
clear
();
slot_values
.
clear
();
if
(
shrink
)
{
slot_values
.
shrink_to_fit
();
slot_offsets
.
shrink_to_fit
();
}
}
};
union
FeatureFeasign
{
uint64_t
uint64_feasign_
;
float
float_feasign_
;
...
...
@@ -97,6 +147,38 @@ struct FeatureItem {
uint16_t
slot_
;
};
struct
AllSlotInfo
{
std
::
string
slot
;
std
::
string
type
;
int
used_idx
;
int
slot_value_idx
;
};
struct
UsedSlotInfo
{
int
idx
;
int
slot_value_idx
;
std
::
string
slot
;
std
::
string
type
;
bool
dense
;
std
::
vector
<
int
>
local_shape
;
int
total_dims_without_inductive
;
int
inductive_shape_index
;
};
struct
SlotRecordObject
{
uint64_t
search_id
;
uint32_t
rank
;
uint32_t
cmatch
;
std
::
string
ins_id_
;
SlotValues
<
uint64_t
>
slot_uint64_feasigns_
;
SlotValues
<
float
>
slot_float_feasigns_
;
~
SlotRecordObject
()
{
clear
(
true
);
}
void
reset
(
void
)
{
clear
(
FLAGS_enable_slotrecord_reset_shrink
);
}
void
clear
(
bool
shrink
)
{
slot_uint64_feasigns_
.
clear
(
shrink
);
slot_float_feasigns_
.
clear
(
shrink
);
}
};
using
SlotRecord
=
SlotRecordObject
*
;
// sizeof Record is much less than std::vector<MultiSlotType>
struct
Record
{
std
::
vector
<
FeatureItem
>
uint64_feasigns_
;
...
...
@@ -108,6 +190,179 @@ struct Record {
uint32_t
cmatch
;
};
inline
SlotRecord
make_slotrecord
()
{
static
const
size_t
slot_record_byte_size
=
sizeof
(
SlotRecordObject
);
void
*
p
=
malloc
(
slot_record_byte_size
);
new
(
p
)
SlotRecordObject
;
return
reinterpret_cast
<
SlotRecordObject
*>
(
p
);
}
inline
void
free_slotrecord
(
SlotRecordObject
*
p
)
{
p
->~
SlotRecordObject
();
free
(
p
);
}
template
<
class
T
>
class
SlotObjAllocator
{
public:
explicit
SlotObjAllocator
(
std
::
function
<
void
(
T
*
)
>
deleter
)
:
free_nodes_
(
NULL
),
capacity_
(
0
),
deleter_
(
deleter
)
{}
~
SlotObjAllocator
()
{
clear
();
}
void
clear
()
{
T
*
tmp
=
NULL
;
while
(
free_nodes_
!=
NULL
)
{
tmp
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
void
*>
(
free_nodes_
));
free_nodes_
=
free_nodes_
->
next
;
deleter_
(
tmp
);
--
capacity_
;
}
CHECK_EQ
(
capacity_
,
static_cast
<
size_t
>
(
0
));
}
T
*
acquire
(
void
)
{
T
*
x
=
NULL
;
x
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
void
*>
(
free_nodes_
));
free_nodes_
=
free_nodes_
->
next
;
--
capacity_
;
return
x
;
}
void
release
(
T
*
x
)
{
Node
*
node
=
reinterpret_cast
<
Node
*>
(
reinterpret_cast
<
void
*>
(
x
));
node
->
next
=
free_nodes_
;
free_nodes_
=
node
;
++
capacity_
;
}
size_t
capacity
(
void
)
{
return
capacity_
;
}
private:
struct
alignas
(
T
)
Node
{
union
{
Node
*
next
;
char
data
[
sizeof
(
T
)];
};
};
Node
*
free_nodes_
;
// a list
size_t
capacity_
;
std
::
function
<
void
(
T
*
)
>
deleter_
=
nullptr
;
};
static
const
int
OBJPOOL_BLOCK_SIZE
=
10000
;
class
SlotObjPool
{
public:
SlotObjPool
()
:
max_capacity_
(
FLAGS_record_pool_max_size
),
alloc_
(
free_slotrecord
)
{
ins_chan_
=
MakeChannel
<
SlotRecord
>
();
ins_chan_
->
SetBlockSize
(
OBJPOOL_BLOCK_SIZE
);
for
(
int
i
=
0
;
i
<
FLAGS_slotpool_thread_num
;
++
i
)
{
threads_
.
push_back
(
std
::
thread
([
this
]()
{
run
();
}));
}
disable_pool_
=
false
;
count_
=
0
;
}
~
SlotObjPool
()
{
ins_chan_
->
Close
();
for
(
auto
&
t
:
threads_
)
{
t
.
join
();
}
}
void
disable_pool
(
bool
disable
)
{
disable_pool_
=
disable
;
}
void
set_max_capacity
(
size_t
max_capacity
)
{
max_capacity_
=
max_capacity
;
}
void
get
(
std
::
vector
<
SlotRecord
>*
output
,
int
n
)
{
output
->
resize
(
n
);
return
get
(
&
(
*
output
)[
0
],
n
);
}
void
get
(
SlotRecord
*
output
,
int
n
)
{
int
size
=
0
;
mutex_
.
lock
();
int
left
=
static_cast
<
int
>
(
alloc_
.
capacity
());
if
(
left
>
0
)
{
size
=
(
left
>=
n
)
?
n
:
left
;
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
output
[
i
]
=
alloc_
.
acquire
();
}
}
mutex_
.
unlock
();
count_
+=
n
;
if
(
size
==
n
)
{
return
;
}
for
(
int
i
=
size
;
i
<
n
;
++
i
)
{
output
[
i
]
=
make_slotrecord
();
}
}
void
put
(
std
::
vector
<
SlotRecord
>*
input
)
{
size_t
size
=
input
->
size
();
if
(
size
==
0
)
{
return
;
}
put
(
&
(
*
input
)[
0
],
size
);
input
->
clear
();
}
void
put
(
SlotRecord
*
input
,
size_t
size
)
{
CHECK
(
ins_chan_
->
WriteMove
(
size
,
input
)
==
size
);
}
void
run
(
void
)
{
std
::
vector
<
SlotRecord
>
input
;
while
(
ins_chan_
->
ReadOnce
(
input
,
OBJPOOL_BLOCK_SIZE
))
{
if
(
input
.
empty
())
{
continue
;
}
// over max capacity
size_t
n
=
input
.
size
();
count_
-=
n
;
if
(
disable_pool_
||
n
+
capacity
()
>
max_capacity_
)
{
for
(
auto
&
t
:
input
)
{
free_slotrecord
(
t
);
}
}
else
{
for
(
auto
&
t
:
input
)
{
t
->
reset
();
}
mutex_
.
lock
();
for
(
auto
&
t
:
input
)
{
alloc_
.
release
(
t
);
}
mutex_
.
unlock
();
}
input
.
clear
();
}
}
void
clear
(
void
)
{
platform
::
Timer
timeline
;
timeline
.
Start
();
mutex_
.
lock
();
alloc_
.
clear
();
mutex_
.
unlock
();
// wait release channel data
if
(
FLAGS_enable_slotpool_wait_release
)
{
while
(
!
ins_chan_
->
Empty
())
{
sleep
(
1
);
}
}
timeline
.
Pause
();
VLOG
(
3
)
<<
"clear slot pool data size="
<<
count_
.
load
()
<<
", span="
<<
timeline
.
ElapsedSec
();
}
size_t
capacity
(
void
)
{
mutex_
.
lock
();
size_t
total
=
alloc_
.
capacity
();
mutex_
.
unlock
();
return
total
;
}
private:
size_t
max_capacity_
;
Channel
<
SlotRecord
>
ins_chan_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
mutex
mutex_
;
SlotObjAllocator
<
SlotRecordObject
>
alloc_
;
bool
disable_pool_
;
std
::
atomic
<
long
>
count_
;
// NOLINT
};
inline
SlotObjPool
&
SlotRecordPool
()
{
static
SlotObjPool
pool
;
return
pool
;
}
struct
PvInstanceObject
{
std
::
vector
<
Record
*>
ads
;
void
merge_instance
(
Record
*
ins
)
{
ads
.
push_back
(
ins
);
}
...
...
@@ -129,7 +384,21 @@ class CustomParser {
CustomParser
()
{}
virtual
~
CustomParser
()
{}
virtual
void
Init
(
const
std
::
vector
<
SlotConf
>&
slots
)
=
0
;
virtual
bool
Init
(
const
std
::
vector
<
AllSlotInfo
>&
slots
)
=
0
;
virtual
void
ParseOneInstance
(
const
char
*
str
,
Record
*
instance
)
=
0
;
virtual
bool
ParseOneInstance
(
const
std
::
string
&
line
,
std
::
function
<
void
(
std
::
vector
<
SlotRecord
>&
,
int
)
>
GetInsFunc
)
{
// NOLINT
return
true
;
}
virtual
bool
ParseFileInstance
(
std
::
function
<
int
(
char
*
buf
,
int
len
)
>
ReadBuffFunc
,
std
::
function
<
void
(
std
::
vector
<
SlotRecord
>&
,
int
,
int
)
>
PullRecordsFunc
,
// NOLINT
int
&
lines
)
{
// NOLINT
return
false
;
}
};
typedef
paddle
::
framework
::
CustomParser
*
(
*
CreateParserObjectFunc
)();
...
...
@@ -194,6 +463,34 @@ class DLManager {
return
nullptr
;
}
paddle
::
framework
::
CustomParser
*
Load
(
const
std
::
string
&
name
,
const
std
::
vector
<
AllSlotInfo
>&
conf
)
{
#ifdef _LINUX
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
DLHandle
handle
;
std
::
map
<
std
::
string
,
DLHandle
>::
iterator
it
=
handle_map_
.
find
(
name
);
if
(
it
!=
handle_map_
.
end
())
{
return
it
->
second
.
parser
;
}
handle
.
module
=
dlopen
(
name
.
c_str
(),
RTLD_NOW
);
if
(
handle
.
module
==
nullptr
)
{
VLOG
(
0
)
<<
"Create so of "
<<
name
<<
" fail"
;
exit
(
-
1
);
return
nullptr
;
}
CreateParserObjectFunc
create_parser_func
=
(
CreateParserObjectFunc
)
dlsym
(
handle
.
module
,
"CreateParserObject"
);
handle
.
parser
=
create_parser_func
();
handle
.
parser
->
Init
(
conf
);
handle_map_
.
insert
({
name
,
handle
});
return
handle
.
parser
;
#endif
VLOG
(
0
)
<<
"Not implement in windows"
;
return
nullptr
;
}
paddle
::
framework
::
CustomParser
*
ReLoad
(
const
std
::
string
&
name
,
const
std
::
vector
<
SlotConf
>&
conf
)
{
Close
(
name
);
...
...
@@ -415,6 +712,11 @@ class InMemoryDataFeed : public DataFeed {
virtual
void
SetCurrentPhase
(
int
current_phase
);
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemoryFromSo
();
virtual
void
SetRecord
(
T
*
records
)
{
records_
=
records
;
}
int
GetDefaultBatchSize
()
{
return
default_batch_size_
;
}
void
AddBatchOffset
(
const
std
::
pair
<
int
,
int
>&
offset
)
{
batch_offsets_
.
push_back
(
offset
);
}
protected:
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
...
...
@@ -424,6 +726,11 @@ class InMemoryDataFeed : public DataFeed {
virtual
void
PutToFeedVec
(
const
std
::
vector
<
T
>&
ins_vec
)
=
0
;
virtual
void
PutToFeedVec
(
const
T
*
ins_vec
,
int
num
)
=
0
;
std
::
vector
<
std
::
vector
<
float
>>
batch_float_feasigns_
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
batch_uint64_feasigns_
;
std
::
vector
<
std
::
vector
<
size_t
>>
offset_
;
std
::
vector
<
bool
>
visit_
;
int
thread_id_
;
int
thread_num_
;
bool
parse_ins_id_
;
...
...
@@ -783,11 +1090,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
MultiSlotInMemoryDataFeed
()
{}
virtual
~
MultiSlotInMemoryDataFeed
()
{}
virtual
void
Init
(
const
DataFeedDesc
&
data_feed_desc
);
void
SetRecord
(
Record
*
records
)
{
records_
=
records
;
}
int
GetDefaultBatchSize
()
{
return
default_batch_size_
;
}
void
AddBatchOffset
(
const
std
::
pair
<
int
,
int
>&
offset
)
{
batch_offsets_
.
push_back
(
offset
);
}
// void SetRecord(Record* records) { records_ = records; }
protected:
virtual
bool
ParseOneInstance
(
Record
*
instance
);
...
...
@@ -798,10 +1101,6 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual
void
GetMsgFromLogKey
(
const
std
::
string
&
log_key
,
uint64_t
*
search_id
,
uint32_t
*
cmatch
,
uint32_t
*
rank
);
virtual
void
PutToFeedVec
(
const
Record
*
ins_vec
,
int
num
);
std
::
vector
<
std
::
vector
<
float
>>
batch_float_feasigns_
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
batch_uint64_feasigns_
;
std
::
vector
<
std
::
vector
<
size_t
>>
offset_
;
std
::
vector
<
bool
>
visit_
;
};
class
PaddleBoxDataFeed
:
public
MultiSlotInMemoryDataFeed
{
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
3fbb6644
...
...
@@ -351,10 +351,8 @@ static int compute_thread_batch_nccl(
return
thread_avg_batch_num
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetHeterPs
(
bool
enable_heterps
)
{
void
MultiSlotDataset
::
PrepareTrain
()
{
#ifdef PADDLE_WITH_GLOO
enable_heterps_
=
enable_heterps
;
if
(
enable_heterps_
)
{
if
(
input_records_
.
size
()
==
0
&&
input_channel_
!=
nullptr
&&
input_channel_
->
Size
()
!=
0
)
{
...
...
@@ -541,22 +539,21 @@ void DatasetImpl<T>::LocalShuffle() {
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
GlobalShuffle
(
int
thread_num
)
{
void
MultiSlotDataset
::
GlobalShuffle
(
int
thread_num
)
{
#ifdef PADDLE_WITH_PSLIB
VLOG
(
3
)
<<
"
DatasetImpl<T>
::GlobalShuffle() begin"
;
VLOG
(
3
)
<<
"
MultiSlotDataset
::GlobalShuffle() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
if
(
!
input_channel_
||
input_channel_
->
Size
()
==
0
)
{
VLOG
(
3
)
<<
"
DatasetImpl<T>
::GlobalShuffle() end, no data to shuffle"
;
VLOG
(
3
)
<<
"
MultiSlotDataset
::GlobalShuffle() end, no data to shuffle"
;
return
;
}
// local shuffle
input_channel_
->
Close
();
std
::
vector
<
T
>
data
;
std
::
vector
<
Record
>
data
;
input_channel_
->
ReadAll
(
data
);
std
::
shuffle
(
data
.
begin
(),
data
.
end
(),
fleet_ptr
->
LocalRandomEngine
());
input_channel_
->
Open
();
...
...
@@ -566,10 +563,10 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
input_channel_
->
Close
();
input_channel_
->
SetBlockSize
(
fleet_send_batch_size_
);
VLOG
(
3
)
<<
"
DatasetImpl<T>
::GlobalShuffle() input_channel_ size "
VLOG
(
3
)
<<
"
MultiSlotDataset
::GlobalShuffle() input_channel_ size "
<<
input_channel_
->
Size
();
auto
get_client_id
=
[
this
,
fleet_ptr
](
const
T
&
data
)
->
size_t
{
auto
get_client_id
=
[
this
,
fleet_ptr
](
const
Record
&
data
)
->
size_t
{
if
(
!
this
->
merge_by_insid_
)
{
return
fleet_ptr
->
LocalRandomEngine
()()
%
this
->
trainer_num_
;
}
else
{
...
...
@@ -580,7 +577,7 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
auto
global_shuffle_func
=
[
this
,
get_client_id
]()
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
std
::
vector
<
T
>
data
;
std
::
vector
<
Record
>
data
;
while
(
this
->
input_channel_
->
Read
(
data
))
{
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
ars
(
this
->
trainer_num_
);
for
(
auto
&
t
:
data
)
{
...
...
@@ -835,9 +832,6 @@ void DatasetImpl<T>::CreateReaders() {
channel_idx
=
0
;
}
}
if
(
enable_heterps_
)
{
SetHeterPs
(
true
);
}
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
}
...
...
@@ -923,8 +917,7 @@ int64_t DatasetImpl<T>::GetShuffleDataSize() {
return
sum
;
}
template
<
typename
T
>
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
int
MultiSlotDataset
::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
#ifdef _LINUX
VLOG
(
3
)
<<
"ReceiveFromClient msg_type="
<<
msg_type
...
...
@@ -937,9 +930,9 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
if
(
ar
.
Cursor
()
==
ar
.
Finish
())
{
return
0
;
}
std
::
vector
<
T
>
data
;
std
::
vector
<
Record
>
data
;
while
(
ar
.
Cursor
()
<
ar
.
Finish
())
{
data
.
push_back
(
ar
.
Get
<
T
>
());
data
.
push_back
(
ar
.
Get
<
Record
>
());
}
CHECK
(
ar
.
Cursor
()
==
ar
.
Finish
());
...
...
@@ -966,6 +959,20 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
// explicit instantiation
template
class
DatasetImpl
<
Record
>;
void
MultiSlotDataset
::
DynamicAdjustReadersNum
(
int
thread_num
)
{
if
(
thread_num_
==
thread_num
)
{
VLOG
(
3
)
<<
"DatasetImpl<T>::DynamicAdjustReadersNum thread_num_="
<<
thread_num_
<<
", thread_num_=thread_num, no need to adjust"
;
return
;
}
VLOG
(
3
)
<<
"adjust readers num from "
<<
thread_num_
<<
" to "
<<
thread_num
;
thread_num_
=
thread_num
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
().
swap
(
readers_
);
CreateReaders
();
VLOG
(
3
)
<<
"adjust readers num done"
;
PrepareTrain
();
}
void
MultiSlotDataset
::
PostprocessInstance
()
{
// divide pv instance, and merge to input_channel_
if
(
enable_pv_merge_
)
{
...
...
@@ -1503,5 +1510,126 @@ void MultiSlotDataset::SlotsShuffle(
<<
", cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
template
class
DatasetImpl
<
SlotRecord
>;
void
SlotRecordDataset
::
CreateChannel
()
{
if
(
input_channel_
==
nullptr
)
{
input_channel_
=
paddle
::
framework
::
MakeChannel
<
SlotRecord
>
();
}
}
void
SlotRecordDataset
::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
VLOG
(
3
)
<<
"thread num in Dataset: "
<<
thread_num_
;
VLOG
(
3
)
<<
"Filelist size in Dataset: "
<<
filelist_
.
size
();
VLOG
(
3
)
<<
"channel num in Dataset: "
<<
channel_num_
;
CHECK
(
thread_num_
>
0
)
<<
"thread num should > 0"
;
CHECK
(
channel_num_
>
0
)
<<
"channel num should > 0"
;
CHECK
(
channel_num_
<=
thread_num_
)
<<
"channel num should <= thread num"
;
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
if
(
readers_
.
size
()
!=
0
)
{
VLOG
(
3
)
<<
"readers_.size() = "
<<
readers_
.
size
()
<<
", will not create again"
;
return
;
}
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
[
i
]
->
Init
(
data_feed_desc_
);
readers_
[
i
]
->
SetThreadId
(
i
);
readers_
[
i
]
->
SetThreadNum
(
thread_num_
);
readers_
[
i
]
->
SetFileListMutex
(
&
mutex_for_pick_file_
);
readers_
[
i
]
->
SetFileListIndex
(
&
file_idx_
);
readers_
[
i
]
->
SetFeaNumMutex
(
&
mutex_for_fea_num_
);
readers_
[
i
]
->
SetFeaNum
(
&
total_fea_num_
);
readers_
[
i
]
->
SetFileList
(
filelist_
);
readers_
[
i
]
->
SetParseInsId
(
parse_ins_id_
);
readers_
[
i
]
->
SetParseContent
(
parse_content_
);
readers_
[
i
]
->
SetParseLogKey
(
parse_logkey_
);
readers_
[
i
]
->
SetEnablePvMerge
(
enable_pv_merge_
);
readers_
[
i
]
->
SetCurrentPhase
(
current_phase_
);
if
(
input_channel_
!=
nullptr
)
{
readers_
[
i
]
->
SetInputChannel
(
input_channel_
.
get
());
}
}
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
}
void
SlotRecordDataset
::
ReleaseMemory
()
{
VLOG
(
3
)
<<
"SlotRecordDataset::ReleaseMemory() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
if
(
input_channel_
)
{
input_channel_
->
Clear
();
input_channel_
=
nullptr
;
}
if
(
enable_heterps_
)
{
VLOG
(
3
)
<<
"put pool records size: "
<<
input_records_
.
size
();
SlotRecordPool
().
put
(
&
input_records_
);
input_records_
.
clear
();
input_records_
.
shrink_to_fit
();
VLOG
(
3
)
<<
"release heterps input records records size: "
<<
input_records_
.
size
();
}
readers_
.
clear
();
readers_
.
shrink_to_fit
();
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
().
swap
(
readers_
);
VLOG
(
3
)
<<
"SlotRecordDataset::ReleaseMemory() end"
;
VLOG
(
3
)
<<
"total_feasign_num_("
<<
STAT_GET
(
STAT_total_feasign_num_in_mem
)
<<
") - current_fea_num_("
<<
total_fea_num_
<<
") = ("
<<
STAT_GET
(
STAT_total_feasign_num_in_mem
)
-
total_fea_num_
<<
")"
<<
" object pool size="
<<
SlotRecordPool
().
capacity
();
// For Debug
STAT_SUB
(
STAT_total_feasign_num_in_mem
,
total_fea_num_
);
}
void
SlotRecordDataset
::
GlobalShuffle
(
int
thread_num
)
{
// TODO(yaoxuefeng)
return
;
}
void
SlotRecordDataset
::
DynamicAdjustChannelNum
(
int
channel_num
,
bool
discard_remaining_ins
)
{
if
(
channel_num_
==
channel_num
)
{
VLOG
(
3
)
<<
"DatasetImpl<T>::DynamicAdjustChannelNum channel_num_="
<<
channel_num_
<<
", channel_num_=channel_num, no need to adjust"
;
return
;
}
VLOG
(
3
)
<<
"adjust channel num from "
<<
channel_num_
<<
" to "
<<
channel_num
;
channel_num_
=
channel_num
;
if
(
static_cast
<
int
>
(
input_channel_
->
Size
())
>=
channel_num
)
{
input_channel_
->
SetBlockSize
(
input_channel_
->
Size
()
/
channel_num
+
(
discard_remaining_ins
?
0
:
1
));
}
VLOG
(
3
)
<<
"adjust channel num done"
;
}
void
SlotRecordDataset
::
PrepareTrain
()
{
#ifdef PADDLE_WITH_GLOO
return
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"dataset set heterps need compile with GLOO"
));
#endif
return
;
}
void
SlotRecordDataset
::
DynamicAdjustReadersNum
(
int
thread_num
)
{
if
(
thread_num_
==
thread_num
)
{
VLOG
(
3
)
<<
"DatasetImpl<T>::DynamicAdjustReadersNum thread_num_="
<<
thread_num_
<<
", thread_num_=thread_num, no need to adjust"
;
return
;
}
VLOG
(
3
)
<<
"adjust readers num from "
<<
thread_num_
<<
" to "
<<
thread_num
;
thread_num_
=
thread_num
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
().
swap
(
readers_
);
CreateReaders
();
VLOG
(
3
)
<<
"adjust readers num done"
;
PrepareTrain
();
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/data_set.h
浏览文件 @
3fbb6644
...
...
@@ -149,7 +149,6 @@ class Dataset {
virtual
void
DynamicAdjustReadersNum
(
int
thread_num
)
=
0
;
// set fleet send sleep seconds
virtual
void
SetFleetSendSleepSeconds
(
int
seconds
)
=
0
;
virtual
void
SetHeterPs
(
bool
enable_heterps
)
=
0
;
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
...
...
@@ -207,7 +206,7 @@ class DatasetImpl : public Dataset {
virtual
void
WaitPreLoadDone
();
virtual
void
ReleaseMemory
();
virtual
void
LocalShuffle
();
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
)
;
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
)
{}
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{}
virtual
const
std
::
vector
<
T
>&
GetSlotsOriginalData
()
{
return
slots_shuffle_original_data_
;
...
...
@@ -233,7 +232,11 @@ class DatasetImpl : public Dataset {
bool
discard_remaining_ins
=
false
);
virtual
void
DynamicAdjustReadersNum
(
int
thread_num
);
virtual
void
SetFleetSendSleepSeconds
(
int
seconds
);
virtual
void
SetHeterPs
(
bool
enable_heterps
);
/* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps;
}
*/
std
::
vector
<
paddle
::
framework
::
Channel
<
T
>>&
GetMultiOutputChannel
()
{
return
multi_output_channel_
;
...
...
@@ -251,7 +254,10 @@ class DatasetImpl : public Dataset {
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
const
std
::
string
&
msg
)
{
// TODO(yaoxuefeng) for SlotRecordDataset
return
-
1
;
}
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
preload_readers_
;
paddle
::
framework
::
Channel
<
T
>
input_channel_
;
...
...
@@ -327,6 +333,32 @@ class MultiSlotDataset : public DatasetImpl<Record> {
const
std
::
unordered_set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
);
virtual
~
MultiSlotDataset
()
{}
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
);
virtual
void
DynamicAdjustReadersNum
(
int
thread_num
);
virtual
void
PrepareTrain
();
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
};
class
SlotRecordDataset
:
public
DatasetImpl
<
SlotRecord
>
{
public:
SlotRecordDataset
()
{
SlotRecordPool
();
}
virtual
~
SlotRecordDataset
()
{}
// create input channel
virtual
void
CreateChannel
();
// create readers
virtual
void
CreateReaders
();
// release memory
virtual
void
ReleaseMemory
();
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
);
virtual
void
DynamicAdjustChannelNum
(
int
channel_num
,
bool
discard_remaining_ins
);
virtual
void
PrepareTrain
();
virtual
void
DynamicAdjustReadersNum
(
int
thread_num
);
protected:
bool
enable_heterps_
=
true
;
};
}
// end namespace framework
...
...
paddle/fluid/framework/dataset_factory.cc
浏览文件 @
3fbb6644
...
...
@@ -53,7 +53,7 @@ std::unique_ptr<Dataset> DatasetFactory::CreateDataset(
std
::
string
dataset_class
)
{
if
(
g_dataset_map
.
count
(
dataset_class
)
<
1
)
{
LOG
(
WARNING
)
<<
"Your Dataset "
<<
dataset_class
<<
"is not supported currently"
;
<<
"
is not supported currently"
;
LOG
(
WARNING
)
<<
"Supported Dataset: "
<<
DatasetTypeList
();
exit
(
-
1
);
}
...
...
@@ -61,5 +61,6 @@ std::unique_ptr<Dataset> DatasetFactory::CreateDataset(
}
REGISTER_DATASET_CLASS
(
MultiSlotDataset
);
REGISTER_DATASET_CLASS
(
SlotRecordDataset
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/platform/flags.cc
浏览文件 @
3fbb6644
...
...
@@ -673,3 +673,11 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120,
PADDLE_DEFINE_EXPORTED_bool
(
apply_pass_to_program
,
false
,
"It controls whether to apply IR pass to program when using Fleet APIs"
);
DEFINE_int32
(
record_pool_max_size
,
2000000
,
"SlotRecordDataset slot record pool max size"
);
DEFINE_int32
(
slotpool_thread_num
,
1
,
"SlotRecordDataset slot pool thread num"
);
DEFINE_bool
(
enable_slotpool_wait_release
,
false
,
"enable slotrecord obejct wait release, default false"
);
DEFINE_bool
(
enable_slotrecord_reset_shrink
,
false
,
"enable slotrecord obejct reset shrink memory, default false"
);
\ No newline at end of file
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
3fbb6644
...
...
@@ -309,8 +309,6 @@ void BindDataset(py::module *m) {
&
framework
::
Dataset
::
SetFleetSendSleepSeconds
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"enable_pv_merge"
,
&
framework
::
Dataset
::
EnablePvMerge
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"set_heter_ps"
,
&
framework
::
Dataset
::
SetHeterPs
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
class_
<
IterableDatasetWrapper
>
(
*
m
,
"IterableDatasetWrapper"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录