Commit 3f619290 (unverified)

merge dymf branch (#42714)

Authored by yaoxuefeng, committed via GitHub on May 20, 2022.
Parent commit: e726960a

Showing 18 changed files with 991 additions and 344 deletions (+991 −344).
Changed files:

paddle/fluid/framework/fleet/heter_context.h                 +0    -5
paddle/fluid/framework/fleet/heter_ps/feature_value.h        +31   -16
paddle/fluid/framework/fleet/heter_ps/hashtable.h            +2    -2
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu    +33   -8
paddle/fluid/framework/fleet/heter_ps/heter_comm.h           +26   -2
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h       +282  -116
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu   +97   -0
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h    +54   -0
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu            +11   -0
paddle/fluid/framework/fleet/heter_ps/heter_ps.h             +4    -1
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h        +4    -0
paddle/fluid/framework/fleet/heter_ps/heter_resource.h       +2    -0
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h        +5    -4
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc               +285  -187
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu               +127  -0
paddle/fluid/framework/fleet/ps_gpu_wrapper.h                +25   -1
paddle/fluid/operators/pull_gpups_sparse_op.h                +2    -1
python/paddle/fluid/layers/nn.py                             +1    -1
paddle/fluid/framework/fleet/heter_context.h

@@ -129,11 +129,6 @@ class HeterContext {
     for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
       feature_dim_keys_[i].resize(dim_num);
       value_dim_ptr_[i].resize(dim_num);
-      if (i == 0) {
-        for (int j = 0; j < dim_num; j++) {
-          feature_dim_keys_[i][j].push_back(0);
-        }
-      }
     }
     device_values_.resize(device_num);
     device_dim_values_.resize(device_num);
paddle/fluid/framework/fleet/heter_ps/feature_value.h

@@ -32,17 +32,33 @@ struct FeatureValue {
   float lr;
   float lr_g2sum;
   int mf_size;
-  float mf[MF_DIM + 1];
+  int mf_dim;
   uint64_t cpu_ptr;
+  float mf[0];

   friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
     out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
-        << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:";
-    for (int i = 0; i < val.mf_size; ++i) {
+        << " lr: " << val.lr << " mf_dim: " << val.mf_dim
+        << "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
+    for (int i = 0; i < val.mf_dim + 1; ++i) {
       out << " " << val.mf[i];
     }
     return out;
   }

+  __device__ __forceinline__ void operator=(const FeatureValue& in) {
+    delta_score = in.delta_score;
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr = in.lr;
+    lr_g2sum = in.lr_g2sum;
+    mf_size = in.mf_size;
+    mf_dim = in.mf_dim;
+    cpu_ptr = in.cpu_ptr;
+    for (int i = 0; i < mf_dim + 1; i++) {
+      mf[i] = in.mf[i];
+    }
+  }
 };

 struct FeaturePushValue {

@@ -50,20 +66,19 @@ struct FeaturePushValue {
   float clk;
   int slot;
   float lr_g;
-  float mf_g[MF_DIM];
+  int mf_dim;
+  float mf_g[0];

   // __device__ __forceinline__ FeaturePushValue
   // operator+(const FeaturePushValue& a) const {
   //   FeaturePushValue out;
   //   out.slot = a.slot;
   //   out.show = a.show + show;
   //   out.clk = a.clk + clk;
   //   out.lr_g = a.lr_g + lr_g;
   //   for (int i = 0; i < MF_DIM; ++i) {
   //     out.mf_g[i] = a.mf_g[i] + mf_g[i];
   //   }
   //   return out;
   // }
+  __device__ __forceinline__ void operator=(const FeaturePushValue& in) {
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr_g = in.lr_g;
+    mf_dim = in.mf_dim;
+    for (int i = 0; i < mf_dim; i++) {
+      mf_g[i] = in.mf_g[i];
+    }
+  }
 };
 }  // end namespace framework
paddle/fluid/framework/fleet/heter_ps/hashtable.h

@@ -118,8 +118,8 @@ class HashTable {
                 StreamType stream);

   template <typename StreamType>
-  void insert(const KeyType* d_keys, size_t len, char* pool,
-              size_t start_index, StreamType stream);
+  void insert(const KeyType* d_keys, size_t len, char* pool,
+              size_t feature_value_size, size_t start_index,
+              StreamType stream);

   template <typename StreamType>
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu

@@ -50,7 +50,8 @@ __global__ void insert_kernel(Table* table,
 template <typename Table>
 __global__ void insert_kernel(Table* table,
                               const typename Table::key_type* const keys,
-                              size_t len, char* pool, int start_index) {
+                              size_t len, char* pool, size_t feature_value_size,
+                              int start_index) {
   ReplaceOp<typename Table::mapped_type> op;
   thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

@@ -58,7 +59,8 @@ __global__ void insert_kernel(Table* table,
   if (i < len) {
     kv.first = keys[i];
-    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
+    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
+    kv.second = (Table::mapped_type)(pool + offset);
     auto it = table->insert(kv, op);
     assert(it != table->end() && "error: insert fails: table is full");
   }

@@ -81,14 +83,16 @@ __global__ void search_kernel(Table* table,
 template <typename Table>
 __global__ void dy_mf_search_kernel(Table* table,
                                     const typename Table::key_type* const keys,
-                                    char* const vals, size_t len,
+                                    char* vals, size_t len,
                                     size_t pull_feature_value_size) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
     auto it = table->find(keys[i]);
     if (it != table->end()) {
-      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
+      uint64_t offset = i * pull_feature_value_size;
+      FeatureValue& cur = *(FeatureValue*)(vals + offset);
+      FeatureValue& input = *(FeatureValue*)(it->second);
     }
   }
 }

@@ -121,7 +125,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
-      printf("yxf::push miss key: %d", keys[i]);
+      printf("warning: push miss key: %d", keys[i]);
     }
   }
 }

@@ -201,7 +205,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
 template <typename KeyType, typename ValType>
 template <typename StreamType>
 void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
-                                         char* pool, size_t start_index,
+                                         char* pool, size_t feature_value_size,
+                                         size_t start_index,
                                          StreamType stream) {
   if (len == 0) {
     return;

@@ -210,8 +215,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
     return;
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
-                                                       pool, start_index);
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, d_keys, len, pool, feature_value_size, start_index);
 }

 template <typename KeyType, typename ValType>

@@ -319,6 +324,7 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
 }

 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
+template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
 template class HashTable<long, int>;
 template class HashTable<unsigned long, int>;
 template class HashTable<unsigned long, unsigned long>;

@@ -331,6 +337,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
     paddle::framework::FeatureValue* d_vals, size_t len, cudaStream_t stream);
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
+    const unsigned long* d_keys, char* d_vals, size_t len,
+    cudaStream_t stream);
 template void HashTable<long, int>::get<cudaStream_t>(
     const long* d_keys, int* d_vals, size_t len, cudaStream_t stream);

@@ -354,6 +364,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
     const paddle::framework::FeatureValue* d_vals, size_t len,
     cudaStream_t stream);
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::insert<cudaStream_t>(
+    const unsigned long* d_keys, size_t len, char* pool,
+    size_t feature_value_size, size_t start_index, cudaStream_t stream);
 template void HashTable<long, int>::insert<cudaStream_t>(
     const long* d_keys, const int* d_vals, size_t len,

@@ -393,6 +408,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
     sgd, cudaStream_t stream);

+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
+    Optimizer<paddle::framework::FeatureValue,
+              paddle::framework::FeaturePushValue>,
+    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
+                  Optimizer<paddle::framework::FeatureValue,
+                            paddle::framework::FeaturePushValue>
+                      sgd,
+                  cudaStream_t stream);
+
 // template void HashTable<unsigned long,
 // paddle::framework::FeatureValue>::update<
 //     Optimizer<paddle::framework::FeatureValue,
paddle/fluid/framework/fleet/heter_ps/heter_comm.h

@@ -15,10 +15,13 @@ limitations under the License. */
 #pragma once
 #include <thread>
 #include <vector>
+#include "cub/cub.cuh"
+#include "cub/util_allocator.cuh"
 #if defined(PADDLE_WITH_CUDA)
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
+#include "paddle/fluid/platform/timer.h"
 #include "thrust/pair.h"
 #elif defined(PADDLE_WITH_XPU_KP)
 // #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"

@@ -38,6 +41,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+#define TYPEALIGN(ALIGNVAL, LEN) \
+  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))
+
 template <typename KeyType, typename ValType, typename GradType>
 class HeterComm {
  public:
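The TYPEALIGN macro introduced above rounds a byte length up to the next multiple of ALIGNVAL; the hunks below use it to size the dynamically-sized value and gradient slots. A minimal host-side check of the arithmetic (illustration only, not part of the patch):

#include <cstdint>
#include <cstdio>

#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

int main() {
  // 41 bytes rounded up to an 8-byte boundary gives 48; 48 stays 48.
  printf("%llu\n", (unsigned long long)TYPEALIGN(8, 41));  // prints 48
  printf("%llu\n", (unsigned long long)TYPEALIGN(8, 48));  // prints 48
  return 0;
}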
@@ -50,9 +56,13 @@ class HeterComm {
                         int* left, int* right, int gpu_num);
   void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
                   int& uniq_len);  // NOLINT
+  void dynamic_merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads,
+                          size_t len, int& uniq_len);
   void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
   void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                 size_t chunk_size, int stream_num);
+  void build_ps(int num, KeyType* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size, int stream_num);
   void dump();
   void show_one_table(int gpu_num);
   int get_index_by_devid(int devid);

@@ -96,6 +106,11 @@ class HeterComm {
     nccl_inter_comms_ = inter_comms;
     node_size_ = comm_size;
   }

+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
+    multi_mf_dim_ = multi_mf_dim;
+    max_mf_dim_ = max_mf_dim;
+  }
+
 #endif

   bool need_transfer(int send_id, int receive_id) {

@@ -114,8 +129,8 @@ class HeterComm {
     char* key_storage;
     char* val_storage;
     int sync;
-    int key_bytes_len;
-    int val_bytes_len;
+    size_t key_bytes_len;
+    size_t val_bytes_len;
     int dev_num;
   };

@@ -206,12 +221,18 @@ class HeterComm {
   void destroy_storage(int start_index, int end_index);
   void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
                     KeyType* src_key, GradType* src_val);
+  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
+                    KeyType* src_key, char* src_val, size_t val_size);
   void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
                    ValType* src_val);
+  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
+                   char* src_val, size_t val_size);

  protected:
   using Table = HashTable<KeyType, ValType>;
+  using PtrTable = HashTable<KeyType, ValType*>;
   std::vector<Table*> tables_;
+  std::vector<PtrTable*> ptr_tables_;
   std::shared_ptr<HeterPsResource> resource_;
   std::vector<std::vector<Path>> path_;
   float load_factor_{0.75};

@@ -221,6 +242,7 @@ class HeterComm {
  private:
   int topo_aware_{0};
   std::vector<LocalStorage> storage_;
+  DynamicGradMerger merger_;
   int feanum_{1800 * 2048};
   int multi_node_{0};
   int node_size_;

@@ -228,6 +250,8 @@ class HeterComm {
 #if defined(PADDLE_WITH_CUDA)
   std::vector<ncclComm_t> nccl_inner_comms_;
   std::vector<ncclComm_t> nccl_inter_comms_;
+  int multi_mf_dim_{8};
+  int max_mf_dim_ = 8;
   std::vector<std::shared_ptr<cub::CachingDeviceAllocator>> allocators_;
 #endif
 };
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h

@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #ifdef PADDLE_WITH_HETERPS
 #include <queue>
+#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
 #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
 #include "paddle/fluid/platform/device_context.h"
 #ifdef PADDLE_WITH_XPU_KP

@@ -22,20 +23,31 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

 template <typename KeyType, typename ValType, typename GradType>
 HeterComm<KeyType, ValType, GradType>::HeterComm(
     size_t capacity, std::shared_ptr<HeterPsResource> resource) {
   resource_ = resource;
   storage_.resize(resource_->total_device());
+  multi_mf_dim_ = resource->multi_mf();
   for (int i = 0; i < resource_->total_device(); ++i) {
 #if defined(PADDLE_WITH_CUDA)
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
     allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
         8, 1, (unsigned int)-1, (size_t)-1, false, false));  // NOLINT
 #endif
-    auto table = new Table(capacity / load_factor_);
-    tables_.push_back(table);
+    if (!multi_mf_dim_) {
+      auto table = new Table(capacity / load_factor_);
+      tables_.push_back(table);
+    } else {
+      max_mf_dim_ = resource_->max_mf_dim();
+      size_t val_type_size = TYPEALIGN(
+          8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
+      size_t grad_type_size = TYPEALIGN(
+          8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
+      auto ptr_table = new PtrTable(capacity / load_factor_);
+      ptr_table->set_feature_value_size(val_type_size, grad_type_size);
+      ptr_tables_.push_back(ptr_table);
+    }
     if (multi_node_) {
       storage_[i].init(feanum_, resource_->dev_id(i));
     }
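The constructor branch above sizes each value slot as TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)) and stores values in a raw byte pool addressed by slot index. A host-side sketch of that layout, using a simplified stand-in struct rather than the real FeatureValue (illustration only; ValueStub and the pool are hypothetical):

#include <cstdint>
#include <vector>

// Simplified stand-in: fixed header fields plus a trailing flexible array,
// so the true per-slot size depends on the embedding dimension.
struct ValueStub {
  float show, clk, lr;
  int mf_size;
  int mf_dim;
  float mf[0];  // embedding payload lives directly after the header
};

#define TYPEALIGN(ALIGNVAL, LEN) \
  (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1)))

int main() {
  int max_mf_dim = 64;
  // Mirrors the expression in the constructor hunk above.
  size_t val_size =
      TYPEALIGN(8, sizeof(ValueStub) + sizeof(float) * (max_mf_dim + 1));

  size_t num_slots = 100;
  std::vector<char> pool(num_slots * val_size);  // stand-in for the device pool

  // Slot i starts at byte offset i * val_size, exactly like
  // "pool + uint64_t(start_index + i) * feature_value_size" in insert_kernel.
  size_t i = 7;
  ValueStub* v = reinterpret_cast<ValueStub*>(pool.data() + i * val_size);
  v->mf_dim = max_mf_dim;
  v->mf[0] = 1.0f;  // first payload float, still inside this slot's val_size bytes
  return 0;
}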
@@ -238,95 +250,128 @@ void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
 }

 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_src(
-    int start_index, int num, int* h_left, int* h_right, ValType* src_val) {
+void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
+    int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
+    char* src_val, size_t val_size) {
+  int need_copy_val = 0;
+  if (src_val) {
+    need_copy_val = 1;
+  }
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    int size = path_[start_index][i].nodes_.size();
+    auto& node = path_[start_index][i].nodes_[0];
+    CopyTask t(&path_[start_index][i], 0);
+    que.push(t);
+    cudaMemcpyAsync(node.key_storage,
+                    reinterpret_cast<char*>(src_key + h_left[i]),
+                    node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
+    if (need_copy_val) {
+      cudaMemcpyAsync(node.val_storage,
+                      src_val + uint64_t(h_left[i]) * uint64_t(val_size),
+                      node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
+    }
+  }
+  while (!que.empty()) {
+    CopyTask& cur_task = que.front();
+    que.pop();
+    if (cur_task.path->nodes_[cur_task.step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
+    }
+    if (cur_task.step != cur_task.path->nodes_.size() - 1) {
+      int cur_step = cur_task.step;
+      CopyTask c(cur_task.path, cur_step + 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
+                      cur_task.path->nodes_[cur_step].key_storage,
+                      cur_task.path->nodes_[cur_step + 1].key_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step + 1].in_stream);
+      if (need_copy_val) {
+        cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
+                        cur_task.path->nodes_[cur_step].val_storage,
+                        cur_task.path->nodes_[cur_step + 1].val_bytes_len,
+                        cudaMemcpyDefault,
+                        cur_task.path->nodes_[cur_step + 1].in_stream);
+      }
+    }
+  }
+}

-  for (int i = 0; i < num; i++) {
+template <typename KeyType, typename ValType, typename GradType>
+void HeterComm<KeyType, ValType, GradType>::walk_to_src(
+    int start_index, int gpu_num, int* h_left, int* h_right, char* src_val,
+    size_t val_size) {
   std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
     if (h_left[i] == -1 || h_right[i] == -1) {
       continue;
     }
     int cur_step = path_[start_index][i].nodes_.size() - 1;
     auto& node = path_[start_index][i].nodes_[cur_step];
-    auto src_dev_id = resource_->dev_id(i);
-    auto src_place = DevPlace(src_dev_id);
     if (cur_step == 0) {
-      auto dst_dev_id = resource_->dev_id(start_index);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place, reinterpret_cast<char*>(src_val + h_left[i]),
-                  src_place, node.val_storage, node.val_bytes_len,
-                  node.out_stream);
+      cudaMemcpyAsync(src_val + uint64_t(h_left[i]) * val_size,
+                      node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
+                      node.out_stream);
     } else {
       CopyTask t(&path_[start_index][i], cur_step - 1);
       que.push(t);
-      auto dst_dev_id =
-          resource_->dev_id(path_[start_index][i].nodes_[cur_step - 1].dev_num);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place,
-                  path_[start_index][i].nodes_[cur_step - 1].val_storage,
-                  src_place, node.val_storage,
-                  path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
-                  path_[start_index][i].nodes_[cur_step - 1].out_stream);
+      cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
+                      node.val_storage,
+                      path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      path_[start_index][i].nodes_[cur_step - 1].out_stream);
     }
   }
   while (!que.empty()) {
     CopyTask& cur_task = que.front();
     que.pop();
     int cur_step = cur_task.step;
     if (cur_task.path->nodes_[cur_step].sync) {
-      sync_stream(cur_task.path->nodes_[cur_step].out_stream);
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
     }
-    auto src_dev_id =
-        resource_->dev_id(cur_task.path->nodes_[cur_step].dev_num);
-    auto src_place = DevPlace(src_dev_id);
     if (cur_step > 0) {
       CopyTask c(cur_task.path, cur_step - 1);
       que.push(c);
-      auto dst_dev_id =
-          resource_->dev_id(cur_task.path->nodes_[cur_step - 1].dev_num);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place, cur_task.path->nodes_[cur_step - 1].val_storage,
-                  src_place, cur_task.path->nodes_[cur_step].val_storage,
-                  cur_task.path->nodes_[cur_step - 1].val_bytes_len,
-                  cur_task.path->nodes_[cur_step - 1].out_stream);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step - 1].out_stream);
     } else if (cur_step == 0) {
       int end_index = cur_task.path->nodes_.back().dev_num;
-      auto dst_dev_id = resource_->dev_id(end_index);
-      auto dst_place = DevPlace(dst_dev_id);
-      memory_copy(dst_place,
-                  reinterpret_cast<char*>(src_val + h_left[end_index]),
-                  src_place, cur_task.path->nodes_[cur_step].val_storage,
-                  cur_task.path->nodes_[cur_step].val_bytes_len,
-                  cur_task.path->nodes_[cur_step].out_stream);
+      cudaMemcpyAsync(src_val + uint64_t(h_left[end_index]) * val_size,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step].out_stream);
     }
   }
 }

 template <typename KeyType, typename ValType, typename GradType>
 HeterComm<KeyType, ValType, GradType>::~HeterComm() {
-  for (auto& table : tables_) {
-    delete table;
-    table = nullptr;
+  if (!multi_mf_dim_) {
+    for (auto& table : tables_) {
+      delete table;
+      table = nullptr;
+    }
+  } else {
+    for (auto& table : ptr_tables_) {
+      delete table;
+      table = nullptr;
+    }
   }
 }

 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::show_one_table(int num) {
-  tables_[num]->show();
+void HeterComm<KeyType, ValType, GradType>::show_one_table(int gpu_num) {
+  if (!multi_mf_dim_) {
+    tables_[gpu_num]->show();
+  }
 }

 template <typename KeyType, typename ValType, typename GradType>

@@ -418,59 +463,165 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(
   }
 }

+template <typename KeyType, typename ValType, typename GradType>
+void HeterComm<KeyType, ValType, GradType>::build_ps(
+    int num, KeyType* h_keys, char* pool, size_t len,
+    size_t feature_value_size, size_t chunk_size, int stream_num) {
+  if (len <= 0) {
+    return;
+  }
+  int dev_id = resource_->dev_id(num);
+  DevPlace place = DevPlace(dev_id);
+  AnyDeviceGuard guard(dev_id);
+  std::vector<memory::allocation::AllocationPtr> d_key_bufs;
+  ppStream streams[stream_num];  // NOLINT
+  for (int i = 0; i < stream_num; ++i) {
+    create_stream(&(streams[i]));
+    auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
+    d_key_bufs.push_back(std::move(d_k_buf));
+  }
+  int cur_len = 0;
+  int cur_stream = 0;
+  while (cur_len < len) {
+    cur_stream = cur_stream % stream_num;
+    auto cur_use_stream = streams[cur_stream];
+#if defined(PADDLE_WITH_XPU_KP)
+    cur_use_stream = 0;
+#endif
+    int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
+    auto dst_place = place;
+    auto src_place = platform::CPUPlace();
+    memory_copy(dst_place,
+                reinterpret_cast<char*>(d_key_bufs[cur_stream]->ptr()),
+                src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len,
+                cur_use_stream);
+    ptr_tables_[num]->insert(
+        reinterpret_cast<KeyType*>(d_key_bufs[cur_stream]->ptr()), tmp_len,
+        pool, feature_value_size, cur_len, cur_use_stream);
+    cur_stream += 1;
+    cur_len += tmp_len;
+  }
+  for (int i = 0; i < stream_num; ++i) {
+    sync_stream(streams[i]);
+    destroy_stream(streams[i]);
+  }
+}
+
 template <typename KeyType, typename ValType, typename GradType>
 void HeterComm<KeyType, ValType, GradType>::merge_grad(
     int dev_num, KeyType* d_keys, GradType* d_grads, size_t len,
     int& uniq_len) {  // NOLINT
   int dev_id = resource_->dev_id(dev_num);
   DevPlace place = DevPlace(dev_id);
   AnyDeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(dev_num, 0);
   size_t temp_storage_bytes;
   auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
   auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
   GradType* d_merge_grads_ptr =
       reinterpret_cast<GradType*>(d_merge_grads->ptr());
   heter_comm_kernel_->sort_pairs(NULL, temp_storage_bytes, d_keys,
                                  d_merge_keys_ptr, d_grads, d_merge_grads_ptr,
                                  len, 0, 8 * sizeof(KeyType), stream, false);
   auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   heter_comm_kernel_->sort_pairs(d_temp_storage->ptr(), temp_storage_bytes,
                                  d_keys, d_merge_keys_ptr, d_grads,
                                  d_merge_grads_ptr, len, 0,
                                  8 * sizeof(KeyType), stream, false);
   temp_storage_bytes = 0;
   auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
   int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
   heter_comm_kernel_->reduce_by_key(NULL, temp_storage_bytes, d_merge_keys_ptr,
                                     d_keys, d_merge_grads_ptr, d_grads,
                                     d_num_runs_out, len, stream, false);
   if (d_temp_storage->size() < temp_storage_bytes) {
     d_temp_storage = NULL;
     d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   }
   heter_comm_kernel_->reduce_by_key(
       d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
       d_merge_grads_ptr, d_grads, d_num_runs_out, len, stream, false);
   auto dst_place = platform::CPUPlace();
   auto src_place = place;
   memory_copy(dst_place, &uniq_len, src_place, d_num_runs_out, sizeof(int),
               stream);
   sync_stream(stream);
 }

+template <typename KeyType, typename ValType, typename GradType>
+void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    int& uniq_len) {
+  int dev_id = resource_->dev_id(gpu_num);
+  platform::CUDAPlace place = platform::CUDAPlace(dev_id);
+  platform::CUDADeviceGuard guard(dev_id);
+  auto stream = resource_->local_stream(gpu_num, 0);
+  size_t temp_storage_bytes;
+  size_t grad_value_size =
+      TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
+  auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
+  KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
+  auto d_merge_grads = memory::Alloc(place, len * grad_value_size);
+  GradType* d_merge_grads_ptr =
+      reinterpret_cast<GradType*>(d_merge_grads->ptr());
+  auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
+  uint32_t* d_fea_num_info_ptr =
+      reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
+  uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len];
+  uint32_t* d_idx = (uint32_t*)&d_index[len];
+  int* d_merged_size = (int*)&d_idx[len];
+  int grid_size = (len - 1) / block_size_ + 1;
+  heter_comm_kernel_->fill_idx(d_idx, len, stream);
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
+      NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len,
+      0, 8 * sizeof(KeyType), stream));
+  void* d_buff = NULL;
+  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
+      d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
+      d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+  temp_storage_bytes = 0;
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
+      NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr,
+      d_merged_size, len, stream));
+  if (d_temp_storage->size() < temp_storage_bytes) {
+    d_temp_storage = NULL;
+    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
+  }
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
+      d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
+      d_fea_num_info_ptr, d_merged_size, len, stream));
+  cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int),
+                  cudaMemcpyDeviceToHost, stream);
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+  assert(d_merged_size > 0);
+  uint32_t* d_offset = (uint32_t*)&d_index[len];
+  temp_storage_bytes = 0;
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
+      NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len,
+      stream));
+  if (d_temp_storage->size() < temp_storage_bytes) {
+    d_temp_storage = NULL;
+    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
+  }
+  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
+      d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset,
+      uniq_len, stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+  heter_comm_kernel_->merge_gradient(
+      d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
+      (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      cudaMemcpyAsync(d_grads, d_merge_grads_ptr, grad_value_size * uniq_len,
+                      cudaMemcpyDeviceToDevice, stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
+}
+
 template <typename KeyType, typename ValType, typename GradType>
 void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
     KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right,
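dynamic_merge_grad above deduplicates gradients with a sort → run-length-encode → exclusive-sum → merge pipeline (cub::DeviceRadixSort::SortPairs, DeviceRunLengthEncode::Encode, DeviceScan::ExclusiveSum, then merge_gradient). A host-side analogue of the same bookkeeping, with std:: algorithms standing in for the CUB primitives (illustration only; GradStub is a simplified stand-in for FeaturePushValue):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

struct GradStub { float show, clk, lr_g; };  // stand-in for FeaturePushValue

int main() {
  std::vector<uint64_t> keys = {42, 7, 42, 7, 7};
  std::vector<GradStub> grads = {
      {1, 0, .1f}, {0, 1, .2f}, {1, 1, .3f}, {0, 0, .4f}, {1, 0, .5f}};

  // 1) Sort indices by key (the GPU path sorts (key, index) pairs).
  std::vector<uint32_t> idx(keys.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](uint32_t a, uint32_t b) { return keys[a] < keys[b]; });

  // 2) Run-length encode the sorted keys -> unique keys + per-key counts.
  std::vector<uint64_t> uniq;
  std::vector<uint32_t> fea_num;
  for (uint32_t j = 0; j < idx.size(); ++j) {
    if (uniq.empty() || keys[idx[j]] != uniq.back()) {
      uniq.push_back(keys[idx[j]]);
      fea_num.push_back(0);
    }
    ++fea_num.back();
  }

  // 3) Exclusive prefix sum of the counts -> start offset of each key's run.
  std::vector<uint32_t> offset(fea_num.size(), 0);
  std::partial_sum(fea_num.begin(), fea_num.end() - 1, offset.begin() + 1);

  // 4) Merge every run into one gradient per unique key
  //    (what merge_gradients_kernel does via update_one / merge_one).
  std::vector<GradStub> merged(uniq.size());
  for (size_t k = 0; k < uniq.size(); ++k) {
    GradStub out = grads[idx[offset[k]]];              // update_one
    for (uint32_t j = 1; j < fea_num[k]; ++j) {
      const GradStub& in = grads[idx[offset[k] + j]];  // merge_one
      out.show += in.show;
      out.clk += in.clk;
      out.lr_g += in.lr_g;
    }
    merged[k] = out;
  }
  return 0;
}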
@@ -529,8 +680,6 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
   AnyDeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(num, 0);

-  // int grid_size = (len - 1) / block_size_ + 1;
-
   int h_left[total_device];   // NOLINT
   int h_right[total_device];  // NOLINT

@@ -562,10 +711,11 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
   auto d_idx = memory::Alloc(place, len * sizeof(int));
   int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

+  size_t val_type_size =
+      TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType));
+  auto d_shard_vals = memory::Alloc(place, len * val_type_size);
   ValType* d_shard_vals_ptr = reinterpret_cast<ValType*>(d_shard_vals->ptr());

   split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);

@@ -589,9 +739,8 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
       continue;
     }
     create_storage(num, i, shard_len * sizeof(KeyType),
-                   shard_len * sizeof(ValType));
+                   shard_len * val_type_size);
   }

   walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);

   for (int i = 0; i < total_device; ++i) {

@@ -600,14 +749,11 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     }
     auto& node = path_[num][i].nodes_.back();
     sync_stream(node.in_stream);

     AnyDeviceGuard guard(resource_->dev_id(i));
-    tables_[i]->rwlock_->RDLock();
-    tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
-                    reinterpret_cast<ValType*>(node.val_storage),
-                    h_right[i] - h_left[i] + 1,
-                    resource_->remote_stream(i, num));
+    ptr_tables_[i]->rwlock_->RDLock();
+    ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
+                        node.val_storage, h_right[i] - h_left[i] + 1,
+                        resource_->remote_stream(i, num));
   }

   for (int i = 0; i < total_device; ++i) {

@@ -615,21 +761,18 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     if (h_left[i] == -1) {
       continue;
     }
-    tables_[i]->rwlock_->UNLock();
+    ptr_tables_[i]->rwlock_->UNLock();
   }

-  walk_to_src(num, total_device, h_left, h_right, d_shard_vals_ptr);
+  walk_to_src(num, total_device, h_left, h_right,
+              reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);

   for (int i = 0; i < total_device; ++i) {
     auto& node = path_[num][i].nodes_.front();
     sync_stream(node.out_stream);
   }

-  heter_comm_kernel_->fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
-                                 stream);
+  heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr,
+                                       len, val_type_size, stream);

   sync_stream(stream);

   for (int i = 0; i < total_device; ++i) {
     if (h_left[i] == -1 || h_right[i] == -1) {
       continue;

@@ -653,6 +796,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   int total_device = resource_->total_device();
   int dev_id = resource_->dev_id(dev_num);

+  size_t grad_value_size =
+      TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
   DevPlace place = DevPlace(dev_id);
   AnyDeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(dev_num, 0);

@@ -691,21 +836,19 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
-  GradType* d_shard_grads_ptr =
-      reinterpret_cast<GradType*>(d_shard_grads->ptr());
+  GradType* d_shard_grads_ptr;
+  auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
+  d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());

   int uniq_len = len;
-  merge_grad(dev_num, d_keys, d_grads, len, uniq_len);
+  dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len);

-  // int grid_size = (uniq_len - 1) / block_size_ + 1;
+  int grid_size = (uniq_len - 1) / block_size_ + 1;

   split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr,
                        dev_num);

-  heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys,
-                                       d_shard_grads_ptr, d_grads, d_idx_ptr,
-                                       uniq_len, stream);
+  heter_comm_kernel_->dy_mf_fill_shard_grads(
+      d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr,
+      uniq_len, grad_value_size, stream);

   sync_stream(stream);

@@ -721,12 +864,22 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
     if (h_left[i] == -1 || h_right[i] == -1) {
       continue;
     }
-    create_storage(dev_num, i, shard_len * sizeof(KeyType),
-                   shard_len * sizeof(GradType));
+    if (!multi_mf_dim_) {
+      create_storage(dev_num, i, shard_len * sizeof(KeyType),
+                     shard_len * sizeof(GradType));
+    } else {
+      create_storage(dev_num, i, shard_len * sizeof(KeyType),
+                     shard_len * grad_value_size);
+    }
   }

-  walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
-               d_shard_grads_ptr);
+  if (!multi_mf_dim_) {
+    walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
+                 d_shard_grads_ptr);
+  } else {
+    walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
+                 reinterpret_cast<char*>(d_shard_grads_ptr), grad_value_size);
+  }

   for (int i = 0; i < total_device; ++i) {
     if (h_left[i] == -1 || h_right[i] == -1) {

@@ -736,17 +889,28 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
     sync_stream(node.in_stream);

     AnyDeviceGuard guard(resource_->dev_id(i));
-    tables_[i]->rwlock_->WRLock();
-    tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
-                       reinterpret_cast<GradType*>(node.val_storage),
-                       h_right[i] - h_left[i] + 1, sgd,
-                       resource_->remote_stream(i, dev_num));
+    if (!multi_mf_dim_) {
+      tables_[i]->rwlock_->WRLock();
+      tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
+                         reinterpret_cast<GradType*>(node.val_storage),
+                         h_right[i] - h_left[i] + 1, sgd,
+                         resource_->remote_stream(i, dev_num));
+    } else {
+      ptr_tables_[i]->rwlock_->WRLock();
+      ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
+                             node.val_storage, h_right[i] - h_left[i] + 1,
+                             sgd, resource_->remote_stream(i, dev_num));
+    }
   }

   for (int i = 0; i < total_device; ++i) {
     sync_stream(resource_->remote_stream(i, dev_num));
     if (h_left[i] != -1) {
-      tables_[i]->rwlock_->UNLock();
+      if (!multi_mf_dim_) {
+        tables_[i]->rwlock_->UNLock();
+      } else {
+        ptr_tables_[i]->rwlock_->UNLock();
+      }
     }
   }

@@ -1078,11 +1242,13 @@ void HeterComm<KeyType, ValType, GradType>::end_pass() {
     tables_[index]->dump_to_cpu(dev_id, stream);
   };

-  for (int i = 0; i < total_device; ++i) {
-    threads.push_back(std::thread(dump_to_cpu_func, i));
-  }
-  for (auto& t : threads) {
-    t.join();
+  if (!multi_mf_dim_) {
+    for (int i = 0; i < total_device; ++i) {
+      threads.push_back(std::thread(dump_to_cpu_func, i));
+    }
+    for (auto& t : threads) {
+      t.join();
+    }
   }
 }
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu

@@ -117,6 +117,52 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
   }
 }

+template <typename KeyType, typename GradType, typename T>
+__global__ void dy_mf_fill_shard_grads_kernel(
+    KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
+    GradType* d_grads, T* idx, size_t len, size_t grad_value_size) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    d_shard_keys[i] = d_keys[idx[i]];
+    *(GradType*)((char*)d_shard_grads + i * grad_value_size) =
+        *(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
+  }
+}
+
+__global__ void merge_gradients_kernel(const uint32_t* offset,
+                                       const uint32_t* fea_num,
+                                       const uint32_t* index,
+                                       const char* input, char* output, int n,
+                                       size_t grad_value_size,
+                                       DynamicGradMerger& merger_) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < n) {
+    uint32_t start = offset[i];
+    uint32_t num = fea_num[i];
+    int ori_index = index[start];
+    FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
+    FeaturePushValue& in =
+        *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+    merger_.update_one(out, in);
+    for (int j = 1; j < num; ++j) {
+      ori_index = index[start + j];
+      in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
+      merger_.merge_one(out, in);
+    }
+  }
+}
+
+template <typename ValType, typename T>
+__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, ValType* d_vals,
+                                        T* idx, size_t len, size_t val_size) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < len) {
+    uint64_t new_offset = uint64_t(idx[i]) * val_size;
+    *(ValType*)((char*)d_vals + new_offset) =
+        *(ValType*)((char*)d_shard_vals + i * val_size);
+  }
+}
+
 // cuda implemention of heter_comm_kernel.h
 template <typename T, typename StreamType>
 void HeterCommKernel::fill_idx(T* idx, long long len,

@@ -207,8 +253,42 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage,
       debug_synchronous));
 }

+template <typename KeyType, typename GradType, typename T, typename StreamType>
+void HeterCommKernel::dy_mf_fill_shard_grads(
+    KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads,
+    GradType* d_grads, T* idx, long long len, size_t grad_value_size,
+    const StreamType& stream) {
+  int grid_size = (len - 1) / block_size_ + 1;
+  size_t c_len = (size_t)len;
+  dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_keys, d_keys, d_shard_grads, d_grads, idx, c_len,
+      grad_value_size);
+}
+
+template <typename StreamType>
+void HeterCommKernel::merge_gradient(const uint32_t* offset,
+                                     const uint32_t* fea_num,
+                                     const uint32_t* index, const char* input,
+                                     char* output, int n,
+                                     size_t grad_value_size,
+                                     DynamicGradMerger& merger_,
+                                     const StreamType& stream) {
+  int grid_size = (n - 1) / block_size_ + 1;
+  merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
+      offset, fea_num, index, input, output, n, grad_value_size, merger_);
+}
+
+template <typename ValType, typename T, typename StreamType>
+void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals,
+                                       T* idx, long long len, size_t val_size,
+                                       const StreamType& stream) {
+  int grid_size = (len - 1) / block_size_ + 1;
+  size_t c_len = (size_t)len;
+  dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_shard_vals, d_vals, idx, c_len, val_size);
+}
+
 template void HeterCommKernel::fill_idx<int, cudaStream_t>(
     int* idx, long long len, const cudaStream_t& stream);
+template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
+    uint32_t* idx, long long len, const cudaStream_t& stream);
 template void HeterCommKernel::calc_shard_offset<int, cudaStream_t>(
     int* idx, int* left, int* right, long long len, int total_devs,

@@ -270,6 +350,23 @@ template void HeterCommKernel::reduce_by_key<
     paddle::framework::FeaturePushValue* d_aggregates_out, int* d_num_runs_out,
     int num_items, cudaStream_t stream, bool debug_synchronous);

+template void HeterCommKernel::dy_mf_fill_shard_grads<
+    unsigned long, paddle::framework::FeaturePushValue, int, cudaStream_t>(
+    unsigned long* d_shard_keys, unsigned long* d_keys,
+    paddle::framework::FeaturePushValue* d_shard_grads,
+    paddle::framework::FeaturePushValue* d_grads, int* idx, long long len,
+    size_t grad_value_size, const cudaStream_t& stream);
+
+template void HeterCommKernel::merge_gradient<cudaStream_t>(
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_value_size,
+    DynamicGradMerger& merger_, const cudaStream_t& stream);
+
+template void HeterCommKernel::dy_mf_fill_dvals<
+    paddle::framework::FeatureValue, int, cudaStream_t>(
+    paddle::framework::FeatureValue* d_shard_vals,
+    paddle::framework::FeatureValue* d_vals, int* idx, long long len,
+    size_t val_size, const cudaStream_t& stream);
 #endif

 }  // namespace framework
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h

@@ -27,6 +27,42 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+struct DynamicGradMerger {
+  template <typename T>
+  CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
+  operator()(const T& a, const T& b) const {
+    T out;
+    out.slot = a.slot;
+    out.mf_dim = a.mf_dim;
+    out.show = a.show + b.show;
+    out.clk = a.clk + b.clk;
+    out.lr_g = a.lr_g + b.lr_g;
+    return out;
+  }
+
+  template <typename T>
+  __device__ __forceinline__ void update_one(T& output, const T& input) {
+    output.slot = input.slot;
+    output.show = input.show;
+    output.clk = input.clk;
+    output.mf_dim = input.mf_dim;
+    output.lr_g = input.lr_g;
+    for (int i = 0; i < output.mf_dim; ++i) {
+      output.mf_g[i] = input.mf_g[i];
+    }
+  }
+  template <typename T>
+  __device__ __forceinline__ void merge_one(T& output, const T& input) {
+    output.show += input.show;
+    output.clk += input.clk;
+    output.lr_g += input.lr_g;
+    for (int i = 0; i < input.mf_dim; ++i) {
+      output.mf_g[i] += input.mf_g[i];
+    }
+  }
+};
+
 class HeterCommKernel {
  public:
   HeterCommKernel() {}

@@ -80,6 +116,24 @@ class HeterCommKernel {
                      StreamType stream = NULL, bool debug_synchronous = false);

+  template <typename KeyType, typename GradType, typename T,
+            typename StreamType>
+  void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys,
+                              GradType* d_shard_grads, GradType* d_grads,
+                              T* idx, long long len, size_t grad_value_size,
+                              const StreamType& stream);
+
+  template <typename StreamType>
+  void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
+                      const uint32_t* index, const char* input, char* output,
+                      int n, size_t grad_value_size,
+                      DynamicGradMerger& merger_, const StreamType& stream);
+
+  template <typename ValType, typename T, typename StreamType>
+  void dy_mf_fill_dvals(ValType* d_shard_vals, ValType* d_vals, T* idx,
+                        long long len, size_t val_size,
+                        const StreamType& stream);
+
  private:
   int block_size_{256};
 };
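DynamicGradMerger above plays two roles: its binary operator() has the shape a CUB reduce-by-key reduction functor takes (which is what HeterCommKernel::reduce_by_key wraps), while update_one/merge_one drive merge_gradients_kernel on the byte-strided dynamic-mf path. A sketch of the functor-style use with simplified stand-in types rather than the real FeaturePushValue (illustration only; assumes device pointers and keys sorted so equal keys are adjacent):

#include <cub/cub.cuh>

struct PushStub { int slot; int mf_dim; float show, clk, lr_g; };

struct MergerStub {
  template <typename T>
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    T out;
    out.slot = a.slot;
    out.mf_dim = a.mf_dim;
    out.show = a.show + b.show;   // accumulate the fixed-width fields,
    out.clk = a.clk + b.clk;      // exactly like DynamicGradMerger::operator()
    out.lr_g = a.lr_g + b.lr_g;
    return out;
  }
};

// d_* arguments are device allocations of length num_items; d_temp_storage is
// sized with the usual two-phase CUB query (first call with nullptr).
void reduce_push_values(void* d_temp_storage, size_t& temp_bytes,
                        const unsigned long* d_keys_in,
                        unsigned long* d_unique_out, const PushStub* d_vals_in,
                        PushStub* d_aggregates_out, int* d_num_runs_out,
                        int num_items, cudaStream_t stream) {
  cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_bytes, d_keys_in,
                                 d_unique_out, d_vals_in, d_aggregates_out,
                                 d_num_runs_out, MergerStub(), num_items,
                                 stream);
}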
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu

@@ -44,6 +44,13 @@ void HeterPs::build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
   comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num);
 }

+void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                       size_t feature_value_size, size_t chunk_size,
+                       int stream_num) {
+  comm_->build_ps(num, h_keys, pool, len, feature_value_size, chunk_size,
+                  stream_num);
+}
+
 int HeterPs::get_index_by_devid(int devid) {
   return comm_->get_index_by_devid(devid);
 }

@@ -72,6 +79,10 @@ void HeterPs::set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
   comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size);
 }

+void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
+  comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
+}
+
 }  // end namespace framework
 }  // end namespace paddle
 #endif
paddle/fluid/framework/fleet/heter_ps/heter_ps.h

@@ -37,11 +37,14 @@ class HeterPs : public HeterPsBase {
                   size_t len) override;
   void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len,
                 size_t chunk_size, int stream_num) override;
+  void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                size_t feature_value_size, size_t chunk_size,
+                int stream_num) override;
 #if defined(PADDLE_WITH_CUDA)
   void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                               const std::vector<ncclComm_t>& inter_comms,
                               int comm_size) override;
+  void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override;
 #endif
   void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h

@@ -35,11 +35,15 @@ class HeterPsBase {
                       size_t len) = 0;
   virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
                         size_t len, size_t chunk_size, int stream_num) = 0;
+  virtual void build_ps(int num, FeatureKey* h_keys, char* pool, size_t len,
+                        size_t feature_value_size, size_t chunk_size,
+                        int stream_num) = 0;
   virtual int get_index_by_devid(int devid) = 0;
 #if defined(PADDLE_WITH_CUDA)
   virtual void set_nccl_comm_and_size(
       const std::vector<ncclComm_t>& inner_comms,
       const std::vector<ncclComm_t>& inter_comms, int comm_size) = 0;
+  virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0;
 #endif
   virtual void end_pass() = 0;
   virtual void show_one_table(int gpu_num) = 0;
paddle/fluid/framework/fleet/heter_ps/heter_resource.h

@@ -107,6 +107,8 @@ class HeterPsResource {
   int get_index_by_devid(int devid);
   int dev_id(int num);
   void set_multi_mf(int multi_mf_dim, int max_mf_dim);
+  int multi_mf() { return multi_mf_dim_; }
+  int max_mf_dim() { return max_mf_dim_; }
   ppStream local_stream(int dev_num, int stream_num);
   ppStream remote_stream(int dev_num, int stream_num);
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h

@@ -125,20 +125,21 @@ class Optimizer {
     if (optimizer_config.mf_create_thresholds <=
         optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
            optimizer_config.clk_coeff * ptr->clk) {
-      // ptr->mf_size = ptr->mf_dim + 1;
-      ptr->mf_size = MF_DIM + 1;
+      ptr->mf_size = ptr->mf_dim + 1;
+      // ptr->mf_size = MF_DIM + 1;
       ptr->mf[0] = 0;
       int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
       curandState state;
       curand_init(clock64(), tid_x, 0, &state);
-      for (int i = 0; i < MF_DIM; ++i) {
+      for (int i = 0; i < ptr->mf_dim; ++i) {
         ptr->mf[i + 1] =
             (curand_uniform(&state)) * optimizer_config.mf_initial_range;
       }
     } else {
-      update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
-                grad.show);
+      update_mf(optimizer_config, ptr->mf_dim, &(ptr->mf[1]), ptr->mf[0],
+                grad.mf_g, grad.show);  // for local test
     }
   }
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
浏览文件 @
3f619290
...
...
@@ -31,7 +31,6 @@ limitations under the License. */
#include <algorithm>
#include <deque>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/platform/timer.h"
...
...
@@ -112,12 +111,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
else
{
gpu_task
->
init
(
thread_keys_shard_num_
,
device_num
,
multi_mf_dim_
);
}
auto
&
local_keys
=
gpu_task
->
feature_keys_
;
auto
&
local_ptr
=
gpu_task
->
value_ptr_
;
std
::
vector
<
std
::
thread
>
threads
;
// data should be in input channel
if
(
!
multi_mf_dim_
)
{
thread_keys_
.
resize
(
thread_keys_thread_num_
);
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
...
...
@@ -141,11 +136,9 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
std
::
string
data_set_name
=
std
::
string
(
typeid
(
*
dataset_
).
name
());
if
(
data_set_name
.
find
(
"SlotRecordDataset"
)
!=
std
::
string
::
npos
)
{
VLOG
(
0
)
<<
"ps_gpu_wrapper use SlotRecordDataset"
;
SlotRecordDataset
*
dataset
=
dynamic_cast
<
SlotRecordDataset
*>
(
dataset_
);
auto
input_channel
=
dataset
->
GetInputChannel
();
VLOG
(
0
)
<<
"yxf::buildtask::inputslotchannle size: "
<<
input_channel
->
Size
();
VLOG
(
0
)
<<
"psgpu wrapperinputslotchannle size: "
<<
input_channel
->
Size
();
const
std
::
deque
<
SlotRecord
>&
vec_data
=
input_channel
->
GetData
();
total_len
=
vec_data
.
size
();
len_per_thread
=
total_len
/
thread_keys_thread_num_
;
...
...
@@ -176,21 +169,12 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
j
<
slot_offset
[
slot_offset_vector_
[
slot_idx
]
+
1
];
j
++
)
{
int
shard_id
=
feasign_v
[
j
]
%
thread_keys_shard_num_
;
int
dim_id
=
slot_index_vec_
[
slot_idx
];
this
->
thread_dim_keys_
[
i
][
shard_id
][
dim_id
].
insert
(
feasign_v
[
j
]);
if
(
feasign_v
[
j
]
!=
0
)
{
this
->
thread_dim_keys_
[
i
][
shard_id
][
dim_id
].
insert
(
feasign_v
[
j
]);
}
}
}
}
/*
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
for (const auto feasign : feasign_v) {
int shard_id = feasign % thread_keys_shard_num_;
this->thread_dim_keys_[i][shard_id][0].insert(feasign);
}
}
*/
};
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
if
(
!
multi_mf_dim_
)
{
...
...
@@ -264,12 +248,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
thread_dim_keys_
[
i
][
shard_num
][
dim_id
].
clear
();
}
};
// for (size_t i = 0; i < thread_keys_.size(); i++) {
// gpu_task->batch_add_keys(thread_keys_[i]);
// for (int j = 0; j < thread_keys_thread_num_; j++) {
// thread_keys_[i][j].clear();
// }
//}
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
++
i
)
{
if
(
!
multi_mf_dim_
)
{
threads
.
push_back
(
std
::
thread
(
merge_ins_func
,
i
));
...
...
@@ -291,20 +269,15 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs task unique cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
if
(
!
multi_mf_dim_
)
{
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
VLOG
(
0
)
<<
"GpuPs shard: "
<<
i
<<
" key len: "
<<
local_keys
[
i
].
size
();
local_ptr
[
i
].
resize
(
local_keys
[
i
].
size
());
}
}
else
{
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
VLOG
(
0
)
<<
"GpuPs shard: "
<<
i
<<
"mf dim: "
<<
index_dim_vec_
[
j
]
<<
" key len: "
<<
gpu_task
->
feature_dim_keys_
[
i
][
j
].
size
();
gpu_task
->
value_dim_ptr_
[
i
][
j
].
resize
(
gpu_task
->
feature_dim_keys_
[
i
][
j
].
size
());
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
if
(
i
==
0
&&
j
==
multi_mf_dim_
-
1
)
{
gpu_task
->
feature_dim_keys_
[
i
][
j
].
push_back
(
0
);
}
VLOG
(
0
)
<<
"GpuPs shard: "
<<
i
<<
"mf dim: "
<<
index_dim_vec_
[
j
]
<<
" key len: "
<<
gpu_task
->
feature_dim_keys_
[
i
][
j
].
size
();
gpu_task
->
value_dim_ptr_
[
i
][
j
].
resize
(
gpu_task
->
feature_dim_keys_
[
i
][
j
].
size
());
}
}
}
...
...
@@ -353,85 +326,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
#endif
timeline
.
Start
();
auto
ptl_func
=
[
this
,
&
local_keys
,
&
local_ptr
,
&
fleet_ptr
](
int
i
)
{
size_t
key_size
=
local_keys
[
i
].
size
();
int32_t
status
=
-
1
;
#ifdef PADDLE_WITH_PSLIB
// auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
// reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
// local_keys[i].data(), key_size);
int32_t
cnt
=
0
;
while
(
true
)
{
auto
tt
=
fleet_ptr
->
pslib_ptr_
->
_worker_ptr
->
pull_sparse_ptr
(
i
,
reinterpret_cast
<
char
**>
(
local_ptr
[
i
].
data
()),
this
->
table_id_
,
local_keys
[
i
].
data
(),
key_size
);
bool
flag
=
true
;
tt
.
wait
();
try
{
status
=
tt
.
get
();
}
catch
(
const
std
::
future_error
&
e
)
{
VLOG
(
0
)
<<
"Caught a future_error with code"
<<
e
.
code
()
<<
", Message:"
<<
e
.
what
();
}
if
(
status
!=
0
)
{
VLOG
(
0
)
<<
"fleet pull sparse failed, status["
<<
status
<<
"]"
;
sleep
(
sleep_seconds_before_fail_exit_
);
flag
=
false
;
cnt
++
;
}
if
(
cnt
>
3
)
{
VLOG
(
0
)
<<
"fleet pull sparse failed, retry 3 times"
;
exit
(
-
1
);
}
if
(
flag
)
{
break
;
}
}
#endif
#ifdef PADDLE_WITH_PSCORE
int32_t
cnt
=
0
;
while
(
true
)
{
auto
tt
=
fleet_ptr
->
worker_ptr_
->
PullSparsePtr
(
reinterpret_cast
<
char
**>
(
local_ptr
[
i
].
data
()),
this
->
table_id_
,
local_keys
[
i
].
data
(),
key_size
);
bool
flag
=
true
;
tt
.
wait
();
try
{
status
=
tt
.
get
();
}
catch
(
const
std
::
future_error
&
e
)
{
VLOG
(
0
)
<<
"Caught a future_error with code"
<<
e
.
code
()
<<
", Message:"
<<
e
.
what
();
}
if
(
status
!=
0
)
{
VLOG
(
0
)
<<
"fleet pull sparse failed, status["
<<
status
<<
"]"
;
sleep
(
sleep_seconds_before_fail_exit_
);
flag
=
false
;
cnt
++
;
}
if
(
cnt
>
3
)
{
VLOG
(
0
)
<<
"fleet pull sparse failed, retry 3 times"
;
exit
(
-
1
);
}
if
(
flag
)
{
break
;
}
}
#endif
if
(
status
!=
0
)
{
LOG
(
ERROR
)
<<
"fleet pull sparse failed, status["
<<
status
<<
"]"
;
sleep
(
300
);
exit
(
-
1
);
}
else
{
VLOG
(
3
)
<<
"FleetWrapper Pull sparse to local done with table size: "
<<
local_keys
[
i
].
size
();
}
};
auto
ptl_dynamic_mf_func
=
[
this
,
&
local_dim_keys
,
&
local_dim_ptr
,
&
fleet_ptr
](
int
i
,
int
j
)
{
...
...
@@ -478,21 +372,18 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
}
#endif
};
if
(
!
multi_mf_dim_
)
{
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
ptl_func
,
i
);
}
}
else
{
threads
.
resize
(
thread_keys_shard_num_
*
multi_mf_dim_
);
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
threads
[
i
*
multi_mf_dim_
+
j
]
=
std
::
thread
(
ptl_dynamic_mf_func
,
i
,
j
);
}
threads
.
resize
(
thread_keys_shard_num_
*
multi_mf_dim_
);
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
task_futures
.
emplace_back
(
pull_thread_pool_
[
i
]
->
enqueue
(
ptl_dynamic_mf_func
,
i
,
j
));
}
}
for
(
std
::
thread
&
t
:
thread
s
)
{
t
.
join
();
for
(
auto
&
f
:
task_future
s
)
{
f
.
wait
();
}
task_futures
.
clear
();
timeline
.
Pause
();
VLOG
(
0
)
<<
"pull sparse from CpuPS into GpuPS cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
...
...
@@ -509,19 +400,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
  std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values;
  bool record_status = false;
#ifdef PADDLE_WITH_PSLIB
  uint16_t pass_id = 0;
  if (multi_node_) {
    record_status = fleet_ptr->pslib_ptr_->_worker_ptr->take_sparse_record(
        table_id_, pass_id, pass_values);
  }
#endif
  auto& device_task_keys = gpu_task->device_task_keys_;
  auto& device_task_ptrs = gpu_task->device_task_ptr_;
  auto build_dynamic_mf_func = [this, device_num, &local_dim_keys,
                                &local_dim_ptr, &device_dim_keys,
                                &device_dim_ptr,
                                &device_dim_mutex](int i, int j) {
  auto build_pull_dynamic_mf_func = [this, device_num, &local_dim_keys,
                                     &local_dim_ptr, &device_dim_keys,
                                     &device_dim_ptr,
                                     &device_dim_mutex](int i, int j) {
#ifdef PADDLE_WITH_PSLIB
    std::vector<std::vector<FeatureKey>> task_keys(device_num);
    std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
...
...
@@ -532,20 +416,16 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
      task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
    }
    for (int dev = 0; dev < device_num; dev++) {
      for (int dim = 0; dim < multi_mf_dim_; dim++) {
        device_dim_mutex[dev][dim]->lock();
        int len = task_keys[dev].size();
        int cur = device_dim_keys[dev][dim].size();
        device_dim_keys[dev][dim].resize(device_dim_keys[dev][dim].size() +
                                         len);
        device_dim_ptr[dev][dim].resize(device_dim_ptr[dev][dim].size() + len);
        for (int k = 0; k < len; ++k) {
          device_dim_keys[dev][dim][cur + k] = task_keys[dev][k];
          device_dim_ptr[dev][dim][cur + k] = task_ptrs[dev][k];
        }
        device_dim_mutex[dev][dim]->unlock();
      device_dim_mutex[dev][j]->lock();
      int len = task_keys[dev].size();
      int cur = device_dim_keys[dev][j].size();
      device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len);
      device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len);
      for (int k = 0; k < len; ++k) {
        device_dim_keys[dev][j][cur + k] = task_keys[dev][k];
        device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k];
      }
      device_dim_mutex[dev][j]->unlock();
    }
#endif
  };
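The lambda above buckets one shard's keys and CPU value pointers by destination device and appends each bucket to that device's per-dim vectors while holding the device's mutex, because every shard thread may feed every device concurrently. A reduced single-dim sketch of the same idea; names are illustrative and the key-to-device mapping (key % device_num) is an assumption, not necessarily the wrapper's exact rule.

#include <cstdint>
#include <mutex>
#include <vector>

// Distribute one shard's keys to devices and append under a per-device lock.
void Scatter(const std::vector<uint64_t>& shard_keys, int device_num,
             std::vector<std::vector<uint64_t>>* device_keys,
             std::vector<std::mutex>* device_mutex) {
  // Bucket this shard's keys by destination device.
  std::vector<std::vector<uint64_t>> task_keys(device_num);
  for (uint64_t key : shard_keys) {
    task_keys[key % device_num].push_back(key);
  }
  // Append each bucket; other shard threads may target the same device.
  for (int dev = 0; dev < device_num; ++dev) {
    std::lock_guard<std::mutex> guard((*device_mutex)[dev]);
    auto& dst = (*device_keys)[dev];
    dst.insert(dst.end(), task_keys[dev].begin(), task_keys[dev].end());
  }
}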
...
...
@@ -697,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
  for (int i = 0; i < thread_keys_shard_num_; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      threads[i * multi_mf_dim_ + j] =
          std::thread(build_dynamic_mf_func, i, j);
          std::thread(build_pull_dynamic_mf_func, i, j);
    }
  }
  for (std::thread& t : threads) {
...
...
@@ -727,21 +607,17 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
  std::vector<size_t> feature_keys_count(device_num);
  size_t size_max = 0;
  if (!multi_mf_dim_) {
    for (int i = 0; i < device_num; i++) {
      feature_keys_count[i] = gpu_task->device_keys_[i].size();
      VLOG(0) << i << " card contains feasign nums: " << feature_keys_count[i];
      size_max = std::max(size_max, feature_keys_count[i]);
    }
  } else {
    for (int i = 0; i < device_num; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
      }
      VLOG(0) << i << " card with dynamic mf contains feasign nums: "
              << feature_keys_count[i];
      size_max = std::max(size_max, feature_keys_count[i]);
  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
      VLOG(1) << i << " card with dynamic mf dim: " << index_dim_vec_[j]
              << " dim index: " << j << " contains feasign nums: "
              << gpu_task->device_dim_ptr_[i][j].size();
    }
    VLOG(1) << i << " card with dynamic mf contains feasign nums total: "
            << feature_keys_count[i];
    size_max = std::max(size_max, feature_keys_count[i]);
  }
  if (HeterPs_) {
    delete HeterPs_;
...
...
@@ -756,17 +632,73 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
#ifdef PADDLE_WITH_CUDA
  HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
#endif
  auto build_func = [this, &gpu_task, &feature_keys_count](int i) {
    VLOG(3) << "building table: " << i;
    this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
                             gpu_task->device_values_[i].data(),
                             feature_keys_count[i], 500000, 2);
    // if (feature_keys_count[i] > 0) {
    //   HeterPs_->show_one_table(i);
    // }
  auto build_dynamic_mf_func = [this, &gpu_task](int i, int j) {
    this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
    int mf_dim = this->index_dim_vec_[j];
    VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim;
    size_t feature_value_size =
        TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
    auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
    auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
    size_t len = device_dim_keys.size();
    CHECK(len == device_dim_ptrs.size());
    this->mem_pools_[i * this->multi_mf_dim_ + j] =
        new MemoryPool(len, feature_value_size);
    auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
    for (size_t k = 0; k < len; k++) {
      FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k));
      float* ptr_val = device_dim_ptrs[k]->data();
      size_t dim = device_dim_ptrs[k]->size();
#ifdef PADDLE_WITH_PSLIB
      val->delta_score =
          ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                      DownpourCtrDymfFeatureValue::delta_score_index()];
      val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                              DownpourCtrDymfFeatureValue::show_index()];
      val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                             DownpourCtrDymfFeatureValue::click_index()];
      val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                                  DownpourCtrDymfFeatureValue::slot_index()]);
      val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                            DownpourCtrDymfFeatureValue::embed_w_index()];
      val->lr_g2sum =
          ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                      DownpourCtrDymfFeatureValue::embed_g2sum_index()];
      val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]);
      ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  mf_dim_index()] = float(mf_dim);
      val->mf_dim = mf_dim;
#endif
      if (dim > 8) {
        // CpuPS alreay expand as mf_dim
        val->mf_size = mf_dim + 1;
        for (int x = 0; x < val->mf_dim + 1; x++) {
          val->mf[x] = ptr_val[x + 8];
        }
      } else {
        val->mf_size = 0;
        for (int x = 0; x < val->mf_dim + 1; x++) {
          val->mf[x] = 0;
        }
      }
    }
    platform::CUDADeviceGuard guard(resource_->dev_id(i));
    this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
    auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
    this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len,
                             feature_value_size, 500000, 2);
    if (device_dim_keys.size() > 0) {
      VLOG(0) << "show ptr table: " << i
              << " table kv size: " << device_dim_keys.size()
              << "dim: " << mf_dim << " len: " << len;
      this->HeterPs_->show_one_table(i);
    }
    delete mem_pool;
  };
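build_dynamic_mf_func sizes each pool slot as sizeof(FeatureValue) plus (mf_dim + 1) floats, rounded up to 8 bytes, so the k-th value can be addressed at a fixed stride without per-entry pointers. A small layout sketch follows, assuming TYPEALIGN(8, len) simply rounds len up to the next multiple of 8; Header and AlignUp are illustrative stand-ins, not Paddle's types.

#include <cstdint>
#include <cstdlib>

// Assumed behaviour of TYPEALIGN(8, len): round len up to a multiple of 8.
inline size_t AlignUp(size_t len, size_t align) {
  return (len + align - 1) & ~(align - 1);
}

// Illustrative stand-in for FeatureValue's fixed (non-embedding) fields.
struct Header {
  float show, clk, lr, lr_g2sum;
  int mf_size, mf_dim;
  uint64_t cpu_ptr;
};

int main() {
  int mf_dim = 64;    // embedding dim of this (device, dim-index) table
  size_t len = 1000;  // number of keys in the table
  // One pool slot = fixed header + (mf_dim + 1) trailing floats, padded to 8 B.
  size_t slot_size = AlignUp(sizeof(Header) + (mf_dim + 1) * sizeof(float), 8);
  char* pool = static_cast<char*>(std::malloc(len * slot_size));
  // The k-th value sits at a fixed stride, like mem_pool->mem_address(k).
  size_t k = 42;
  Header* val = reinterpret_cast<Header*>(pool + k * slot_size);
  float* mf = reinterpret_cast<float*>(val + 1);  // embedding follows the header
  val->mf_dim = mf_dim;
  mf[0] = 0.f;  // mf[0..mf_dim] hold the (mf_dim + 1) trailing floats
  std::free(pool);
  return 0;
}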
  for (size_t i = 0; i < threads.size(); i++) {
    threads[i] = std::thread(build_func, i);
  threads.resize(device_num * multi_mf_dim_);
  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j);
    }
  }
  for (std::thread& t : threads) {
    t.join();
...
...
@@ -788,7 +720,6 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
  if (is_shuffle) {
    dataset_->LocalShuffle();
  }
  std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
  gpu_task->Reset();
  data_ready_channel_->Put(gpu_task);
...
...
@@ -874,17 +805,86 @@ void PSGPUWrapper::EndPass() {
  size_t keysize_max = 0;
  // in case of feasign_num = 0, skip dump_to_cpu
  for (size_t i = 0; i < heter_devices_.size(); i++) {
    keysize_max = std::max(keysize_max, current_task_->device_keys_[i].size());
    for (int j = 0; j < multi_mf_dim_; j++) {
      keysize_max =
          std::max(keysize_max, current_task_->device_dim_keys_[i][j].size());
    }
  }
  auto dump_pool_to_cpu_func = [this](int i, int j) {
    PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i)));
    auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
    auto& device_keys = this->current_task_->device_dim_keys_[i][j];
    size_t len = device_keys.size();
    int mf_dim = this->index_dim_vec_[j];
    VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim;
    size_t feature_value_size =
        TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
    char* test_build_values = (char*)malloc(feature_value_size * len);
    cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len,
               cudaMemcpyDeviceToHost);
    CHECK(len == hbm_pool->capacity());
#ifdef PADDLE_WITH_PSLIB
    uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
    for (size_t i = 0; i < len; ++i) {
      if (device_keys[i] == unuse_key) {
        continue;
      }
      size_t offset = i * feature_value_size;
      FeatureValue* gpu_val = (FeatureValue*)(test_build_values + offset);
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val->mf_size > 0 && downpour_value_size == 8) {
        downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  delta_score_index()] = gpu_val->delta_score;
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  show_index()] = gpu_val->show;
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  click_index()] = gpu_val->clk;
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  embed_w_index()] = gpu_val->lr;
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  embed_g2sum_index()] = gpu_val->lr_g2sum;
      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  slot_index()] = gpu_val->slot;
      if (gpu_val->mf_size > 0) {
        for (int x = 0; x < gpu_val->mf_dim + 1; x++) {
          cpu_val[x + 8] = gpu_val->mf[x];
        }
      }
    }
#endif
    free(test_build_values);
  };
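dump_pool_to_cpu_func copies the whole HBM pool back to host memory and walks it with the same per-slot stride used at build time; a CPU-side value still at its 8-float minimum is widened only when the GPU actually learned an embedding for it (mf_size > 0). A reduced sketch of that walk; GpuVal, kFixedFloats and the flat cpu_values container are illustrative, not the real PSLib structures.

#include <cstddef>
#include <vector>

// Illustrative GPU-side value header; the learned embedding follows it inside
// the pool slot, as in the layout sketch above.
struct GpuVal {
  int mf_size, mf_dim;
};

void DumpPool(const char* host_copy, size_t len, size_t slot_size,
              std::vector<std::vector<float>>* cpu_values) {
  constexpr size_t kFixedFloats = 8;  // fixed leading floats of the CPU value
  for (size_t i = 0; i < len; ++i) {
    const GpuVal* val =
        reinterpret_cast<const GpuVal*>(host_copy + i * slot_size);
    const float* mf = reinterpret_cast<const float*>(val + 1);
    std::vector<float>& cpu_val = (*cpu_values)[i];
    if (val->mf_size > 0 && cpu_val.size() == kFixedFloats) {
      cpu_val.resize(kFixedFloats + val->mf_dim + 1);  // widen only when learned
    }
    if (val->mf_size > 0) {
      for (int x = 0; x < val->mf_dim + 1; ++x) {
        cpu_val[kFixedFloats + x] = mf[x];
      }
    }
  }
}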
  if (multi_mf_dim_) {
    VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
    size_t device_num = heter_devices_.size();
    std::vector<std::thread> threads(device_num * multi_mf_dim_);
    for (size_t i = 0; i < device_num; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        threads[i + j * device_num] = std::thread(dump_pool_to_cpu_func, i, j);
      }
    }
    for (std::thread& t : threads) {
      t.join();
    }
  }
  if (keysize_max != 0) {
    HeterPs_->end_pass();
  }
  for (size_t i = 0; i < hbm_pools_.size(); i++) {
    delete hbm_pools_[i];
  }
  gpu_task_pool_.Push(current_task_);
  current_task_ = nullptr;
  gpu_free_channel_->Put(current_task_);
  timer.Pause();
  VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
}

void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
...
...
@@ -936,8 +936,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
    pull_gpups_timer.Start();
    HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
                          static_cast<int>(total_length));
    // PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
    //    "PullSparseGPU failed in GPUPS."));
    pull_gpups_timer.Pause();
    VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
...
...
@@ -945,6 +943,98 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
    this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
                      static_cast<int>(slot_lengths.size()), hidden_size,
                      total_length);
  } else {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GpuPs: PullSparse Only Support CUDAPlace Now."));
  }
  all_timer.Pause();
  VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
          << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
          << " s";
  VLOG(3) << "End PullSparse";
}

void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
                              const int table_id,
                              const std::vector<const uint64_t*>& keys,
                              const std::vector<float*>& values,
                              const std::vector<int64_t>& slot_lengths,
                              const std::vector<int>& slot_dim,
                              const int hidden_size) {
  VLOG(3) << "Begine Gpu Ps PullSparse";
  platform::Timer all_timer;
  platform::Timer pull_gpups_timer;
  all_timer.Start();
  size_t total_length =
      std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
  size_t feature_value_size = 0;
  feature_value_size = TYPEALIGN(
      8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1));
  VLOG(0) << "yxf pull sparse feature_value_size: " << feature_value_size;
#ifdef PADDLE_WITH_CUDA
  VLOG(3) << "Begine Gpu Ps PullSparse";
  auto buf = memory::Alloc(place, total_length * feature_value_size);
  FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
#endif
#ifdef PADDLE_WITH_XPU_KP
  VLOG(3) << "Begine Xpu Ps PullSparse";
  FeatureValue* total_values_gpu = nullptr;
  xpu_malloc(reinterpret_cast<void**>(&total_values_gpu),
             total_length * feature_value_size);
#endif
  if (platform::is_cpu_place(place)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Warning:: CPUPlace is not supported in GpuPs now."));
  } else if (platform::is_gpu_place(place)) {
    VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
    int device_id = place.GetDeviceId();
    int devid_2_index = HeterPs_->get_index_by_devid(device_id);
    LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
    uint64_t* total_keys = reinterpret_cast<uint64_t*>(
        total_keys_tensor.mutable_data<int64_t>({int64_t(total_length), 1},
                                                place));
    // construct slot_level lod info
    auto slot_lengths_lod = slot_lengths;
    for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
      slot_lengths_lod[i] += slot_lengths_lod[i - 1];
    }
    auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
    auto buf_length =
        memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
    uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
    int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
    cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
               cudaMemcpyHostToDevice);
    cudaMemcpy(gpu_len, slot_lengths_lod.data(),
               slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
    auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int));
    int* gpu_dim = reinterpret_cast<int*>(buf_dim->ptr());
    cudaMemcpy(gpu_dim, slot_dim.data(), slot_dim.size() * sizeof(int),
               cudaMemcpyHostToDevice);
    this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
                   static_cast<int>(slot_lengths.size()),
                   static_cast<int>(total_length));
    VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
            << " len: " << total_length;
    pull_gpups_timer.Start();
    HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
                          total_length);
    VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
            << "]";
    this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
                      static_cast<int>(slot_lengths.size()), hidden_size,
                      total_length, gpu_dim);
    pull_gpups_timer.Pause();
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_KP
...
...
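The multi-dim PullSparse overload above is the entry point used by the pull_gpups_sparse operator further down in this diff: the caller supplies per-slot key pointers, per-slot output buffers, per-slot lengths, and per-slot embedding dims (the op's "size" attribute). A hedged two-slot call-site sketch follows; all slot* names and the dims {8, 64} are invented for illustration.

#include <cstdint>
#include <vector>
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"

// Hypothetical call site mirroring pull_gpups_sparse_op.h below:
// one entry per slot, embedding dims taken from the op's "size" attribute.
void PullTwoSlots(const paddle::platform::Place& place,
                  const uint64_t* slot0_ids, float* slot0_out, int64_t n0,
                  const uint64_t* slot1_ids, float* slot1_out, int64_t n1) {
  std::vector<const uint64_t*> all_keys = {slot0_ids, slot1_ids};
  std::vector<float*> all_values = {slot0_out, slot1_out};
  std::vector<int64_t> slot_lengths = {n0, n1};
  std::vector<int> slot_dim = {8, 64};  // per-slot embedding dims (illustrative)
  auto gpu_ps = paddle::framework::PSGPUWrapper::GetInstance();
  gpu_ps->PullSparse(place, /*table_id=*/0, all_keys, all_values, slot_lengths,
                     slot_dim, /*hidden_size=*/0);
}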
@@ -1013,7 +1103,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
      std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
  // #ifdef PADDLE_WITH_CUDA
  VLOG(3) << "Begin GPUPS PushSparseGrad";
  auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
  size_t grad_value_size =
      TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
  auto buf = memory::Alloc(place, total_length * grad_value_size);
  VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_;
  FeaturePushValue* total_grad_values_gpu =
      reinterpret_cast<FeaturePushValue*>(buf->ptr());
  if (platform::is_cpu_place(place)) {
...
...
@@ -1027,8 +1120,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
    uint64_t* total_keys =
        reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
    VLOG(3) << "Begin copy grad tensor to gpups struct";
    this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
                      hidden_size, total_length, batch_size);
    if (!multi_mf_dim_) {
      this->CopyForPush(place, grad_values, total_grad_values_gpu,
                        slot_lengths, hidden_size, total_length, batch_size);
    } else {
      this->CopyForPush(place, grad_values, total_grad_values_gpu,
                        slot_lengths, total_length, batch_size,
                        grad_value_size);
    }
    VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index
            << " len: " << total_length;
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu    View file @ 3f619290
...
...
@@ -61,6 +61,45 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
  }
}

__global__ void PullCopy(float** dest, const FeatureValue* src,
                         const int64_t* len, int slot_num, int total_len,
                         uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    FeatureValue* feature_value_ptr =
        (FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
    int mf_dim = gpu_dim[x] - 3;
    if (*(keys[x] + y) == 0) {
      *(dest[x] + y * (mf_dim + 3)) = 0;
      *(dest[x] + y * (mf_dim + 3) + 1) = 0;
      *(dest[x] + y * (mf_dim + 3) + 2) = 0;
    } else {
      *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
      *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
      *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
    }
    if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
      for (int j = 0; j < mf_dim; j++) {
        *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
      }
    } else {
      for (int j = 0; j < mf_dim; j++) {
        *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
      }
    }
  }
}
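In this kernel each thread handles one flattened feature index i; the binary search over the prefix-summed per-slot lengths (len) recovers which slot x the index belongs to and its offset y inside that slot, so the output row is written at dest[x] + y * (mf_dim + 3). A host-side sketch of the same lookup, in plain C++ with illustrative names:

#include <cstdint>
#include <utility>
#include <vector>

// Prefix-sum the per-slot lengths, as PullSparse does before launching the
// kernel (slot_lengths_lod in the wrapper code above).
std::vector<int64_t> BuildLod(const std::vector<int64_t>& slot_lengths) {
  std::vector<int64_t> lod = slot_lengths;
  for (size_t i = 1; i < lod.size(); ++i) lod[i] += lod[i - 1];
  return lod;
}

// Map a flat feature index back to (slot, offset-in-slot) by binary search,
// mirroring the low/high loop inside PullCopy.
std::pair<int, int64_t> Locate(const std::vector<int64_t>& lod, int64_t i) {
  int low = 0, high = static_cast<int>(lod.size()) - 1;
  while (low < high) {
    int mid = (low + high) / 2;
    if (i < lod[mid])
      high = mid;
    else
      low = mid + 1;
  }
  int64_t y = i - (low ? lod[low - 1] : 0);
  return {low, y};
}

// Example: slots of length 3, 2, 4 -> flat index 4 lands in slot 1, offset 1:
//   auto lod = BuildLod({3, 2, 4});  auto [x, y] = Locate(lod, 4);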
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                               const int64_t* len, int slot_num,
                               int total_len) {
...
...
@@ -105,6 +144,35 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
  }
}

__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
                                 int64_t* len, int slot_num, uint64_t total_len,
                                 int bs, int* slot_vector, int* mf_dim_vector,
                                 size_t grad_value_size) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[low - 1] : 0);
    FeaturePushValue* cur =
        (FeaturePushValue*)((char*)dest + i * grad_value_size);
    cur->slot = slot_vector[x];
    int mf_dim = mf_dim_vector[x];
    cur->mf_dim = mf_dim;
    cur->show = *(src[x] + y * (mf_dim + 3));
    cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
    cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
    for (int j = 0; j < cur->mf_dim; j++) {
      cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
    }
  }
}

PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
...
...
@@ -128,6 +196,26 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
  cudaStreamSynchronize(stream);
}

void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len, const int slot_num,
                               const int hidden_size,
                               const int64_t total_length, int* gpu_dim) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
      val_type_size_, gpu_dim);
  cudaStreamSynchronize(stream);
}

void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys, uint64_t* total_keys,
                            const int64_t* gpu_len, int slot_num,
...
...
@@ -177,6 +265,45 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
  cudaStreamSynchronize(stream);
}

void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const uint64_t total_length,
                               const int batch_size, size_t grad_value_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto slot_lengths_lod = slot_lengths;
  for (int i = 1; i < slot_lengths_lod.size(); i++) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }
  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
  auto buf_slot_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  auto buf_mf_dim_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
  cudaMemcpy(gpu_values, grad_values.data(),
             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len, slot_lengths_lod.data(),
             slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_slot_vector, slot_vector_.data(),
             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
  PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
      total_length, batch_size, d_slot_vector, d_mf_dim_vector,
      grad_value_size);
  cudaStreamSynchronize(stream);
}

void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                float min_bound, float max_bound,
                                float learning_rate, float initial_g2sum,
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.h    View file @ 3f619290
...
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
...
...
@@ -54,6 +55,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_PSLIB
#include "afs_api.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h" // NOLINT
#endif
namespace paddle {
namespace framework {
...
...
@@ -95,12 +99,21 @@ class PSGPUWrapper {
  PSGPUWrapper() {
    HeterPs_ = NULL;
    sleep_seconds_before_fail_exit_ = 300;
    pull_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
      pull_thread_pool_[i].reset(new ::ThreadPool(1));
    }
    hbm_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
    }
  }

  void PullSparse(const paddle::platform::Place& place, const int table_id,
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const std::vector<int>& slot_dim, const int hidden_size);
  void PullSparse(const paddle::platform::Place& place, const int table_id,
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
...
...
@@ -119,13 +132,23 @@ class PSGPUWrapper {
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len, const int slot_num,
                   const int hidden_size, const int64_t total_length);
  void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
                   const std::vector<float*>& values,
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len, const int slot_num,
                   const int hidden_size, const int64_t total_length,
                   int* gpu_dim);
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
                   const int hidden_size, const int64_t total_length,
                   const int batch_size);
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
                   const uint64_t total_length, const int batch_size,
                   size_t grad_value_size);
  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
...
...
@@ -428,6 +451,7 @@ class PSGPUWrapper {
  std::shared_ptr<HeterContext> current_task_ = nullptr;
  std::thread pre_build_threads_;
  bool running_ = false;
  std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
  std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;

 protected:
...
...
paddle/fluid/operators/pull_gpups_sparse_op.h    View file @ 3f619290
...
...
@@ -26,6 +26,7 @@ template <typename T>
static void PullGpuPSSparseFunctor(const framework::ExecutionContext& ctx) {
  auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
  auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
  auto embedding_size_vec = ctx.Attr<std::vector<int>>("size");
  const auto slot_size = inputs.size();
  std::vector<const uint64_t*> all_keys(slot_size);
// GpuPSPS only supports float now
...
...
@@ -44,7 +45,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) {
#ifdef PADDLE_WITH_HETERPS
  auto gpu_ps_ptr = paddle::framework::PSGPUWrapper::GetInstance();
  gpu_ps_ptr->PullSparse(ctx.GetPlace(), 0, all_keys, all_values, slot_lengths,
                         0);
                         embedding_size_vec, 0);
#endif
}
...
...
python/paddle/fluid/layers/nn.py    View file @ 3f619290
...
...
@@ -737,7 +737,7 @@ def _pull_gpups_sparse(input,
        for i in range(len(inputs))
    ]
    w = helper.create_parameter(
        attr=helper.param_attr, shape=[11], dtype=dtype, is_bias=False)
        attr=helper.param_attr, shape=[size[0]], dtype=dtype, is_bias=False)
    helper.append_op(
        type='pull_gpups_sparse',
        inputs={'Ids': inputs,
...
...