PaddlePaddle / Paddle

Commit bb0713b2 (unverified)
Authored by 石晓伟 on Dec 20, 2021; committed via GitHub on Dec 20, 2021.
changes the call AllocShared to Alloc, test=develop (#38258)
Parent: 2635cc86

Showing 8 changed files with 89 additions and 87 deletions (+89 -87):
  paddle/fluid/framework/fleet/box_wrapper.cu             +4   -5
  paddle/fluid/framework/fleet/box_wrapper_impl.h         +4   -4
  paddle/fluid/framework/fleet/heter_ps/heter_comm.h      +11  -10
  paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h  +58  -54
  paddle/fluid/framework/fleet/ps_gpu_wrapper.cc          +4   -5
  paddle/fluid/framework/fleet/ps_gpu_wrapper.cu          +4   -5
  paddle/fluid/operators/math/math_function.cu            +2   -2
  paddle/pten/kernels/hybird/transpose.cu                 +2   -2
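For context on the API being swapped, a sketch under assumed signatures: at this revision, memory::AllocShared returns a std::shared_ptr<memory::Allocation>, while memory::Alloc returns memory::allocation::AllocationPtr, a move-only unique_ptr-style handle. Since Allocation::ptr() is read immediately and the holder lives to the end of the enclosing scope, most call sites below only need the function name changed; containers of buffers also need std::move, which is why build_ps in heter_comm_inl.h gains push_back(std::move(...)). A minimal stand-alone illustration of that ownership difference, with simplified stand-ins for the real Paddle types (the real ones carry places, deleters, and allocator state):

#include <memory>
#include <utility>
#include <vector>

// Simplified stand-ins for paddle::memory types (assumption for illustration).
struct Allocation {
  explicit Allocation(size_t n) : data(n) {}
  void* ptr() { return data.data(); }
  std::vector<char> data;
};
using AllocationPtr = std::unique_ptr<Allocation>;     // what Alloc returns
using SharedAllocation = std::shared_ptr<Allocation>;  // what AllocShared returns

AllocationPtr Alloc(size_t n) { return std::make_unique<Allocation>(n); }
SharedAllocation AllocShared(size_t n) { return std::make_shared<Allocation>(n); }

int main() {
  std::vector<SharedAllocation> shared_bufs;
  shared_bufs.push_back(AllocShared(256));  // copyable handle: plain push_back works

  std::vector<AllocationPtr> unique_bufs;
  auto buf = Alloc(256);
  // A unique_ptr cannot be copied into the vector; it must be moved.
  // This is exactly the push_back(std::move(...)) change in build_ps.
  unique_bufs.push_back(std::move(buf));
}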
paddle/fluid/framework/fleet/box_wrapper.cu

@@ -140,7 +140,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
           platform::DeviceContextPool::Instance().Get(
               BOOST_GET_CONST(platform::CUDAPlace, place)))
           ->stream();
-  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
   float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
 #ifdef PADDLE_WITH_HIP
   hipMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),

@@ -233,11 +233,10 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
   auto buf_grad_value =
-      memory::AllocShared(place, grad_values.size() * sizeof(float*));
-  auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   auto buf_slot_vector =
-      memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
   float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
paddle/fluid/framework/fleet/box_wrapper_impl.h

@@ -32,7 +32,7 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(
+  auto buf = memory::Alloc(
       place,
       total_length * sizeof(boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
   boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* total_values_gpu =

@@ -55,9 +55,9 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
   for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
-  auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
+  auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
   auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
 #ifdef PADDLE_WITH_HIP

@@ -118,7 +118,7 @@ void BoxWrapper::PushSparseGradCase(
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(
+  auto buf = memory::Alloc(
       place,
       total_length * sizeof(boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
paddle/fluid/framework/fleet/heter_ps/heter_comm.h

@@ -17,9 +17,10 @@ limitations under the License. */
 #include <vector>
 #include "cub/cub.cuh"
 #include "cub/util_allocator.cuh"
-#include "hashtable.h"
-#include "heter_resource.h"
+#include "hashtable.h"       // NOLINT
+#include "heter_resource.h"  // NOLINT
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
+#include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/dynload/nccl.h"

@@ -58,7 +59,7 @@ class HeterComm {
   void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
                             int* left, int* right, int gpu_num);
   void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
-                  int& uniq_len);
+                  int& uniq_len);  // NOLINT
   void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
   void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                 size_t chunk_size, int stream_num);

@@ -68,15 +69,15 @@ class HeterComm {
   template <typename Sgd>
   void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
-                   Sgd& sgd);
+                   Sgd& sgd);  // NOLINT

   template <typename Sgd>
   void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
-                              size_t len, Sgd& sgd);
+                              size_t len, Sgd& sgd);  // NOLINT

   template <typename Sgd>
   void update_one_table(int num, KeyType* d_keys, GradType* d_grads,
                         size_t len,
-                        Sgd& sgd);
+                        Sgd& sgd);  // NOLINT

   int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                            int len);

@@ -136,16 +137,16 @@ class HeterComm {
     if (force || size > all_keys_mem->size()) {
       all_keys_mem.reset();
       all_grads_mem.reset();
-      all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
-      all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
+      all_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
+      all_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
       all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
       all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
     }
     if (force || size > local_keys_mem->size()) {
       local_keys_mem.reset();
       local_grads_mem.reset();
-      local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
-      local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
+      local_keys_mem = memory::Alloc(place_, size * sizeof(KeyType));
+      local_grads_mem = memory::Alloc(place_, size * sizeof(GradType));
       local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
       local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
     }
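The // NOLINT markers added in this header (and throughout heter_comm_inl.h below) are cpplint suppressions rather than functional changes: the non-const reference parameters, variable-length arrays such as int h_left[total_gpu], and C-style casts all predate this commit, and touched lines now need an explicit opt-out. A minimal sketch of the idiom, with an illustrative declaration:

#include <cstddef>

// A one-line cpplint suppression: the trailing comment disables lint checks
// for that line only; the declaration itself is unchanged. (Sketch; the rule
// silenced on the real declarations is the non-const-reference check.)
void merge_grad(int gpu_num, std::size_t len, int& uniq_len);  // NOLINT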
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h

@@ -28,7 +28,7 @@ __global__ void fill_idx(T* idx, size_t len) {
 template <typename T>
 void show_tensor(T* input, size_t len, gpuStream_t stream, std::string name) {
-  T tmp[len];
+  T tmp[len];  // NOLINT
   cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost,
                   stream);
   cudaStreamSynchronize(stream);
   std::cout << name;

@@ -101,7 +101,7 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
   for (int i = 0; i < resource_->total_gpu(); ++i) {
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
     allocators_.push_back(std::make_shared<cub::CachingDeviceAllocator>(
-        8, 1, (unsigned int)-1, (size_t)-1, false, false));
+        8, 1, (unsigned int)-1, (size_t)-1, false, false));  // NOLINT
     auto table = new Table(capacity / load_factor_);
     tables_.push_back(table);
     if (multi_node_) {

@@ -174,10 +174,12 @@ void HeterComm<KeyType, ValType, GradType>::create_storage(int start_index,
   for (size_t i = 0; i < nodes.size(); ++i) {
     platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
     allocator->DeviceAllocate(
-        resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].key_storage),
+        resource_->dev_id(nodes[i].gpu_num),
+        (void**)&(nodes[i].key_storage),  // NOLINT
         keylen, resource_->remote_stream(nodes[i].gpu_num, start_index));
     allocator->DeviceAllocate(
-        resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].val_storage),
+        resource_->dev_id(nodes[i].gpu_num),
+        (void**)&(nodes[i].val_storage),  // NOLINT
         vallen, resource_->remote_stream(nodes[i].gpu_num, start_index));
     nodes[i].key_bytes_len = keylen;

@@ -342,16 +344,16 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
   platform::CUDADeviceGuard guard(dev_id);
-  std::vector<std::shared_ptr<memory::Allocation>> d_key_bufs;
-  std::vector<std::shared_ptr<memory::Allocation>> d_val_bufs;
+  std::vector<memory::allocation::AllocationPtr> d_key_bufs;
+  std::vector<memory::allocation::AllocationPtr> d_val_bufs;
-  gpuStream_t streams[stream_num];
+  gpuStream_t streams[stream_num];  // NOLINT
   for (int i = 0; i < stream_num; ++i) {
     PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(streams[i])));
-    auto d_k_buf = memory::AllocShared(place, chunk_size * sizeof(KeyType));
-    auto d_v_buf = memory::AllocShared(place, chunk_size * sizeof(ValType));
-    d_key_bufs.push_back(d_k_buf);
-    d_val_bufs.push_back(d_v_buf);
+    auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType));
+    auto d_v_buf = memory::Alloc(place, chunk_size * sizeof(ValType));
+    d_key_bufs.push_back(std::move(d_k_buf));
+    d_val_bufs.push_back(std::move(d_v_buf));
   }
   int cur_len = 0;

@@ -383,11 +385,9 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
 }
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
-                                                       KeyType* d_keys,
-                                                       GradType* d_grads,
-                                                       size_t len,
-                                                       int& uniq_len) {
+void HeterComm<KeyType, ValType, GradType>::merge_grad(
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    int& uniq_len) {  // NOLINT
   int dev_id = resource_->dev_id(gpu_num);
   platform::CUDAPlace place = platform::CUDAPlace(dev_id);
   platform::CUDADeviceGuard guard(dev_id);

@@ -395,10 +395,10 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
   size_t temp_storage_bytes;
-  auto d_merge_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());
-  auto d_merge_grads = memory::AllocShared(place, len * sizeof(GradType));
+  auto d_merge_grads = memory::Alloc(place, len * sizeof(GradType));
   GradType* d_merge_grads_ptr =
       reinterpret_cast<GradType*>(d_merge_grads->ptr());

@@ -407,14 +407,14 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
       d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
   void* d_buff = NULL;
-  auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
       d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
       d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
   temp_storage_bytes = 0;
-  auto d_num_runs_out_mem = memory::AllocShared(place, sizeof(int));
+  auto d_num_runs_out_mem = memory::Alloc(place, sizeof(int));
   int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(

@@ -423,7 +423,7 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
   if (d_temp_storage->size() < temp_storage_bytes) {
     d_temp_storage = NULL;
-    d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   }
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(

@@ -445,13 +445,13 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
   platform::CUDADeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(gpu_num, 0);
-  auto d_idx_tmp = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx_tmp = memory::Alloc(place, len * sizeof(int));
   int* d_idx_tmp_ptr = reinterpret_cast<int*>(d_idx_tmp->ptr());
-  auto d_shard_index = memory::AllocShared(place, len * sizeof(int));
+  auto d_shard_index = memory::Alloc(place, len * sizeof(int));
   int* d_shard_index_ptr = reinterpret_cast<int*>(d_shard_index->ptr());
-  auto d_shard_index_tmp = memory::AllocShared(place, len * sizeof(int));
+  auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int));
   int* d_shard_index_tmp_ptr = reinterpret_cast<int*>(d_shard_index_tmp->ptr());
   int grid_size = (len - 1) / block_size_ + 1;

@@ -465,7 +465,7 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
       NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr,
       d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
-  auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
+  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
   PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
       d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr,
       d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));

@@ -491,23 +491,23 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
   int grid_size = (len - 1) / block_size_ + 1;
-  int h_left[total_gpu];
-  int h_right[total_gpu];
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
   cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
   //
-  auto d_idx = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx = memory::Alloc(place, len * sizeof(int));
   int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-  auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_vals = memory::AllocShared(place, len * sizeof(ValType));
+  auto d_shard_vals = memory::Alloc(place, len * sizeof(ValType));
   ValType* d_shard_vals_ptr = reinterpret_cast<ValType*>(d_shard_vals->ptr());
   split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);

@@ -574,7 +574,8 @@ template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
                                                         KeyType* d_keys,
                                                         GradType* d_grads,
-                                                        size_t len, Sgd& sgd) {
+                                                        size_t len,
+                                                        Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }

@@ -585,23 +586,23 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
   platform::CUDADeviceGuard guard(dev_id);
   auto stream = resource_->local_stream(gpu_num, 0);
-  int h_left[total_gpu];
-  int h_right[total_gpu];
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
   cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
   //
-  auto d_idx = memory::AllocShared(place, len * sizeof(int));
+  auto d_idx = memory::Alloc(place, len * sizeof(int));
   int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-  auto d_shard_keys = memory::AllocShared(place, len * sizeof(KeyType));
+  auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
   KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
-  auto d_shard_grads = memory::AllocShared(place, len * sizeof(GradType));
+  auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
   GradType* d_shard_grads_ptr =
       reinterpret_cast<GradType*>(d_shard_grads->ptr());

@@ -664,7 +665,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
 template <typename KeyType, typename ValType, typename GradType>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::update_one_table(
-    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }

@@ -681,7 +683,8 @@ void HeterComm<KeyType, ValType, GradType>::update_one_table(
 template <typename KeyType, typename ValType, typename GradType>
 template <typename Sgd>
 void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
-    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
+    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
+    Sgd& sgd) {  // NOLINT
   if (len == 0) {
     return;
   }

@@ -711,8 +714,8 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
   // alloc for size
-  int h_node_len[total_gpu];
-  auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
+  int h_node_len[total_gpu];  // NOLINT
+  auto d_node_len_mem = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
   h_node_len[gpu_num] = len;

@@ -721,9 +724,10 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   // allgather grad len
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
-  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
-      (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
-      nccl_inner_comm, stream));
+  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
+      (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1,  // NOLINT
+      ncclInt, nccl_inner_comm, stream));
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
   cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,

@@ -747,17 +751,17 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
-  int h_left[total_gpu];
-  int h_right[total_gpu];
+  int h_left[total_gpu];   // NOLINT
+  int h_right[total_gpu];  // NOLINT
-  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
-  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
+  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
+  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
   int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
   int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
   int merge_num = 0;
   for (int i = 0; i < total_gpu; ++i) {
     int index = i * max_size;
-    auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
+    auto d_idx = memory::Alloc(place, h_node_len[i] * sizeof(int));
     int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
     cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));

@@ -794,8 +798,8 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
   int max_size = 0;
   ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
   // alloc for size
-  int h_node_len[node_size_];
-  auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
+  int h_node_len[node_size_];  // NOLINT
+  auto d_node_len_mem = memory::Alloc(place, node_size_ * sizeof(int));
   int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
   h_node_len[0] = len;
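Several of the allocations renamed above sit inside CUB's two-phase temp-storage idiom used by merge_grad and split_input_to_shard: a first SortPairs call with a null scratch pointer only reports temp_storage_bytes, the buffer is allocated in between, and a second identical call does the work. A self-contained CUDA sketch of the pattern, with raw cudaMalloc standing in for memory::Alloc and illustrative names:

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Two-phase CUB pattern: query scratch size, allocate, then run for real.
void sort_pairs(const unsigned long long* d_keys_in,
                unsigned long long* d_keys_out, const float* d_vals_in,
                float* d_vals_out, int num_items, cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Phase 1: a null temp pointer makes CUB only compute temp_storage_bytes.
  cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, d_keys_in,
                                  d_keys_out, d_vals_in, d_vals_out, num_items,
                                  0, 64, stream);
  void* d_temp_storage = nullptr;  // memory::Alloc(place, ...) in Paddle
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Phase 2: the same call with real scratch storage performs the sort.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys_in,
                                  d_keys_out, d_vals_in, d_vals_out, num_items,
                                  0, 64, stream);
  cudaFree(d_temp_storage);
}

The guard if (d_temp_storage->size() < temp_storage_bytes) in merge_grad applies the same idea to reuse: the scratch buffer from the sort is recycled for ReduceByKey and regrown only when the second size query asks for more.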
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc

@@ -592,7 +592,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf = memory::AllocShared(place, total_length * sizeof(FeatureValue));
+  auto buf = memory::Alloc(place, total_length * sizeof(FeatureValue));
   FeatureValue* total_values_gpu = reinterpret_cast<FeatureValue*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
     PADDLE_THROW(platform::errors::Unimplemented(

@@ -610,9 +610,9 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
-  auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
+  auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
   auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
   cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),

@@ -660,8 +660,7 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
   all_timer.Start();
   int64_t total_length =
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
-  auto buf =
-      memory::AllocShared(place, total_length * sizeof(FeaturePushValue));
+  auto buf = memory::Alloc(place, total_length * sizeof(FeaturePushValue));
   FeaturePushValue* total_grad_values_gpu =
       reinterpret_cast<FeaturePushValue*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
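The buf_key allocation in PullSparse is an instance of a pointer-table pattern these wrappers rely on: the host keeps an array of per-slot device pointers, and a single cudaMemcpy publishes the whole table to the GPU so kernels can index slot buffers by slot id. A trimmed sketch with plain CUDA calls standing in for memory::Alloc (illustrative names; error handling omitted):

#include <cstdint>
#include <cuda_runtime.h>
#include <vector>

// Copy a host-side table of device pointers to the GPU in one transfer.
// Kernels can then dereference gpu_keys[slot] directly.
std::uint64_t** publish_pointer_table(const std::vector<std::uint64_t*>& keys) {
  std::uint64_t** gpu_keys = nullptr;  // stands in for memory::Alloc + ptr()
  cudaMalloc(&gpu_keys, keys.size() * sizeof(std::uint64_t*));
  cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(std::uint64_t*),
             cudaMemcpyHostToDevice);
  return gpu_keys;
}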
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu

@@ -116,7 +116,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
           platform::DeviceContextPool::Instance().Get(
               BOOST_GET_CONST(platform::CUDAPlace, place)))
           ->stream();
-  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
   float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
   cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
              cudaMemcpyHostToDevice);

@@ -156,11 +156,10 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
     slot_lengths_lod[i] += slot_lengths_lod[i - 1];
   }
   auto buf_grad_value =
-      memory::AllocShared(place, grad_values.size() * sizeof(float*));
-  auto buf_length =
-      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
   auto buf_slot_vector =
-      memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
   float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
paddle/fluid/operators/math/math_function.cu

@@ -102,8 +102,8 @@ struct TransposeNormal<platform::CUDADeviceContext, T> {
         BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());
     platform::CPUPlace cpu_place = platform::CPUPlace();
     size_t size = 3 * rank * sizeof(int64_t);
-    auto cpu_buf_holder = memory::AllocShared(cpu_place, size);
-    auto cuda_buf_holder = memory::AllocShared(cuda_place, size);
+    auto cpu_buf_holder = memory::Alloc(cpu_place, size);
+    auto cuda_buf_holder = memory::Alloc(cuda_place, size);
     REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
     REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
     for (int i = 0; i < rank; ++i) {
paddle/pten/kernels/hybird/transpose.cu

@@ -69,8 +69,8 @@ struct TransposeNormal<CUDAContext, T> {
         BOOST_GET_CONST(paddle::platform::CUDAPlace, dev_ctx.GetPlace());
     paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
     size_t size = 3 * rank * sizeof(int64_t);
-    auto cpu_buf_holder = paddle::memory::AllocShared(cpu_place, size);
-    auto cuda_buf_holder = paddle::memory::AllocShared(cuda_place, size);
+    auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
+    auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
     REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
     REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
     for (int i = 0; i < rank; ++i) {
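Both TransposeNormal changes touch the same staging trick: 3 * rank int64_t values of metadata (input strides, output strides, and the axis permutation, going by the surrounding code in both files) are packed into one CPU buffer and shipped to the device in a single transfer rather than three. A hedged sketch with plain CUDA calls standing in for memory::Alloc and Paddle's copy helpers:

#include <cstdint>
#include <cuda_runtime.h>
#include <vector>

// Pack in-strides, out-strides, and axis into one 3 * rank block and copy it
// to the device with a single transfer (layout assumed from the Paddle code).
std::int64_t* stage_transpose_meta(const std::vector<std::int64_t>& in_stride,
                                   const std::vector<std::int64_t>& out_stride,
                                   const std::vector<int>& axis) {
  const int rank = static_cast<int>(axis.size());
  std::vector<std::int64_t> cpu_buf(3 * rank);
  for (int i = 0; i < rank; ++i) {
    cpu_buf[i] = in_stride[i];
    cpu_buf[rank + i] = out_stride[i];
    cpu_buf[2 * rank + i] = axis[i];
  }
  std::int64_t* cuda_buf = nullptr;
  cudaMalloc(&cuda_buf, cpu_buf.size() * sizeof(std::int64_t));
  cudaMemcpy(cuda_buf, cpu_buf.data(), cpu_buf.size() * sizeof(std::int64_t),
             cudaMemcpyHostToDevice);
  return cuda_buf;  // the transpose kernel reads strides/axis from this block
}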