Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6e0da01c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2297
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6e0da01c
编写于
1月 14, 2021
作者:
Y
yaoxuefeng
提交者:
GitHub
1月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Heter ps new (#30198)
上级
49e79cad
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
305 addition
and
230 deletion
+305
-230
paddle/fluid/framework/fleet/heter_context.h
paddle/fluid/framework/fleet/heter_context.h
+36
-0
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
+5
-1
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+148
-12
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
+37
-0
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+73
-2
paddle/fluid/framework/ps_gpu_trainer.cc
paddle/fluid/framework/ps_gpu_trainer.cc
+0
-208
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+1
-7
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
+5
-0
未找到文件。
paddle/fluid/framework/fleet/heter_context.h
浏览文件 @
6e0da01c
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
#include <algorithm>
#include <map>
#include <unordered_map>
#include <vector>
...
...
@@ -33,6 +34,8 @@ class HeterContext {
std
::
vector
<
std
::
vector
<
FeatureKey
>>
feature_keys_
;
std
::
vector
<
std
::
vector
<
paddle
::
ps
::
DownpourFixedFeatureValue
*>>
value_ptr_
;
std
::
vector
<
std
::
vector
<
FeatureValue
>>
feature_values_
;
std
::
vector
<
std
::
mutex
*>
mutex_lock_
;
uint32_t
shard_num_
=
37
;
uint64_t
size
()
{
uint64_t
total_size
=
0
;
for
(
auto
&
keys
:
feature_keys_
)
{
...
...
@@ -40,6 +43,39 @@ class HeterContext {
}
return
total_size
;
}
void
SetShardNum
(
uint32_t
shard_num
)
{
shard_num_
=
shard_num
;
}
uint32_t
ShardNum
()
{
return
shard_num_
;
}
void
init
()
{
feature_keys_
.
resize
(
shard_num_
);
}
void
batch_add_keys
(
const
std
::
vector
<
std
::
vector
<
uint64_t
>>&
thread_keys
)
{
assert
(
thread_keys
.
size
()
==
feature_keys_
.
size
());
for
(
uint32_t
i
=
0
;
i
<
shard_num_
;
i
++
)
{
int
idx
=
0
;
// mutex_lock_[i]->lock();
idx
=
feature_keys_
[
i
].
size
();
feature_keys_
[
i
].
resize
(
feature_keys_
[
i
].
size
()
+
thread_keys
[
i
].
size
());
for
(
uint64_t
j
=
0
;
j
<
thread_keys
[
i
].
size
();
j
++
)
{
feature_keys_
[
i
][
idx
+
j
]
=
thread_keys
[
i
][
j
];
}
// mutex_lock_[i]->unlock();
}
}
void
UniqueKeys
()
{
std
::
vector
<
std
::
thread
>
threads
;
auto
unique_func
=
[
this
](
int
i
)
{
auto
&
cur_keys
=
feature_keys_
[
i
];
std
::
sort
(
cur_keys
.
begin
(),
cur_keys
.
end
());
std
::
vector
<
FeatureKey
>::
iterator
it
;
it
=
std
::
unique
(
cur_keys
.
begin
(),
cur_keys
.
end
());
cur_keys
.
resize
(
std
::
distance
(
cur_keys
.
begin
(),
it
));
};
for
(
uint32_t
i
=
0
;
i
<
shard_num_
;
i
++
)
{
threads
.
push_back
(
std
::
thread
(
unique_func
,
i
));
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
}
};
}
// end namespace framework
...
...
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
浏览文件 @
6e0da01c
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include <curand_kernel.h>
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
...
...
@@ -106,8 +107,11 @@ class Optimizer {
optimizer_config
::
clk_coeff
*
val
.
clk
)
{
val
.
mf_size
=
MF_DIM
+
1
;
val
.
mf
[
0
]
=
0
;
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandState
state
;
curand_init
(
clock64
(),
tid_x
,
0
,
&
state
);
for
(
int
i
=
0
;
i
<
MF_DIM
;
++
i
)
{
val
.
mf
[
i
+
1
]
=
(
cu
da_normal_random
((
int
)
grad
.
show
)
*
2
-
1
)
*
val
.
mf
[
i
+
1
]
=
(
cu
rand_uniform
(
&
state
)
)
*
optimizer_config
::
mf_initial_range
;
}
}
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
浏览文件 @
6e0da01c
...
...
@@ -27,13 +27,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
/*
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
*/
#include <deque>
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/platform/timer.h"
...
...
@@ -43,10 +40,142 @@ namespace framework {
std
::
shared_ptr
<
PSGPUWrapper
>
PSGPUWrapper
::
s_instance_
=
NULL
;
bool
PSGPUWrapper
::
is_initialized_
=
false
;
void
PSGPUWrapper
::
BuildGPUPS
(
uint64_t
table_id
,
int
feature_dim
,
std
::
shared_ptr
<
HeterContext
>
gpu_task
)
{
void
PSGPUWrapper
::
BuildTask
(
uint64_t
table_id
,
int
feature_dim
)
{
VLOG
(
3
)
<<
"PSGPUWrapper::BuildGPUPSTask begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
MultiSlotDataset
*
dataset
=
dynamic_cast
<
MultiSlotDataset
*>
(
dataset_
);
std
::
shared_ptr
<
HeterContext
>
gpu_task
=
gpu_task_pool_
.
Get
();
auto
input_channel
=
dataset
->
GetInputChannel
();
auto
&
local_keys
=
gpu_task
->
feature_keys_
;
auto
&
local_values
=
gpu_task
->
feature_values_
;
auto
&
local_ptr
=
gpu_task
->
value_ptr_
;
std
::
vector
<
std
::
thread
>
threads
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
// data should be in input channel
thread_keys_
.
resize
(
thread_keys_thread_num_
);
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
thread_keys_
[
i
].
resize
(
thread_keys_shard_num_
);
for
(
int
j
=
0
;
j
<
thread_keys_shard_num_
;
j
++
)
{
thread_keys_
[
i
][
j
].
reserve
(
2
*
max_fea_num_per_pass_
/
thread_keys_shard_num_
/
thread_keys_thread_num_
);
}
}
const
std
::
deque
<
Record
>&
vec_data
=
input_channel
->
GetData
();
size_t
total_len
=
vec_data
.
size
();
size_t
len_per_thread
=
total_len
/
thread_keys_thread_num_
;
int
remain
=
total_len
%
thread_keys_thread_num_
;
size_t
begin
=
0
;
auto
gen_func
=
[
this
](
const
std
::
deque
<
Record
>&
total_data
,
int
begin_index
,
int
end_index
,
int
i
)
{
for
(
auto
iter
=
total_data
.
begin
()
+
begin_index
;
iter
!=
total_data
.
begin
()
+
end_index
;
iter
++
)
{
const
auto
&
ins
=
*
iter
;
const
auto
&
feasign_v
=
ins
.
uint64_feasigns_
;
for
(
const
auto
feasign
:
feasign_v
)
{
uint64_t
cur_key
=
feasign
.
sign
().
uint64_feasign_
;
int
shard_id
=
cur_key
%
thread_keys_shard_num_
;
this
->
thread_keys_
[
i
][
shard_id
].
push_back
(
cur_key
);
}
}
};
for
(
int
i
=
0
;
i
<
thread_keys_thread_num_
;
i
++
)
{
threads
.
push_back
(
std
::
thread
(
gen_func
,
std
::
ref
(
vec_data
),
begin
,
begin
+
len_per_thread
+
(
i
<
remain
?
1
:
0
),
i
));
begin
+=
len_per_thread
+
(
i
<
remain
?
1
:
0
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs build task cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
timeline
.
Start
();
// merge thread_keys to shard_keys
gpu_task
->
init
();
for
(
size_t
i
=
0
;
i
<
thread_keys_
.
size
();
i
++
)
{
gpu_task
->
batch_add_keys
(
thread_keys_
[
i
]);
for
(
int
j
=
0
;
j
<
thread_keys_thread_num_
;
j
++
)
{
thread_keys_
[
i
][
j
].
clear
();
}
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs task unique11111 cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
VLOG
(
0
)
<<
"FK1"
;
timeline
.
Start
();
gpu_task
->
UniqueKeys
();
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs task unique cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
local_values
[
i
].
resize
(
local_keys
[
i
].
size
());
local_ptr
[
i
].
resize
(
local_keys
[
i
].
size
());
}
auto
ptl_func
=
[
this
,
&
local_keys
,
&
local_values
,
&
local_ptr
,
&
table_id
,
&
fleet_ptr
](
int
i
)
{
size_t
key_size
=
local_keys
[
i
].
size
();
auto
tt
=
fleet_ptr
->
pslib_ptr_
->
_worker_ptr
->
pull_sparse_ptr
(
reinterpret_cast
<
char
**>
(
local_ptr
[
i
].
data
()),
table_id
,
local_keys
[
i
].
data
(),
key_size
);
tt
.
wait
();
auto
status
=
tt
.
get
();
// auto status = 0;
if
(
status
!=
0
)
{
LOG
(
ERROR
)
<<
"fleet pull sparse failed, status["
<<
status
<<
"]"
;
sleep
(
300
);
exit
(
-
1
);
}
else
{
VLOG
(
3
)
<<
"FleetWrapper Pull sparse to local done with table size: "
<<
local_keys
[
i
].
size
();
}
for
(
size_t
num
=
0
;
num
<
local_ptr
[
i
].
size
();
++
num
)
{
float
*
ptr_val
=
local_ptr
[
i
][
num
]
->
data
();
FeatureValue
&
val
=
local_values
[
i
][
num
];
size_t
dim
=
local_ptr
[
i
][
num
]
->
size
();
val
.
delta_score
=
ptr_val
[
1
];
val
.
show
=
ptr_val
[
2
];
val
.
clk
=
ptr_val
[
3
];
val
.
slot
=
ptr_val
[
6
];
val
.
lr
=
ptr_val
[
4
];
val
.
lr_g2sum
=
ptr_val
[
5
];
if
(
dim
>
7
)
{
val
.
mf_size
=
MF_DIM
+
1
;
for
(
int
x
=
0
;
x
<
val
.
mf_size
;
x
++
)
{
val
.
mf
[
x
]
=
ptr_val
[
x
+
7
];
}
}
else
{
val
.
mf_size
=
0
;
for
(
int
x
=
0
;
x
<
MF_DIM
+
1
;
x
++
)
{
val
.
mf
[
x
]
=
0
;
}
}
}
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
ptl_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs pull sparse cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
}
void
PSGPUWrapper
::
BuildGPUPS
(
uint64_t
table_id
,
int
feature_dim
)
{
BuildTask
(
table_id
,
feature_dim
);
platform
::
Timer
timeline
;
timeline
.
Start
();
std
::
shared_ptr
<
HeterContext
>
gpu_task
=
gpu_task_pool_
.
Get
();
int
shard_num
=
gpu_task
->
feature_keys_
.
size
();
if
(
shard_num
==
0
)
{
return
;
...
...
@@ -62,13 +191,20 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim,
HeterPs_
->
show_one_table
(
0
);
return
;
}
std
::
vector
<
std
::
thread
>
threads
(
shard_num
);
HeterPs_
=
HeterPsBase
::
get_instance
(
size_max
,
resource_
);
for
(
int
i
=
0
;
i
<
shard_num
;
++
i
)
{
auto
build_func
=
[
this
,
&
gpu_task
,
&
feature_keys_count
](
int
i
)
{
std
::
cout
<<
"building table: "
<<
i
<<
std
::
endl
;
HeterPs_
->
build_ps
(
i
,
gpu_task
->
feature_keys_
[
i
].
data
(),
this
->
HeterPs_
->
build_ps
(
i
,
gpu_task
->
feature_keys_
[
i
].
data
(),
gpu_task
->
feature_values_
[
i
].
data
(),
feature_keys_count
[
i
],
10000
,
2
);
HeterPs_
->
show_one_table
(
i
);
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
build_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs build table total costs: "
<<
timeline
.
ElapsedSec
()
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
浏览文件 @
6e0da01c
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
...
...
@@ -177,6 +178,42 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
slot_lengths
.
size
(),
total_length
,
batch_size
,
d_slot_vector
);
cudaStreamSynchronize
(
stream
);
}
void
PSGPUWrapper
::
SetSparseSGD
(
float
nonclk_coeff
,
float
clk_coeff
,
float
min_bound
,
float
max_bound
,
float
learning_rate
,
float
initial_g2sum
,
float
initial_range
)
{
cudaMemcpyToSymbol
(
optimizer_config
::
nonclk_coeff
,
&
nonclk_coeff
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
clk_coeff
,
&
clk_coeff
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
min_bound
,
&
min_bound
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
max_bound
,
&
max_bound
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
learning_rate
,
&
learning_rate
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
initial_g2sum
,
&
initial_g2sum
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
initial_range
,
&
initial_range
,
sizeof
(
float
));
}
void
PSGPUWrapper
::
SetEmbedxSGD
(
float
mf_create_thresholds
,
float
mf_learning_rate
,
float
mf_initial_g2sum
,
float
mf_initial_range
,
float
mf_min_bound
,
float
mf_max_bound
)
{
cudaMemcpyToSymbol
(
optimizer_config
::
mf_create_thresholds
,
&
mf_create_thresholds
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
mf_learning_rate
,
&
mf_learning_rate
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
mf_initial_g2sum
,
&
mf_initial_g2sum
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
mf_initial_range
,
&
mf_initial_range
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
mf_min_bound
,
&
mf_min_bound
,
sizeof
(
float
));
cudaMemcpyToSymbol
(
optimizer_config
::
mf_max_bound
,
&
mf_max_bound
,
sizeof
(
float
));
}
}
// end namespace framework
}
// end namespace paddle
#endif
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
浏览文件 @
6e0da01c
...
...
@@ -23,8 +23,10 @@ limitations under the License. */
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
...
...
@@ -73,16 +75,77 @@ class PSGPUWrapper {
const
int
hidden_size
,
const
int64_t
total_length
,
const
int
batch_size
);
void
BuildGPUPS
(
const
uint64_t
table_id
,
int
feature_dim
,
std
::
shared_ptr
<
HeterContext
>
context
);
void
BuildGPUPS
(
const
uint64_t
table_id
,
int
feature_dim
);
void
BuildTask
(
uint64_t
table_id
,
int
feature_dim
);
void
InitializeGPU
(
const
std
::
vector
<
int
>&
dev_ids
)
{
if
(
s_instance_
!=
NULL
)
{
VLOG
(
3
)
<<
"PSGPUWrapper Begin InitializeGPU"
;
resource_
=
std
::
make_shared
<
HeterPsResource
>
(
dev_ids
);
resource_
->
enable_p2p
();
keys_tensor
.
resize
(
resource_
->
total_gpu
());
heter_devices_
=
dev_ids
;
}
}
void
SetSparseSGD
(
float
nonclk_coeff
,
float
clk_coeff
,
float
min_bound
,
float
max_bound
,
float
learning_rate
,
float
initial_g2sum
,
float
initial_range
);
void
SetEmbedxSGD
(
float
mf_create_thresholds
,
float
mf_learning_rate
,
float
mf_initial_g2sum
,
float
mf_initial_range
,
float
mf_min_bound
,
float
mf_max_bound
);
void
InitializeGPUServer
(
std
::
unordered_map
<
std
::
string
,
float
>
config
)
{
float
nonclk_coeff
=
(
config
.
find
(
"nonclk_coeff"
)
==
config
.
end
())
?
1.0
:
config
[
"nonclk_coeff"
];
float
clk_coeff
=
(
config
.
find
(
"clk_coeff"
)
==
config
.
end
())
?
1.0
:
config
[
"clk_coeff"
];
float
min_bound
=
(
config
.
find
(
"min_bound"
)
==
config
.
end
())
?
-
10000.0
:
config
[
"min_bound"
];
float
max_bound
=
(
config
.
find
(
"max_bound"
)
==
config
.
end
())
?
10000.0
:
config
[
"max_bound"
];
float
learning_rate
=
(
config
.
find
(
"learning_rate"
)
==
config
.
end
())
?
1.0
:
config
[
"learning_rate"
];
float
initial_g2sum
=
(
config
.
find
(
"initial_g2sum"
)
==
config
.
end
())
?
1.0
:
config
[
"initial_g2sum"
];
float
initial_range
=
(
config
.
find
(
"initial_range"
)
==
config
.
end
())
?
1.0
:
config
[
"initial_range"
];
// mf config settings
float
mf_create_thresholds
=
(
config
.
find
(
"mf_create_thresholds"
)
==
config
.
end
())
?
static_cast
<
float
>
(
1.0
)
:
config
[
"mf_create_thresholds"
];
float
mf_learning_rate
=
(
config
.
find
(
"mf_learning_rate"
)
==
config
.
end
())
?
1.0
:
config
[
"mf_learning_rate"
];
float
mf_initial_g2sum
=
(
config
.
find
(
"mf_initial_g2sum"
)
==
config
.
end
())
?
1.0
:
config
[
"mf_initial_g2sum"
];
float
mf_initial_range
=
(
config
.
find
(
"mf_initial_range"
)
==
config
.
end
())
?
1.0
:
config
[
"mf_initial_range"
];
float
mf_min_bound
=
(
config
.
find
(
"mf_min_bound"
)
==
config
.
end
())
?
1.0
:
config
[
"mf_min_bound"
];
float
mf_max_bound
=
(
config
.
find
(
"mf_max_bound"
)
==
config
.
end
())
?
1.0
:
config
[
"mf_max_bound"
];
for
(
size_t
i
=
0
;
i
<
heter_devices_
.
size
();
i
++
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaSetDevice
(
heter_devices_
[
i
]));
this
->
SetSparseSGD
(
nonclk_coeff
,
clk_coeff
,
min_bound
,
max_bound
,
learning_rate
,
initial_g2sum
,
initial_range
);
this
->
SetEmbedxSGD
(
mf_create_thresholds
,
mf_learning_rate
,
mf_initial_g2sum
,
mf_initial_range
,
mf_min_bound
,
mf_max_bound
);
}
}
void
SetDataset
(
Dataset
*
dataset
)
{
dataset_
=
dataset
;
}
// PSGPUWrapper singleton
static
std
::
shared_ptr
<
PSGPUWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
...
...
@@ -100,6 +163,7 @@ class PSGPUWrapper {
private:
static
std
::
shared_ptr
<
PSGPUWrapper
>
s_instance_
;
Dataset
*
dataset_
;
std
::
unordered_map
<
uint64_t
,
std
::
vector
<
std
::
unordered_map
<
uint64_t
,
std
::
vector
<
float
>>>>
local_tables_
;
...
...
@@ -108,6 +172,13 @@ class PSGPUWrapper {
std
::
shared_ptr
<
HeterPsResource
>
resource_
;
int32_t
sleep_seconds_before_fail_exit_
;
std
::
vector
<
int
>
slot_vector_
;
std
::
vector
<
int
>
heter_devices_
;
std
::
unordered_set
<
std
::
string
>
gpu_ps_config_keys_
;
HeterObjectPool
<
HeterContext
>
gpu_task_pool_
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
thread_keys_
;
int
thread_keys_thread_num_
=
37
;
int
thread_keys_shard_num_
=
37
;
uint64_t
max_fea_num_per_pass_
=
5000000000
;
protected:
static
bool
is_initialized_
;
...
...
paddle/fluid/framework/ps_gpu_trainer.cc
浏览文件 @
6e0da01c
...
...
@@ -131,219 +131,11 @@ void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
}
void
PSGPUTrainer
::
Run
()
{
BuildGPUPSTask
(
0
,
8
);
for
(
size_t
thidx
=
0
;
thidx
<
places_
.
size
();
++
thidx
)
{
threads_
.
push_back
(
std
::
thread
(
&
DeviceWorker
::
TrainFiles
,
workers_
[
thidx
].
get
()));
}
}
void
PSGPUTrainer
::
BuildGPUPSTask
(
int
table_id
,
int
feadim
)
{
VLOG
(
3
)
<<
"PSGPUTrainer::BuildGPUPSTask begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
MultiSlotDataset
*
dataset
=
dynamic_cast
<
MultiSlotDataset
*>
(
dataset_
);
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
std
::
shared_ptr
<
HeterContext
>
heter_context
=
std
::
make_shared
<
HeterContext
>
();
auto
&
multi_output_channel
=
dataset
->
GetCurOutputChannel
();
auto
&
input_channel
=
dataset
->
GetInputChannelRef
();
int
gen_shard_num
=
multi_output_channel
.
size
();
int
device_num
=
places_
.
size
();
auto
gpu_ps_wrapper
=
PSGPUWrapper
::
GetInstance
();
auto
&
local_keys
=
heter_context
->
feature_keys_
;
local_keys
.
resize
(
device_num
);
auto
&
local_values
=
heter_context
->
feature_values_
;
local_values
.
resize
(
device_num
);
auto
&
local_ptr
=
heter_context
->
value_ptr_
;
local_ptr
.
resize
(
device_num
);
for
(
auto
&
ks
:
local_keys
)
{
ks
.
reserve
(
100000
);
}
// read thread
std
::
vector
<
std
::
thread
>
threads
(
gen_shard_num
);
std
::
vector
<
std
::
shared_ptr
<
ThreadPool
>>
consume_task_pool
(
device_num
);
for
(
size_t
i
=
0
;
i
<
consume_task_pool
.
size
();
i
++
)
{
consume_task_pool
[
i
].
reset
(
new
::
ThreadPool
(
1
));
}
auto
consume_func
=
[
&
local_keys
](
int
shard_id
,
int
feadim
,
std
::
vector
<
uint64_t
>&
keys
)
{
local_keys
[
shard_id
].
insert
(
local_keys
[
shard_id
].
end
(),
keys
.
begin
(),
keys
.
end
());
};
if
(
input_channel
->
Size
()
==
0
)
{
// output_channel_ should hold one pass instances now
uint64_t
output_channels_data_size
=
0
;
for
(
size_t
i
=
0
;
i
<
multi_output_channel
.
size
();
i
++
)
{
int
cur_channel_size
=
multi_output_channel
[
i
]
->
Size
();
output_channels_data_size
+=
cur_channel_size
;
}
CHECK
(
output_channels_data_size
>
0
);
for
(
auto
&
ks
:
local_keys
)
{
ks
.
reserve
(
output_channels_data_size
*
10
);
// magic number
}
auto
gen_func
=
[
&
dataset
,
&
device_num
,
&
feadim
,
&
consume_task_pool
,
&
multi_output_channel
,
&
consume_func
](
int
i
)
{
const
std
::
deque
<
Record
>&
vec_data
=
multi_output_channel
[
i
]
->
GetData
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
task_keys
(
device_num
);
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
for
(
size_t
j
=
0
;
j
<
vec_data
.
size
();
j
++
)
{
for
(
auto
&
feature
:
vec_data
[
j
].
uint64_feasigns_
)
{
int
shard
=
feature
.
sign
().
uint64_feasign_
%
device_num
;
task_keys
[
shard
].
push_back
(
feature
.
sign
().
uint64_feasign_
);
}
}
for
(
int
shard_id
=
0
;
shard_id
<
device_num
;
shard_id
++
)
{
task_futures
.
emplace_back
(
consume_task_pool
[
shard_id
]
->
enqueue
(
consume_func
,
shard_id
,
feadim
,
task_keys
[
shard_id
]));
}
for
(
auto
&
tf
:
task_futures
)
{
tf
.
wait
();
}
for
(
auto
&
tk
:
task_keys
)
{
tk
.
clear
();
std
::
vector
<
uint64_t
>
().
swap
(
tk
);
}
task_keys
.
clear
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
().
swap
(
task_keys
);
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
gen_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
}
else
{
int
input_channel_size
=
input_channel
->
Size
();
CHECK
(
input_channel_size
>
0
);
CHECK
(
gen_shard_num
>
0
);
for
(
auto
&
ks
:
local_keys
)
{
ks
.
reserve
(
input_channel_size
*
10
);
// magic number
}
const
std
::
deque
<
Record
>&
vec_data
=
input_channel
->
GetData
();
auto
gen_func
=
[
&
dataset
,
&
vec_data
,
&
device_num
,
&
gen_shard_num
,
&
input_channel_size
,
&
feadim
,
&
consume_task_pool
,
multi_output_channel
,
&
consume_func
](
int
i
)
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
task_keys
(
device_num
);
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
size_t
per_shard_num
=
input_channel_size
/
gen_shard_num
+
1
;
size_t
total_size
=
vec_data
.
size
();
size_t
start_index
=
i
*
per_shard_num
;
size_t
end_index
=
std
::
min
(
start_index
+
per_shard_num
-
1
,
total_size
-
1
);
for
(
size_t
j
=
start_index
;
j
<=
end_index
;
j
++
)
{
for
(
auto
&
feature
:
vec_data
[
j
].
uint64_feasigns_
)
{
int
shard
=
feature
.
sign
().
uint64_feasign_
%
device_num
;
task_keys
[
shard
].
push_back
(
feature
.
sign
().
uint64_feasign_
);
}
}
for
(
int
shard_id
=
0
;
shard_id
<
device_num
;
shard_id
++
)
{
task_futures
.
emplace_back
(
consume_task_pool
[
shard_id
]
->
enqueue
(
consume_func
,
shard_id
,
feadim
,
task_keys
[
shard_id
]));
}
for
(
auto
&
tf
:
task_futures
)
{
tf
.
wait
();
}
for
(
auto
&
tk
:
task_keys
)
{
tk
.
clear
();
std
::
vector
<
uint64_t
>
().
swap
(
tk
);
}
task_keys
.
clear
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
().
swap
(
task_keys
);
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
gen_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs build task cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
timeline
.
Start
();
auto
unique_func
=
[
&
local_keys
](
int
i
)
{
auto
&
cur_keys
=
local_keys
[
i
];
std
::
sort
(
cur_keys
.
begin
(),
cur_keys
.
end
());
cur_keys
.
erase
(
std
::
unique
(
cur_keys
.
begin
(),
cur_keys
.
end
()),
cur_keys
.
end
());
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
unique_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs task unique cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
timeline
.
Start
();
for
(
size_t
i
=
0
;
i
<
consume_task_pool
.
size
();
i
++
)
{
consume_task_pool
[
i
].
reset
();
}
consume_task_pool
.
clear
();
for
(
int
i
=
0
;
i
<
device_num
;
i
++
)
{
local_values
[
i
].
resize
(
local_keys
[
i
].
size
());
local_ptr
[
i
].
resize
(
local_keys
[
i
].
size
());
}
auto
ptl_func
=
[
this
,
&
local_keys
,
&
local_values
,
&
local_ptr
,
&
table_id
,
&
fleet_ptr
](
int
i
)
{
size_t
key_size
=
local_keys
[
i
].
size
();
auto
tt
=
fleet_ptr
->
pslib_ptr_
->
_worker_ptr
->
pull_sparse_ptr
(
(
char
**
)(
local_ptr
[
i
].
data
()),
table_id
,
local_keys
[
i
].
data
(),
key_size
);
tt
.
wait
();
auto
status
=
tt
.
get
();
// auto status = 0;
if
(
status
!=
0
)
{
LOG
(
ERROR
)
<<
"fleet pull sparse failed, status["
<<
status
<<
"]"
;
sleep
(
300
);
exit
(
-
1
);
}
else
{
VLOG
(
3
)
<<
"FleetWrapper Pull sparse to local done with table size: "
<<
local_keys
[
i
].
size
();
}
for
(
size_t
num
=
0
;
num
<
local_ptr
[
i
].
size
();
++
num
)
{
float
*
ptr_val
=
local_ptr
[
i
][
num
]
->
data
();
FeatureValue
&
val
=
local_values
[
i
][
num
];
size_t
dim
=
local_ptr
[
i
][
num
]
->
size
();
val
.
delta_score
=
ptr_val
[
1
];
val
.
show
=
ptr_val
[
2
];
val
.
clk
=
ptr_val
[
3
];
val
.
slot
=
ptr_val
[
6
];
val
.
lr
=
ptr_val
[
4
];
val
.
lr_g2sum
=
ptr_val
[
5
];
if
(
dim
>
7
)
{
val
.
mf_size
=
MF_DIM
+
1
;
for
(
int
x
=
0
;
x
<
val
.
mf_size
;
x
++
)
{
val
.
mf
[
x
]
=
ptr_val
[
x
+
7
];
}
}
else
{
val
.
mf_size
=
0
;
for
(
int
x
=
0
;
x
<
MF_DIM
+
1
;
x
++
)
{
val
.
mf
[
x
]
=
0
;
}
}
}
};
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
i
++
)
{
threads
[
i
]
=
std
::
thread
(
ptl_func
,
i
);
}
for
(
std
::
thread
&
t
:
threads
)
{
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
0
)
<<
"GpuPs pull sparse cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
gpu_ps_wrapper
->
BuildGPUPS
(
table_id
,
feadim
,
heter_context
);
}
Scope
*
PSGPUTrainer
::
GetWorkerScope
(
int
thread_id
)
{
return
nullptr
;
}
...
...
paddle/fluid/framework/trainer.h
浏览文件 @
6e0da01c
...
...
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
@@ -296,13 +297,6 @@ class PSGPUTrainer : public TrainerBase {
}
virtual
std
::
string
GetDumpPath
(
int
tid
)
{
return
""
;
}
virtual
void
InitDumpEnv
()
{}
void
BuildGPUPSTask
(
int
table_id
,
int
feadim
);
/*
template <typename T>
void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor,
const paddle::platform::Place& thread_place,
cudaStream_t stream);
*/
template
<
typename
T
>
void
MergeToRootScope
(
LoDTensor
*
root_tensor
,
LoDTensor
*
thread_tensor
);
...
...
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
浏览文件 @
6e0da01c
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
...
...
@@ -37,6 +38,10 @@ void BindPSGPUWrapper(py::module* m) {
*
m
,
"PSGPU"
)
.
def
(
py
::
init
([]()
{
return
framework
::
PSGPUWrapper
::
GetInstance
();
}))
.
def
(
"set_slot_vector"
,
&
framework
::
PSGPUWrapper
::
SetSlotVector
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"init_GPU_server"
,
&
framework
::
PSGPUWrapper
::
InitializeGPUServer
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"build_gpu_ps"
,
&
framework
::
PSGPUWrapper
::
BuildGPUPS
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
// end PSGPUWrapper
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录