BaiXuePrincess / Paddle · Commit 0ce42fb0 (forked from PaddlePaddle / Paddle; in sync with the upstream project)
Unverified commit 0ce42fb0, authored by zmxdream on May 10, 2022; committed by GitHub on May 10, 2022.
merge develop. test=develop (#42624)
Parent 21b35167

Showing 14 changed files with 212 additions and 220 deletions
paddle/fluid/framework/fleet/heter_ps/hashtable.h             +3   -6
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu     +27  -5
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps    +9   -24
paddle/fluid/framework/fleet/heter_ps/heter_comm.h            +0   -2
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h        +0   -2
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu             +8   -0
paddle/fluid/framework/fleet/heter_ps/heter_ps.h              +0   -2
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h         +2   -6
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h         +40  -33
paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h        +57  -38
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu                +9   -23
paddle/fluid/framework/fleet/ps_gpu_wrapper.kps               +5   -13
paddle/fluid/framework/ps_gpu_trainer.cc                      +51  -65
paddle/fluid/framework/trainer.h                              +1   -1
paddle/fluid/framework/fleet/heter_ps/hashtable.h

@@ -41,9 +41,7 @@ limitations under the License. */
 #include "xpu/kernel/simd.h"
 #endif
 
-#if defined(PADDLE_WITH_XPU_KP)
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
-#endif
 
 namespace paddle {
 namespace framework {
@@ -132,10 +130,8 @@ class HashTable {
 
   void show();
 
-#if defined(PADDLE_WITH_XPU_KP)
   void set_sparse_sgd(const OptimizerConfig& optimizer_config);
   void set_embedx_sgd(const OptimizerConfig& optimizer_config);
-#endif
 
   template <typename StreamType>
   void dump_to_cpu(int devid, StreamType stream);
@@ -178,9 +174,10 @@ class HashTable {
   TableContainer<KeyType, ValType>* container_;
 #elif defined(PADDLE_WITH_XPU_KP)
   XPUCacheArray<KeyType, ValType>* container_;
-  OptimizerConfig* xpu_optimizer_config_;
-  OptimizerConfig cpu_optimizer_config_;
 #endif
+  OptimizerConfig* device_optimizer_config_;
+  OptimizerConfig host_optimizer_config_;
+
   int BLOCK_SIZE_{256};
   float LOAD_FACTOR{0.75f};
   size_t capacity_;
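What this header change buys: every HashTable now owns a host-side OptimizerConfig (host_optimizer_config_) plus a device-side mirror (device_optimizer_config_) that is refreshed whenever a setter runs, instead of all tables sharing one set of global configuration symbols. A minimal, self-contained sketch of that host/device mirror pattern (class and member names here are illustrative, not Paddle's):

#include <cuda_runtime.h>

struct Config {          // plain-old-data, safe to memcpy to the device
  float learning_rate = 0.05f;
  float min_bound = -10.0f;
  float max_bound = 10.0f;
};

class Table {
 public:
  Table() {
    cudaMalloc((void**)&device_cfg_, sizeof(Config));
    sync();                        // upload the defaults immediately
  }
  ~Table() { cudaFree(device_cfg_); }

  void set_learning_rate(float lr) {
    host_cfg_.learning_rate = lr;  // mutate the host copy ...
    sync();                        // ... then refresh the device mirror
  }

  const Config* device_config() const { return device_cfg_; }

 private:
  void sync() {
    cudaMemcpy(device_cfg_, &host_cfg_, sizeof(Config),
               cudaMemcpyHostToDevice);
  }

  Config host_cfg_;                // authoritative copy on the CPU
  Config* device_cfg_ = nullptr;   // mirror read by the kernels
};

Because the host copy is authoritative, partial updates (only the sparse-SGD fields, only the embedx fields) still leave the device with a complete, consistent struct after the re-upload.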
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu

@@ -95,6 +95,7 @@ __global__ void dy_mf_search_kernel(Table* table,
 
 template <typename Table, typename GradType, typename Sgd>
 __global__ void update_kernel(Table* table,
+                              const OptimizerConfig& optimizer_config,
                               const typename Table::key_type* const keys,
                               const GradType* const grads, size_t len,
                               Sgd sgd) {
@@ -102,13 +103,14 @@ __global__ void update_kernel(Table* table,
   if (i < len) {
     auto it = table->find(keys[i]);
    if (it != table->end()) {
-      sgd.update_value((it.getter())->second, grads[i]);
+      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
     }
   }
 }
 
 template <typename Table, typename Sgd>
 __global__ void dy_mf_update_kernel(Table* table,
+                                    const OptimizerConfig& optimizer_config,
                                     const typename Table::key_type* const keys,
                                     const char* const grads, size_t len,
                                     Sgd sgd, size_t grad_value_size) {
@@ -117,7 +119,7 @@ __global__ void dy_mf_update_kernel(Table* table,
     auto it = table->find(keys[i]);
     if (it != table->end()) {
       FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
-      sgd.dy_mf_update_value((it.getter())->second, *cur);
+      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
       printf("yxf::push miss key: %d", keys[i]);
     }
@@ -127,6 +129,9 @@ __global__ void dy_mf_update_kernel(Table* table,
 template <typename KeyType, typename ValType>
 HashTable<KeyType, ValType>::HashTable(size_t capacity) {
   container_ = new TableContainer<KeyType, ValType>(capacity);
+  cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig));
+  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
+             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
   rwlock_.reset(new phi::RWLock);
 }
@@ -135,6 +140,22 @@ HashTable<KeyType, ValType>::~HashTable() {
   delete container_;
 }
 
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::set_sparse_sgd(
+    const OptimizerConfig& optimizer_config) {
+  host_optimizer_config_.set_sparse_sgd(optimizer_config);
+  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
+             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
+}
+
+template <typename KeyType, typename ValType>
+void HashTable<KeyType, ValType>::set_embedx_sgd(
+    const OptimizerConfig& optimizer_config) {
+  host_optimizer_config_.set_embedx_sgd(optimizer_config);
+  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
+             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
+}
+
 template <typename KeyType, typename ValType>
 void HashTable<KeyType, ValType>::show() {
   container_->print();
@@ -279,8 +300,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
     return;
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
-                                                       d_grads, len, sgd);
+  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
 }
 
 template <typename KeyType, typename ValType>
@@ -293,7 +314,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
   dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
-      container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
+      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd,
+      push_grad_value_size_);
 }
 
 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
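Note how the launch sites pass *device_optimizer_config_ into a const OptimizerConfig& kernel parameter: a reference parameter of a __global__ function is lowered to a pointer, so the host never actually dereferences the device pointer, and the kernel reads the struct from global memory. For a small read-only config the more common alternative is to pass the struct by value through the kernel argument buffer; a runnable sketch of that variant, with hypothetical names:

#include <cstdio>
#include <cuda_runtime.h>

struct Config { float learning_rate; };

// the struct travels in the kernel-argument buffer: no cudaMalloc and no
// extra global-memory read, but it is re-copied on every launch
__global__ void scale_kernel(float* w, int n, Config cfg) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) w[i] *= cfg.learning_rate;
}

int main() {
  const int n = 4;
  float host_w[n] = {1.f, 2.f, 3.f, 4.f};
  float* dev_w;
  cudaMalloc(&dev_w, n * sizeof(float));
  cudaMemcpy(dev_w, host_w, n * sizeof(float), cudaMemcpyHostToDevice);

  Config cfg{0.5f};
  scale_kernel<<<1, 32>>>(dev_w, n, cfg);  // by value: no device mirror needed

  cudaMemcpy(host_w, dev_w, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%f\n", host_w[i]);
  cudaFree(dev_w);
  return 0;
}

Keeping the config device-resident, as this commit does, trades that per-launch copy for a single upload per setter call, which suits a config that changes rarely but is read by every push.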
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps

@@ -163,7 +163,7 @@ __global__ void search_kernel(Table& table, const KeyType* const keys,
 }
 
 template <typename KeyType, typename ValType, typename Table, typename GradType>
-__global__ void update_kernel(OptimizerConfig& optimizer_config, Table& table,
+__global__ void update_kernel(Table& table, OptimizerConfig& optimizer_config,
                               const KeyType* const keys,
                               const GradType* const grads, long long len) {
   int cid = core_id();
@@ -202,12 +202,9 @@ HashTable<KeyType, ValType>::HashTable(size_t capacity) {
              sizeof(XPUCacheArray<KeyType, ValType>));
   xpu_memcpy((void*)container_, &tmp_container,
              sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE);
-  OptimizerConfig tmp_opt_config;
-  xpu_malloc(reinterpret_cast<void**>(&xpu_optimizer_config_),
+  xpu_malloc(reinterpret_cast<void**>(&device_optimizer_config_),
              sizeof(OptimizerConfig));
-  xpu_memcpy((void*)xpu_optimizer_config_, &tmp_opt_config,
+  xpu_memcpy((void*)device_optimizer_config_, &host_optimizer_config_,
              sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
   rwlock_.reset(new phi::RWLock);
@@ -216,7 +213,7 @@ HashTable<KeyType, ValType>::HashTable(size_t capacity) {
 template <typename KeyType, typename ValType>
 HashTable<KeyType, ValType>::~HashTable() {
   xpu_free((void*)container_);
-  xpu_free((void*)xpu_optimizer_config_);
+  xpu_free((void*)device_optimizer_config_);
 }
 
 template <typename KeyType, typename ValType>
@@ -227,28 +224,16 @@ void HashTable<KeyType, ValType>::show() {
 template <typename KeyType, typename ValType>
 void HashTable<KeyType, ValType>::set_sparse_sgd(
     const OptimizerConfig& optimizer_config) {
-  cpu_optimizer_config_.nonclk_coeff = optimizer_config.nonclk_coeff;
-  cpu_optimizer_config_.clk_coeff = optimizer_config.clk_coeff;
-  cpu_optimizer_config_.min_bound = optimizer_config.min_bound;
-  cpu_optimizer_config_.max_bound = optimizer_config.max_bound;
-  cpu_optimizer_config_.learning_rate = optimizer_config.learning_rate;
-  cpu_optimizer_config_.initial_g2sum = optimizer_config.initial_g2sum;
-  cpu_optimizer_config_.initial_range = optimizer_config.initial_range;
-  xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_,
+  host_optimizer_config_.set_sparse_sgd(optimizer_config);
+  xpu_memcpy((void*)device_optimizer_config_, &host_optimizer_config_,
              sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
 }
 
 template <typename KeyType, typename ValType>
 void HashTable<KeyType, ValType>::set_embedx_sgd(
     const OptimizerConfig& optimizer_config) {
-  cpu_optimizer_config_.mf_create_thresholds =
-      optimizer_config.mf_create_thresholds;
-  cpu_optimizer_config_.mf_learning_rate = optimizer_config.mf_learning_rate;
-  cpu_optimizer_config_.mf_initial_g2sum = optimizer_config.mf_initial_g2sum;
-  cpu_optimizer_config_.mf_initial_range = optimizer_config.mf_initial_range;
-  cpu_optimizer_config_.mf_min_bound = optimizer_config.mf_min_bound;
-  cpu_optimizer_config_.mf_max_bound = optimizer_config.mf_max_bound;
-  xpu_memcpy((void*)xpu_optimizer_config_, &cpu_optimizer_config_,
+  host_optimizer_config_.set_embedx_sgd(optimizer_config);
+  xpu_memcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
 }
@@ -306,7 +291,7 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
   long long c_len = (long long)len;
   update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
                 GradType><<<4, 64, stream>>>(
-      *xpu_optimizer_config_, *container_, d_keys, d_grads, c_len);
+      *container_, *device_optimizer_config_, d_keys, d_grads, c_len);
 }
 
 template <typename KeyType, typename ValType>
paddle/fluid/framework/fleet/heter_ps/heter_comm.h

@@ -65,10 +65,8 @@ class HeterComm {
   void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
 #endif
 
-#if defined(PADDLE_WITH_XPU_KP)
   void set_sparse_sgd(const OptimizerConfig& optimizer_config);
   void set_embedx_sgd(const OptimizerConfig& optimizer_config);
-#endif
 
   int log2i(int x);
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h

@@ -342,7 +342,6 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
   return resource_->get_index_by_devid(devid);
 }
 
-#if defined(PADDLE_WITH_XPU_KP)
 template <typename KeyType, typename ValType, typename GradType>
 void HeterComm<KeyType, ValType, GradType>::set_sparse_sgd(
     const OptimizerConfig& optimizer_config) {
@@ -358,7 +357,6 @@ void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd(
     table->set_embedx_sgd(optimizer_config);
   }
 }
-#endif
 
 template <typename KeyType, typename ValType, typename GradType>
 void HeterComm<KeyType, ValType, GradType>::build_ps(
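The HeterComm setters are pure fan-out: the hunk above shows the tail of the loop that forwards one config to every device's table, and the change merely drops the XPU-only guard so the same path exists on CUDA. Schematically (a tables_-style vector of per-device tables is an assumption here, not taken from the diff):

#include <vector>

// hypothetical mirror of the HeterComm fan-out: one config, N device tables
template <typename Table, typename Config>
void broadcast_config(const std::vector<Table*>& tables, const Config& cfg) {
  for (auto* table : tables) {
    table->set_sparse_sgd(cfg);   // each table re-uploads to its own device
    table->set_embedx_sgd(cfg);
  }
}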
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu

@@ -48,6 +48,14 @@ int HeterPs::get_index_by_devid(int devid) {
   return comm_->get_index_by_devid(devid);
 }
 
+void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) {
+  comm_->set_sparse_sgd(optimizer_config);
+}
+
+void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) {
+  comm_->set_embedx_sgd(optimizer_config);
+}
+
 void HeterPs::end_pass() { comm_->end_pass(); }
 
 void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
paddle/fluid/framework/fleet/heter_ps/heter_ps.h

@@ -44,10 +44,8 @@ class HeterPs : public HeterPsBase {
                    int comm_size) override;
 #endif
 
-#if defined(PADDLE_WITH_XPU_KP)
   void set_sparse_sgd(const OptimizerConfig& optimizer_config) override;
   void set_embedx_sgd(const OptimizerConfig& optimizer_config) override;
-#endif
 
   void end_pass() override;
   int get_index_by_devid(int devid) override;
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h

@@ -16,9 +16,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
 #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
-#if defined(PADDLE_WITH_XPU_KP)
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
-#endif
 
 #ifdef PADDLE_WITH_HETERPS
@@ -48,10 +46,8 @@ class HeterPsBase {
   virtual void push_sparse(int num, FeatureKey* d_keys,
                            FeaturePushValue* d_grads, size_t len) = 0;
-#if defined(PADDLE_WITH_XPU_KP)
-  virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) = 0;
-  virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) = 0;
-#endif
+  virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) {}
+  virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) {}
 
   static HeterPsBase* get_instance(size_t capacity,
                                    std::shared_ptr<HeterPsResource> resource);
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h

@@ -35,58 +35,64 @@ class Optimizer {
 
   void initialize() {}
 
-  __device__ void update_lr(float& w, float& g2sum, float g,  // NOLINT
+  __device__ void update_lr(const OptimizerConfig& optimizer_config,
+                            float& w,  // NOLINT
+                            float& g2sum, float g,  // NOLINT
                             float scale) {
     double add_g2sum = 0;
-    double ratio = optimizer_config::learning_rate *
-                   sqrt(optimizer_config::initial_g2sum /
-                        (optimizer_config::initial_g2sum + g2sum));
+    double ratio = optimizer_config.learning_rate *
+                   sqrt(optimizer_config.initial_g2sum /
+                        (optimizer_config.initial_g2sum + g2sum));
     double scaled_grad = g / scale;
 
     w += scaled_grad * ratio;
 
-    if (w < optimizer_config::min_bound) w = optimizer_config::min_bound;
-    if (w > optimizer_config::max_bound) w = optimizer_config::max_bound;
+    if (w < optimizer_config.min_bound) w = optimizer_config.min_bound;
+    if (w > optimizer_config.max_bound) w = optimizer_config.max_bound;
 
     add_g2sum += scaled_grad * scaled_grad;
 
    g2sum += add_g2sum;
   }
 
-  __device__ void update_mf(int n, float* w, float& g2sum,  // NOLINT
+  __device__ void update_mf(const OptimizerConfig& optimizer_config, int n,
+                            float* w, float& g2sum,  // NOLINT
                             const float* g, float scale) {
     double add_g2sum = 0;
-    double ratio = optimizer_config::mf_learning_rate *
-                   sqrt(optimizer_config::mf_initial_g2sum /
-                        (optimizer_config::mf_initial_g2sum + g2sum));
+    double ratio = optimizer_config.mf_learning_rate *
+                   sqrt(optimizer_config.mf_initial_g2sum /
+                        (optimizer_config.mf_initial_g2sum + g2sum));
     for (int i = 0; i < n; ++i) {
       double scaled_grad = g[i] / scale;
 
       w[i] += scaled_grad * ratio;
 
-      if (w[i] < optimizer_config::mf_min_bound)
-        w[i] = optimizer_config::mf_min_bound;
-      if (w[i] > optimizer_config::mf_max_bound)
-        w[i] = optimizer_config::mf_max_bound;
+      if (w[i] < optimizer_config.mf_min_bound)
+        w[i] = optimizer_config.mf_min_bound;
+      if (w[i] > optimizer_config.mf_max_bound)
+        w[i] = optimizer_config.mf_max_bound;
 
       add_g2sum += scaled_grad * scaled_grad;
     }
 
     g2sum += add_g2sum / n;
   }
 
-  __device__ void update_value(ValType& val, const GradType& grad) {  // NOLINT
+  __device__ void update_value(const OptimizerConfig& optimizer_config,
+                               ValType& val,  // NOLINT
+                               const GradType& grad) {
     val.slot = grad.slot;
     val.show += grad.show;
     val.clk += grad.clk;
-    val.delta_score += optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
-                       optimizer_config::clk_coeff * grad.clk;
+    val.delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
+                       optimizer_config.clk_coeff * grad.clk;
 
-    update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);
+    update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);
 
     if (val.mf_size == 0) {
-      if (optimizer_config::mf_create_thresholds <=
-          optimizer_config::nonclk_coeff * (val.show - val.clk) +
-              optimizer_config::clk_coeff * val.clk) {
+      if (optimizer_config.mf_create_thresholds <=
+          optimizer_config.nonclk_coeff * (val.show - val.clk) +
+              optimizer_config.clk_coeff * val.clk) {
         val.mf_size = MF_DIM + 1;
         val.mf[0] = 0;
         int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
@@ -94,30 +100,31 @@ class Optimizer {
        curand_init(clock64(), tid_x, 0, &state);
        for (int i = 0; i < MF_DIM; ++i) {
          val.mf[i + 1] =
-              (curand_uniform(&state)) * optimizer_config::mf_initial_range;
+              (curand_uniform(&state)) * optimizer_config.mf_initial_range;
        }
      }
    } else {
-      update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
+      update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
+                grad.show);
    }
  }
 
-  __device__ void dy_mf_update_value(ValType* ptr, const GradType& grad) {
+  __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config,
+                                     ValType* ptr, const GradType& grad) {
    ptr->slot = grad.slot;
    ptr->show += grad.show;
    ptr->clk += grad.clk;
    ptr->delta_score +=
-        optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
-        optimizer_config::clk_coeff * grad.clk;
+        optimizer_config.nonclk_coeff * (grad.show - grad.clk) +
+        optimizer_config.clk_coeff * grad.clk;
 
-    update_lr(ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
+    update_lr(optimizer_config, ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show);
    // use MF_DIM temporarily
    // ptr->mf_dim = grad.mf_dim;
    if (ptr->mf_size == 0) {
-      if (optimizer_config::mf_create_thresholds <=
-          optimizer_config::nonclk_coeff * (ptr->show - ptr->clk) +
-              optimizer_config::clk_coeff * ptr->clk) {
+      if (optimizer_config.mf_create_thresholds <=
+          optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) +
+              optimizer_config.clk_coeff * ptr->clk) {
        // ptr->mf_size = ptr->mf_dim + 1;
        ptr->mf_size = MF_DIM + 1;
@@ -127,11 +134,11 @@ class Optimizer {
       curand_init(clock64(), tid_x, 0, &state);
       for (int i = 0; i < MF_DIM; ++i) {
         ptr->mf[i + 1] =
-            (curand_uniform(&state)) * optimizer_config::mf_initial_range;
+            (curand_uniform(&state)) * optimizer_config.mf_initial_range;
       }
     }
   } else {
-      update_mf(MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
-                grad.show);  // for local test
+      update_mf(optimizer_config, MF_DIM, &(ptr->mf[1]), ptr->mf[0], grad.mf_g,
+                grad.show);  // for local test
     }
   }
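The numerics are untouched by this file's changes; only the hyperparameter lookup moved from namespace-scope symbols (optimizer_config::) to the passed-in struct (optimizer_config.). For reference, the rule update_lr implements, and update_mf applies per dimension with the mf_-prefixed parameters and the squared-gradient sum averaged over n, is an AdaGrad-style scaled step with clipping. In the file's own symbols, with learning_rate as eta, initial_g2sum as g_0, g2sum as G, and scale as s:

\mathrm{ratio} = \eta \sqrt{\frac{g_0}{g_0 + G}}, \qquad
w \leftarrow \mathrm{clip}\!\Big(w + \frac{g}{s}\,\mathrm{ratio},\; w_{\min},\; w_{\max}\Big), \qquad
G \leftarrow G + \Big(\frac{g}{s}\Big)^{2}

Here w_min and w_max are min_bound and max_bound, and s is the show count used to normalize the accumulated gradient; update_mf replaces the last step by G <- G + (1/n) * sum_i (g_i / s)^2.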
paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h

@@ -14,50 +14,69 @@ limitations under the License. */
 
 #pragma once
 
-#if defined(PADDLE_WITH_CUDA)
-namespace optimizer_config {
-__constant__ float nonclk_coeff = 0.1;
-__constant__ float clk_coeff = 1;
-
-__constant__ float min_bound = -10;
-__constant__ float max_bound = 10;
-__constant__ float learning_rate = 0.05;
-__constant__ float initial_g2sum = 3.0;
-__constant__ float initial_range = 0;
-
-__constant__ float mf_create_thresholds = 10;
-__constant__ float mf_learning_rate = 0.05;
-__constant__ float mf_initial_g2sum = 3.0;
-__constant__ float mf_initial_range = 1e-4;
-__constant__ float mf_min_bound = -10;
-__constant__ float mf_max_bound = 10;
-}  // namespace optimizer_config
-
-#elif defined(PADDLE_WITH_XPU_KP)
-namespace paddle {
-namespace framework {
-
-class OptimizerConfig {
- public:
-  float nonclk_coeff;
-  float clk_coeff;
-
-  float min_bound;
-  float max_bound;
-  float learning_rate;
-  float initial_g2sum;
-  float initial_range;
-
-  float mf_create_thresholds;
-  float mf_learning_rate;
-  float mf_initial_g2sum;
-  float mf_initial_range;
-  float mf_min_bound;
-  float mf_max_bound;
-};
+namespace paddle {
+namespace framework {
+
+class OptimizerConfig {
+ public:
+  float nonclk_coeff = 0.1;
+  float clk_coeff = 1;
+
+  float min_bound = -10;
+  float max_bound = 10;
+  float learning_rate = 0.05;
+  float initial_g2sum = 3.0;
+  float initial_range = 0;
+
+  float mf_create_thresholds = 10;
+  float mf_learning_rate = 0.05;
+  float mf_initial_g2sum = 3.0;
+  float mf_initial_range = 1e-4;
+  float mf_min_bound = -10;
+  float mf_max_bound = 10;
+
+  void set_sparse_sgd(float nonclk_coeff, float clk_coeff, float min_bound,
+                      float max_bound, float learning_rate,
+                      float initial_g2sum, float initial_range) {
+    this->nonclk_coeff = nonclk_coeff;
+    this->clk_coeff = clk_coeff;
+    this->min_bound = min_bound;
+    this->max_bound = max_bound;
+    this->learning_rate = learning_rate;
+    this->initial_g2sum = initial_g2sum;
+    this->initial_range = initial_range;
+  }
+
+  void set_sparse_sgd(const OptimizerConfig& optimizer_config) {
+    this->nonclk_coeff = optimizer_config.nonclk_coeff;
+    this->clk_coeff = optimizer_config.clk_coeff;
+    this->min_bound = optimizer_config.min_bound;
+    this->max_bound = optimizer_config.max_bound;
+    this->learning_rate = optimizer_config.learning_rate;
+    this->initial_g2sum = optimizer_config.initial_g2sum;
+    this->initial_range = optimizer_config.initial_range;
+  }
+
+  void set_embedx_sgd(float mf_create_thresholds, float mf_learning_rate,
+                      float mf_initial_g2sum, float mf_initial_range,
+                      float mf_min_bound, float mf_max_bound) {
+    this->mf_create_thresholds = mf_create_thresholds;
+    this->mf_learning_rate = mf_learning_rate;
+    this->mf_initial_g2sum = mf_initial_g2sum;
+    this->mf_initial_range = mf_initial_range;
+    this->mf_min_bound = mf_min_bound;
+    this->mf_max_bound = mf_max_bound;
+  }
+
+  void set_embedx_sgd(const OptimizerConfig& optimizer_config) {
+    this->mf_create_thresholds = optimizer_config.mf_create_thresholds;
+    this->mf_learning_rate = optimizer_config.mf_learning_rate;
+    this->mf_initial_g2sum = optimizer_config.mf_initial_g2sum;
+    this->mf_initial_range = optimizer_config.mf_initial_range;
+    this->mf_min_bound = optimizer_config.mf_min_bound;
+    this->mf_max_bound = optimizer_config.mf_max_bound;
+  }
+};
 
 }  // namespace framework
 }  // namespace paddle
-#endif
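The class now carries defaults plus two overloads per parameter group: scalar setters for the entry point that still receives raw floats, and copying setters for the layers that pass a whole config down. A trimmed, standalone illustration of how the two overloads compose (the class is cut to one field per group for brevity; this is a sketch, not Paddle's header):

#include <cstdio>

class OptimizerConfig {
 public:
  float learning_rate = 0.05f;     // sparse-SGD group
  float mf_learning_rate = 0.05f;  // embedx group

  // scalar overload: used where raw floats arrive at the API boundary
  void set_sparse_sgd(float learning_rate) {
    this->learning_rate = learning_rate;
  }
  // copying overload: used where a staged config is handed down,
  // e.g. before the host-to-device upload inside a table
  void set_sparse_sgd(const OptimizerConfig& other) {
    this->learning_rate = other.learning_rate;
  }
};

int main() {
  OptimizerConfig staged, host_copy;
  staged.set_sparse_sgd(0.01f);      // scalars in at the API boundary
  host_copy.set_sparse_sgd(staged);  // config-to-config further down
  printf("%.2f\n", host_copy.learning_rate);  // prints 0.01
  return 0;
}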
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu

@@ -181,35 +181,21 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                 float min_bound, float max_bound,
                                 float learning_rate, float initial_g2sum,
                                 float initial_range) {
-  cudaMemcpyToSymbol(optimizer_config::nonclk_coeff, &nonclk_coeff,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::clk_coeff, &clk_coeff, sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::min_bound, &min_bound, sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::max_bound, &max_bound, sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::learning_rate, &learning_rate,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::initial_g2sum, &initial_g2sum,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::initial_range, &initial_range,
-                     sizeof(float));
+  OptimizerConfig optimizer_config;
+  optimizer_config.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound,
+                                  learning_rate, initial_g2sum, initial_range);
+  HeterPs_->set_sparse_sgd(optimizer_config);
 }
 
 void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                 float mf_learning_rate, float mf_initial_g2sum,
                                 float mf_initial_range, float mf_min_bound,
                                 float mf_max_bound) {
-  cudaMemcpyToSymbol(optimizer_config::mf_create_thresholds,
-                     &mf_create_thresholds, sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::mf_learning_rate, &mf_learning_rate,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::mf_initial_g2sum, &mf_initial_g2sum,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::mf_initial_range, &mf_initial_range,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::mf_min_bound, &mf_min_bound,
-                     sizeof(float));
-  cudaMemcpyToSymbol(optimizer_config::mf_max_bound, &mf_max_bound,
-                     sizeof(float));
+  OptimizerConfig optimizer_config;
+  optimizer_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate,
+                                  mf_initial_g2sum, mf_initial_range,
+                                  mf_min_bound, mf_max_bound);
+  HeterPs_->set_embedx_sgd(optimizer_config);
 }
 
 }  // end namespace framework
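For contrast, the deleted lines show the previous mechanism: one cudaMemcpyToSymbol per hyperparameter into __constant__ symbols. A minimal runnable sketch of that older pattern follows (the symbol name is illustrative); its limitation, which the per-table OptimizerConfig removes, is that a __constant__ variable is a single slot per CUDA module, so every table necessarily shared the same values and each field cost its own host-to-device copy.

#include <cstdio>
#include <cuda_runtime.h>

__constant__ float learning_rate;  // one global slot per CUDA module

__global__ void read_cfg(float* out) { *out = learning_rate; }

int main() {
  float lr = 0.05f;
  // per-field upload into constant memory; no struct, no per-table copy
  cudaMemcpyToSymbol(learning_rate, &lr, sizeof(float));

  float* dev_out;
  float host_out;
  cudaMalloc(&dev_out, sizeof(float));
  read_cfg<<<1, 1>>>(dev_out);
  cudaMemcpy(&host_out, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f\n", host_out);  // 0.050000
  cudaFree(dev_out);
  return 0;
}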
paddle/fluid/framework/fleet/ps_gpu_wrapper.kps

@@ -256,13 +256,8 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                 float learning_rate, float initial_g2sum,
                                 float initial_range) {
   OptimizerConfig optimizer_config;
-  optimizer_config.nonclk_coeff = nonclk_coeff;
-  optimizer_config.clk_coeff = clk_coeff;
-  optimizer_config.min_bound = min_bound;
-  optimizer_config.max_bound = max_bound;
-  optimizer_config.learning_rate = learning_rate;
-  optimizer_config.initial_g2sum = initial_g2sum;
-  optimizer_config.initial_range = initial_range;
+  optimizer_config.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound,
+                                  learning_rate, initial_g2sum, initial_range);
   HeterPs_->set_sparse_sgd(optimizer_config);
 }
@@ -271,12 +266,9 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                 float mf_initial_range, float mf_min_bound,
                                 float mf_max_bound) {
   OptimizerConfig optimizer_config;
-  optimizer_config.mf_create_thresholds = mf_create_thresholds;
-  optimizer_config.mf_learning_rate = mf_learning_rate;
-  optimizer_config.mf_initial_g2sum = mf_initial_g2sum;
-  optimizer_config.mf_initial_range = mf_initial_range;
-  optimizer_config.mf_min_bound = mf_min_bound;
-  optimizer_config.mf_max_bound = mf_max_bound;
+  optimizer_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate,
+                                  mf_initial_g2sum, mf_initial_range,
+                                  mf_min_bound, mf_max_bound);
   HeterPs_->set_embedx_sgd(optimizer_config);
 }
paddle/fluid/framework/ps_gpu_trainer.cc

@@ -95,8 +95,46 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
   return;
 }
 
+void add_sparse_optimizer(
+    std::unordered_map<std::string, float>& config,  // NOLINT
+    const ::paddle::SparseCommonSGDRuleParameter& sgd_param,
+    const std::string& prefix = "") {
+  auto optimizer_name = sgd_param.name();
+  if (optimizer_name == "naive") {
+    config[prefix + "learning_rate"] = sgd_param.naive().learning_rate();
+    config[prefix + "initial_range"] = sgd_param.naive().initial_range();
+    if (sgd_param.naive().weight_bounds_size() == 2) {
+      config[prefix + "min_bound"] = sgd_param.naive().weight_bounds()[0];
+      config[prefix + "max_bound"] = sgd_param.naive().weight_bounds()[1];
+    }
+  } else if (optimizer_name == "adagrad") {
+    config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate();
+    config[prefix + "initial_range"] = sgd_param.adagrad().initial_range();
+    config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum();
+    if (sgd_param.adagrad().weight_bounds_size() == 2) {
+      config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0];
+      config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1];
+    }
+  } else if (optimizer_name == "std_adagrad") {
+    config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate();
+    config[prefix + "initial_range"] = sgd_param.adagrad().initial_range();
+    config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum();
+    if (sgd_param.adagrad().weight_bounds_size() == 2) {
+      config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0];
+      config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1];
+    }
+  } else if (optimizer_name == "adam") {
+    config[prefix + "learning_rate"] = sgd_param.adam().learning_rate();
+    config[prefix + "initial_range"] = sgd_param.adam().initial_range();
+    if (sgd_param.adam().weight_bounds_size() == 2) {
+      config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0];
+      config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1];
+    }
+  }
+}
+
 void PSGPUTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) {
-  // add for hbmps optimizer config
+  // optimizer config for hbmps
   auto fleet_desc_str = trainer_desc.fleet_desc();
   google::protobuf::TextFormat::ParseFromString(fleet_desc_str, &_ps_param);
   auto sparse_table =
@@ -105,7 +143,7 @@ void PSGPUTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) {
   auto sparse_table_accessor_parameter =
       sparse_table_accessor.downpour_accessor_param();
   auto accessor_class = sparse_table_accessor.accessor_class();
-  // gpups' sparse table optimizer config
+  // NOTE(zhangminxu): gpups' sparse table optimizer config,
   // now only support single sparse table
   // auto sparse_table = param_.sparse_table(0);
   std::unordered_map<std::string, float> config;
@@ -126,7 +164,14 @@ void PSGPUTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) {
       config["max_bound"] =
          sparse_table_accessor.sparse_sgd_param().weight_bounds()[1];
     }
+    // NOTE(zhangminxu): for DownpourCtrAccessor & DownpourCtrDoubleAccessor,
+    // optimizer config for embed_w & embedx_w is the same
     config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold();
+    config["mf_learning_rate"] = config["learning_rate"];
+    config["mf_initial_g2sum"] = config["initial_g2sum"];
+    config["mf_initial_range"] = config["initial_range"];
+    config["mf_min_bound"] = config["min_bound"];
+    config["mf_max_bound"] = config["max_bound"];
   } else if (accessor_class == "DownpourSparseValueAccessor") {
     auto optimizer_name = sparse_table_accessor.sparse_commonsgd_param().name();
     if (optimizer_name == "naive") {
@@ -186,71 +231,12 @@ void PSGPUTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) {
              accessor_class == "DownpourDoubleUnitAccessor") {
     config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff();
     config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff();
-    auto optimizer_name = sparse_table_accessor.embedx_sgd_param().name();
-    if (optimizer_name == "naive") {
-      config["mf_learning_rate"] =
-          sparse_table_accessor.embedx_sgd_param().naive().learning_rate();
-      config["mf_initial_range"] =
-          sparse_table_accessor.embedx_sgd_param().naive().initial_range();
-      if (sparse_table_accessor.embedx_sgd_param()
-              .naive()
-              .weight_bounds_size() == 2) {
-        config["mf_min_bound"] =
-            sparse_table_accessor.embedx_sgd_param().naive().weight_bounds()[0];
-        config["mf_max_bound"] =
-            sparse_table_accessor.embedx_sgd_param().naive().weight_bounds()[1];
-      }
-    } else if (optimizer_name == "adagrad") {
-      config["mf_learning_rate"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().learning_rate();
-      config["mf_initial_range"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().initial_range();
-      config["mf_initial_g2sum"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().initial_g2sum();
-      if (sparse_table_accessor.embedx_sgd_param()
-              .adagrad()
-              .weight_bounds_size() == 2) {
-        config["mf_min_bound"] = sparse_table_accessor.embedx_sgd_param()
-                                     .adagrad()
-                                     .weight_bounds()[0];
-        config["mf_max_bound"] = sparse_table_accessor.embedx_sgd_param()
-                                     .adagrad()
-                                     .weight_bounds()[1];
-      }
-    } else if (optimizer_name == "std_adagrad") {
-      config["mf_learning_rate"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().learning_rate();
-      config["mf_initial_range"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().initial_range();
-      config["mf_initial_g2sum"] =
-          sparse_table_accessor.embedx_sgd_param().adagrad().initial_g2sum();
-      if (sparse_table_accessor.embedx_sgd_param()
-              .adagrad()
-              .weight_bounds_size() == 2) {
-        config["mf_min_bound"] = sparse_table_accessor.embedx_sgd_param()
-                                     .adagrad()
-                                     .weight_bounds()[0];
-        config["mf_max_bound"] = sparse_table_accessor.embedx_sgd_param()
-                                     .adagrad()
-                                     .weight_bounds()[1];
-      }
-    } else if (optimizer_name == "adam") {
-      config["mf_learning_rate"] =
-          sparse_table_accessor.embedx_sgd_param().adam().learning_rate();
-      config["mf_initial_range"] =
-          sparse_table_accessor.embedx_sgd_param().adam().initial_range();
-      if (sparse_table_accessor.embedx_sgd_param()
-              .adam()
-              .weight_bounds_size() == 2) {
-        config["mf_min_bound"] =
-            sparse_table_accessor.embedx_sgd_param().adam().weight_bounds()[0];
-        config["mf_max_bound"] =
-            sparse_table_accessor.embedx_sgd_param().adam().weight_bounds()[1];
-      }
-    }
     config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold();
+    // optimizer config for embed_w and embedx
+    add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param());
+    add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(),
+                         "mf_");
   }
 
   auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance();
   ps_gpu_wrapper->InitializeGPUServer(config);
 }
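The trainer-side refactor is mostly deduplication: four near-identical naive/adagrad/std_adagrad/adam blocks collapse into add_sparse_optimizer, and the prefix argument reuses the same body for both the embed_w keys ("learning_rate", ...) and the embedx_w keys ("mf_learning_rate", ...). A stripped-down illustration of the prefix trick, with SgdParamStub standing in for ::paddle::SparseCommonSGDRuleParameter:

#include <cstdio>
#include <string>
#include <unordered_map>

// stub standing in for the protobuf message; one field for brevity
struct SgdParamStub {
  float learning_rate;
};

void add_sparse_optimizer(
    std::unordered_map<std::string, float>& config,
    const SgdParamStub& sgd_param, const std::string& prefix = "") {
  config[prefix + "learning_rate"] = sgd_param.learning_rate;
}

int main() {
  std::unordered_map<std::string, float> config;
  add_sparse_optimizer(config, {0.05f});         // embed_w  -> "learning_rate"
  add_sparse_optimizer(config, {0.01f}, "mf_");  // embedx_w -> "mf_learning_rate"
  printf("%f %f\n", config["learning_rate"], config["mf_learning_rate"]);
  return 0;
}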
paddle/fluid/framework/trainer.h

@@ -37,7 +37,7 @@ limitations under the License. */
 #include "paddle/phi/backends/dynload/port.h"
 
 #ifdef PADDLE_WITH_PSLIB
-#include <pslib.h>
+#include "proto/ps.pb.h"
 #endif
 
 namespace paddle {