Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e6b87b31
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e6b87b31
编写于
5月 25, 2020
作者:
H
hutuxian
提交者:
GitHub
5月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support AucRunner in PaddleBox (#22884)
* Support AucRunner in PaddleBox * update some code style
上级
c417f991
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
413 addition
and
85 deletion
+413
-85
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+31
-27
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+70
-11
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+24
-15
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+14
-7
paddle/fluid/framework/fleet/box_wrapper.cc
paddle/fluid/framework/fleet/box_wrapper.cc
+107
-0
paddle/fluid/framework/fleet/box_wrapper.h
paddle/fluid/framework/fleet/box_wrapper.h
+134
-21
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+2
-2
paddle/fluid/pybind/box_helper_py.cc
paddle/fluid/pybind/box_helper_py.cc
+5
-1
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+2
-0
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+21
-0
python/paddle/fluid/tests/unittests/test_boxps.py
python/paddle/fluid/tests/unittests/test_boxps.py
+1
-0
python/paddle/fluid/tests/unittests/test_dataset.py
python/paddle/fluid/tests/unittests/test_dataset.py
+2
-1
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
e6b87b31
...
@@ -41,44 +41,44 @@ namespace paddle {
...
@@ -41,44 +41,44 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
void
RecordCandidateList
::
ReSize
(
size_t
length
)
{
void
RecordCandidateList
::
ReSize
(
size_t
length
)
{
_mutex
.
lock
();
mutex_
.
lock
();
_capacity
=
length
;
capacity_
=
length
;
CHECK
(
_capacity
>
0
);
// NOLINT
CHECK
(
capacity_
>
0
);
// NOLINT
_candidate_list
.
clear
();
candidate_list_
.
clear
();
_candidate_list
.
resize
(
_capacity
);
candidate_list_
.
resize
(
capacity_
);
_full
=
false
;
full_
=
false
;
_cur_size
=
0
;
cur_size_
=
0
;
_total_size
=
0
;
total_size_
=
0
;
_mutex
.
unlock
();
mutex_
.
unlock
();
}
}
void
RecordCandidateList
::
ReInit
()
{
void
RecordCandidateList
::
ReInit
()
{
_mutex
.
lock
();
mutex_
.
lock
();
_full
=
false
;
full_
=
false
;
_cur_size
=
0
;
cur_size_
=
0
;
_total_size
=
0
;
total_size_
=
0
;
_mutex
.
unlock
();
mutex_
.
unlock
();
}
}
void
RecordCandidateList
::
AddAndGet
(
const
Record
&
record
,
void
RecordCandidateList
::
AddAndGet
(
const
Record
&
record
,
RecordCandidate
*
result
)
{
RecordCandidate
*
result
)
{
_mutex
.
lock
();
mutex_
.
lock
();
size_t
index
=
0
;
size_t
index
=
0
;
++
_total_size
;
++
total_size_
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
if
(
!
_full
)
{
if
(
!
full_
)
{
_candidate_list
[
_cur_size
++
]
=
record
;
candidate_list_
[
cur_size_
++
]
=
record
;
_full
=
(
_cur_size
==
_capacity
);
full_
=
(
cur_size_
==
capacity_
);
}
else
{
}
else
{
CHECK
(
_cur_size
==
_capacity
);
CHECK
(
cur_size_
==
capacity_
);
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
_total_size
;
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
total_size_
;
if
(
index
<
_capacity
)
{
if
(
index
<
capacity_
)
{
_candidate_list
[
index
]
=
record
;
candidate_list_
[
index
]
=
record
;
}
}
}
}
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
_cur_size
;
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
cur_size_
;
*
result
=
_candidate_list
[
index
];
*
result
=
candidate_list_
[
index
];
_mutex
.
unlock
();
mutex_
.
unlock
();
}
}
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
...
@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
...
@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
int
PaddleBoxDataFeed
::
GetCurrentPhase
()
{
int
PaddleBoxDataFeed
::
GetCurrentPhase
()
{
#ifdef PADDLE_WITH_BOX_PS
#ifdef PADDLE_WITH_BOX_PS
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
auto
box_ptr
=
paddle
::
framework
::
BoxWrapper
::
GetInstance
();
return
box_ptr
->
PassFlag
();
// join: 1, update: 0
if
(
box_ptr
->
Mode
()
==
1
)
{
// For AucRunner
return
1
;
}
else
{
return
box_ptr
->
Phase
();
}
#else
#else
LOG
(
WARNING
)
<<
"It should be complied with BOX_PS..."
;
LOG
(
WARNING
)
<<
"It should be complied with BOX_PS..."
;
return
current_phase_
;
return
current_phase_
;
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
e6b87b31
...
@@ -27,6 +27,7 @@ limitations under the License. */
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
...
@@ -34,6 +35,7 @@ limitations under the License. */
...
@@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
...
@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
...
@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
struct
RecordCandidate
{
struct
RecordCandidate
{
std
::
string
ins_id_
;
std
::
string
ins_id_
;
std
::
unordered_multimap
<
uint16_t
,
FeatureKey
>
feas
;
std
::
unordered_multimap
<
uint16_t
,
FeatureKey
>
feas_
;
size_t
shadow_index_
=
-
1
;
// Optimization for Reservoir Sample
RecordCandidate
()
{}
RecordCandidate
(
const
Record
&
rec
,
const
std
::
unordered_set
<
uint16_t
>&
slot_index_to_replace
)
{
for
(
const
auto
&
fea
:
rec
.
uint64_feasigns_
)
{
if
(
slot_index_to_replace
.
find
(
fea
.
slot
())
!=
slot_index_to_replace
.
end
())
{
feas_
.
insert
({
fea
.
slot
(),
fea
.
sign
()});
}
}
}
RecordCandidate
&
operator
=
(
const
Record
&
rec
)
{
RecordCandidate
&
operator
=
(
const
Record
&
rec
)
{
feas
.
clear
();
feas
_
.
clear
();
ins_id_
=
rec
.
ins_id_
;
ins_id_
=
rec
.
ins_id_
;
for
(
auto
&
fea
:
rec
.
uint64_feasigns_
)
{
for
(
auto
&
fea
:
rec
.
uint64_feasigns_
)
{
feas
.
insert
({
fea
.
slot
(),
fea
.
sign
()});
feas
_
.
insert
({
fea
.
slot
(),
fea
.
sign
()});
}
}
return
*
this
;
return
*
this
;
}
}
...
@@ -499,22 +513,67 @@ struct RecordCandidate {
...
@@ -499,22 +513,67 @@ struct RecordCandidate {
class
RecordCandidateList
{
class
RecordCandidateList
{
public:
public:
RecordCandidateList
()
=
default
;
RecordCandidateList
()
=
default
;
RecordCandidateList
(
const
RecordCandidateList
&
)
=
delete
;
RecordCandidateList
(
const
RecordCandidateList
&
)
{}
RecordCandidateList
&
operator
=
(
const
RecordCandidateList
&
)
=
delete
;
size_t
Size
()
{
return
cur_size_
;
}
void
ReSize
(
size_t
length
);
void
ReSize
(
size_t
length
);
void
ReInit
();
void
ReInit
();
void
ReInitPass
()
{
for
(
size_t
i
=
0
;
i
<
cur_size_
;
++
i
)
{
if
(
candidate_list_
[
i
].
shadow_index_
!=
i
)
{
candidate_list_
[
i
].
ins_id_
=
candidate_list_
[
candidate_list_
[
i
].
shadow_index_
].
ins_id_
;
candidate_list_
[
i
].
feas_
.
swap
(
candidate_list_
[
candidate_list_
[
i
].
shadow_index_
].
feas_
);
candidate_list_
[
i
].
shadow_index_
=
i
;
}
}
candidate_list_
.
resize
(
cur_size_
);
}
void
AddAndGet
(
const
Record
&
record
,
RecordCandidate
*
result
);
void
AddAndGet
(
const
Record
&
record
,
RecordCandidate
*
result
);
void
AddAndGet
(
const
Record
&
record
,
size_t
&
index_result
)
{
// NOLINT
// std::unique_lock<std::mutex> lock(mutex_);
size_t
index
=
0
;
++
total_size_
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
if
(
!
full_
)
{
candidate_list_
.
emplace_back
(
record
,
slot_index_to_replace_
);
candidate_list_
.
back
().
shadow_index_
=
cur_size_
;
++
cur_size_
;
full_
=
(
cur_size_
==
capacity_
);
}
else
{
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
total_size_
;
if
(
index
<
capacity_
)
{
candidate_list_
.
emplace_back
(
record
,
slot_index_to_replace_
);
candidate_list_
[
index
].
shadow_index_
=
candidate_list_
.
size
()
-
1
;
}
}
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
cur_size_
;
index_result
=
candidate_list_
[
index
].
shadow_index_
;
}
const
RecordCandidate
&
Get
(
size_t
index
)
const
{
PADDLE_ENFORCE_LT
(
index
,
candidate_list_
.
size
(),
platform
::
errors
::
OutOfRange
(
"Your index [%lu] exceeds the number of "
"elements in candidate_list[%lu]."
,
index
,
candidate_list_
.
size
()));
return
candidate_list_
[
index
];
}
void
SetSlotIndexToReplace
(
const
std
::
unordered_set
<
uint16_t
>&
slot_index_to_replace
)
{
slot_index_to_replace_
=
slot_index_to_replace
;
}
private:
private:
size_t
_capacity
=
0
;
size_t
capacity_
=
0
;
std
::
mutex
_mutex
;
std
::
mutex
mutex_
;
bool
_full
=
false
;
bool
full_
=
false
;
size_t
_cur_size
=
0
;
size_t
cur_size_
=
0
;
size_t
_total_size
=
0
;
size_t
total_size_
=
0
;
std
::
vector
<
RecordCandidate
>
_candidate_list
;
std
::
vector
<
RecordCandidate
>
candidate_list_
;
std
::
unordered_set
<
uint16_t
>
slot_index_to_replace_
;
};
};
template
<
class
AR
>
template
<
class
AR
>
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
e6b87b31
...
@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
...
@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
VLOG
(
3
)
<<
"MultiSlotDataset::MergeByInsId end"
;
VLOG
(
3
)
<<
"MultiSlotDataset::MergeByInsId end"
;
}
}
void
MultiSlotDataset
::
GetRandomData
(
const
std
::
set
<
uint16_t
>&
slots_to_replace
,
void
MultiSlotDataset
::
GetRandomData
(
std
::
vector
<
Record
>*
result
)
{
const
std
::
unordered_set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
)
{
int
debug_erase_cnt
=
0
;
int
debug_erase_cnt
=
0
;
int
debug_push_cnt
=
0
;
int
debug_push_cnt
=
0
;
auto
multi_slot_desc
=
data_feed_desc_
.
multi_slot_desc
();
auto
multi_slot_desc
=
data_feed_desc_
.
multi_slot_desc
();
slots_shuffle_rclist_
.
ReInit
();
slots_shuffle_rclist_
.
ReInit
();
for
(
const
auto
&
rec
:
slots_shuffle_original_data_
)
{
const
auto
&
slots_shuffle_original_data
=
GetSlotsOriginalData
();
for
(
const
auto
&
rec
:
slots_shuffle_original_data
)
{
RecordCandidate
rand_rec
;
RecordCandidate
rand_rec
;
Record
new_rec
=
rec
;
Record
new_rec
=
rec
;
slots_shuffle_rclist_
.
AddAndGet
(
rec
,
&
rand_rec
);
slots_shuffle_rclist_
.
AddAndGet
(
rec
,
&
rand_rec
);
...
@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
...
@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
}
}
}
}
for
(
auto
slot
:
slots_to_replace
)
{
for
(
auto
slot
:
slots_to_replace
)
{
auto
range
=
rand_rec
.
feas
.
equal_range
(
slot
);
auto
range
=
rand_rec
.
feas
_
.
equal_range
(
slot
);
for
(
auto
it
=
range
.
first
;
it
!=
range
.
second
;
++
it
)
{
for
(
auto
it
=
range
.
first
;
it
!=
range
.
second
;
++
it
)
{
new_rec
.
uint64_feasigns_
.
push_back
({
it
->
second
,
it
->
first
});
new_rec
.
uint64_feasigns_
.
push_back
({
it
->
second
,
it
->
first
});
debug_push_cnt
+=
1
;
debug_push_cnt
+=
1
;
...
@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
...
@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
<<
" repush feasign num: "
<<
debug_push_cnt
;
<<
" repush feasign num: "
<<
debug_push_cnt
;
}
}
// slots shuffle to input_channel_ with needed-shuffle slots
void
MultiSlotDataset
::
PreprocessChannel
(
void
MultiSlotDataset
::
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
,
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{
std
::
unordered_set
<
uint16_t
>&
index_slots
)
{
// NOLINT
int
out_channel_size
=
0
;
int
out_channel_size
=
0
;
if
(
cur_channel_
==
0
)
{
if
(
cur_channel_
==
0
)
{
for
(
size_t
i
=
0
;
i
<
multi_output_channel_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
multi_output_channel_
.
size
();
++
i
)
{
...
@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
...
@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
VLOG
(
2
)
<<
"DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
VLOG
(
2
)
<<
"DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<<
input_channel_
->
Size
()
<<
input_channel_
->
Size
()
<<
" output channel size: "
<<
out_channel_size
;
<<
" output channel size: "
<<
out_channel_size
;
if
(
!
slots_shuffle_fea_eval_
)
{
VLOG
(
3
)
<<
"DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle"
;
return
;
}
if
((
!
input_channel_
||
input_channel_
->
Size
()
==
0
)
&&
if
((
!
input_channel_
||
input_channel_
->
Size
()
==
0
)
&&
slots_shuffle_original_data_
.
size
()
==
0
&&
out_channel_size
==
0
)
{
slots_shuffle_original_data_
.
size
()
==
0
&&
out_channel_size
==
0
)
{
VLOG
(
3
)
<<
"DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle"
;
VLOG
(
3
)
<<
"DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle"
;
return
;
return
;
}
}
platform
::
Timer
timeline
;
timeline
.
Start
();
auto
multi_slot_desc
=
data_feed_desc_
.
multi_slot_desc
();
auto
multi_slot_desc
=
data_feed_desc_
.
multi_slot_desc
();
std
::
set
<
uint16_t
>
index_slots
;
for
(
int
i
=
0
;
i
<
multi_slot_desc
.
slots_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
multi_slot_desc
.
slots_size
();
++
i
)
{
std
::
string
cur_slot
=
multi_slot_desc
.
slots
(
i
).
name
();
std
::
string
cur_slot
=
multi_slot_desc
.
slots
(
i
).
name
();
if
(
slots_to_replace
.
find
(
cur_slot
)
!=
slots_to_replace
.
end
())
{
if
(
slots_to_replace
.
find
(
cur_slot
)
!=
slots_to_replace
.
end
())
{
...
@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
...
@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
}
}
CHECK
(
input_channel_
->
Size
()
==
0
)
CHECK
(
input_channel_
->
Size
()
==
0
)
<<
"input channel should be empty before slots shuffle"
;
<<
"input channel should be empty before slots shuffle"
;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void
MultiSlotDataset
::
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{
PADDLE_ENFORCE_EQ
(
slots_shuffle_fea_eval_
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"fea eval mode off, need to set on for slots shuffle"
));
platform
::
Timer
timeline
;
timeline
.
Start
();
std
::
unordered_set
<
uint16_t
>
index_slots
;
PreprocessChannel
(
slots_to_replace
,
index_slots
);
std
::
vector
<
Record
>
random_data
;
std
::
vector
<
Record
>
random_data
;
random_data
.
clear
();
random_data
.
clear
();
// get slots shuffled random_data
// get slots shuffled random_data
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
e6b87b31
...
@@ -67,6 +67,7 @@ class Dataset {
...
@@ -67,6 +67,7 @@ class Dataset {
virtual
void
SetParseContent
(
bool
parse_content
)
=
0
;
virtual
void
SetParseContent
(
bool
parse_content
)
=
0
;
virtual
void
SetParseLogKey
(
bool
parse_logkey
)
=
0
;
virtual
void
SetParseLogKey
(
bool
parse_logkey
)
=
0
;
virtual
void
SetEnablePvMerge
(
bool
enable_pv_merge
)
=
0
;
virtual
void
SetEnablePvMerge
(
bool
enable_pv_merge
)
=
0
;
virtual
bool
EnablePvMerge
()
=
0
;
virtual
void
SetMergeBySid
(
bool
is_merge
)
=
0
;
virtual
void
SetMergeBySid
(
bool
is_merge
)
=
0
;
// set merge by ins id
// set merge by ins id
virtual
void
SetMergeByInsId
(
int
merge_size
)
=
0
;
virtual
void
SetMergeByInsId
(
int
merge_size
)
=
0
;
...
@@ -108,10 +109,7 @@ class Dataset {
...
@@ -108,10 +109,7 @@ class Dataset {
virtual
void
LocalShuffle
()
=
0
;
virtual
void
LocalShuffle
()
=
0
;
// global shuffle data
// global shuffle data
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
)
=
0
;
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
)
=
0
;
// for slots shuffle
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
=
0
;
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
=
0
;
virtual
void
GetRandomData
(
const
std
::
set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
)
=
0
;
// create readers
// create readers
virtual
void
CreateReaders
()
=
0
;
virtual
void
CreateReaders
()
=
0
;
// destroy readers
// destroy readers
...
@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
...
@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
virtual
int
GetTrainerNum
()
{
return
trainer_num_
;
}
virtual
int
GetTrainerNum
()
{
return
trainer_num_
;
}
virtual
Channel
<
T
>
GetInputChannel
()
{
return
input_channel_
;
}
virtual
Channel
<
T
>
GetInputChannel
()
{
return
input_channel_
;
}
virtual
void
SetInputChannel
(
const
Channel
<
T
>&
input_channel
)
{
input_channel_
=
input_channel
;
}
virtual
int64_t
GetFleetSendBatchSize
()
{
return
fleet_send_batch_size_
;
}
virtual
int64_t
GetFleetSendBatchSize
()
{
return
fleet_send_batch_size_
;
}
virtual
std
::
pair
<
std
::
string
,
std
::
string
>
GetHdfsConfig
()
{
virtual
std
::
pair
<
std
::
string
,
std
::
string
>
GetHdfsConfig
()
{
return
std
::
make_pair
(
fs_name_
,
fs_ugi_
);
return
std
::
make_pair
(
fs_name_
,
fs_ugi_
);
...
@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
...
@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
return
data_feed_desc_
;
return
data_feed_desc_
;
}
}
virtual
int
GetChannelNum
()
{
return
channel_num_
;
}
virtual
int
GetChannelNum
()
{
return
channel_num_
;
}
virtual
bool
EnablePvMerge
()
{
return
enable_pv_merge_
;
}
virtual
std
::
vector
<
paddle
::
framework
::
DataFeed
*>
GetReaders
();
virtual
std
::
vector
<
paddle
::
framework
::
DataFeed
*>
GetReaders
();
virtual
void
CreateChannel
();
virtual
void
CreateChannel
();
virtual
void
RegisterClientToClientMsgHandler
();
virtual
void
RegisterClientToClientMsgHandler
();
...
@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
...
@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
);
virtual
void
GlobalShuffle
(
int
thread_num
=
-
1
);
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{}
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{}
virtual
void
GetRandomData
(
const
std
::
set
<
uint16_t
>&
slots_to_replace
,
virtual
const
std
::
vector
<
T
>&
GetSlotsOriginalData
()
{
std
::
vector
<
Record
>*
result
)
{}
return
slots_shuffle_original_data_
;
}
virtual
void
CreateReaders
();
virtual
void
CreateReaders
();
virtual
void
DestroyReaders
();
virtual
void
DestroyReaders
();
virtual
int64_t
GetMemoryDataSize
();
virtual
int64_t
GetMemoryDataSize
();
...
@@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> {
...
@@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> {
}
}
std
::
vector
<
std
::
unordered_set
<
uint64_t
>>
().
swap
(
local_tables_
);
std
::
vector
<
std
::
unordered_set
<
uint64_t
>>
().
swap
(
local_tables_
);
}
}
virtual
void
PreprocessChannel
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
,
std
::
unordered_set
<
uint16_t
>&
index_slot
);
// NOLINT
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
);
virtual
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
);
virtual
void
GetRandomData
(
const
std
::
set
<
uint16_t
>&
slots_to_replace
,
virtual
void
GetRandomData
(
std
::
vector
<
Record
>*
result
);
const
std
::
unordered_set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
);
virtual
~
MultiSlotDataset
()
{}
virtual
~
MultiSlotDataset
()
{}
};
};
...
...
paddle/fluid/framework/fleet/box_wrapper.cc
浏览文件 @
e6b87b31
...
@@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
...
@@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
<<
" s"
;
<<
" s"
;
VLOG
(
3
)
<<
"End PushSparseGrad"
;
VLOG
(
3
)
<<
"End PushSparseGrad"
;
}
}
void
BoxWrapper
::
GetRandomReplace
(
const
std
::
vector
<
Record
>&
pass_data
)
{
VLOG
(
0
)
<<
"Begin GetRandomReplace"
;
size_t
ins_num
=
pass_data
.
size
();
replace_idx_
.
resize
(
ins_num
);
for
(
auto
&
cand_list
:
random_ins_pool_list
)
{
cand_list
.
ReInitPass
();
}
std
::
vector
<
std
::
thread
>
threads
;
for
(
int
tid
=
0
;
tid
<
auc_runner_thread_num_
;
++
tid
)
{
threads
.
push_back
(
std
::
thread
([
this
,
&
pass_data
,
tid
,
ins_num
]()
{
int
start
=
tid
*
ins_num
/
auc_runner_thread_num_
;
int
end
=
(
tid
+
1
)
*
ins_num
/
auc_runner_thread_num_
;
VLOG
(
3
)
<<
"GetRandomReplace begin for thread["
<<
tid
<<
"], and process ["
<<
start
<<
", "
<<
end
<<
"), total ins: "
<<
ins_num
;
auto
&
random_pool
=
random_ins_pool_list
[
tid
];
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
const
auto
&
ins
=
pass_data
[
i
];
random_pool
.
AddAndGet
(
ins
,
replace_idx_
[
i
]);
}
}));
}
for
(
int
tid
=
0
;
tid
<
auc_runner_thread_num_
;
++
tid
)
{
threads
[
tid
].
join
();
}
pass_done_semi_
->
Put
(
1
);
VLOG
(
0
)
<<
"End GetRandomReplace"
;
}
void
BoxWrapper
::
GetRandomData
(
const
std
::
vector
<
Record
>&
pass_data
,
const
std
::
unordered_set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
)
{
VLOG
(
0
)
<<
"Begin GetRandomData"
;
std
::
vector
<
std
::
thread
>
threads
;
for
(
int
tid
=
0
;
tid
<
auc_runner_thread_num_
;
++
tid
)
{
threads
.
push_back
(
std
::
thread
([
this
,
&
pass_data
,
tid
,
&
slots_to_replace
,
result
]()
{
int
debug_erase_cnt
=
0
;
int
debug_push_cnt
=
0
;
size_t
ins_num
=
pass_data
.
size
();
int
start
=
tid
*
ins_num
/
auc_runner_thread_num_
;
int
end
=
(
tid
+
1
)
*
ins_num
/
auc_runner_thread_num_
;
VLOG
(
3
)
<<
"GetRandomData begin for thread["
<<
tid
<<
"], and process ["
<<
start
<<
", "
<<
end
<<
"), total ins: "
<<
ins_num
;
const
auto
&
random_pool
=
random_ins_pool_list
[
tid
];
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
const
auto
&
ins
=
pass_data
[
i
];
const
RecordCandidate
&
rand_rec
=
random_pool
.
Get
(
replace_idx_
[
i
]);
Record
new_rec
=
ins
;
for
(
auto
it
=
new_rec
.
uint64_feasigns_
.
begin
();
it
!=
new_rec
.
uint64_feasigns_
.
end
();)
{
if
(
slots_to_replace
.
find
(
it
->
slot
())
!=
slots_to_replace
.
end
())
{
it
=
new_rec
.
uint64_feasigns_
.
erase
(
it
);
debug_erase_cnt
+=
1
;
}
else
{
++
it
;
}
}
for
(
auto
slot
:
slots_to_replace
)
{
auto
range
=
rand_rec
.
feas_
.
equal_range
(
slot
);
for
(
auto
it
=
range
.
first
;
it
!=
range
.
second
;
++
it
)
{
new_rec
.
uint64_feasigns_
.
push_back
({
it
->
second
,
it
->
first
});
debug_push_cnt
+=
1
;
}
}
(
*
result
)[
i
]
=
std
::
move
(
new_rec
);
}
VLOG
(
3
)
<<
"thread["
<<
tid
<<
"]: erase feasign num: "
<<
debug_erase_cnt
<<
" repush feasign num: "
<<
debug_push_cnt
;
}));
}
for
(
int
tid
=
0
;
tid
<
auc_runner_thread_num_
;
++
tid
)
{
threads
[
tid
].
join
();
}
VLOG
(
0
)
<<
"End GetRandomData"
;
}
void
BoxWrapper
::
AddReplaceFeasign
(
boxps
::
PSAgentBase
*
p_agent
,
int
feed_pass_thread_num
)
{
VLOG
(
0
)
<<
"Enter AddReplaceFeasign Function"
;
int
semi
;
pass_done_semi_
->
Get
(
semi
);
VLOG
(
0
)
<<
"Last Pass had updated random pool done. Begin AddReplaceFeasign"
;
std
::
vector
<
std
::
thread
>
threads
;
for
(
int
tid
=
0
;
tid
<
feed_pass_thread_num
;
++
tid
)
{
threads
.
push_back
(
std
::
thread
([
this
,
tid
,
p_agent
,
feed_pass_thread_num
]()
{
VLOG
(
3
)
<<
"AddReplaceFeasign begin for thread["
<<
tid
<<
"]"
;
for
(
size_t
pool_id
=
tid
;
pool_id
<
random_ins_pool_list
.
size
();
pool_id
+=
feed_pass_thread_num
)
{
auto
&
random_pool
=
random_ins_pool_list
[
pool_id
];
for
(
size_t
i
=
0
;
i
<
random_pool
.
Size
();
++
i
)
{
auto
&
ins_candidate
=
random_pool
.
Get
(
i
);
for
(
const
auto
&
pair
:
ins_candidate
.
feas_
)
{
p_agent
->
AddKey
(
pair
.
second
.
uint64_feasign_
,
tid
);
}
}
}
}));
}
for
(
int
tid
=
0
;
tid
<
feed_pass_thread_num
;
++
tid
)
{
threads
[
tid
].
join
();
}
VLOG
(
0
)
<<
"End AddReplaceFeasign"
;
}
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
#endif
#endif
paddle/fluid/framework/fleet/box_wrapper.h
浏览文件 @
e6b87b31
...
@@ -31,10 +31,12 @@ limitations under the License. */
...
@@ -31,10 +31,12 @@ limitations under the License. */
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <mutex> // NOLINT
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <string>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -469,16 +471,16 @@ class BoxWrapper {
...
@@ -469,16 +471,16 @@ class BoxWrapper {
public:
public:
MetricMsg
()
{}
MetricMsg
()
{}
MetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
MetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
int
is_join
,
int
bucket_size
=
1000000
)
int
metric_phase
,
int
bucket_size
=
1000000
)
:
label_varname_
(
label_varname
),
:
label_varname_
(
label_varname
),
pred_varname_
(
pred_varname
),
pred_varname_
(
pred_varname
),
is_join_
(
is_join
)
{
metric_phase_
(
metric_phase
)
{
calculator
=
new
BasicAucCalculator
();
calculator
=
new
BasicAucCalculator
();
calculator
->
init
(
bucket_size
);
calculator
->
init
(
bucket_size
);
}
}
virtual
~
MetricMsg
()
{}
virtual
~
MetricMsg
()
{}
int
IsJoin
()
const
{
return
is_join
_
;
}
int
MetricPhase
()
const
{
return
metric_phase
_
;
}
BasicAucCalculator
*
GetCalculator
()
{
return
calculator
;
}
BasicAucCalculator
*
GetCalculator
()
{
return
calculator
;
}
virtual
void
add_data
(
const
Scope
*
exe_scope
)
{
virtual
void
add_data
(
const
Scope
*
exe_scope
)
{
std
::
vector
<
int64_t
>
label_data
;
std
::
vector
<
int64_t
>
label_data
;
...
@@ -514,20 +516,20 @@ class BoxWrapper {
...
@@ -514,20 +516,20 @@ class BoxWrapper {
protected:
protected:
std
::
string
label_varname_
;
std
::
string
label_varname_
;
std
::
string
pred_varname_
;
std
::
string
pred_varname_
;
int
is_join
_
;
int
metric_phase
_
;
BasicAucCalculator
*
calculator
;
BasicAucCalculator
*
calculator
;
};
};
class
MultiTaskMetricMsg
:
public
MetricMsg
{
class
MultiTaskMetricMsg
:
public
MetricMsg
{
public:
public:
MultiTaskMetricMsg
(
const
std
::
string
&
label_varname
,
MultiTaskMetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname_list
,
int
is_join
,
const
std
::
string
&
pred_varname_list
,
int
metric_phase
,
const
std
::
string
&
cmatch_rank_group
,
const
std
::
string
&
cmatch_rank_group
,
const
std
::
string
&
cmatch_rank_varname
,
const
std
::
string
&
cmatch_rank_varname
,
int
bucket_size
=
1000000
)
{
int
bucket_size
=
1000000
)
{
label_varname_
=
label_varname
;
label_varname_
=
label_varname
;
cmatch_rank_varname_
=
cmatch_rank_varname
;
cmatch_rank_varname_
=
cmatch_rank_varname
;
is_join_
=
is_join
;
metric_phase_
=
metric_phase
;
calculator
=
new
BasicAucCalculator
();
calculator
=
new
BasicAucCalculator
();
calculator
->
init
(
bucket_size
);
calculator
->
init
(
bucket_size
);
for
(
auto
&
cmatch_rank
:
string
::
split_string
(
cmatch_rank_group
))
{
for
(
auto
&
cmatch_rank
:
string
::
split_string
(
cmatch_rank_group
))
{
...
@@ -594,14 +596,14 @@ class BoxWrapper {
...
@@ -594,14 +596,14 @@ class BoxWrapper {
class
CmatchRankMetricMsg
:
public
MetricMsg
{
class
CmatchRankMetricMsg
:
public
MetricMsg
{
public:
public:
CmatchRankMetricMsg
(
const
std
::
string
&
label_varname
,
CmatchRankMetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
int
is_join
,
const
std
::
string
&
pred_varname
,
int
metric_phase
,
const
std
::
string
&
cmatch_rank_group
,
const
std
::
string
&
cmatch_rank_group
,
const
std
::
string
&
cmatch_rank_varname
,
const
std
::
string
&
cmatch_rank_varname
,
int
bucket_size
=
1000000
)
{
int
bucket_size
=
1000000
)
{
label_varname_
=
label_varname
;
label_varname_
=
label_varname
;
pred_varname_
=
pred_varname
;
pred_varname_
=
pred_varname
;
cmatch_rank_varname_
=
cmatch_rank_varname
;
cmatch_rank_varname_
=
cmatch_rank_varname
;
is_join_
=
is_join
;
metric_phase_
=
metric_phase
;
calculator
=
new
BasicAucCalculator
();
calculator
=
new
BasicAucCalculator
();
calculator
->
init
(
bucket_size
);
calculator
->
init
(
bucket_size
);
for
(
auto
&
cmatch_rank
:
string
::
split_string
(
cmatch_rank_group
))
{
for
(
auto
&
cmatch_rank
:
string
::
split_string
(
cmatch_rank_group
))
{
...
@@ -653,12 +655,12 @@ class BoxWrapper {
...
@@ -653,12 +655,12 @@ class BoxWrapper {
class
MaskMetricMsg
:
public
MetricMsg
{
class
MaskMetricMsg
:
public
MetricMsg
{
public:
public:
MaskMetricMsg
(
const
std
::
string
&
label_varname
,
MaskMetricMsg
(
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
int
is_join
,
const
std
::
string
&
pred_varname
,
int
metric_phase
,
const
std
::
string
&
mask_varname
,
int
bucket_size
=
1000000
)
{
const
std
::
string
&
mask_varname
,
int
bucket_size
=
1000000
)
{
label_varname_
=
label_varname
;
label_varname_
=
label_varname
;
pred_varname_
=
pred_varname
;
pred_varname_
=
pred_varname
;
mask_varname_
=
mask_varname
;
mask_varname_
=
mask_varname
;
is_join_
=
is_join
;
metric_phase_
=
metric_phase
;
calculator
=
new
BasicAucCalculator
();
calculator
=
new
BasicAucCalculator
();
calculator
->
init
(
bucket_size
);
calculator
->
init
(
bucket_size
);
}
}
...
@@ -682,36 +684,59 @@ class BoxWrapper {
...
@@ -682,36 +684,59 @@ class BoxWrapper {
protected:
protected:
std
::
string
mask_varname_
;
std
::
string
mask_varname_
;
};
};
const
std
::
vector
<
std
::
string
>&
GetMetricNameList
()
const
{
const
std
::
vector
<
std
::
string
>
GetMetricNameList
(
return
metric_name_list_
;
int
metric_phase
=
-
1
)
const
{
VLOG
(
0
)
<<
"Want to Get metric phase: "
<<
metric_phase
;
if
(
metric_phase
==
-
1
)
{
return
metric_name_list_
;
}
else
{
std
::
vector
<
std
::
string
>
ret
;
for
(
const
auto
&
name
:
metric_name_list_
)
{
const
auto
iter
=
metric_lists_
.
find
(
name
);
PADDLE_ENFORCE_NE
(
iter
,
metric_lists_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"The metric name you provided is not registered."
));
if
(
iter
->
second
->
MetricPhase
()
==
metric_phase
)
{
VLOG
(
0
)
<<
name
<<
"'s phase is "
<<
iter
->
second
->
MetricPhase
()
<<
", we want"
;
ret
.
push_back
(
name
);
}
else
{
VLOG
(
0
)
<<
name
<<
"'s phase is "
<<
iter
->
second
->
MetricPhase
()
<<
", not we want"
;
}
}
return
ret
;
}
}
}
int
P
assFlag
()
const
{
return
pass_flag
_
;
}
int
P
hase
()
const
{
return
phase
_
;
}
void
FlipP
assFlag
()
{
pass_flag_
=
1
-
pass_flag
_
;
}
void
FlipP
hase
()
{
phase_
=
(
phase_
+
1
)
%
phase_num
_
;
}
std
::
map
<
std
::
string
,
MetricMsg
*>&
GetMetricList
()
{
return
metric_lists_
;
}
std
::
map
<
std
::
string
,
MetricMsg
*>&
GetMetricList
()
{
return
metric_lists_
;
}
void
InitMetric
(
const
std
::
string
&
method
,
const
std
::
string
&
name
,
void
InitMetric
(
const
std
::
string
&
method
,
const
std
::
string
&
name
,
const
std
::
string
&
label_varname
,
const
std
::
string
&
label_varname
,
const
std
::
string
&
pred_varname
,
const
std
::
string
&
pred_varname
,
const
std
::
string
&
cmatch_rank_varname
,
const
std
::
string
&
cmatch_rank_varname
,
const
std
::
string
&
mask_varname
,
bool
is_join
,
const
std
::
string
&
mask_varname
,
int
metric_phase
,
const
std
::
string
&
cmatch_rank_group
,
const
std
::
string
&
cmatch_rank_group
,
int
bucket_size
=
1000000
)
{
int
bucket_size
=
1000000
)
{
if
(
method
==
"AucCalculator"
)
{
if
(
method
==
"AucCalculator"
)
{
metric_lists_
.
emplace
(
name
,
new
MetricMsg
(
label_varname
,
pred_varname
,
metric_lists_
.
emplace
(
name
,
new
MetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
bucket_size
));
metric_phase
,
bucket_size
));
}
else
if
(
method
==
"MultiTaskAucCalculator"
)
{
}
else
if
(
method
==
"MultiTaskAucCalculator"
)
{
metric_lists_
.
emplace
(
metric_lists_
.
emplace
(
name
,
new
MultiTaskMetricMsg
(
label_varname
,
pred_varname
,
name
,
new
MultiTaskMetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
cmatch_rank_group
,
metric_phase
,
cmatch_rank_group
,
cmatch_rank_varname
,
bucket_size
));
cmatch_rank_varname
,
bucket_size
));
}
else
if
(
method
==
"CmatchRankAucCalculator"
)
{
}
else
if
(
method
==
"CmatchRankAucCalculator"
)
{
metric_lists_
.
emplace
(
metric_lists_
.
emplace
(
name
,
new
CmatchRankMetricMsg
(
label_varname
,
pred_varname
,
name
,
new
CmatchRankMetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
cmatch_rank_group
,
metric_phase
,
cmatch_rank_group
,
cmatch_rank_varname
,
bucket_size
));
cmatch_rank_varname
,
bucket_size
));
}
else
if
(
method
==
"MaskAucCalculator"
)
{
}
else
if
(
method
==
"MaskAucCalculator"
)
{
metric_lists_
.
emplace
(
metric_lists_
.
emplace
(
name
,
new
MaskMetricMsg
(
label_varname
,
pred_varname
,
is_join
?
1
:
0
,
name
,
new
MaskMetricMsg
(
label_varname
,
pred_varname
,
metric_phase
,
mask_varname
,
bucket_size
));
mask_varname
,
bucket_size
));
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
@@ -753,7 +778,8 @@ class BoxWrapper {
...
@@ -753,7 +778,8 @@ class BoxWrapper {
std
::
unordered_set
<
std
::
string
>
slot_name_omited_in_feedpass_
;
std
::
unordered_set
<
std
::
string
>
slot_name_omited_in_feedpass_
;
// Metric Related
// Metric Related
int
pass_flag_
=
1
;
// join: 1, update: 0
int
phase_
=
1
;
int
phase_num_
=
2
;
std
::
map
<
std
::
string
,
MetricMsg
*>
metric_lists_
;
std
::
map
<
std
::
string
,
MetricMsg
*>
metric_lists_
;
std
::
vector
<
std
::
string
>
metric_name_list_
;
std
::
vector
<
std
::
string
>
metric_name_list_
;
std
::
vector
<
int
>
slot_vector_
;
std
::
vector
<
int
>
slot_vector_
;
...
@@ -762,6 +788,57 @@ class BoxWrapper {
...
@@ -762,6 +788,57 @@ class BoxWrapper {
public:
public:
static
AfsManager
*
afs_manager
;
static
AfsManager
*
afs_manager
;
// Auc Runner
public:
void
InitializeAucRunner
(
std
::
vector
<
std
::
vector
<
std
::
string
>>
slot_eval
,
int
thread_num
,
int
pool_size
,
std
::
vector
<
std
::
string
>
slot_list
)
{
mode_
=
1
;
phase_num_
=
static_cast
<
int
>
(
slot_eval
.
size
());
phase_
=
phase_num_
-
1
;
auc_runner_thread_num_
=
thread_num
;
pass_done_semi_
=
paddle
::
framework
::
MakeChannel
<
int
>
();
pass_done_semi_
->
Put
(
1
);
// Note: At most 1 pipeline in AucRunner
random_ins_pool_list
.
resize
(
thread_num
);
std
::
unordered_set
<
std
::
string
>
slot_set
;
for
(
size_t
i
=
0
;
i
<
slot_eval
.
size
();
++
i
)
{
for
(
const
auto
&
slot
:
slot_eval
[
i
])
{
slot_set
.
insert
(
slot
);
}
}
for
(
size_t
i
=
0
;
i
<
slot_list
.
size
();
++
i
)
{
if
(
slot_set
.
find
(
slot_list
[
i
])
!=
slot_set
.
end
())
{
slot_index_to_replace_
.
insert
(
static_cast
<
int16_t
>
(
i
));
}
}
for
(
int
i
=
0
;
i
<
auc_runner_thread_num_
;
++
i
)
{
random_ins_pool_list
[
i
].
SetSlotIndexToReplace
(
slot_index_to_replace_
);
}
VLOG
(
0
)
<<
"AucRunner configuration: thread number["
<<
thread_num
<<
"], pool size["
<<
pool_size
<<
"], runner_group["
<<
phase_num_
<<
"]"
;
VLOG
(
0
)
<<
"Slots that need to be evaluated:"
;
for
(
auto
e
:
slot_index_to_replace_
)
{
VLOG
(
0
)
<<
e
<<
": "
<<
slot_list
[
e
];
}
}
void
GetRandomReplace
(
const
std
::
vector
<
Record
>&
pass_data
);
void
AddReplaceFeasign
(
boxps
::
PSAgentBase
*
p_agent
,
int
feed_pass_thread_num
);
void
GetRandomData
(
const
std
::
vector
<
Record
>&
pass_data
,
const
std
::
unordered_set
<
uint16_t
>&
slots_to_replace
,
std
::
vector
<
Record
>*
result
);
int
Mode
()
const
{
return
mode_
;
}
private:
int
mode_
=
0
;
// 0 means train/test 1 means auc_runner
int
auc_runner_thread_num_
=
1
;
bool
init_done_
=
false
;
paddle
::
framework
::
Channel
<
int
>
pass_done_semi_
;
std
::
unordered_set
<
uint16_t
>
slot_index_to_replace_
;
std
::
vector
<
RecordCandidateList
>
random_ins_pool_list
;
std
::
vector
<
size_t
>
replace_idx_
;
};
};
#endif
#endif
...
@@ -810,7 +887,38 @@ class BoxHelper {
...
@@ -810,7 +887,38 @@ class BoxHelper {
VLOG
(
3
)
<<
"After PreLoadIntoMemory()"
;
VLOG
(
3
)
<<
"After PreLoadIntoMemory()"
;
}
}
void
WaitFeedPassDone
()
{
feed_data_thread_
->
join
();
}
void
WaitFeedPassDone
()
{
feed_data_thread_
->
join
();
}
void
SlotsShuffle
(
const
std
::
set
<
std
::
string
>&
slots_to_replace
)
{
#ifdef PADDLE_WITH_BOX_PS
auto
box_ptr
=
BoxWrapper
::
GetInstance
();
PADDLE_ENFORCE_EQ
(
box_ptr
->
Mode
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"Should call InitForAucRunner first."
));
box_ptr
->
FlipPhase
();
std
::
unordered_set
<
uint16_t
>
index_slots
;
dynamic_cast
<
MultiSlotDataset
*>
(
dataset_
)
->
PreprocessChannel
(
slots_to_replace
,
index_slots
);
const
std
::
vector
<
Record
>&
pass_data
=
dynamic_cast
<
MultiSlotDataset
*>
(
dataset_
)
->
GetSlotsOriginalData
();
if
(
!
get_random_replace_done_
)
{
box_ptr
->
GetRandomReplace
(
pass_data
);
get_random_replace_done_
=
true
;
}
std
::
vector
<
Record
>
random_data
;
random_data
.
resize
(
pass_data
.
size
());
box_ptr
->
GetRandomData
(
pass_data
,
index_slots
,
&
random_data
);
auto
new_input_channel
=
paddle
::
framework
::
MakeChannel
<
Record
>
();
new_input_channel
->
Open
();
new_input_channel
->
Write
(
std
::
move
(
random_data
));
new_input_channel
->
Close
();
dynamic_cast
<
MultiSlotDataset
*>
(
dataset_
)
->
SetInputChannel
(
new_input_channel
);
if
(
dataset_
->
EnablePvMerge
())
{
dataset_
->
PreprocessInstance
();
}
#endif
}
#ifdef PADDLE_WITH_BOX_PS
#ifdef PADDLE_WITH_BOX_PS
// notify boxps to feed this pass feasigns from SSD to memory
// notify boxps to feed this pass feasigns from SSD to memory
static
void
FeedPassThread
(
const
std
::
deque
<
Record
>&
t
,
int
begin_index
,
static
void
FeedPassThread
(
const
std
::
deque
<
Record
>&
t
,
int
begin_index
,
...
@@ -881,6 +989,10 @@ class BoxHelper {
...
@@ -881,6 +989,10 @@ class BoxHelper {
for
(
size_t
i
=
0
;
i
<
tnum
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tnum
;
++
i
)
{
threads
[
i
].
join
();
threads
[
i
].
join
();
}
}
if
(
box_ptr
->
Mode
()
==
1
)
{
box_ptr
->
AddReplaceFeasign
(
p_agent
,
tnum
);
}
VLOG
(
3
)
<<
"Begin call EndFeedPass in BoxPS"
;
VLOG
(
3
)
<<
"Begin call EndFeedPass in BoxPS"
;
box_ptr
->
EndFeedPass
(
p_agent
);
box_ptr
->
EndFeedPass
(
p_agent
);
#endif
#endif
...
@@ -892,6 +1004,7 @@ class BoxHelper {
...
@@ -892,6 +1004,7 @@ class BoxHelper {
int
year_
;
int
year_
;
int
month_
;
int
month_
;
int
day_
;
int
day_
;
bool
get_random_replace_done_
=
false
;
};
};
}
// end namespace framework
}
// end namespace framework
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
e6b87b31
...
@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
...
@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
auto
&
metric_list
=
box_ptr
->
GetMetricList
();
auto
&
metric_list
=
box_ptr
->
GetMetricList
();
for
(
auto
iter
=
metric_list
.
begin
();
iter
!=
metric_list
.
end
();
iter
++
)
{
for
(
auto
iter
=
metric_list
.
begin
();
iter
!=
metric_list
.
end
();
iter
++
)
{
auto
*
metric_msg
=
iter
->
second
;
auto
*
metric_msg
=
iter
->
second
;
if
(
metric_msg
->
IsJoin
()
!=
box_ptr
->
PassFlag
())
{
if
(
box_ptr
->
Phase
()
!=
metric_msg
->
MetricPhase
())
{
continue
;
continue
;
}
}
metric_msg
->
add_data
(
exe_scope
);
metric_msg
->
add_data
(
exe_scope
);
...
@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
...
@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
auto
&
metric_list
=
box_ptr
->
GetMetricList
();
auto
&
metric_list
=
box_ptr
->
GetMetricList
();
for
(
auto
iter
=
metric_list
.
begin
();
iter
!=
metric_list
.
end
();
iter
++
)
{
for
(
auto
iter
=
metric_list
.
begin
();
iter
!=
metric_list
.
end
();
iter
++
)
{
auto
*
metric_msg
=
iter
->
second
;
auto
*
metric_msg
=
iter
->
second
;
if
(
metric_msg
->
IsJoin
()
!=
box_ptr
->
PassFlag
())
{
if
(
box_ptr
->
Phase
()
!=
metric_msg
->
MetricPhase
())
{
continue
;
continue
;
}
}
metric_msg
->
add_data
(
exe_scope
);
metric_msg
->
add_data
(
exe_scope
);
...
...
paddle/fluid/pybind/box_helper_py.cc
浏览文件 @
e6b87b31
...
@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
...
@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
.
def
(
"preload_into_memory"
,
&
framework
::
BoxHelper
::
PreLoadIntoMemory
,
.
def
(
"preload_into_memory"
,
&
framework
::
BoxHelper
::
PreLoadIntoMemory
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"load_into_memory"
,
&
framework
::
BoxHelper
::
LoadIntoMemory
,
.
def
(
"load_into_memory"
,
&
framework
::
BoxHelper
::
LoadIntoMemory
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"slots_shuffle"
,
&
framework
::
BoxHelper
::
SlotsShuffle
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
// end BoxHelper
}
// end BoxHelper
...
@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
...
@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
.
def
(
"initialize_gpu_and_load_model"
,
.
def
(
"initialize_gpu_and_load_model"
,
&
framework
::
BoxWrapper
::
InitializeGPUAndLoadModel
,
&
framework
::
BoxWrapper
::
InitializeGPUAndLoadModel
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"initialize_auc_runner"
,
&
framework
::
BoxWrapper
::
InitializeAucRunner
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"init_metric"
,
&
framework
::
BoxWrapper
::
InitMetric
,
.
def
(
"init_metric"
,
&
framework
::
BoxWrapper
::
InitMetric
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"get_metric_msg"
,
&
framework
::
BoxWrapper
::
GetMetricMsg
,
.
def
(
"get_metric_msg"
,
&
framework
::
BoxWrapper
::
GetMetricMsg
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"get_metric_name_list"
,
&
framework
::
BoxWrapper
::
GetMetricNameList
,
.
def
(
"get_metric_name_list"
,
&
framework
::
BoxWrapper
::
GetMetricNameList
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"flip_p
ass_flag"
,
&
framework
::
BoxWrapper
::
FlipPassFlag
,
.
def
(
"flip_p
hase"
,
&
framework
::
BoxWrapper
::
FlipPhase
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"init_afs_api"
,
&
framework
::
BoxWrapper
::
InitAfsAPI
,
.
def
(
"init_afs_api"
,
&
framework
::
BoxWrapper
::
InitAfsAPI
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
e6b87b31
...
@@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
...
@@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"set_fleet_send_sleep_seconds"
,
.
def
(
"set_fleet_send_sleep_seconds"
,
&
framework
::
Dataset
::
SetFleetSendSleepSeconds
,
&
framework
::
Dataset
::
SetFleetSendSleepSeconds
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"enable_pv_merge"
,
&
framework
::
Dataset
::
EnablePvMerge
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
class_
<
IterableDatasetWrapper
>
(
*
m
,
"IterableDatasetWrapper"
)
py
::
class_
<
IterableDatasetWrapper
>
(
*
m
,
"IterableDatasetWrapper"
)
...
...
python/paddle/fluid/dataset.py
浏览文件 @
e6b87b31
...
@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
...
@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
def
_dynamic_adjust_after_train
(
self
):
def
_dynamic_adjust_after_train
(
self
):
pass
pass
def
slots_shuffle
(
self
,
slots
):
"""
Slots Shuffle
Slots Shuffle is a shuffle method in slots level, which is usually used
in sparse feature with large scale of instances. To compare the metric, i.e.
auc while doing slots shuffle on one or several slots with baseline to
evaluate the importance level of slots(features).
Args:
slots(list[string]): the set of slots(string) to do slots shuffle.
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
#suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
slots_set
=
set
(
slots
)
self
.
boxps
.
slots_shuffle
(
slots_set
)
python/paddle/fluid/tests/unittests/test_boxps.py
浏览文件 @
e6b87b31
...
@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
...
@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
datasets
[
0
].
load_into_memory
()
datasets
[
0
].
load_into_memory
()
datasets
[
0
].
begin_pass
()
datasets
[
0
].
begin_pass
()
datasets
[
0
].
slots_shuffle
([])
datasets
[
1
].
preload_into_memory
()
datasets
[
1
].
preload_into_memory
()
exe
.
train_from_dataset
(
exe
.
train_from_dataset
(
program
=
fluid
.
default_main_program
(),
program
=
fluid
.
default_main_program
(),
...
...
python/paddle/fluid/tests/unittests/test_dataset.py
浏览文件 @
e6b87b31
...
@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
...
@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
dataset
.
set_trainer_num
(
4
)
dataset
.
set_trainer_num
(
4
)
dataset
.
set_hdfs_config
(
"my_fs_name"
,
"my_fs_ugi"
)
dataset
.
set_hdfs_config
(
"my_fs_name"
,
"my_fs_ugi"
)
dataset
.
set_download_cmd
(
"./read_from_afs my_fs_name my_fs_ugi"
)
dataset
.
set_download_cmd
(
"./read_from_afs my_fs_name my_fs_ugi"
)
dataset
.
enable_pv_merge
()
thread_num
=
dataset
.
get_thread_num
()
thread_num
=
dataset
.
get_thread_num
()
self
.
assertEqual
(
thread_num
,
12
)
self
.
assertEqual
(
thread_num
,
12
)
...
@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
...
@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
dataset
.
set_pipe_command
(
"cat"
)
dataset
.
set_pipe_command
(
"cat"
)
dataset
.
set_use_var
(
slots_vars
)
dataset
.
set_use_var
(
slots_vars
)
dataset
.
load_into_memory
()
dataset
.
load_into_memory
()
dataset
.
set_fea_eval
(
1
0000
,
True
)
dataset
.
set_fea_eval
(
1
,
True
)
dataset
.
slots_shuffle
([
"slot1"
])
dataset
.
slots_shuffle
([
"slot1"
])
dataset
.
local_shuffle
()
dataset
.
local_shuffle
()
dataset
.
set_generate_unique_feasigns
(
True
,
15
)
dataset
.
set_generate_unique_feasigns
(
True
,
15
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录