Commit d9c174d1 (unverified)

add hashtable dynamic mf support (#38493)

add hashtable dynamic mf support

Authored by yaoxuefeng on Dec 29, 2021; committed via GitHub on Dec 29, 2021.
Parent: 7411dab5

Showing 3 changed files with 142 additions and 0 deletions (+142, -0):
paddle/fluid/framework/fleet/heter_ps/hashtable.h (+20, -0)
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h (+88, -0)
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h (+34, -0)
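The "dynamic mf" support added here keeps feature values outside the hash table itself: the table maps each key to a raw pointer into a pre-allocated char pool, and values are addressed by a fixed byte stride (the new insert_kernel hard-codes 80 bytes per slot; the pull/push strides are registered via set_feature_value_size). The short host-side sketch below is illustrative only and not part of the commit; FeatureValueLike is a hypothetical stand-in for Paddle's FeatureValue.

// Illustrative sketch (not from this commit): fixed-stride slots in a raw
// char pool reinterpreted as value structs, mirroring
// `pool + (start_index + i) * 80` in the new insert_kernel.
#include <cstddef>
#include <cstdio>
#include <vector>

struct FeatureValueLike {  // hypothetical stand-in for Paddle's FeatureValue
  float show;
  float clk;
  float lr;
  float mf[8 + 1];  // mf[0] holds the embedding size, mf[1..] the embedding
};
static_assert(sizeof(FeatureValueLike) <= 80, "value must fit in one slot");

int main() {
  const size_t slot_size = 80;   // per-value stride assumed by insert_kernel
  const size_t start_index = 2;  // first slot this batch writes to
  const size_t len = 3;
  std::vector<char> pool((start_index + len) * slot_size, 0);

  for (size_t i = 0; i < len; ++i) {
    // Address of the i-th value, exactly as insert_kernel computes it.
    char* slot = pool.data() + (start_index + i) * slot_size;
    reinterpret_cast<FeatureValueLike*>(slot)->show = 1.0f + i;
  }
  char* slot1 = pool.data() + (start_index + 1) * slot_size;
  printf("show of slot 1: %f\n",
         reinterpret_cast<FeatureValueLike*>(slot1)->show);
  return 0;
}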
paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -27,6 +27,8 @@ limitations under the License. */
#include "thrust/pair.h"
// #include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"

@@ -53,8 +55,11 @@ class HashTable {
  HashTable& operator=(const HashTable&) = delete;
  void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
              gpuStream_t stream);
  void insert(const KeyType* d_keys, size_t len, char* pool,
              size_t start_index, gpuStream_t stream);
  void get(const KeyType* d_keys, ValType* d_vals, size_t len,
           gpuStream_t stream);
  void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
  void show();
  void dump_to_cpu(int devid, cudaStream_t stream);

@@ -62,8 +67,20 @@ class HashTable {
  void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
              Sgd sgd, gpuStream_t stream);

  template <typename Sgd>
  void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
              gpuStream_t stream);

  int size() { return container_->size(); }

  void set_feature_value_size(size_t pull_feature_value_size,
                              size_t push_grad_value_size) {
    pull_feature_value_size_ = pull_feature_value_size;
    push_grad_value_size_ = push_grad_value_size;
    VLOG(3) << "hashtable set pull value size: " << pull_feature_value_size_
            << " push value size: " << push_grad_value_size_;
  }

  std::unique_ptr<RWLock> rwlock_{nullptr};

 private:

@@ -71,6 +88,9 @@ class HashTable {
  int BLOCK_SIZE_{256};
  float LOAD_FACTOR{0.75f};
  size_t capacity_;
  size_t max_mf_dim_ = 8;
  size_t pull_feature_value_size_;
  size_t push_grad_value_size_;
};
}  // end namespace framework
}  // end namespace paddle
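The header above also adds set_feature_value_size, which records the byte strides later used by the char*-based get/update overloads in hashtable_inl.h. How a caller derives those strides is not part of this diff; the sketch below shows only one plausible way to size slots for a value carrying up to max_mf_dim_ embedding floats plus a leading count, and every number other than max_mf_dim = 8 is assumed for illustration.

// Hypothetical sizing helper (not in this commit): a fixed header plus
// (max_mf_dim + 1) floats, rounded up so consecutive slots stay 8-byte
// aligned. Real Paddle code may compute these sizes differently.
#include <cstddef>
#include <cstdio>

constexpr size_t AlignUp(size_t n, size_t a) { return (n + a - 1) / a * a; }

int main() {
  const size_t max_mf_dim = 8;     // matches max_mf_dim_ in hashtable.h
  const size_t header_bytes = 40;  // assumed size of the non-mf fields
  const size_t pull_value_size =
      AlignUp(header_bytes + (max_mf_dim + 1) * sizeof(float), 8);  // 80
  const size_t push_value_size =
      AlignUp(header_bytes + max_mf_dim * sizeof(float), 8);  // 72
  printf("pull=%zu push=%zu\n", pull_value_size, push_value_size);
  // A caller would then register the strides with the table, e.g.
  // table->set_feature_value_size(pull_value_size, push_value_size);
  return 0;
}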
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -42,6 +42,23 @@ __global__ void insert_kernel(Table* table,
  }
}

template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,

@@ -56,6 +73,20 @@ __global__ void search_kernel(Table* table,
  }
}

template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* const vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);

    if (it != table->end()) {
      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
    }
  }
}

template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const typename Table::key_type* const keys,

@@ -70,6 +101,23 @@ __global__ void update_kernel(Table* table,
  }
}

template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur);
    } else {
      printf("yxf::push miss key: %d", keys[i]);
    }
  }
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);

@@ -97,6 +145,17 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                                      d_vals, len);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,

@@ -109,6 +168,21 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                                      d_vals, len);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  if (pool == NULL) {
    return;
  }
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
                                                       pool, start_index);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
  container_->prefetch(cudaCpuDeviceId, stream);

@@ -166,6 +240,20 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                                                d_grads, len,
                                                                sgd);
}

template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}

}  // end namespace framework
}  // end namespace paddle
#endif
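Each new host-side wrapper above launches its kernel with grid_size = (len - 1) / BLOCK_SIZE_ + 1, i.e. the ceiling of len / BLOCK_SIZE_, so every element gets a thread while the kernels guard with if (i < len). Below is a minimal standalone CUDA sketch of that launch pattern; the kernel and variable names are illustrative, not Paddle's.

// Standalone illustration of the ceiling-division launch used by the new
// wrappers; the bounds check makes the extra threads in the last block no-ops.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void touch_kernel(int* data, size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {  // same guard as the insert/search/update kernels
    data[i] = static_cast<int>(i);
  }
}

int main() {
  const size_t len = 1000;
  const int block_size = 256;  // BLOCK_SIZE_ in HashTable
  const int grid_size = (len - 1) / block_size + 1;  // ceil(1000 / 256) = 4

  int* d_data = nullptr;
  cudaMalloc(&d_data, len * sizeof(int));
  touch_kernel<<<grid_size, block_size>>>(d_data, len);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  printf("launched %d blocks of %d threads for %zu elements\n", grid_size,
         block_size, len);
  return 0;
}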
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
@@ -96,6 +96,40 @@ class Optimizer {
      update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
    }
  }

  __device__ void dy_mf_update_value(ValType* ptr, const GradType& grad) {
    ptr->slot = grad.slot;
    ptr->show += grad.show;
    ptr->clk += grad.clk;
    ptr->delta_score +=
        optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
        optimizer_config::clk_coeff * grad.clk;

    update_lr(ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
    // use MF_DIM temporarily
    // ptr->mf_dim = grad.mf_dim;
    if (ptr->mf_size == 0) {
      if (optimizer_config::mf_create_thresholds <=
          optimizer_config::nonclk_coeff * (ptr->show - ptr->clk) +
              optimizer_config::clk_coeff * ptr->clk) {
        // ptr->mf_size = ptr->mf_dim + 1;
        ptr->mf_size = MF_DIM + 1;
        ptr->mf[0] = 0;
        int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < MF_DIM; ++i) {
          ptr->mf[i + 1] =
              (curand_uniform(&state)) * optimizer_config::mf_initial_range;
        }
      }
    } else {
      update_mf(MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g, grad.show);
      // for local test
    }
  }
};
}  // end namespace framework
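In dy_mf_update_value above, the embedding (mf) part of a value is created lazily: only when the accumulated nonclk_coeff * (show - clk) + clk_coeff * clk reaches mf_create_thresholds does the optimizer set mf_size to MF_DIM + 1 and randomly initialize the factors; the same expression, evaluated on the incoming gradient's show/clk, is what gets added to delta_score. A small host-side sketch of that gating arithmetic follows; the coefficient values are made up for the example, since the real ones live in optimizer_config and are not shown in this diff.

// Hedged illustration of the mf-creation gate in dy_mf_update_value.
// All three constants are assumed values, not Paddle's defaults.
#include <cstdio>

int main() {
  const float nonclk_coeff = 0.1f;
  const float clk_coeff = 1.0f;
  const float mf_create_thresholds = 10.0f;

  const float show = 50.0f, clk = 6.0f;  // accumulated totals for one feature
  // Same expression the kernel compares against the creation threshold.
  const float signal = nonclk_coeff * (show - clk) + clk_coeff * clk;
  printf("signal = %.1f\n", signal);  // 0.1 * 44 + 1.0 * 6 = 10.4

  if (mf_create_thresholds <= signal) {
    printf("mf would be created (mf_size = MF_DIM + 1)\n");
  } else {
    printf("mf stays empty until more shows/clicks accumulate\n");
  }
  return 0;
}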