Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
effe2c11
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
effe2c11
编写于
3月 15, 2023
作者:
P
pangengzheng
提交者:
GitHub
3月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speedup datafeed (#51624)
上级
6f86c96b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
284 addition
and
94 deletion
+284
-94
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+192
-56
paddle/fluid/framework/data_feed.cu
paddle/fluid/framework/data_feed.cu
+4
-11
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+79
-24
paddle/fluid/framework/fleet/heter_ps/feature_value.h
paddle/fluid/framework/fleet/heter_ps/feature_value.h
+9
-3
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
effe2c11
...
...
@@ -2001,6 +2001,20 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<Record*>& ins_vec) {
#endif
}
SlotRecordInMemoryDataFeed
::~
SlotRecordInMemoryDataFeed
()
{
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
stop_token_
.
store
(
true
);
for
(
auto
&
thread
:
pack_threads_
)
{
if
(
thread
.
joinable
())
{
thread
.
join
();
}
}
for
(
auto
*
pack
:
pack_vec_
)
{
pack
->
set_use_flag
(
false
);
}
#endif
}
template
class
InMemoryDataFeed
<
SlotRecord
>;
void
SlotRecordInMemoryDataFeed
::
Init
(
const
DataFeedDesc
&
data_feed_desc
)
{
finish_init_
=
false
;
...
...
@@ -2513,9 +2527,7 @@ void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
}
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
paddle
::
platform
::
SetDeviceId
(
place_
.
GetDeviceId
());
pack_
->
pack_instance
(
ins_vec
,
num
);
BuildSlotBatchGPU
(
pack_
->
ins_num
());
// do nothing
#else
for
(
int
j
=
0
;
j
<
use_slot_size_
;
++
j
)
{
auto
&
feed
=
feed_vec_
[
j
];
...
...
@@ -2658,7 +2670,7 @@ void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) {
}
bool
SlotRecordInMemoryDataFeed
::
Start
()
{
VLOG
(
4
)
<<
"entering SlotRecordInMemoryDataFeed::Start"
;
VLOG
(
3
)
<<
"entering SlotRecordInMemoryDataFeed::Start"
;
#ifdef _LINUX
this
->
CheckSetFileList
();
if
(
input_channel_
->
Size
()
!=
0
)
{
...
...
@@ -2674,7 +2686,40 @@ bool SlotRecordInMemoryDataFeed::Start() {
this
->
finish_start_
=
true
;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
CHECK
(
paddle
::
platform
::
is_gpu_place
(
this
->
place_
));
pack_
=
BatchGpuPackMgr
().
get
(
this
->
GetPlace
(),
used_slots_info_
);
for
(
int
i
=
0
;
i
<
pack_thread_num_
+
1
;
i
++
)
{
auto
pack
=
BatchGpuPackMgr
().
get
(
this
->
GetPlace
(),
used_slots_info_
);
pack_vec_
.
push_back
(
pack
);
free_pack_queue_
.
Push
(
pack
);
}
pack_offset_index_
.
store
(
0
);
pack_is_end_
.
store
(
false
);
thread_count_
.
store
(
pack_thread_num_
);
pack_threads_
.
reserve
(
pack_thread_num_
);
for
(
int
i
=
0
;
i
<
pack_thread_num_
;
i
++
)
{
pack_threads_
.
emplace_back
(
std
::
thread
([
this
]()
->
void
{
while
(
!
stop_token_
.
load
())
{
uint64_t
offset_index
=
pack_offset_index_
.
fetch_add
(
1
);
if
(
offset_index
>=
batch_offsets_
.
size
())
{
int
thread_num
=
thread_count_
.
fetch_sub
(
1
);
if
(
thread_num
==
1
)
{
pack_is_end_
.
store
(
true
);
}
return
;
}
auto
*
pack
=
free_pack_queue_
.
Pop
();
auto
&
batch
=
batch_offsets_
[
offset_index
];
auto
offset
=
batch
.
first
;
auto
batch_size
=
batch
.
second
;
paddle
::
platform
::
SetDeviceId
(
place_
.
GetDeviceId
());
pack
->
pack_instance
(
&
records_
[
offset
],
batch_size
);
this
->
BuildSlotBatchGPU
(
batch_size
,
pack
);
using_pack_queue_
.
Push
(
pack
);
}
}));
}
#endif
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
gpu_graph_data_generator_
.
SetFeedVec
(
feed_vec_
);
...
...
@@ -2686,6 +2731,27 @@ int SlotRecordInMemoryDataFeed::Next() {
#ifdef _LINUX
this
->
CheckStart
();
if
(
!
gpu_graph_mode_
)
{
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
while
(
true
)
{
if
(
last_pack_
!=
nullptr
)
{
free_pack_queue_
.
Push
(
last_pack_
);
last_pack_
=
nullptr
;
}
if
(
using_pack_queue_
.
Size
()
!=
0
)
{
auto
*
pack
=
using_pack_queue_
.
Pop
();
PackToScope
(
pack
);
last_pack_
=
pack
;
return
pack
->
ins_num
();
}
bool
is_end
=
pack_is_end_
.
load
();
if
(
is_end
)
{
if
(
using_pack_queue_
.
Size
()
==
0
)
{
return
0
;
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
microseconds
(
200
));
}
#else
VLOG
(
3
)
<<
"enable heter next: "
<<
offset_index_
<<
" batch_offsets: "
<<
batch_offsets_
.
size
();
if
(
offset_index_
>=
batch_offsets_
.
size
())
{
...
...
@@ -2703,9 +2769,7 @@ int SlotRecordInMemoryDataFeed::Next() {
VLOG
(
3
)
<<
"finish reading for heterps, batch size zero, thread_id="
<<
thread_id_
;
}
VLOG
(
3
)
<<
"enable heter next: "
<<
offset_index_
<<
" batch_offsets: "
<<
batch_offsets_
.
size
()
<<
" baych_size: "
<<
this
->
batch_size_
;
#endif
}
else
{
VLOG
(
3
)
<<
"datafeed in gpu graph mode"
;
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
...
...
@@ -2736,47 +2800,59 @@ void SlotRecordInMemoryDataFeed::DumpWalkPath(std::string dump_path,
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void
SlotRecordInMemoryDataFeed
::
BuildSlotBatchGPU
(
const
int
ins_num
)
{
void
SlotRecordInMemoryDataFeed
::
BuildSlotBatchGPU
(
const
int
ins_num
,
MiniBatchGpuPack
*
pack
)
{
int
offset_cols_size
=
(
ins_num
+
1
);
size_t
slot_total_num
=
(
use_slot_size_
*
offset_cols_size
);
pack
_
->
resize_gpu_slot_offsets
(
slot_total_num
*
sizeof
(
size_t
));
pack
->
resize_gpu_slot_offsets
(
slot_total_num
*
sizeof
(
size_t
));
auto
&
value
=
pack
_
->
value
();
auto
&
value
=
pack
->
value
();
const
UsedSlotGpuType
*
used_slot_gpu_types
=
static_cast
<
const
UsedSlotGpuType
*>
(
pack
_
->
get_gpu_slots
());
static_cast
<
const
UsedSlotGpuType
*>
(
pack
->
get_gpu_slots
());
FillSlotValueOffset
(
ins_num
,
use_slot_size_
,
reinterpret_cast
<
size_t
*>
(
pack
_
->
gpu_slot_offsets
()),
reinterpret_cast
<
size_t
*>
(
pack
->
gpu_slot_offsets
()),
value
.
d_uint64_offset
.
data
(),
uint64_use_slot_size_
,
value
.
d_float_offset
.
data
(),
float_use_slot_size_
,
used_slot_gpu_types
);
size_t
*
d_slot_offsets
=
reinterpret_cast
<
size_t
*>
(
pack_
->
gpu_slot_offsets
());
used_slot_gpu_types
,
pack
->
get_stream
());
size_t
*
d_slot_offsets
=
reinterpret_cast
<
size_t
*>
(
pack
->
gpu_slot_offsets
());
HostBuffer
<
size_t
>&
offsets
=
pack
_
->
offsets
();
HostBuffer
<
size_t
>&
offsets
=
pack
->
offsets
();
offsets
.
resize
(
slot_total_num
);
HostBuffer
<
void
*>&
h_tensor_ptrs
=
pack
_
->
h_tensor_ptrs
();
HostBuffer
<
void
*>&
h_tensor_ptrs
=
pack
->
h_tensor_ptrs
();
h_tensor_ptrs
.
resize
(
use_slot_size_
);
// alloc gpu memory
pack
_
->
resize_tensor
();
pack
->
resize_tensor
();
phi
::
DenseTensor
&
float_tensor
=
pack
_
->
float_tensor
();
phi
::
DenseTensor
&
uint64_tensor
=
pack
_
->
uint64_tensor
();
phi
::
DenseTensor
&
float_tensor
=
pack
->
float_tensor
();
phi
::
DenseTensor
&
uint64_tensor
=
pack
->
uint64_tensor
();
int64_t
float_offset
=
0
;
int64_t
uint64_offset
=
0
;
size_t
float_zero_slot_index
=
0
;
size_t
uint64_zero_slot_index
=
0
;
// copy index
CUDA_CHECK
(
cudaMemcpy
(
offsets
.
data
(),
d_slot_offsets
,
slot_total_num
*
sizeof
(
size_t
),
cudaMemcpyDeviceToHost
));
auto
*
dev_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
this
->
place_
));
for
(
int
j
=
0
;
j
<
use_slot_size_
;
++
j
)
{
auto
&
feed
=
feed_vec_
[
j
];
if
(
feed
==
nullptr
)
{
h_tensor_ptrs
[
j
]
=
nullptr
;
continue
;
if
(
scpoe_feed_vec_
.
size
()
>
0
)
{
if
(
scpoe_feed_vec_
.
begin
()
->
second
[
j
]
==
nullptr
)
{
h_tensor_ptrs
[
j
]
=
nullptr
;
continue
;
}
}
else
{
if
(
feed_vec_
[
j
]
==
nullptr
)
{
h_tensor_ptrs
[
j
]
=
nullptr
;
continue
;
}
}
size_t
*
off_start_ptr
=
&
offsets
[
j
*
offset_cols_size
];
...
...
@@ -2786,6 +2862,85 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
<<
"slot idx:"
<<
j
<<
", total instance:"
<<
total_instance
;
auto
&
info
=
used_slots_info_
[
j
];
// fill slot value with default value 0
if
(
info
.
type
[
0
]
==
'f'
)
{
// float
if
(
total_instance
>
0
)
{
h_tensor_ptrs
[
j
]
=
float_tensor
.
data
<
float
>
()
+
float_offset
;
float_offset
+=
total_instance
;
}
else
{
phi
::
DenseTensor
&
f_tensor
=
pack
->
float_tensor_vec
()[
float_zero_slot_index
];
f_tensor
.
Resize
({
total_instance
,
1
});
dev_ctx
->
Alloc
<
float
>
(
&
f_tensor
);
h_tensor_ptrs
[
j
]
=
f_tensor
.
data
<
float
>
();
float_zero_slot_index
++
;
}
}
else
if
(
info
.
type
[
0
]
==
'u'
)
{
// uint64
if
(
total_instance
>
0
)
{
h_tensor_ptrs
[
j
]
=
uint64_tensor
.
data
<
int64_t
>
()
+
uint64_offset
;
uint64_offset
+=
total_instance
;
}
else
{
phi
::
DenseTensor
&
i_tensor
=
pack
->
uint64_tensor_vec
()[
uint64_zero_slot_index
];
i_tensor
.
Resize
({
total_instance
,
1
});
dev_ctx
->
Alloc
<
int64_t
>
(
&
i_tensor
);
h_tensor_ptrs
[
j
]
=
i_tensor
.
data
<
int64_t
>
();
uint64_zero_slot_index
++
;
}
}
}
void
**
dest_gpu_p
=
reinterpret_cast
<
void
**>
(
pack
->
slot_buf_ptr
());
CUDA_CHECK
(
cudaMemcpyAsync
(
dest_gpu_p
,
h_tensor_ptrs
.
data
(),
use_slot_size_
*
sizeof
(
void
*
),
cudaMemcpyHostToDevice
,
pack
->
get_stream
()));
CopyForTensor
(
ins_num
,
use_slot_size_
,
dest_gpu_p
,
(
const
size_t
*
)
pack
->
gpu_slot_offsets
(),
(
const
uint64_t
*
)
value
.
d_uint64_keys
.
data
(),
(
const
int
*
)
value
.
d_uint64_offset
.
data
(),
(
const
int
*
)
value
.
d_uint64_lens
.
data
(),
uint64_use_slot_size_
,
(
const
float
*
)
value
.
d_float_keys
.
data
(),
(
const
int
*
)
value
.
d_float_offset
.
data
(),
(
const
int
*
)
value
.
d_float_lens
.
data
(),
float_use_slot_size_
,
used_slot_gpu_types
,
pack
->
get_stream
());
}
void
SlotRecordInMemoryDataFeed
::
PackToScope
(
MiniBatchGpuPack
*
pack
,
const
Scope
*
scope
)
{
int64_t
float_offset
=
0
;
int64_t
uint64_offset
=
0
;
size_t
float_zero_slot_index
=
0
;
size_t
uint64_zero_slot_index
=
0
;
int
offset_cols_size
=
(
pack
->
ins_num
()
+
1
);
HostBuffer
<
size_t
>&
offsets
=
pack
->
offsets
();
phi
::
DenseTensor
&
float_tensor
=
pack
->
float_tensor
();
phi
::
DenseTensor
&
uint64_tensor
=
pack
->
uint64_tensor
();
auto
*
feed_vec
=
&
feed_vec_
;
if
(
scope
)
{
CHECK
(
scpoe_feed_vec_
.
count
(
scope
)
>
0
)
<<
"scope not found."
;
feed_vec
=
&
scpoe_feed_vec_
[
scope
];
}
CHECK
(
feed_vec
!=
nullptr
)
<<
"feed_vec nullptr."
;
for
(
int
j
=
0
;
j
<
use_slot_size_
;
++
j
)
{
auto
&
feed
=
(
*
feed_vec
)[
j
];
if
(
feed
==
nullptr
)
{
continue
;
}
size_t
*
off_start_ptr
=
&
offsets
[
j
*
offset_cols_size
];
int
total_instance
=
static_cast
<
int
>
(
off_start_ptr
[
offset_cols_size
-
1
]);
auto
&
info
=
used_slots_info_
[
j
];
// fill slot value with default value 0
if
(
info
.
type
[
0
]
==
'f'
)
{
// float
if
(
total_instance
>
0
)
{
...
...
@@ -2794,10 +2949,9 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
static_cast
<
int64_t
>
(
float_offset
+
total_instance
)));
feed
->
Resize
({
total_instance
,
1
});
float_offset
+=
total_instance
;
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
float
>
(
this
->
place_
);
}
else
{
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
float
>
({
total_instance
,
1
},
this
->
place_
);
feed
->
ShareDataWith
(
pack
->
float_tensor_vec
()[
float_zero_slot_index
++
]);
feed
->
Resize
({
total_instance
,
1
}
);
}
}
else
if
(
info
.
type
[
0
]
==
'u'
)
{
// uint64
if
(
total_instance
>
0
)
{
...
...
@@ -2806,10 +2960,10 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
static_cast
<
int64_t
>
(
uint64_offset
+
total_instance
)));
feed
->
Resize
({
total_instance
,
1
});
uint64_offset
+=
total_instance
;
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
int64_t
>
(
this
->
place_
);
}
else
{
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
int64_t
>
({
total_instance
,
1
},
this
->
place_
);
feed
->
ShareDataWith
(
pack
->
uint64_tensor_vec
()[
uint64_zero_slot_index
++
]);
feed
->
Resize
({
total_instance
,
1
});
}
}
...
...
@@ -2829,33 +2983,14 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {
offset_cols_size
*
sizeof
(
size_t
));
}
}
void
**
dest_gpu_p
=
reinterpret_cast
<
void
**>
(
pack_
->
slot_buf_ptr
());
CUDA_CHECK
(
cudaMemcpy
(
dest_gpu_p
,
h_tensor_ptrs
.
data
(),
use_slot_size_
*
sizeof
(
void
*
),
cudaMemcpyHostToDevice
));
CopyForTensor
(
ins_num
,
use_slot_size_
,
dest_gpu_p
,
(
const
size_t
*
)
pack_
->
gpu_slot_offsets
(),
(
const
uint64_t
*
)
value
.
d_uint64_keys
.
data
(),
(
const
int
*
)
value
.
d_uint64_offset
.
data
(),
(
const
int
*
)
value
.
d_uint64_lens
.
data
(),
uint64_use_slot_size_
,
(
const
float
*
)
value
.
d_float_keys
.
data
(),
(
const
int
*
)
value
.
d_float_offset
.
data
(),
(
const
int
*
)
value
.
d_float_lens
.
data
(),
float_use_slot_size_
,
used_slot_gpu_types
);
}
MiniBatchGpuPack
::
MiniBatchGpuPack
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
)
{
const
std
::
vector
<
UsedSlotInfo
>&
infos
,
phi
::
StreamId
stream_id
)
{
place_
=
place
;
stream_
=
dynamic_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
stream_holder_
.
reset
(
new
phi
::
CUDAStream
(
place
));
stream_
=
stream_holder_
->
raw_stream
();
ins_num_
=
0
;
pv_num_
=
0
;
...
...
@@ -2881,15 +3016,16 @@ MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
VLOG
(
3
)
<<
"begin get batch pack device id: "
<<
device_id
;
// sync
CUDA_CHECK
(
cudaStreamSynchronize
(
stream_
));
float_tensor_vec_
.
resize
(
used_slot_size_
);
uint64_tensor_vec_
.
resize
(
used_slot_size_
);
}
MiniBatchGpuPack
::~
MiniBatchGpuPack
()
{}
void
MiniBatchGpuPack
::
reset
(
const
paddle
::
platform
::
Place
&
place
)
{
place_
=
place
;
stream_
=
dynamic_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
stream_holder_
.
reset
(
new
phi
::
CUDAStream
(
place
));
stream_
=
stream_holder_
->
raw_stream
();
ins_num_
=
0
;
pv_num_
=
0
;
}
...
...
paddle/fluid/framework/data_feed.cu
浏览文件 @
effe2c11
...
...
@@ -320,11 +320,8 @@ void SlotRecordInMemoryDataFeed::FillSlotValueOffset(
const
int
uint64_slot_size
,
const
int
*
float_offsets
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
auto
stream
=
dynamic_cast
<
phi
::
GPUContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
this
->
place_
))
->
stream
();
const
UsedSlotGpuType
*
used_slots
,
cudaStream_t
stream
)
{
FillSlotValueOffsetKernel
<<<
GET_BLOCKS
(
used_slot_num
),
CUDA_NUM_THREADS
,
0
,
...
...
@@ -399,12 +396,8 @@ void SlotRecordInMemoryDataFeed::CopyForTensor(
const
int
*
float_offsets
,
const
int
*
float_ins_lens
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
auto
stream
=
dynamic_cast
<
phi
::
GPUContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
this
->
place_
))
->
stream
();
const
UsedSlotGpuType
*
used_slots
,
cudaStream_t
stream
)
{
CopyForTensorKernel
<<<
GET_BLOCKS
(
used_slot_num
*
ins_num
),
CUDA_NUM_THREADS
,
0
,
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
effe2c11
...
...
@@ -46,6 +46,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"
#endif
DECLARE_int32
(
record_pool_max_size
);
...
...
@@ -535,8 +536,11 @@ struct BatchGPUValue {
class
MiniBatchGpuPack
{
public:
MiniBatchGpuPack
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
);
const
std
::
vector
<
UsedSlotInfo
>&
infos
,
phi
::
StreamId
stream_id
);
~
MiniBatchGpuPack
();
bool
is_use
()
{
return
is_using_
;
}
void
set_use_flag
(
bool
is_use
)
{
is_using_
=
is_use
;
}
void
reset
(
const
paddle
::
platform
::
Place
&
place
);
void
pack_instance
(
const
SlotRecord
*
ins_vec
,
int
num
);
int
ins_num
()
{
return
ins_num_
;
}
...
...
@@ -566,6 +570,12 @@ class MiniBatchGpuPack {
}
phi
::
DenseTensor
&
float_tensor
(
void
)
{
return
float_tensor_
;
}
phi
::
DenseTensor
&
uint64_tensor
(
void
)
{
return
uint64_tensor_
;
}
std
::
vector
<
phi
::
DenseTensor
>&
float_tensor_vec
(
void
)
{
return
float_tensor_vec_
;
}
std
::
vector
<
phi
::
DenseTensor
>&
uint64_tensor_vec
(
void
)
{
return
uint64_tensor_vec_
;
}
HostBuffer
<
size_t
>&
offsets
(
void
)
{
return
offsets_
;
}
HostBuffer
<
void
*>&
h_tensor_ptrs
(
void
)
{
return
h_tensor_ptrs_
;
}
...
...
@@ -590,6 +600,8 @@ class MiniBatchGpuPack {
return
batch_ins_
[
idx
]
->
ins_id_
;
}
cudaStream_t
get_stream
()
{
return
stream_
;
}
private:
void
transfer_to_gpu
(
void
);
void
pack_all_data
(
const
SlotRecord
*
ins_vec
,
int
num
);
...
...
@@ -612,7 +624,9 @@ class MiniBatchGpuPack {
}
private:
bool
is_using_
=
false
;
paddle
::
platform
::
Place
place_
;
std
::
unique_ptr
<
phi
::
CUDAStream
>
stream_holder_
;
cudaStream_t
stream_
;
BatchGPUValue
value_
;
BatchCPUValue
buf_
;
...
...
@@ -631,8 +645,10 @@ class MiniBatchGpuPack {
// uint64 tensor
phi
::
DenseTensor
uint64_tensor_
;
std
::
vector
<
phi
::
DenseTensor
>
uint64_tensor_vec_
;
// float tensor
phi
::
DenseTensor
float_tensor_
;
std
::
vector
<
phi
::
DenseTensor
>
float_tensor_vec_
;
// batch
HostBuffer
<
size_t
>
offsets_
;
HostBuffer
<
void
*>
h_tensor_ptrs_
;
...
...
@@ -645,33 +661,52 @@ class MiniBatchGpuPackMgr {
public:
MiniBatchGpuPackMgr
()
{
pack_list_
.
resize
(
MAX_DEIVCE_NUM
);
for
(
int
i
=
0
;
i
<
MAX_DEIVCE_NUM
;
++
i
)
{
pack_list_
[
i
]
=
nullptr
;
pack_list_
[
i
]
.
clear
()
;
}
}
~
MiniBatchGpuPackMgr
()
{
for
(
int
i
=
0
;
i
<
MAX_DEIVCE_NUM
;
++
i
)
{
if
(
pack_list_
[
i
]
==
nullptr
)
{
continue
;
for
(
size_t
j
=
0
;
j
<
pack_list_
[
i
].
size
();
j
++
)
{
if
(
pack_list_
[
i
][
j
]
==
nullptr
)
{
continue
;
}
delete
pack_list_
[
i
][
j
];
pack_list_
[
i
][
j
]
=
nullptr
;
}
delete
pack_list_
[
i
];
pack_list_
[
i
]
=
nullptr
;
}
}
// one device one thread
// thread unsafe
MiniBatchGpuPack
*
get
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
)
{
int
device_id
=
place
.
GetDeviceId
();
if
(
pack_list_
[
device_id
]
==
nullptr
)
{
pack_list_
[
device_id
]
=
new
MiniBatchGpuPack
(
place
,
infos
);
}
else
{
pack_list_
[
device_id
]
->
reset
(
place
);
for
(
size_t
i
=
0
;
i
<
pack_list_
[
device_id
].
size
();
i
++
)
{
if
(
!
pack_list_
[
device_id
][
i
]
->
is_use
())
{
pack_list_
[
device_id
][
i
]
->
set_use_flag
(
true
);
pack_list_
[
device_id
][
i
]
->
reset
(
place
);
return
pack_list_
[
device_id
][
i
];
}
}
return
pack_list_
[
device_id
];
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
alloc_stream_map_
.
count
(
device_id
))
{
alloc_stream_map_
.
emplace
(
device_id
,
new
phi
::
CUDAStream
(
place
));
}
}
phi
::
StreamId
alloc_stream_id
=
reinterpret_cast
<
phi
::
StreamId
>
(
alloc_stream_map_
[
device_id
]
->
raw_stream
());
auto
*
pack
=
new
MiniBatchGpuPack
(
place
,
infos
,
alloc_stream_id
);
pack
->
set_use_flag
(
true
);
pack_list_
[
device_id
].
push_back
(
pack
);
return
pack
;
}
private:
MiniBatchGpuPack
*
pack_list_
[
MAX_DEIVCE_NUM
];
std
::
vector
<
std
::
vector
<
MiniBatchGpuPack
*>>
pack_list_
;
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
phi
::
CUDAStream
>>
alloc_stream_map_
;
std
::
mutex
mutex_
;
};
// global mgr
inline
MiniBatchGpuPackMgr
&
BatchGpuPackMgr
()
{
...
...
@@ -1212,6 +1247,13 @@ class DataFeed {
}
virtual
const
paddle
::
platform
::
Place
&
GetPlace
()
const
{
return
place_
;
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
virtual
void
PackToScope
(
MiniBatchGpuPack
*
pack
,
const
Scope
*
scope
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This function(PackToScope) is not implemented."
));
}
#endif
virtual
void
DumpWalkPath
(
std
::
string
dump_path
,
size_t
dump_rate
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This function(DumpWalkPath) is not implemented."
));
...
...
@@ -1766,13 +1808,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
class
SlotRecordInMemoryDataFeed
:
public
InMemoryDataFeed
<
SlotRecord
>
{
public:
SlotRecordInMemoryDataFeed
()
{}
virtual
~
SlotRecordInMemoryDataFeed
()
{
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if
(
pack_
!=
nullptr
)
{
pack_
=
nullptr
;
}
#endif
}
virtual
~
SlotRecordInMemoryDataFeed
();
virtual
void
Init
(
const
DataFeedDesc
&
data_feed_desc
);
virtual
void
LoadIntoMemory
();
void
ExpandSlotRecord
(
SlotRecord
*
ins
);
...
...
@@ -1797,7 +1833,11 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
virtual
void
PutToFeedVec
(
const
SlotRecord
*
ins_vec
,
int
num
);
virtual
void
AssignFeedVar
(
const
Scope
&
scope
);
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void
BuildSlotBatchGPU
(
const
int
ins_num
);
void
BuildSlotBatchGPU
(
const
int
ins_num
,
MiniBatchGpuPack
*
pack
);
virtual
void
PackToScope
(
MiniBatchGpuPack
*
pack
,
const
Scope
*
scope
=
nullptr
);
void
FillSlotValueOffset
(
const
int
ins_num
,
const
int
used_slot_num
,
size_t
*
slot_value_offsets
,
...
...
@@ -1805,7 +1845,8 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
const
int
uint64_slot_size
,
const
int
*
float_offsets
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
);
const
UsedSlotGpuType
*
used_slots
,
cudaStream_t
stream
);
void
CopyForTensor
(
const
int
ins_num
,
const
int
used_slot_num
,
void
**
dest
,
...
...
@@ -1818,7 +1859,8 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
const
int
*
float_offsets
,
const
int
*
float_ins_lens
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
);
const
UsedSlotGpuType
*
used_slots
,
cudaStream_t
stream
);
#endif
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
...
...
@@ -1838,7 +1880,20 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
std
::
vector
<
int
>
float_total_dims_without_inductives_
;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
MiniBatchGpuPack
*
pack_
=
nullptr
;
int
pack_thread_num_
{
5
};
std
::
vector
<
std
::
thread
>
pack_threads_
;
std
::
vector
<
MiniBatchGpuPack
*>
pack_vec_
;
BlockingQueue
<
MiniBatchGpuPack
*>
free_pack_queue_
;
BlockingQueue
<
MiniBatchGpuPack
*>
using_pack_queue_
;
std
::
atomic
<
bool
>
pack_is_end_
{
false
};
std
::
atomic
<
uint64_t
>
pack_offset_index_
{
0
};
MiniBatchGpuPack
*
last_pack_
{
nullptr
};
std
::
atomic
<
bool
>
stop_token_
{
false
};
std
::
atomic
<
int
>
thread_count_
{
0
};
std
::
mutex
pack_mutex_
;
// async infershape
std
::
map
<
const
Scope
*
,
std
::
vector
<
phi
::
DenseTensor
*>>
scpoe_feed_vec_
;
#endif
};
...
...
paddle/fluid/framework/fleet/heter_ps/feature_value.h
浏览文件 @
effe2c11
...
...
@@ -99,11 +99,13 @@ class CommonFeatureValueAccessor {
// 根据mf_dim计算的总长度
__host__
__device__
int
Dim
(
int
mf_dim
)
{
int
tmp_embedx_sgd_dim
=
1
;
int
tmp_embedx_sgd_dim
=
1
;
// shared adagrad
if
(
optimizer_type_
==
3
)
{
// adam
tmp_embedx_sgd_dim
=
mf_dim
*
2
+
2
;
}
else
if
(
optimizer_type_
==
4
)
{
// shared_adam
tmp_embedx_sgd_dim
=
4
;
}
else
if
(
optimizer_type_
==
2
)
{
tmp_embedx_sgd_dim
=
mf_dim
;
}
return
9
+
embed_sgd_dim
+
tmp_embedx_sgd_dim
+
mf_dim
;
}
...
...
@@ -115,11 +117,13 @@ class CommonFeatureValueAccessor {
// 根据mf_dim 计算的 mf_size byte数
__host__
__device__
size_t
MFSize
(
int
mf_dim
)
{
int
tmp_embedx_sgd_dim
=
1
;
int
tmp_embedx_sgd_dim
=
1
;
// shared adagrad
if
(
optimizer_type_
==
3
)
{
// adam
tmp_embedx_sgd_dim
=
mf_dim
*
2
+
2
;
}
else
if
(
optimizer_type_
==
4
)
{
// shared_adam
tmp_embedx_sgd_dim
=
4
;
}
else
if
(
optimizer_type_
=
2
)
{
// std adagrad
tmp_embedx_sgd_dim
=
mf_dim
;
}
return
(
tmp_embedx_sgd_dim
+
mf_dim
)
*
sizeof
(
float
);
}
...
...
@@ -127,12 +131,14 @@ class CommonFeatureValueAccessor {
__host__
__device__
int
EmbedxG2SumOffsetIndex
()
{
return
0
;
}
__host__
__device__
int
EmbedxWOffsetIndex
(
float
*
val
)
{
// has mf
int
tmp_embedx_sgd_dim
=
1
;
int
tmp_embedx_sgd_dim
=
1
;
// shared adagrad
if
(
static_cast
<
int
>
(
MfSize
(
val
))
>
0
)
{
if
(
optimizer_type_
==
3
)
{
// adam
tmp_embedx_sgd_dim
=
MfDim
(
val
)
*
2
+
2
;
}
else
if
(
optimizer_type_
==
4
)
{
// shared_adam
tmp_embedx_sgd_dim
=
4
;
}
else
if
(
optimizer_type_
==
2
)
{
// std adagrad
tmp_embedx_sgd_dim
=
static_cast
<
int
>
(
MfDim
(
val
));
}
return
EmbedxG2SumIndex
()
+
tmp_embedx_sgd_dim
;
}
else
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录