PaddlePaddle / Paddle
Commit 59888bba (unverified)
Authored Jan 04, 2022 by yaoxuefeng; committed by GitHub on Jan 04, 2022
Parent: 08b7f17d

heter context support dynamic mf dim (#38487)
Showing 3 changed files with 337 additions and 34 deletions (+337 -34):
paddle/fluid/framework/fleet/heter_context.h (+123 -14)
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc (+212 -20)
paddle/fluid/framework/fleet/ps_gpu_wrapper.h (+2 -0)
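The commit message is terse, so here is a minimal, standalone sketch (not code from this patch) of the layout the change introduces: the key containers gain an extra index for the embedding ("mf") dimension bucket, so one GPU task can hold slots with different embedding widths side by side. The shard count, dim count, and sample keys below are made-up illustration values.

```cpp
// Standalone sketch of a per-shard, per-dim key layout; illustration only.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  using FeatureKey = uint64_t;
  const int shard_num = 3;     // hypothetical shard count
  const int multi_mf_dim = 2;  // e.g. two embedding widths such as 8 and 64

  // feature_dim_keys[shard][dim_id] -> keys routed to that bucket
  std::vector<std::vector<std::vector<FeatureKey>>> feature_dim_keys(
      shard_num, std::vector<std::vector<FeatureKey>>(multi_mf_dim));

  // Route a few sample keys the way the patch does: shard by key % shard_num,
  // dim bucket chosen per slot.
  std::vector<std::pair<FeatureKey, int>> samples = {
      {101, 0}, {102, 1}, {103, 0}, {205, 1}};
  for (const auto& [key, dim_id] : samples) {
    int shard_id = key % shard_num;
    feature_dim_keys[shard_id][dim_id].push_back(key);
  }

  for (int s = 0; s < shard_num; ++s) {
    for (int d = 0; d < multi_mf_dim; ++d) {
      std::cout << "shard " << s << ", dim bucket " << d << ": "
                << feature_dim_keys[s][d].size() << " keys\n";
    }
  }
  return 0;
}
```

The patch applies the same idea to the real members (feature_dim_keys_, value_dim_ptr_, device_dim_keys_, dim_mutex_), all indexed as [shard or device][dim bucket].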
paddle/fluid/framework/fleet/heter_context.h
@@ -39,22 +39,45 @@ namespace framework {
class HeterContext {
 public:
  ~HeterContext() {
    if (!multi_mf_dim_) {
      for (size_t i = 0; i < mutex_.size(); ++i) {
        delete mutex_[i];
      }
      mutex_.clear();
    } else {
      for (size_t i = 0; i < dim_mutex_.size(); ++i) {
        for (size_t j = 0; j < dim_mutex_[i].size(); j++) {
          delete dim_mutex_[i][j];
        }
        dim_mutex_[i].clear();
      }
    }
  }
  Scope* scope_{nullptr};
  std::vector<std::vector<FeatureKey>> feature_keys_;
  std::vector<std::vector<std::vector<FeatureKey>>> feature_dim_keys_;
#ifdef PADDLE_WITH_PSLIB
  std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> value_ptr_;
  std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
      value_dim_ptr_;
  std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
      device_dim_ptr_;
#endif
#ifdef PADDLE_WITH_PSCORE
  std::vector<std::vector<paddle::distributed::VALUE*>> value_ptr_;
  std::vector<std::vector<std::vector<paddle::distributed::VALUE*>>>
      value_dim_ptr_;
  std::vector<std::vector<std::vector<paddle::distributed::VALUE*>>>
      device_dim_ptr_;
#endif
  std::vector<std::vector<FeatureValue>> device_values_;
  std::vector<std::vector<FeatureKey>> device_keys_;
  std::vector<std::vector<std::vector<FeatureKey>>> device_dim_keys_;
  std::vector<std::vector<std::vector<FeatureValue>>> device_dim_values_;
  std::vector<std::mutex*> mutex_;
  std::vector<std::vector<std::mutex*>> dim_mutex_;
  int multi_mf_dim_ = 0;

  uint32_t shard_num_ = 37;
  uint64_t size() {
@@ -79,18 +102,78 @@ class HeterContext {
    }
  }
  void init(int shard_num, int device_num, int dim_num) {
    shard_num_ = shard_num;
    feature_keys_.resize(shard_num_);
    feature_dim_keys_.resize(shard_num_);
    value_ptr_.resize(shard_num_);
    value_dim_ptr_.resize(shard_num_);
    for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
      feature_dim_keys_[i].resize(dim_num);
      value_dim_ptr_[i].resize(dim_num);
      if (i == 0) {
        for (int j = 0; j < dim_num; j++) {
          feature_dim_keys_[i][j].push_back(0);
        }
      }
    }
    device_values_.resize(device_num);
    device_dim_values_.resize(device_num);
    device_keys_.resize(device_num);
    device_dim_keys_.resize(device_num);
    device_dim_ptr_.resize(device_num);
    mutex_.resize(device_num);
    dim_mutex_.resize(device_num);
    for (size_t i = 0; i < mutex_.size(); ++i) {
      mutex_[i] = new std::mutex();
    }
    for (size_t i = 0; i < dim_mutex_.size(); ++i) {
      dim_mutex_[i].resize(dim_num);
      for (int j = 0; j < dim_num; j++) {
        dim_mutex_[i][j] = new std::mutex();
      }
    }
    multi_mf_dim_ = dim_num;
  }

  void Reset() {
    if (!multi_mf_dim_) {
      for (size_t i = 0; i < feature_keys_.size(); ++i) {
        feature_keys_[i].clear();
      }
      for (size_t i = 0; i < value_ptr_.size(); ++i) {
        value_ptr_[i].clear();
      }
      for (size_t i = 0; i < device_values_.size(); ++i) {
        device_values_[i].clear();
      }
      for (size_t i = 0; i < device_keys_.size(); ++i) {
        device_keys_[i].clear();
      }
    } else {
      VLOG(3) << "Reset gpu task with dynamic mf dimention";
      for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
        for (size_t j = 0; j < feature_dim_keys_[i].size(); j++) {
          feature_dim_keys_[i][j].clear();
        }
      }
      for (size_t i = 0; i < value_dim_ptr_.size(); i++) {
        for (size_t j = 0; j < value_dim_ptr_[i].size(); j++) {
          value_dim_ptr_[i][j].clear();
        }
      }
      for (size_t i = 0; i < device_dim_keys_.size(); i++) {
        for (size_t j = 0; j < device_dim_keys_[i].size(); j++) {
          device_dim_keys_[i][j].clear();
        }
      }
      for (size_t i = 0; i < device_dim_ptr_.size(); i++) {
        for (size_t j = 0; j < device_dim_ptr_[i].size(); j++) {
          device_dim_ptr_[i][j].clear();
        }
      }
    }
  }
  void batch_add_keys(
@@ -115,6 +198,15 @@ class HeterContext {
              feature_keys_[shard_num].begin() + idx);
  }

  void batch_add_keys(int shard_num, int dim_id,
                      const robin_hood::unordered_set<uint64_t>& shard_keys) {
    int idx = feature_dim_keys_[shard_num][dim_id].size();
    feature_dim_keys_[shard_num][dim_id].resize(
        feature_dim_keys_[shard_num][dim_id].size() + shard_keys.size());
    std::copy(shard_keys.begin(), shard_keys.end(),
              feature_dim_keys_[shard_num][dim_id].begin() + idx);
  }

  void UniqueKeys() {
    std::vector<std::thread> threads;
    auto unique_func = [this](int i) {
@@ -124,9 +216,26 @@ class HeterContext {
      it = std::unique(cur_keys.begin(), cur_keys.end());
      cur_keys.resize(std::distance(cur_keys.begin(), it));
    };
    auto unique_dynamic_mf_func = [this](int i, int j) {
      auto& cur_keys = feature_dim_keys_[i][j];
      std::sort(cur_keys.begin(), cur_keys.end());
      std::vector<FeatureKey>::iterator it;
      it = std::unique(cur_keys.begin(), cur_keys.end());
      cur_keys.resize(std::distance(cur_keys.begin(), it));
    };
    if (!multi_mf_dim_) {
      for (uint32_t i = 0; i < shard_num_; i++) {
        threads.push_back(std::thread(unique_func, i));
      }
    } else {
      for (uint32_t i = 0; i < shard_num_; i++) {
        for (int j = 0; j < multi_mf_dim_; j++) {
          threads.push_back(std::thread(unique_dynamic_mf_func, i, j));
        }
      }
      VLOG(3) << "heter_context unique keys with dynamic mf dimention";
    }
    for (std::thread& t : threads) {
      t.join();
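A recurring pattern in this patch, first visible in UniqueKeys above and repeated in ps_gpu_wrapper.cc below, is that when multi_mf_dim_ is set the work is fanned out as one thread per (shard, dim bucket) pair instead of one thread per shard. The following standalone sketch shows that dispatch shape with a toy sort-and-unique job; the shard/dim counts and keys are made-up illustration values, and this is not Paddle code.

```cpp
// Standalone sketch of the "one thread per (shard, dim) bucket" dispatch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  const int shard_num = 4;     // hypothetical
  const int multi_mf_dim = 2;  // hypothetical number of dim buckets

  // keys[shard][dim] -> keys, seeded with duplicates to deduplicate.
  std::vector<std::vector<std::vector<uint64_t>>> keys(
      shard_num,
      std::vector<std::vector<uint64_t>>(multi_mf_dim, {7, 3, 7, 11, 3}));

  // Same shape as unique_dynamic_mf_func above: sort + unique one bucket.
  auto unique_bucket = [&keys](int i, int j) {
    auto& cur = keys[i][j];
    std::sort(cur.begin(), cur.end());
    cur.erase(std::unique(cur.begin(), cur.end()), cur.end());
  };

  std::vector<std::thread> threads;
  for (int i = 0; i < shard_num; ++i) {
    for (int j = 0; j < multi_mf_dim; ++j) {
      threads.emplace_back(unique_bucket, i, j);
    }
  }
  for (auto& t : threads) {
    t.join();
  }

  std::cout << "bucket (0,0) after dedup has " << keys[0][0].size()
            << " keys\n";  // prints 3
  return 0;
}
```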
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -45,16 +45,30 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
  platform::Timer timeline;
  timeline.Start();
  int device_num = heter_devices_.size();
  if (!multi_mf_dim_) {
    gpu_task->init(thread_keys_shard_num_, device_num);
  } else {
    gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
  }
  auto& local_keys = gpu_task->feature_keys_;
  auto& local_ptr = gpu_task->value_ptr_;

  std::vector<std::thread> threads;

  // data should be in input channel
  if (!multi_mf_dim_) {
    thread_keys_.resize(thread_keys_thread_num_);
    for (int i = 0; i < thread_keys_thread_num_; i++) {
      thread_keys_[i].resize(thread_keys_shard_num_);
    }
  } else {
    thread_dim_keys_.resize(thread_keys_thread_num_);
    for (int i = 0; i < thread_keys_thread_num_; i++) {
      thread_dim_keys_[i].resize(thread_keys_shard_num_);
      for (int j = 0; j < thread_keys_shard_num_; j++) {
        thread_dim_keys_[i][j].resize(multi_mf_dim_);
      }
    }
  }

  size_t total_len = 0;
@@ -87,10 +101,47 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
      }
    }
  };
  auto gen_dynamic_mf_func = [this](const std::deque<SlotRecord>& total_data,
                                    int begin_index, int end_index, int i) {
    for (auto iter = total_data.begin() + begin_index;
         iter != total_data.begin() + end_index; iter++) {
      const auto& ins = *iter;
      const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
      const auto& slot_offset = ins->slot_uint64_feasigns_.slot_offsets;
      for (size_t slot_idx = 0; slot_idx < slot_offset_vector_.size();
           slot_idx++) {
        for (size_t j = slot_offset[slot_offset_vector_[slot_idx]];
             j < slot_offset[slot_offset_vector_[slot_idx] + 1]; j++) {
          int shard_id = feasign_v[j] % thread_keys_shard_num_;
          int dim_id = slot_index_vec_[slot_idx];
          this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]);
        }
      }
    }
    /*
    for (auto iter = total_data.begin() + begin_index;
         iter != total_data.begin() + end_index; iter++) {
      const auto& ins = *iter;
      const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
      for (const auto feasign : feasign_v) {
        int shard_id = feasign % thread_keys_shard_num_;
        this->thread_dim_keys_[i][shard_id][0].insert(feasign);
      }
    }
    */
  };
  for (int i = 0; i < thread_keys_thread_num_; i++) {
    if (!multi_mf_dim_) {
      VLOG(0) << "yxf::psgpu wrapper genfunc";
      threads.push_back(
          std::thread(gen_func, std::ref(vec_data), begin,
                      begin + len_per_thread + (i < remain ? 1 : 0), i));
    } else {
      VLOG(0) << "yxf::psgpu wrapper genfunc with dynamic mf";
      threads.push_back(
          std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin,
                      begin + len_per_thread + (i < remain ? 1 : 0), i));
    }
    begin += len_per_thread + (i < remain ? 1 : 0);
  }
  for (std::thread& t : threads) {
@@ -144,7 +195,13 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
      thread_keys_[i][shard_num].clear();
    }
  };
  auto merge_ins_dynamic_mf_func = [this, gpu_task](int shard_num, int dim_id) {
    for (int i = 0; i < thread_keys_thread_num_; ++i) {
      gpu_task->batch_add_keys(shard_num, dim_id,
                               thread_dim_keys_[i][shard_num][dim_id]);
      thread_dim_keys_[i][shard_num][dim_id].clear();
    }
  };
  // for (size_t i = 0; i < thread_keys_.size(); i++) {
  //   gpu_task->batch_add_keys(thread_keys_[i]);
  //   for (int j = 0; j < thread_keys_thread_num_; j++) {
@@ -152,7 +209,13 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
  //   }
  //}
  for (int i = 0; i < thread_keys_shard_num_; ++i) {
    if (!multi_mf_dim_) {
      threads.push_back(std::thread(merge_ins_func, i));
    } else {
      for (int j = 0; j < multi_mf_dim_; j++) {
        threads.push_back(std::thread(merge_ins_dynamic_mf_func, i, j));
      }
    }
  }
  for (auto& t : threads) {
    t.join();
@@ -167,9 +230,20 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
  VLOG(1) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
  if (!multi_mf_dim_) {
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      VLOG(0) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
      local_ptr[i].resize(local_keys[i].size());
    }
  } else {
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j]
                << " key len: " << gpu_task->feature_dim_keys_[i][j].size();
        gpu_task->value_dim_ptr_[i][j].resize(
            gpu_task->feature_dim_keys_[i][j].size());
      }
    }
  }
}
@@ -179,8 +253,20 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
  auto& local_keys = gpu_task->feature_keys_;
  auto& local_ptr = gpu_task->value_ptr_;

  auto& local_dim_keys = gpu_task->feature_dim_keys_;
  auto& local_dim_ptr = gpu_task->value_dim_ptr_;

  auto& device_keys = gpu_task->device_keys_;
  auto& device_vals = gpu_task->device_values_;
  auto& device_dim_keys = gpu_task->device_dim_keys_;
  auto& device_dim_ptr = gpu_task->device_dim_ptr_;
  auto& device_dim_mutex = gpu_task->dim_mutex_;
  if (multi_mf_dim_) {
    for (size_t dev = 0; dev < device_dim_keys.size(); dev++) {
      device_dim_keys[dev].resize(multi_mf_dim_);
      device_dim_ptr[dev].resize(multi_mf_dim_);
    }
  }
  auto& device_mutex = gpu_task->mutex_;

  std::vector<std::thread> threads(thread_keys_shard_num_);
@@ -283,8 +369,63 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
            << local_keys[i].size();
    }
  };

  auto ptl_dynamic_mf_func = [this, &local_dim_keys, &local_dim_ptr,
                              &fleet_ptr](int i, int j) {
#ifdef PADDLE_WITH_PSLIB
    size_t key_size = local_dim_keys[i][j].size();
    int32_t status = -1;
    int32_t cnt = 0;
    while (true) {
      auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
          reinterpret_cast<char**>(local_dim_ptr[i][j].data()),
          this->table_id_, local_dim_keys[i][j].data(), key_size);
      bool flag = true;
      tt.wait();
      try {
        status = tt.get();
      } catch (const std::future_error& e) {
        VLOG(0) << "Caught a future_error with code" << e.code()
                << ", Message:" << e.what();
      }
      if (status != 0) {
        VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
        sleep(sleep_seconds_before_fail_exit_);
        flag = false;
        cnt++;
      }
      if (cnt > 3) {
        VLOG(0) << "fleet pull sparse failed, retry 3 times";
        exit(-1);
      }
      if (flag) {
        break;
      }
    }
    if (status != 0) {
      LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
      sleep(300);
      exit(-1);
    } else {
      VLOG(0) << "FleetWrapper Pull sparse to local done with table size: "
              << local_dim_keys[i][j].size();
    }
#endif
  };
  if (!multi_mf_dim_) {
    for (size_t i = 0; i < threads.size(); i++) {
      threads[i] = std::thread(ptl_func, i);
    }
  } else {
    threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        threads[i * multi_mf_dim_ + j] = std::thread(ptl_dynamic_mf_func, i, j);
      }
    }
  }
  for (std::thread& t : threads) {
    t.join();
@@ -312,6 +453,37 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
      table_id_, pass_id, pass_values);
  }
#endif

  auto build_dynamic_mf_func = [this, device_num, &local_dim_keys,
                                &local_dim_ptr, &device_dim_keys,
                                &device_dim_ptr,
                                &device_dim_mutex](int i, int j) {
#ifdef PADDLE_WITH_PSLIB
    std::vector<std::vector<FeatureKey>> task_keys(device_num);
    std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
        device_num);
    for (size_t k = 0; k < local_dim_keys[i][j].size(); k++) {
      int shard = local_dim_keys[i][j][k] % device_num;
      task_keys[shard].push_back(local_dim_keys[i][j][k]);
      task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
    }
    for (int dev = 0; dev < device_num; dev++) {
      for (int dim = 0; dim < multi_mf_dim_; dim++) {
        device_dim_mutex[dev][dim]->lock();
        int len = task_keys[dev].size();
        int cur = device_dim_keys[dev][dim].size();
        device_dim_keys[dev][dim].resize(device_dim_keys[dev][dim].size() +
                                         len);
        device_dim_ptr[dev][dim].resize(device_dim_ptr[dev][dim].size() + len);
        for (int k = 0; k < len; ++k) {
          device_dim_keys[dev][dim][cur + k] = task_keys[dev][k];
          device_dim_ptr[dev][dim][cur + k] = task_ptrs[dev][k];
        }
        device_dim_mutex[dev][dim]->unlock();
      }
    }
#endif
  };
  auto build_func = [device_num, record_status, &pass_values, &local_keys,
                     &local_ptr, &device_keys, &device_vals,
                     &device_mutex](int i) {
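build_dynamic_mf_func above routes one (shard, dim) bucket's keys to devices by key % device_num and appends them to the per-device, per-dim vectors while holding that slot's mutex. Below is a minimal standalone sketch of this guarded scatter-append, assuming illustrative device/dim counts and keys; it is not the Paddle implementation, and it stores unique_ptr where the patch stores raw std::mutex*.

```cpp
// Standalone sketch of the per-(device, dim) guarded append; illustration only.
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  const int device_num = 2;    // hypothetical
  const int multi_mf_dim = 2;  // hypothetical

  // device_dim_keys[dev][dim] collects keys routed to that device/dim slot.
  std::vector<std::vector<std::vector<uint64_t>>> device_dim_keys(
      device_num, std::vector<std::vector<uint64_t>>(multi_mf_dim));

  // One mutex per (device, dim) slot, mirroring dim_mutex_ in the patch.
  std::vector<std::vector<std::unique_ptr<std::mutex>>> dim_mutex(device_num);
  for (auto& row : dim_mutex) {
    for (int d = 0; d < multi_mf_dim; ++d) {
      row.emplace_back(new std::mutex());
    }
  }

  // Each worker handles one bucket: route keys by key % device_num and append
  // to the target slot under that slot's lock.
  auto scatter = [&](const std::vector<uint64_t>& bucket_keys, int dim) {
    std::vector<std::vector<uint64_t>> task_keys(device_num);
    for (uint64_t key : bucket_keys) {
      task_keys[key % device_num].push_back(key);
    }
    for (int dev = 0; dev < device_num; ++dev) {
      std::lock_guard<std::mutex> guard(*dim_mutex[dev][dim]);
      auto& dst = device_dim_keys[dev][dim];
      dst.insert(dst.end(), task_keys[dev].begin(), task_keys[dev].end());
    }
  };

  std::thread t1(scatter, std::vector<uint64_t>{1, 2, 3, 4}, 0);
  std::thread t2(scatter, std::vector<uint64_t>{5, 6, 7, 8}, 0);
  t1.join();
  t2.join();

  std::cout << "device 0, dim 0 holds " << device_dim_keys[0][0].size()
            << " keys\n";  // the 4 even keys: 2, 4, 6, 8
  return 0;
}
```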
@@ -415,8 +587,17 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
    }
  };
  if (!multi_mf_dim_) {
    for (size_t i = 0; i < threads.size(); i++) {
      threads[i] = std::thread(build_func, i);
    }
  } else {
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        threads[i * multi_mf_dim_ + j] =
            std::thread(build_dynamic_mf_func, i, j);
      }
    }
  }
  for (std::thread& t : threads) {
    t.join();
@@ -433,10 +614,21 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
  std::vector<size_t> feature_keys_count(device_num);
  size_t size_max = 0;
  if (!multi_mf_dim_) {
    for (int i = 0; i < device_num; i++) {
      feature_keys_count[i] = gpu_task->device_keys_[i].size();
      VLOG(1) << i << " card contains feasign nums: " << feature_keys_count[i];
      size_max = std::max(size_max, feature_keys_count[i]);
    }
  } else {
    for (int i = 0; i < device_num; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
      }
      VLOG(1) << i << " card with dynamic mf contains feasign nums: "
              << feature_keys_count[i];
      size_max = std::max(size_max, feature_keys_count[i]);
    }
  }
  if (HeterPs_) {
    delete HeterPs_;
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -335,6 +335,8 @@ class PSGPUWrapper {
  std::unordered_set<std::string> gpu_ps_config_keys_;
  HeterObjectPool<HeterContext> gpu_task_pool_;
  std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
  std::vector<std::vector<std::vector<robin_hood::unordered_set<uint64_t>>>>
      thread_dim_keys_;
  int thread_keys_thread_num_ = 37;
  int thread_keys_shard_num_ = 37;
  uint64_t max_fea_num_per_pass_ = 5000000000;