Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c0001a24
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c0001a24
编写于
5月 23, 2022
作者:
Y
yaoxuefeng
提交者:
GitHub
5月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Acc name (#42906)
add dymf support of gpups
上级
3b488bae
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
114 addition
and
82 deletion
+114
-82
paddle/fluid/framework/fleet/heter_context.h
paddle/fluid/framework/fleet/heter_context.h
+0
-18
paddle/fluid/framework/fleet/heter_ps/feature_value.h
paddle/fluid/framework/fleet/heter_ps/feature_value.h
+14
-0
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+14
-1
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+38
-5
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+1
-0
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+36
-58
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+11
-0
未找到文件。
paddle/fluid/framework/fleet/heter_context.h
浏览文件 @
c0001a24
...
...
@@ -95,24 +95,6 @@ class HeterContext {
}
void
SetShardNum
(
uint32_t
shard_num
)
{
shard_num_
=
shard_num
;
}
uint32_t
ShardNum
()
{
return
shard_num_
;
}
void
init
(
int
shard_num
,
int
device_num
)
{
shard_num_
=
shard_num
;
feature_keys_
.
resize
(
shard_num_
);
value_ptr_
.
resize
(
shard_num_
);
device_task_ptr_
.
resize
(
shard_num_
);
device_task_keys_
.
resize
(
shard_num_
);
for
(
size_t
i
=
0
;
i
<
device_task_ptr_
.
size
();
i
++
)
{
device_task_ptr_
[
i
].
resize
(
device_num
);
device_task_keys_
[
i
].
resize
(
device_num
);
}
device_values_
.
resize
(
device_num
);
device_keys_
.
resize
(
device_num
);
mutex_
.
resize
(
device_num
);
for
(
size_t
i
=
0
;
i
<
mutex_
.
size
();
++
i
)
{
mutex_
[
i
]
=
new
std
::
mutex
();
}
}
void
init
(
int
shard_num
,
int
device_num
,
int
dim_num
)
{
shard_num_
=
shard_num
;
...
...
paddle/fluid/framework/fleet/heter_ps/feature_value.h
浏览文件 @
c0001a24
...
...
@@ -69,6 +69,20 @@ struct FeaturePushValue {
int
mf_dim
;
float
mf_g
[
0
];
__device__
__forceinline__
FeaturePushValue
operator
+
(
const
FeaturePushValue
&
a
)
const
{
FeaturePushValue
out
;
out
.
slot
=
a
.
slot
;
out
.
mf_dim
=
a
.
mf_dim
;
out
.
show
=
a
.
show
+
show
;
out
.
clk
=
a
.
clk
+
clk
;
out
.
lr_g
=
a
.
lr_g
+
lr_g
;
// out.mf_g = a.mf_g;
for
(
int
i
=
0
;
i
<
out
.
mf_dim
;
++
i
)
{
out
.
mf_g
[
i
]
=
a
.
mf_g
[
i
]
+
mf_g
[
i
];
}
return
out
;
}
__device__
__forceinline__
void
operator
=
(
const
FeaturePushValue
&
in
)
{
show
=
in
.
show
;
clk
=
in
.
clk
;
...
...
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
浏览文件 @
c0001a24
...
...
@@ -86,13 +86,26 @@ __global__ void dy_mf_search_kernel(Table* table,
char
*
vals
,
size_t
len
,
size_t
pull_feature_value_size
)
{
const
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// return;
if
(
i
<
len
)
{
auto
it
=
table
->
find
(
keys
[
i
]);
if
(
it
!=
table
->
end
())
{
uint64_t
offset
=
i
*
pull_feature_value_size
;
FeatureValue
&
cur
=
*
(
FeatureValue
*
)(
vals
+
offset
);
FeatureValue
*
cur
=
(
FeatureValue
*
)(
vals
+
offset
);
FeatureValue
&
input
=
*
(
FeatureValue
*
)(
it
->
second
);
cur
->
slot
=
input
.
slot
;
cur
->
show
=
input
.
show
;
cur
->
clk
=
input
.
clk
;
cur
->
mf_dim
=
input
.
mf_dim
;
cur
->
lr
=
input
.
lr
;
cur
->
mf_size
=
input
.
mf_size
;
cur
->
cpu_ptr
=
input
.
cpu_ptr
;
cur
->
delta_score
=
input
.
delta_score
;
cur
->
lr_g2sum
=
input
.
lr_g2sum
;
for
(
int
j
=
0
;
j
<
cur
->
mf_dim
+
1
;
++
j
)
{
cur
->
mf
[
j
]
=
input
.
mf
[
j
];
}
}
}
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
浏览文件 @
c0001a24
...
...
@@ -26,6 +26,7 @@ namespace framework {
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
HeterComm
<
KeyType
,
ValType
,
GradType
>::
HeterComm
(
size_t
capacity
,
std
::
shared_ptr
<
HeterPsResource
>
resource
)
{
VLOG
(
1
)
<<
"Construct new HeterComm"
;
resource_
=
resource
;
storage_
.
resize
(
resource_
->
total_device
());
multi_mf_dim_
=
resource
->
multi_mf
();
...
...
@@ -364,6 +365,10 @@ HeterComm<KeyType, ValType, GradType>::~HeterComm() {
delete
table
;
table
=
nullptr
;
}
for
(
auto
&
table
:
tables_
)
{
delete
table
;
table
=
nullptr
;
}
}
}
...
...
@@ -473,17 +478,23 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
return
;
}
int
dev_id
=
resource_
->
dev_id
(
num
);
DevPlace
place
=
DevPlace
(
dev_id
);
AnyDeviceGuard
guard
(
dev_id
);
// use hbm pool
std
::
vector
<
memory
::
allocation
::
AllocationPtr
>
d_key_bufs
;
ppStream
streams
[
stream_num
];
// NOLINT
for
(
int
i
=
0
;
i
<
stream_num
;
++
i
)
{
create_stream
(
&
(
streams
[
i
]));
auto
d_k_buf
=
memory
::
Alloc
(
place
,
chunk_size
*
sizeof
(
KeyType
));
d_key_bufs
.
push_back
(
std
::
move
(
d_k_buf
));
}
int
cur_len
=
0
;
int
cur_stream
=
0
;
while
(
cur_len
<
len
)
{
cur_stream
=
cur_stream
%
stream_num
;
auto
cur_use_stream
=
streams
[
cur_stream
];
...
...
@@ -491,8 +502,10 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
cur_use_stream
=
0
;
#endif
int
tmp_len
=
cur_len
+
chunk_size
>
len
?
len
-
cur_len
:
chunk_size
;
auto
dst_place
=
place
;
auto
src_place
=
platform
::
CPUPlace
();
memory_copy
(
dst_place
,
reinterpret_cast
<
char
*>
(
d_key_bufs
[
cur_stream
]
->
ptr
()),
src_place
,
h_keys
+
cur_len
,
sizeof
(
KeyType
)
*
tmp_len
,
cur_use_stream
);
...
...
@@ -557,14 +570,20 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
auto
stream
=
resource_
->
local_stream
(
gpu_num
,
0
);
size_t
temp_storage_bytes
;
// VLOG(1) << "hetercomm merge_grad: max_mf_dim: " << max_mf_dim_;
size_t
grad_value_size
=
TYPEALIGN
(
8
,
sizeof
(
FeaturePushValue
)
+
(
max_mf_dim_
*
sizeof
(
float
)));
auto
d_merge_keys
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
KeyType
));
KeyType
*
d_merge_keys_ptr
=
reinterpret_cast
<
KeyType
*>
(
d_merge_keys
->
ptr
());
auto
d_merge_grads
=
memory
::
Alloc
(
place
,
len
*
grad_value_size
);
GradType
*
d_merge_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_merge_grads
->
ptr
());
auto
d_fea_num_info
=
memory
::
Alloc
(
place
,
sizeof
(
uint32_t
)
*
(
len
*
3
+
1
));
uint32_t
*
d_fea_num_info_ptr
=
reinterpret_cast
<
uint32_t
*>
(
d_fea_num_info
->
ptr
());
...
...
@@ -836,9 +855,16 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
auto
d_shard_keys
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
KeyType
));
KeyType
*
d_shard_keys_ptr
=
reinterpret_cast
<
KeyType
*>
(
d_shard_keys
->
ptr
());
GradType
*
d_shard_grads_ptr
;
if
(
!
multi_mf_dim_
)
{
auto
d_shard_grads
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
GradType
));
d_shard_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_shard_grads
->
ptr
());
}
else
{
auto
d_shard_grads
=
memory
::
Alloc
(
place
,
len
*
grad_value_size
);
d_shard_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_shard_grads
->
ptr
());
}
int
uniq_len
=
len
;
dynamic_merge_grad
(
dev_num
,
d_keys
,
d_grads
,
len
,
uniq_len
);
...
...
@@ -846,9 +872,16 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
split_input_to_shard
(
d_keys
,
d_idx_ptr
,
uniq_len
,
d_left_ptr
,
d_right_ptr
,
dev_num
);
if
(
!
multi_mf_dim_
)
{
heter_comm_kernel_
->
fill_shard_grads
(
d_shard_keys_ptr
,
d_keys
,
d_shard_grads_ptr
,
d_grads
,
d_idx_ptr
,
uniq_len
,
stream
);
}
else
{
heter_comm_kernel_
->
dy_mf_fill_shard_grads
(
d_shard_keys_ptr
,
d_keys
,
d_shard_grads_ptr
,
d_grads
,
d_idx_ptr
,
uniq_len
,
grad_value_size
,
stream
);
d_shard_keys_ptr
,
d_keys
,
d_shard_grads_ptr
,
d_grads
,
d_idx_ptr
,
uniq_len
,
grad_value_size
,
stream
);
}
sync_stream
(
stream
);
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
浏览文件 @
c0001a24
...
...
@@ -136,6 +136,7 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
size_t
grad_value_size
,
DynamicGradMerger
&
merger_
)
{
const
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
<
n
)
{
uint32_t
start
=
offset
[
i
];
uint32_t
num
=
fea_num
[
i
];
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
浏览文件 @
c0001a24
...
...
@@ -106,19 +106,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
platform
::
Timer
timeline
;
timeline
.
Start
();
int
device_num
=
heter_devices_
.
size
();
if
(
!
multi_mf_dim_
)
{
gpu_task
->
init
(
thread_keys_shard_num_
,
device_num
);
}
else
{
gpu_task
->
init
(
thread_keys_shard_num_
,
device_num
,
multi_mf_dim_
);
}
std
::
vector
<
std
::
thread
>
threads
;
if
(
!
multi_mf_dim_
)
{
thread_keys_
.
resize
(
thread_keys_thread_num_
);
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
thread_keys_
[
i
].
resize
(
thread_keys_shard_num_
);
}
}
else
{
// data should be in input channel
thread_dim_keys_
.
resize
(
thread_keys_thread_num_
);
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
thread_dim_keys_
[
i
].
resize
(
thread_keys_shard_num_
);
...
...
@@ -126,7 +119,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
thread_dim_keys_
[
i
][
j
].
resize
(
multi_mf_dim_
);
}
}
}
size_t
total_len
=
0
;
size_t
len_per_thread
=
0
;
...
...
@@ -144,18 +136,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
len_per_thread
=
total_len
/
thread_keys_thread_num_
;
remain
=
total_len
%
thread_keys_thread_num_
;
VLOG
(
0
)
<<
"total len: "
<<
total_len
;
auto
gen_func
=
[
this
](
const
std
::
deque
<
SlotRecord
>&
total_data
,
int
begin_index
,
int
end_index
,
int
i
)
{
for
(
auto
iter
=
total_data
.
begin
()
+
begin_index
;
iter
!=
total_data
.
begin
()
+
end_index
;
iter
++
)
{
const
auto
&
ins
=
*
iter
;
const
auto
&
feasign_v
=
ins
->
slot_uint64_feasigns_
.
slot_values
;
for
(
const
auto
feasign
:
feasign_v
)
{
int
shard_id
=
feasign
%
thread_keys_shard_num_
;
this
->
thread_keys_
[
i
][
shard_id
].
insert
(
feasign
);
}
}
};
auto
gen_dynamic_mf_func
=
[
this
](
const
std
::
deque
<
SlotRecord
>&
total_data
,
int
begin_index
,
int
end_index
,
int
i
)
{
for
(
auto
iter
=
total_data
.
begin
()
+
begin_index
;
...
...
@@ -177,17 +157,10 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
if
(
!
multi_mf_dim_
)
{
VLOG
(
0
)
<<
"yxf::psgpu wrapper genfunc"
;
threads
.
push_back
(
std
::
thread
(
gen_func
,
std
::
ref
(
vec_data
),
begin
,
begin
+
len_per_thread
+
(
i
<
remain
?
1
:
0
),
i
));
}
else
{
VLOG
(
0
)
<<
"yxf::psgpu wrapper genfunc with dynamic mf"
;
threads
.
push_back
(
std
::
thread
(
gen_dynamic_mf_func
,
std
::
ref
(
vec_data
),
begin
,
begin
+
len_per_thread
+
(
i
<
remain
?
1
:
0
),
i
));
}
begin
+=
len_per_thread
+
(
i
<
remain
?
1
:
0
);
}
for
(
std
::
thread
&
t
:
threads
)
{
...
...
@@ -235,12 +208,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
threads
.
clear
();
// merge thread_keys to shard_keys
auto
merge_ins_func
=
[
this
,
gpu_task
](
int
shard_num
)
{
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
++
i
)
{
gpu_task
->
batch_add_keys
(
shard_num
,
thread_keys_
[
i
][
shard_num
]);
thread_keys_
[
i
][
shard_num
].
clear
();
}
};
auto
merge_ins_dynamic_mf_func
=
[
this
,
gpu_task
](
int
shard_num
,
int
dim_id
)
{
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
++
i
)
{
gpu_task
->
batch_add_keys
(
shard_num
,
dim_id
,
...
...
@@ -249,14 +216,10 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
++
i
)
{
if
(
!
multi_mf_dim_
)
{
threads
.
push_back
(
std
::
thread
(
merge_ins_func
,
i
));
}
else
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
threads
.
push_back
(
std
::
thread
(
merge_ins_dynamic_mf_func
,
i
,
j
));
}
}
}
for
(
auto
&
t
:
threads
)
{
t
.
join
();
}
...
...
@@ -297,12 +260,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
auto
&
device_dim_keys
=
gpu_task
->
device_dim_keys_
;
auto
&
device_dim_ptr
=
gpu_task
->
device_dim_ptr_
;
auto
&
device_dim_mutex
=
gpu_task
->
dim_mutex_
;
if
(
multi_mf_dim_
)
{
for
(
size_t
dev
=
0
;
dev
<
device_dim_keys
.
size
();
dev
++
)
{
device_dim_keys
[
dev
].
resize
(
multi_mf_dim_
);
device_dim_ptr
[
dev
].
resize
(
multi_mf_dim_
);
}
}
// auto& device_mutex = gpu_task->mutex_;
std
::
vector
<
std
::
thread
>
threads
(
thread_keys_shard_num_
);
...
...
@@ -415,6 +378,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
task_keys
[
shard
].
push_back
(
local_dim_keys
[
i
][
j
][
k
]);
task_ptrs
[
shard
].
push_back
(
local_dim_ptr
[
i
][
j
][
k
]);
}
// allocate local keys to devices
for
(
int
dev
=
0
;
dev
<
device_num
;
dev
++
)
{
device_dim_mutex
[
dev
][
j
]
->
lock
();
int
len
=
task_keys
[
dev
].
size
();
...
...
@@ -619,6 +583,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
<<
feature_keys_count
[
i
];
size_max
=
std
::
max
(
size_max
,
feature_keys_count
[
i
]);
}
if
(
HeterPs_
)
{
delete
HeterPs_
;
HeterPs_
=
nullptr
;
...
...
@@ -665,6 +630,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
ptr_val
[
paddle
::
ps
::
DownpourCtrDymfAccessor
::
DownpourCtrDymfFeatureValue
::
embed_g2sum_index
()];
val
->
cpu_ptr
=
(
uint64_t
)(
device_dim_ptrs
[
k
]);
// TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor
ptr_val
[
paddle
::
ps
::
DownpourCtrDymfAccessor
::
DownpourCtrDymfFeatureValue
::
mf_dim_index
()]
=
float
(
mf_dim
);
val
->
mf_dim
=
mf_dim
;
...
...
@@ -681,11 +648,15 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
}
}
}
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
this
->
hbm_pools_
[
i
*
this
->
multi_mf_dim_
+
j
]
=
new
HBMMemoryPool
(
mem_pool
);
auto
&
cur_pool
=
this
->
hbm_pools_
[
i
*
this
->
multi_mf_dim_
+
j
];
this
->
HeterPs_
->
build_ps
(
i
,
device_dim_keys
.
data
(),
cur_pool
->
mem
(),
len
,
feature_value_size
,
500000
,
2
);
if
(
device_dim_keys
.
size
()
>
0
)
{
VLOG
(
0
)
<<
"show ptr table: "
<<
i
<<
" table kv size: "
<<
device_dim_keys
.
size
()
...
...
@@ -700,6 +671,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
threads
[
i
+
j
*
device_num
]
=
std
::
thread
(
build_dynamic_mf_func
,
i
,
j
);
}
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
...
...
@@ -723,7 +695,9 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
InitSlotInfo
();
std
::
shared_ptr
<
HeterContext
>
gpu_task
=
gpu_task_pool_
.
Get
();
gpu_task
->
Reset
();
data_ready_channel_
->
Put
(
gpu_task
);
VLOG
(
3
)
<<
"End LoadIntoMemory(), dataset["
<<
dataset_
<<
"]"
;
}
...
...
@@ -805,6 +779,7 @@ void PSGPUWrapper::EndPass() {
timer
.
Start
();
size_t
keysize_max
=
0
;
// in case of feasign_num = 0, skip dump_to_cpu
for
(
size_t
i
=
0
;
i
<
heter_devices_
.
size
();
i
++
)
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
keysize_max
=
...
...
@@ -821,9 +796,11 @@ void PSGPUWrapper::EndPass() {
VLOG
(
0
)
<<
"dump pool to cpu table: "
<<
i
<<
"with mf dim: "
<<
mf_dim
;
size_t
feature_value_size
=
TYPEALIGN
(
8
,
sizeof
(
FeatureValue
)
+
((
mf_dim
+
1
)
*
sizeof
(
float
)));
char
*
test_build_values
=
(
char
*
)
malloc
(
feature_value_size
*
len
);
cudaMemcpy
(
test_build_values
,
hbm_pool
->
mem
(),
feature_value_size
*
len
,
cudaMemcpyDeviceToHost
);
CHECK
(
len
==
hbm_pool
->
capacity
());
#ifdef PADDLE_WITH_PSLIB
uint64_t
unuse_key
=
std
::
numeric_limits
<
uint64_t
>::
max
();
...
...
@@ -972,7 +949,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
feature_value_size
=
TYPEALIGN
(
8
,
sizeof
(
FeatureValue
)
+
sizeof
(
float
)
*
(
index_dim_vec_
.
back
()
+
1
));
VLOG
(
0
)
<<
"yxf pull sparse feature_value_size: "
<<
feature_value_size
;
#ifdef PADDLE_WITH_CUDA
VLOG
(
3
)
<<
"Begine Gpu Ps PullSparse"
;
...
...
@@ -1159,6 +1135,8 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
"GPUPS: PushSparseGrad Only Support CUDAPlace Now."
));
}
all_timer
.
Pause
();
time_3
+=
all_timer
.
ElapsedSec
();
time_4
+=
push_gpups_timer
.
ElapsedSec
();
VLOG
(
3
)
<<
"PushSparseGrad total cost: "
<<
all_timer
.
ElapsedSec
()
<<
" s, of which GPUPS cost: "
<<
push_gpups_timer
.
ElapsedSec
()
<<
" s"
;
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
浏览文件 @
c0001a24
...
...
@@ -333,6 +333,11 @@ class PSGPUWrapper {
void
SetSlotOffsetVector
(
const
std
::
vector
<
int
>&
slot_offset_vector
)
{
slot_offset_vector_
=
slot_offset_vector
;
std
::
cout
<<
"yxf set: "
;
for
(
auto
s
:
slot_offset_vector_
)
{
std
::
cout
<<
s
<<
" | "
;
}
std
::
cout
<<
" end "
<<
std
::
endl
;
}
#ifdef PADDLE_WITH_CUDA
...
...
@@ -431,6 +436,12 @@ class PSGPUWrapper {
int
max_mf_dim_
{
0
};
size_t
val_type_size_
{
0
};
size_t
grad_type_size_
{
0
};
double
time_1
=
0.0
;
double
time_2
=
0.0
;
double
time_3
=
0.0
;
double
time_4
=
0.0
;
int
multi_node_
{
0
};
int
node_size_
;
uint64_t
table_id_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录