Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c202a613
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c202a613
编写于
4月 12, 2022
作者:
D
danleifeng
提交者:
GitHub
4月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【heterps】datafeed puttofeedvec performance (#40168)
* perform SlotRecordInMemoryDataFeed feedvec;test=develop
上级
7b627dd8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
793 addition
and
22 deletion
+793
-22
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+5
-5
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+328
-3
paddle/fluid/framework/data_feed.cu
paddle/fluid/framework/data_feed.cu
+149
-0
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+292
-1
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+9
-9
paddle/fluid/framework/ps_gpu_worker.cc
paddle/fluid/framework/ps_gpu_worker.cc
+10
-4
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
c202a613
...
...
@@ -295,7 +295,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cu
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer
...
...
@@ -316,7 +316,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc heter_pipeline_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
downpour_worker.cc downpour_lite_worker.cc downpour_worker_opt.cc
downpour_worker.cc downpour_lite_worker.cc downpour_worker_opt.cc
data_feed.cu
pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
index_sampler index_wrapper sampler index_dataset_proto
...
...
@@ -339,7 +339,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cu
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
...
...
@@ -359,7 +359,7 @@ elseif(WITH_PSLIB)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cu
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
...
...
@@ -369,7 +369,7 @@ else()
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
data_feed.cu
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
c202a613
...
...
@@ -2394,9 +2394,6 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line,
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
uint64_t
feasign
=
static_cast
<
uint64_t
>
(
strtoull
(
endptr
,
&
endptr
,
10
));
if
(
feasign
==
0
&&
!
used_slots_info_
[
info
.
used_idx
].
dense
)
{
continue
;
}
slot_fea
.
push_back
(
feasign
);
++
uint64_total_slot_num
;
}
...
...
@@ -2419,8 +2416,21 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line,
return
(
uint64_total_slot_num
>
0
);
}
void
SlotRecordInMemoryDataFeed
::
AssignFeedVar
(
const
Scope
&
scope
)
{
CheckInit
();
for
(
int
i
=
0
;
i
<
use_slot_size_
;
++
i
)
{
feed_vec_
[
i
]
=
scope
.
FindVar
(
used_slots_info_
[
i
].
slot
)
->
GetMutable
<
LoDTensor
>
();
}
}
void
SlotRecordInMemoryDataFeed
::
PutToFeedVec
(
const
SlotRecord
*
ins_vec
,
int
num
)
{
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
paddle
::
platform
::
SetDeviceId
(
place_
.
GetDeviceId
());
pack_
->
pack_instance
(
ins_vec
,
num
);
BuildSlotBatchGPU
(
pack_
->
ins_num
());
#else
for
(
int
j
=
0
;
j
<
use_slot_size_
;
++
j
)
{
auto
&
feed
=
feed_vec_
[
j
];
if
(
feed
==
nullptr
)
{
...
...
@@ -2497,6 +2507,7 @@ void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
feed_vec_
[
j
]
->
set_lod
(
data_lod
);
}
}
#endif
}
void
SlotRecordInMemoryDataFeed
::
ExpandSlotRecord
(
SlotRecord
*
rec
)
{
...
...
@@ -2573,6 +2584,10 @@ bool SlotRecordInMemoryDataFeed::Start() {
this
->
offset_index_
=
0
;
}
this
->
finish_start_
=
true
;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
CHECK
(
paddle
::
platform
::
is_gpu_place
(
this
->
place_
));
pack_
=
BatchGpuPackMgr
().
get
(
this
->
GetPlace
(),
used_slots_info_
);
#endif
return
true
;
}
...
...
@@ -2607,5 +2622,315 @@ int SlotRecordInMemoryDataFeed::Next() {
#endif
}
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void
SlotRecordInMemoryDataFeed
::
BuildSlotBatchGPU
(
const
int
ins_num
)
{
int
offset_cols_size
=
(
ins_num
+
1
);
size_t
slot_total_num
=
(
use_slot_size_
*
offset_cols_size
);
pack_
->
resize_gpu_slot_offsets
(
slot_total_num
*
sizeof
(
size_t
));
auto
&
value
=
pack_
->
value
();
const
UsedSlotGpuType
*
used_slot_gpu_types
=
static_cast
<
const
UsedSlotGpuType
*>
(
pack_
->
get_gpu_slots
());
FillSlotValueOffset
(
ins_num
,
use_slot_size_
,
reinterpret_cast
<
size_t
*>
(
pack_
->
gpu_slot_offsets
()),
value
.
d_uint64_offset
.
data
(),
uint64_use_slot_size_
,
value
.
d_float_offset
.
data
(),
float_use_slot_size_
,
used_slot_gpu_types
);
size_t
*
d_slot_offsets
=
reinterpret_cast
<
size_t
*>
(
pack_
->
gpu_slot_offsets
());
HostBuffer
<
size_t
>&
offsets
=
pack_
->
offsets
();
offsets
.
resize
(
slot_total_num
);
HostBuffer
<
void
*>&
h_tensor_ptrs
=
pack_
->
h_tensor_ptrs
();
h_tensor_ptrs
.
resize
(
use_slot_size_
);
// alloc gpu memory
pack_
->
resize_tensor
();
LoDTensor
&
float_tensor
=
pack_
->
float_tensor
();
LoDTensor
&
uint64_tensor
=
pack_
->
uint64_tensor
();
int64_t
float_offset
=
0
;
int64_t
uint64_offset
=
0
;
// copy index
CUDA_CHECK
(
cudaMemcpy
(
offsets
.
data
(),
d_slot_offsets
,
slot_total_num
*
sizeof
(
size_t
),
cudaMemcpyDeviceToHost
));
for
(
int
j
=
0
;
j
<
use_slot_size_
;
++
j
)
{
auto
&
feed
=
feed_vec_
[
j
];
if
(
feed
==
nullptr
)
{
h_tensor_ptrs
[
j
]
=
nullptr
;
continue
;
}
size_t
*
off_start_ptr
=
&
offsets
[
j
*
offset_cols_size
];
int
total_instance
=
static_cast
<
int
>
(
off_start_ptr
[
offset_cols_size
-
1
]);
CHECK
(
total_instance
>=
0
)
<<
"slot idx:"
<<
j
<<
", total instance:"
<<
total_instance
;
auto
&
info
=
used_slots_info_
[
j
];
// fill slot value with default value 0
if
(
info
.
type
[
0
]
==
'f'
)
{
// float
if
(
total_instance
>
0
)
{
feed
->
ShareDataWith
(
float_tensor
.
Slice
(
static_cast
<
int64_t
>
(
float_offset
),
static_cast
<
int64_t
>
(
float_offset
+
total_instance
)));
feed
->
Resize
({
total_instance
,
1
});
float_offset
+=
total_instance
;
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
float
>
(
this
->
place_
);
}
else
{
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
float
>
({
total_instance
,
1
},
this
->
place_
);
}
}
else
if
(
info
.
type
[
0
]
==
'u'
)
{
// uint64
if
(
total_instance
>
0
)
{
feed
->
ShareDataWith
(
uint64_tensor
.
Slice
(
static_cast
<
int64_t
>
(
uint64_offset
),
static_cast
<
int64_t
>
(
uint64_offset
+
total_instance
)));
feed
->
Resize
({
total_instance
,
1
});
uint64_offset
+=
total_instance
;
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
int64_t
>
(
this
->
place_
);
}
else
{
h_tensor_ptrs
[
j
]
=
feed
->
mutable_data
<
int64_t
>
({
total_instance
,
1
},
this
->
place_
);
}
}
if
(
info
.
dense
)
{
if
(
info
.
inductive_shape_index
!=
-
1
)
{
info
.
local_shape
[
info
.
inductive_shape_index
]
=
total_instance
/
info
.
total_dims_without_inductive
;
}
feed
->
Resize
(
phi
::
make_ddim
(
info
.
local_shape
));
}
else
{
LoD
&
lod
=
(
*
feed
->
mutable_lod
());
lod
.
resize
(
1
);
lod
[
0
].
resize
(
offset_cols_size
);
paddle
::
framework
::
MixVector
<
size_t
>
mixv_lod
(
&
lod
[
0
]);
memcpy
(
mixv_lod
.
MutableData
(
platform
::
CPUPlace
()),
off_start_ptr
,
offset_cols_size
*
sizeof
(
size_t
));
}
}
void
**
dest_gpu_p
=
reinterpret_cast
<
void
**>
(
pack_
->
slot_buf_ptr
());
CUDA_CHECK
(
cudaMemcpy
(
dest_gpu_p
,
h_tensor_ptrs
.
data
(),
use_slot_size_
*
sizeof
(
void
*
),
cudaMemcpyHostToDevice
));
CopyForTensor
(
ins_num
,
use_slot_size_
,
dest_gpu_p
,
(
const
size_t
*
)
pack_
->
gpu_slot_offsets
(),
(
const
uint64_t
*
)
value
.
d_uint64_keys
.
data
(),
(
const
int
*
)
value
.
d_uint64_offset
.
data
(),
(
const
int
*
)
value
.
d_uint64_lens
.
data
(),
uint64_use_slot_size_
,
(
const
float
*
)
value
.
d_float_keys
.
data
(),
(
const
int
*
)
value
.
d_float_offset
.
data
(),
(
const
int
*
)
value
.
d_float_lens
.
data
(),
float_use_slot_size_
,
used_slot_gpu_types
);
}
MiniBatchGpuPack
::
MiniBatchGpuPack
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
)
{
place_
=
place
;
stream_
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
ins_num_
=
0
;
pv_num_
=
0
;
used_float_num_
=
0
;
used_uint64_num_
=
0
;
used_slot_size_
=
static_cast
<
int
>
(
infos
.
size
());
for
(
int
i
=
0
;
i
<
used_slot_size_
;
++
i
)
{
auto
&
info
=
infos
[
i
];
if
(
info
.
type
[
0
]
==
'u'
)
{
gpu_used_slots_
.
push_back
({
1
,
info
.
slot_value_idx
});
++
used_uint64_num_
;
}
else
{
gpu_used_slots_
.
push_back
({
0
,
info
.
slot_value_idx
});
++
used_float_num_
;
}
}
copy_host2device
(
&
gpu_slots_
,
gpu_used_slots_
.
data
(),
gpu_used_slots_
.
size
());
slot_buf_ptr_
=
memory
::
AllocShared
(
place_
,
used_slot_size_
*
sizeof
(
void
*
));
int
device_id
=
place_
.
GetDeviceId
();
VLOG
(
3
)
<<
"begin get batch pack device id: "
<<
device_id
;
// sync
CUDA_CHECK
(
cudaStreamSynchronize
(
stream_
));
}
MiniBatchGpuPack
::~
MiniBatchGpuPack
()
{}
void
MiniBatchGpuPack
::
reset
(
const
paddle
::
platform
::
Place
&
place
)
{
place_
=
place
;
stream_
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
ins_num_
=
0
;
pv_num_
=
0
;
}
void
MiniBatchGpuPack
::
pack_all_data
(
const
SlotRecord
*
ins_vec
,
int
num
)
{
int
uint64_total_num
=
0
;
int
float_total_num
=
0
;
buf_
.
h_uint64_lens
.
resize
(
num
+
1
);
buf_
.
h_uint64_lens
[
0
]
=
0
;
buf_
.
h_float_lens
.
resize
(
num
+
1
);
buf_
.
h_float_lens
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
uint64_total_num
+=
r
->
slot_uint64_feasigns_
.
slot_values
.
size
();
buf_
.
h_uint64_lens
[
i
+
1
]
=
uint64_total_num
;
float_total_num
+=
r
->
slot_float_feasigns_
.
slot_values
.
size
();
buf_
.
h_float_lens
[
i
+
1
]
=
float_total_num
;
}
int
uint64_cols
=
(
used_uint64_num_
+
1
);
buf_
.
h_uint64_offset
.
resize
(
uint64_cols
*
num
);
buf_
.
h_uint64_keys
.
resize
(
uint64_total_num
);
int
float_cols
=
(
used_float_num_
+
1
);
buf_
.
h_float_offset
.
resize
(
float_cols
*
num
);
buf_
.
h_float_keys
.
resize
(
float_total_num
);
size_t
fea_num
=
0
;
uint64_total_num
=
0
;
float_total_num
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
auto
&
uint64_feasigns
=
r
->
slot_uint64_feasigns_
;
fea_num
=
uint64_feasigns
.
slot_values
.
size
();
if
(
fea_num
>
0
)
{
memcpy
(
&
buf_
.
h_uint64_keys
[
uint64_total_num
],
uint64_feasigns
.
slot_values
.
data
(),
fea_num
*
sizeof
(
uint64_t
));
}
uint64_total_num
+=
fea_num
;
// copy uint64 offset
memcpy
(
&
buf_
.
h_uint64_offset
[
i
*
uint64_cols
],
uint64_feasigns
.
slot_offsets
.
data
(),
sizeof
(
int
)
*
uint64_cols
);
auto
&
float_feasigns
=
r
->
slot_float_feasigns_
;
fea_num
=
float_feasigns
.
slot_values
.
size
();
memcpy
(
&
buf_
.
h_float_keys
[
float_total_num
],
float_feasigns
.
slot_values
.
data
(),
fea_num
*
sizeof
(
float
));
float_total_num
+=
fea_num
;
// copy float offset
memcpy
(
&
buf_
.
h_float_offset
[
i
*
float_cols
],
float_feasigns
.
slot_offsets
.
data
(),
sizeof
(
int
)
*
float_cols
);
}
CHECK
(
uint64_total_num
==
static_cast
<
int
>
(
buf_
.
h_uint64_lens
.
back
()))
<<
"uint64 value length error"
;
CHECK
(
float_total_num
==
static_cast
<
int
>
(
buf_
.
h_float_lens
.
back
()))
<<
"float value length error"
;
}
void
MiniBatchGpuPack
::
pack_uint64_data
(
const
SlotRecord
*
ins_vec
,
int
num
)
{
int
uint64_total_num
=
0
;
buf_
.
h_float_lens
.
clear
();
buf_
.
h_float_keys
.
clear
();
buf_
.
h_float_offset
.
clear
();
buf_
.
h_uint64_lens
.
resize
(
num
+
1
);
buf_
.
h_uint64_lens
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
uint64_total_num
+=
r
->
slot_uint64_feasigns_
.
slot_values
.
size
();
buf_
.
h_uint64_lens
[
i
+
1
]
=
uint64_total_num
;
}
int
uint64_cols
=
(
used_uint64_num_
+
1
);
buf_
.
h_uint64_offset
.
resize
(
uint64_cols
*
num
);
buf_
.
h_uint64_keys
.
resize
(
uint64_total_num
);
size_t
fea_num
=
0
;
uint64_total_num
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
auto
&
uint64_feasigns
=
r
->
slot_uint64_feasigns_
;
fea_num
=
uint64_feasigns
.
slot_values
.
size
();
if
(
fea_num
>
0
)
{
memcpy
(
&
buf_
.
h_uint64_keys
[
uint64_total_num
],
uint64_feasigns
.
slot_values
.
data
(),
fea_num
*
sizeof
(
uint64_t
));
}
uint64_total_num
+=
fea_num
;
// copy uint64 offset
memcpy
(
&
buf_
.
h_uint64_offset
[
i
*
uint64_cols
],
uint64_feasigns
.
slot_offsets
.
data
(),
sizeof
(
int
)
*
uint64_cols
);
}
CHECK
(
uint64_total_num
==
static_cast
<
int
>
(
buf_
.
h_uint64_lens
.
back
()))
<<
"uint64 value length error"
;
}
void
MiniBatchGpuPack
::
pack_float_data
(
const
SlotRecord
*
ins_vec
,
int
num
)
{
int
float_total_num
=
0
;
buf_
.
h_uint64_lens
.
clear
();
buf_
.
h_uint64_offset
.
clear
();
buf_
.
h_uint64_keys
.
clear
();
buf_
.
h_float_lens
.
resize
(
num
+
1
);
buf_
.
h_float_lens
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
float_total_num
+=
r
->
slot_float_feasigns_
.
slot_values
.
size
();
buf_
.
h_float_lens
[
i
+
1
]
=
float_total_num
;
}
int
float_cols
=
(
used_float_num_
+
1
);
buf_
.
h_float_offset
.
resize
(
float_cols
*
num
);
buf_
.
h_float_keys
.
resize
(
float_total_num
);
size_t
fea_num
=
0
;
float_total_num
=
0
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
auto
r
=
ins_vec
[
i
];
auto
&
float_feasigns
=
r
->
slot_float_feasigns_
;
fea_num
=
float_feasigns
.
slot_values
.
size
();
memcpy
(
&
buf_
.
h_float_keys
[
float_total_num
],
float_feasigns
.
slot_values
.
data
(),
fea_num
*
sizeof
(
float
));
float_total_num
+=
fea_num
;
// copy float offset
memcpy
(
&
buf_
.
h_float_offset
[
i
*
float_cols
],
float_feasigns
.
slot_offsets
.
data
(),
sizeof
(
int
)
*
float_cols
);
}
CHECK
(
float_total_num
==
static_cast
<
int
>
(
buf_
.
h_float_lens
.
back
()))
<<
"float value length error"
;
}
void
MiniBatchGpuPack
::
pack_instance
(
const
SlotRecord
*
ins_vec
,
int
num
)
{
ins_num_
=
num
;
batch_ins_
=
ins_vec
;
CHECK
(
used_uint64_num_
>
0
||
used_float_num_
>
0
);
// uint64 and float
if
(
used_uint64_num_
>
0
&&
used_float_num_
>
0
)
{
pack_all_data
(
ins_vec
,
num
);
}
else
if
(
used_uint64_num_
>
0
)
{
// uint64
pack_uint64_data
(
ins_vec
,
num
);
}
else
{
// only float
pack_float_data
(
ins_vec
,
num
);
}
// to gpu
transfer_to_gpu
();
}
void
MiniBatchGpuPack
::
transfer_to_gpu
(
void
)
{
copy_host2device
(
&
value_
.
d_uint64_lens
,
buf_
.
h_uint64_lens
);
copy_host2device
(
&
value_
.
d_uint64_keys
,
buf_
.
h_uint64_keys
);
copy_host2device
(
&
value_
.
d_uint64_offset
,
buf_
.
h_uint64_offset
);
copy_host2device
(
&
value_
.
d_float_lens
,
buf_
.
h_float_lens
);
copy_host2device
(
&
value_
.
d_float_keys
,
buf_
.
h_float_keys
);
copy_host2device
(
&
value_
.
d_float_offset
,
buf_
.
h_float_offset
);
CUDA_CHECK
(
cudaStreamSynchronize
(
stream_
));
}
#endif
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/data_feed.cu
0 → 100644
浏览文件 @
c202a613
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
framework
{
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// CUDA: use 512 threads per block
const
int
CUDA_NUM_THREADS
=
512
;
// CUDA: number of blocks for threads.
inline
int
GET_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
}
// fill slot values
__global__
void
FillSlotValueOffsetKernel
(
const
int
ins_num
,
const
int
used_slot_num
,
size_t
*
slot_value_offsets
,
const
int
*
uint64_offsets
,
const
int
uint64_slot_size
,
const
int
*
float_offsets
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
int
col_num
=
ins_num
+
1
;
int
uint64_cols
=
uint64_slot_size
+
1
;
int
float_cols
=
float_slot_size
+
1
;
CUDA_KERNEL_LOOP
(
slot_idx
,
used_slot_num
)
{
int
value_off
=
slot_idx
*
col_num
;
slot_value_offsets
[
value_off
]
=
0
;
auto
&
info
=
used_slots
[
slot_idx
];
if
(
info
.
is_uint64_value
)
{
for
(
int
k
=
0
;
k
<
ins_num
;
++
k
)
{
int
pos
=
k
*
uint64_cols
+
info
.
slot_value_idx
;
int
num
=
uint64_offsets
[
pos
+
1
]
-
uint64_offsets
[
pos
];
PADDLE_ENFORCE
(
num
>=
0
,
"The number of slot size must be ge 0."
);
slot_value_offsets
[
value_off
+
k
+
1
]
=
slot_value_offsets
[
value_off
+
k
]
+
num
;
}
}
else
{
for
(
int
k
=
0
;
k
<
ins_num
;
++
k
)
{
int
pos
=
k
*
float_cols
+
info
.
slot_value_idx
;
int
num
=
float_offsets
[
pos
+
1
]
-
float_offsets
[
pos
];
PADDLE_ENFORCE
(
num
>=
0
,
"The number of slot size must be ge 0."
);
slot_value_offsets
[
value_off
+
k
+
1
]
=
slot_value_offsets
[
value_off
+
k
]
+
num
;
}
}
}
}
void
SlotRecordInMemoryDataFeed
::
FillSlotValueOffset
(
const
int
ins_num
,
const
int
used_slot_num
,
size_t
*
slot_value_offsets
,
const
int
*
uint64_offsets
,
const
int
uint64_slot_size
,
const
int
*
float_offsets
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
auto
stream
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
this
->
place_
))
->
stream
();
FillSlotValueOffsetKernel
<<<
GET_BLOCKS
(
used_slot_num
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
ins_num
,
used_slot_num
,
slot_value_offsets
,
uint64_offsets
,
uint64_slot_size
,
float_offsets
,
float_slot_size
,
used_slots
);
cudaStreamSynchronize
(
stream
);
}
__global__
void
CopyForTensorKernel
(
const
int
used_slot_num
,
const
int
ins_num
,
void
**
dest
,
const
size_t
*
slot_value_offsets
,
const
uint64_t
*
uint64_feas
,
const
int
*
uint64_offsets
,
const
int
*
uint64_ins_lens
,
const
int
uint64_slot_size
,
const
float
*
float_feas
,
const
int
*
float_offsets
,
const
int
*
float_ins_lens
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
int
col_num
=
ins_num
+
1
;
int
uint64_cols
=
uint64_slot_size
+
1
;
int
float_cols
=
float_slot_size
+
1
;
CUDA_KERNEL_LOOP
(
i
,
ins_num
*
used_slot_num
)
{
int
slot_idx
=
i
/
ins_num
;
int
ins_idx
=
i
%
ins_num
;
uint32_t
value_offset
=
slot_value_offsets
[
slot_idx
*
col_num
+
ins_idx
];
auto
&
info
=
used_slots
[
slot_idx
];
if
(
info
.
is_uint64_value
)
{
uint64_t
*
up
=
reinterpret_cast
<
uint64_t
*>
(
dest
[
slot_idx
]);
int
index
=
info
.
slot_value_idx
+
uint64_cols
*
ins_idx
;
int
old_off
=
uint64_offsets
[
index
];
int
num
=
uint64_offsets
[
index
+
1
]
-
old_off
;
PADDLE_ENFORCE
(
num
>=
0
,
"The number of slot size must be ge 0."
);
int
uint64_value_offset
=
uint64_ins_lens
[
ins_idx
];
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
up
[
k
+
value_offset
]
=
uint64_feas
[
k
+
old_off
+
uint64_value_offset
];
}
}
else
{
float
*
fp
=
reinterpret_cast
<
float
*>
(
dest
[
slot_idx
]);
int
index
=
info
.
slot_value_idx
+
float_cols
*
ins_idx
;
int
old_off
=
float_offsets
[
index
];
int
num
=
float_offsets
[
index
+
1
]
-
old_off
;
PADDLE_ENFORCE
(
num
>=
0
,
"The number of slot size must be ge 0."
);
int
float_value_offset
=
float_ins_lens
[
ins_idx
];
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
fp
[
k
+
value_offset
]
=
float_feas
[
k
+
old_off
+
float_value_offset
];
}
}
}
}
void
SlotRecordInMemoryDataFeed
::
CopyForTensor
(
const
int
ins_num
,
const
int
used_slot_num
,
void
**
dest
,
const
size_t
*
slot_value_offsets
,
const
uint64_t
*
uint64_feas
,
const
int
*
uint64_offsets
,
const
int
*
uint64_ins_lens
,
const
int
uint64_slot_size
,
const
float
*
float_feas
,
const
int
*
float_offsets
,
const
int
*
float_ins_lens
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
)
{
auto
stream
=
dynamic_cast
<
platform
::
CUDADeviceContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
this
->
place_
))
->
stream
();
CopyForTensorKernel
<<<
GET_BLOCKS
(
used_slot_num
*
ins_num
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
used_slot_num
,
ins_num
,
dest
,
slot_value_offsets
,
uint64_feas
,
uint64_offsets
,
uint64_ins_lens
,
uint64_slot_size
,
float_feas
,
float_offsets
,
float_ins_lens
,
float_slot_size
,
used_slots
);
cudaStreamSynchronize
(
stream
);
}
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/data_feed.h
浏览文件 @
c202a613
...
...
@@ -41,6 +41,10 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
DECLARE_int32
(
record_pool_max_size
);
DECLARE_int32
(
slotpool_thread_num
);
...
...
@@ -409,6 +413,266 @@ class CustomParser {
}
};
struct
UsedSlotGpuType
{
int
is_uint64_value
;
int
slot_value_idx
;
};
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
#define CUDA_CHECK(val) CHECK(val == gpuSuccess)
template
<
typename
T
>
struct
CudaBuffer
{
T
*
cu_buffer
;
uint64_t
buf_size
;
CudaBuffer
<
T
>
()
{
cu_buffer
=
NULL
;
buf_size
=
0
;
}
~
CudaBuffer
<
T
>
()
{
free
();
}
T
*
data
()
{
return
cu_buffer
;
}
uint64_t
size
()
{
return
buf_size
;
}
void
malloc
(
uint64_t
size
)
{
buf_size
=
size
;
CUDA_CHECK
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
cu_buffer
),
size
*
sizeof
(
T
)));
}
void
free
()
{
if
(
cu_buffer
!=
NULL
)
{
CUDA_CHECK
(
cudaFree
(
cu_buffer
));
cu_buffer
=
NULL
;
}
buf_size
=
0
;
}
void
resize
(
uint64_t
size
)
{
if
(
size
<=
buf_size
)
{
return
;
}
free
();
malloc
(
size
);
}
};
template
<
typename
T
>
struct
HostBuffer
{
T
*
host_buffer
;
size_t
buf_size
;
size_t
data_len
;
HostBuffer
<
T
>
()
{
host_buffer
=
NULL
;
buf_size
=
0
;
data_len
=
0
;
}
~
HostBuffer
<
T
>
()
{
free
();
}
T
*
data
()
{
return
host_buffer
;
}
const
T
*
data
()
const
{
return
host_buffer
;
}
size_t
size
()
const
{
return
data_len
;
}
void
clear
()
{
free
();
}
T
&
back
()
{
return
host_buffer
[
data_len
-
1
];
}
T
&
operator
[](
size_t
i
)
{
return
host_buffer
[
i
];
}
const
T
&
operator
[](
size_t
i
)
const
{
return
host_buffer
[
i
];
}
void
malloc
(
size_t
len
)
{
buf_size
=
len
;
CUDA_CHECK
(
cudaHostAlloc
(
reinterpret_cast
<
void
**>
(
&
host_buffer
),
buf_size
*
sizeof
(
T
),
cudaHostAllocDefault
));
CHECK
(
host_buffer
!=
NULL
);
}
void
free
()
{
if
(
host_buffer
!=
NULL
)
{
CUDA_CHECK
(
cudaFreeHost
(
host_buffer
));
host_buffer
=
NULL
;
}
buf_size
=
0
;
}
void
resize
(
size_t
size
)
{
if
(
size
<=
buf_size
)
{
data_len
=
size
;
return
;
}
data_len
=
size
;
free
();
malloc
(
size
);
}
};
struct
BatchCPUValue
{
HostBuffer
<
int
>
h_uint64_lens
;
HostBuffer
<
uint64_t
>
h_uint64_keys
;
HostBuffer
<
int
>
h_uint64_offset
;
HostBuffer
<
int
>
h_float_lens
;
HostBuffer
<
float
>
h_float_keys
;
HostBuffer
<
int
>
h_float_offset
;
HostBuffer
<
int
>
h_rank
;
HostBuffer
<
int
>
h_cmatch
;
HostBuffer
<
int
>
h_ad_offset
;
};
struct
BatchGPUValue
{
CudaBuffer
<
int
>
d_uint64_lens
;
CudaBuffer
<
uint64_t
>
d_uint64_keys
;
CudaBuffer
<
int
>
d_uint64_offset
;
CudaBuffer
<
int
>
d_float_lens
;
CudaBuffer
<
float
>
d_float_keys
;
CudaBuffer
<
int
>
d_float_offset
;
CudaBuffer
<
int
>
d_rank
;
CudaBuffer
<
int
>
d_cmatch
;
CudaBuffer
<
int
>
d_ad_offset
;
};
class
MiniBatchGpuPack
{
public:
MiniBatchGpuPack
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
);
~
MiniBatchGpuPack
();
void
reset
(
const
paddle
::
platform
::
Place
&
place
);
void
pack_instance
(
const
SlotRecord
*
ins_vec
,
int
num
);
int
ins_num
()
{
return
ins_num_
;
}
int
pv_num
()
{
return
pv_num_
;
}
BatchGPUValue
&
value
()
{
return
value_
;
}
BatchCPUValue
&
cpu_value
()
{
return
buf_
;
}
UsedSlotGpuType
*
get_gpu_slots
(
void
)
{
return
reinterpret_cast
<
UsedSlotGpuType
*>
(
gpu_slots_
.
data
());
}
SlotRecord
*
get_records
(
void
)
{
return
&
ins_vec_
[
0
];
}
// tensor gpu memory reused
void
resize_tensor
(
void
)
{
if
(
used_float_num_
>
0
)
{
int
float_total_len
=
buf_
.
h_float_lens
.
back
();
if
(
float_total_len
>
0
)
{
float_tensor_
.
mutable_data
<
float
>
({
float_total_len
,
1
},
this
->
place_
);
}
}
if
(
used_uint64_num_
>
0
)
{
int
uint64_total_len
=
buf_
.
h_uint64_lens
.
back
();
if
(
uint64_total_len
>
0
)
{
uint64_tensor_
.
mutable_data
<
int64_t
>
({
uint64_total_len
,
1
},
this
->
place_
);
}
}
}
LoDTensor
&
float_tensor
(
void
)
{
return
float_tensor_
;
}
LoDTensor
&
uint64_tensor
(
void
)
{
return
uint64_tensor_
;
}
HostBuffer
<
size_t
>&
offsets
(
void
)
{
return
offsets_
;
}
HostBuffer
<
void
*>&
h_tensor_ptrs
(
void
)
{
return
h_tensor_ptrs_
;
}
void
*
gpu_slot_offsets
(
void
)
{
return
gpu_slot_offsets_
->
ptr
();
}
void
*
slot_buf_ptr
(
void
)
{
return
slot_buf_ptr_
->
ptr
();
}
void
resize_gpu_slot_offsets
(
const
size_t
slot_total_bytes
)
{
if
(
gpu_slot_offsets_
==
nullptr
)
{
gpu_slot_offsets_
=
memory
::
AllocShared
(
place_
,
slot_total_bytes
);
}
else
if
(
gpu_slot_offsets_
->
size
()
<
slot_total_bytes
)
{
auto
buf
=
memory
::
AllocShared
(
place_
,
slot_total_bytes
);
gpu_slot_offsets_
.
swap
(
buf
);
buf
=
nullptr
;
}
}
const
std
::
string
&
get_lineid
(
int
idx
)
{
if
(
enable_pv_
)
{
return
ins_vec_
[
idx
]
->
ins_id_
;
}
return
batch_ins_
[
idx
]
->
ins_id_
;
}
private:
void
transfer_to_gpu
(
void
);
void
pack_all_data
(
const
SlotRecord
*
ins_vec
,
int
num
);
void
pack_uint64_data
(
const
SlotRecord
*
ins_vec
,
int
num
);
void
pack_float_data
(
const
SlotRecord
*
ins_vec
,
int
num
);
public:
template
<
typename
T
>
void
copy_host2device
(
CudaBuffer
<
T
>*
buf
,
const
T
*
val
,
size_t
size
)
{
if
(
size
==
0
)
{
return
;
}
buf
->
resize
(
size
);
CUDA_CHECK
(
cudaMemcpyAsync
(
buf
->
data
(),
val
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
,
stream_
));
}
template
<
typename
T
>
void
copy_host2device
(
CudaBuffer
<
T
>*
buf
,
const
HostBuffer
<
T
>&
val
)
{
copy_host2device
(
buf
,
val
.
data
(),
val
.
size
());
}
private:
paddle
::
platform
::
Place
place_
;
cudaStream_t
stream_
;
BatchGPUValue
value_
;
BatchCPUValue
buf_
;
int
ins_num_
=
0
;
int
pv_num_
=
0
;
bool
enable_pv_
=
false
;
int
used_float_num_
=
0
;
int
used_uint64_num_
=
0
;
int
used_slot_size_
=
0
;
CudaBuffer
<
UsedSlotGpuType
>
gpu_slots_
;
std
::
vector
<
UsedSlotGpuType
>
gpu_used_slots_
;
std
::
vector
<
SlotRecord
>
ins_vec_
;
const
SlotRecord
*
batch_ins_
=
nullptr
;
// uint64 tensor
LoDTensor
uint64_tensor_
;
// float tensor
LoDTensor
float_tensor_
;
// batch
HostBuffer
<
size_t
>
offsets_
;
HostBuffer
<
void
*>
h_tensor_ptrs_
;
std
::
shared_ptr
<
phi
::
Allocation
>
gpu_slot_offsets_
=
nullptr
;
std
::
shared_ptr
<
phi
::
Allocation
>
slot_buf_ptr_
=
nullptr
;
};
class
MiniBatchGpuPackMgr
{
static
const
int
MAX_DEIVCE_NUM
=
16
;
public:
MiniBatchGpuPackMgr
()
{
for
(
int
i
=
0
;
i
<
MAX_DEIVCE_NUM
;
++
i
)
{
pack_list_
[
i
]
=
nullptr
;
}
}
~
MiniBatchGpuPackMgr
()
{
for
(
int
i
=
0
;
i
<
MAX_DEIVCE_NUM
;
++
i
)
{
if
(
pack_list_
[
i
]
==
nullptr
)
{
continue
;
}
delete
pack_list_
[
i
];
pack_list_
[
i
]
=
nullptr
;
}
}
// one device one thread
MiniBatchGpuPack
*
get
(
const
paddle
::
platform
::
Place
&
place
,
const
std
::
vector
<
UsedSlotInfo
>&
infos
)
{
int
device_id
=
place
.
GetDeviceId
();
if
(
pack_list_
[
device_id
]
==
nullptr
)
{
pack_list_
[
device_id
]
=
new
MiniBatchGpuPack
(
place
,
infos
);
}
else
{
pack_list_
[
device_id
]
->
reset
(
place
);
}
return
pack_list_
[
device_id
];
}
private:
MiniBatchGpuPack
*
pack_list_
[
MAX_DEIVCE_NUM
];
};
// global mgr
inline
MiniBatchGpuPackMgr
&
BatchGpuPackMgr
()
{
static
MiniBatchGpuPackMgr
mgr
;
return
mgr
;
}
#endif
typedef
paddle
::
framework
::
CustomParser
*
(
*
CreateParserObjectFunc
)();
class
DLManager
{
...
...
@@ -1126,7 +1390,13 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
class
SlotRecordInMemoryDataFeed
:
public
InMemoryDataFeed
<
SlotRecord
>
{
public:
SlotRecordInMemoryDataFeed
()
{}
virtual
~
SlotRecordInMemoryDataFeed
()
{}
virtual
~
SlotRecordInMemoryDataFeed
()
{
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if
(
pack_
!=
nullptr
)
{
pack_
=
nullptr
;
}
#endif
}
virtual
void
Init
(
const
DataFeedDesc
&
data_feed_desc
);
virtual
void
LoadIntoMemory
();
void
ExpandSlotRecord
(
SlotRecord
*
ins
);
...
...
@@ -1149,6 +1419,23 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
}
bool
ParseOneInstance
(
const
std
::
string
&
line
,
SlotRecord
*
rec
);
virtual
void
PutToFeedVec
(
const
SlotRecord
*
ins_vec
,
int
num
);
virtual
void
AssignFeedVar
(
const
Scope
&
scope
);
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
void
BuildSlotBatchGPU
(
const
int
ins_num
);
void
FillSlotValueOffset
(
const
int
ins_num
,
const
int
used_slot_num
,
size_t
*
slot_value_offsets
,
const
int
*
uint64_offsets
,
const
int
uint64_slot_size
,
const
int
*
float_offsets
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
);
void
CopyForTensor
(
const
int
ins_num
,
const
int
used_slot_num
,
void
**
dest
,
const
size_t
*
slot_value_offsets
,
const
uint64_t
*
uint64_feas
,
const
int
*
uint64_offsets
,
const
int
*
uint64_ins_lens
,
const
int
uint64_slot_size
,
const
float
*
float_feas
,
const
int
*
float_offsets
,
const
int
*
float_ins_lens
,
const
int
float_slot_size
,
const
UsedSlotGpuType
*
used_slots
);
#endif
float
sample_rate_
=
1.0
f
;
int
use_slot_size_
=
0
;
int
float_use_slot_size_
=
0
;
...
...
@@ -1157,6 +1444,10 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
std
::
vector
<
UsedSlotInfo
>
used_slots_info_
;
size_t
float_total_dims_size_
=
0
;
std
::
vector
<
int
>
float_total_dims_without_inductives_
;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
MiniBatchGpuPack
*
pack_
=
nullptr
;
#endif
};
class
PaddleBoxDataFeed
:
public
MultiSlotInMemoryDataFeed
{
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
浏览文件 @
c202a613
...
...
@@ -271,13 +271,13 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
timeline
.
Pause
();
VLOG
(
1
)
<<
"GpuPs task add keys cost "
<<
timeline
.
ElapsedSec
()
VLOG
(
0
)
<<
"GpuPs task add keys cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
timeline
.
Start
();
gpu_task
->
UniqueKeys
();
timeline
.
Pause
();
VLOG
(
1
)
<<
"GpuPs task unique cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
VLOG
(
0
)
<<
"GpuPs task unique cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds."
;
if
(
!
multi_mf_dim_
)
{
for
(
int
i
=
0
;
i
<
thread_keys_shard_num_
;
i
++
)
{
...
...
@@ -667,7 +667,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
if
(
!
multi_mf_dim_
)
{
for
(
int
i
=
0
;
i
<
device_num
;
i
++
)
{
feature_keys_count
[
i
]
=
gpu_task
->
device_keys_
[
i
].
size
();
VLOG
(
1
)
<<
i
<<
" card contains feasign nums: "
<<
feature_keys_count
[
i
];
VLOG
(
0
)
<<
i
<<
" card contains feasign nums: "
<<
feature_keys_count
[
i
];
size_max
=
std
::
max
(
size_max
,
feature_keys_count
[
i
]);
}
}
else
{
...
...
@@ -675,7 +675,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
for
(
int
j
=
0
;
j
<
multi_mf_dim_
;
j
++
)
{
feature_keys_count
[
i
]
+=
gpu_task
->
device_dim_ptr_
[
i
][
j
].
size
();
}
VLOG
(
1
)
<<
i
<<
" card with dynamic mf contains feasign nums: "
VLOG
(
0
)
<<
i
<<
" card with dynamic mf contains feasign nums: "
<<
feature_keys_count
[
i
];
size_max
=
std
::
max
(
size_max
,
feature_keys_count
[
i
]);
}
...
...
@@ -685,7 +685,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
HeterPs_
=
nullptr
;
}
if
(
size_max
<=
0
)
{
VLOG
(
1
)
<<
"Skip build gpu ps cause feasign nums = "
<<
size_max
;
VLOG
(
0
)
<<
"Skip build gpu ps cause feasign nums = "
<<
size_max
;
return
;
}
std
::
vector
<
std
::
thread
>
threads
(
device_num
);
...
...
@@ -707,7 +707,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
t
.
join
();
}
timeline
.
Pause
();
VLOG
(
1
)
<<
"GpuPs build table total costs: "
<<
timeline
.
ElapsedSec
()
VLOG
(
0
)
<<
"GpuPs build table total costs: "
<<
timeline
.
ElapsedSec
()
<<
" s."
;
}
...
...
@@ -749,7 +749,7 @@ void PSGPUWrapper::pre_build_thread() {
// build cpu ps data process
PreBuildTask
(
gpu_task
);
timer
.
Pause
();
VLOG
(
1
)
<<
"thread PreBuildTask end, cost time: "
<<
timer
.
ElapsedSec
()
VLOG
(
0
)
<<
"thread PreBuildTask end, cost time: "
<<
timer
.
ElapsedSec
()
<<
"s"
;
buildcpu_ready_channel_
->
Put
(
gpu_task
);
}
...
...
@@ -768,13 +768,13 @@ void PSGPUWrapper::build_task() {
return
;
}
VLOG
(
1
)
<<
"BuildPull start."
;
VLOG
(
0
)
<<
"BuildPull start."
;
platform
::
Timer
timer
;
timer
.
Start
();
BuildPull
(
gpu_task
);
BuildGPUTask
(
gpu_task
);
timer
.
Pause
();
VLOG
(
1
)
<<
"BuildPull + BuildGPUTask end, cost time: "
<<
timer
.
ElapsedSec
()
VLOG
(
0
)
<<
"BuildPull + BuildGPUTask end, cost time: "
<<
timer
.
ElapsedSec
()
<<
"s"
;
current_task_
=
gpu_task
;
...
...
paddle/fluid/framework/ps_gpu_worker.cc
浏览文件 @
c202a613
...
...
@@ -119,6 +119,7 @@ void PSGPUWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
}
void
PSGPUWorker
::
TrainFiles
()
{
VLOG
(
0
)
<<
"Begin to train files"
;
platform
::
SetNumThreads
(
1
);
platform
::
Timer
timeline
;
timeline
.
Start
();
...
...
@@ -129,6 +130,8 @@ void PSGPUWorker::TrainFiles() {
device_reader_
->
Start
();
int
cur_batch
;
int
batch_cnt
=
0
;
platform
::
SetDeviceId
(
thread_id_
);
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
total_ins_num
+=
cur_batch
;
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -190,14 +193,14 @@ void PSGPUWorker::TrainFiles() {
writer_
.
Flush
();
}
timeline
.
Pause
();
VLOG
(
1
)
<<
"GpuPs worker "
<<
thread_id_
<<
" train cost "
VLOG
(
0
)
<<
"GpuPs worker "
<<
thread_id_
<<
" train cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds, ins_num: "
<<
total_ins_num
;
return
;
}
void
PSGPUWorker
::
TrainFilesWithProfiler
()
{
platform
::
SetNumThreads
(
1
);
VLOG
(
1
)
<<
"Begin to train files with profiler"
;
VLOG
(
0
)
<<
"Begin to train files with profiler"
;
device_reader_
->
Start
();
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
...
...
@@ -225,6 +228,7 @@ void PSGPUWorker::TrainFilesWithProfiler() {
int
total_ins_num
=
0
;
int
cur_batch
;
timeline
.
Start
();
platform
::
SetDeviceId
(
thread_id_
);
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
total_ins_num
+=
cur_batch
;
timeline
.
Pause
();
...
...
@@ -260,13 +264,15 @@ void PSGPUWorker::TrainFilesWithProfiler() {
total_time
+=
timeline
.
ElapsedSec
();
timeline
.
Start
();
}
VLOG
(
1
)
<<
"GpuPs worker "
<<
thread_id_
<<
" train cost "
<<
total_time
VLOG
(
0
)
<<
"GpuPs worker "
<<
thread_id_
<<
" train cost "
<<
total_time
<<
" seconds, ins_num: "
<<
total_ins_num
;
for
(
size_t
i
=
0
;
i
<
op_name
.
size
();
++
i
)
{
VLOG
(
1
)
<<
"card:"
<<
thread_id_
<<
", op: "
<<
op_name
[
i
]
VLOG
(
0
)
<<
"card:"
<<
thread_id_
<<
", op: "
<<
op_name
[
i
]
<<
", mean time: "
<<
op_total_time
[
i
]
/
total_ins_num
<<
"s, totol time:"
<<
op_total_time
[
i
]
<<
"sec"
;
}
VLOG
(
0
)
<<
"card: "
<<
thread_id_
<<
" read time: "
<<
read_time
<<
", percent: "
<<
read_time
/
total_time
*
100
;
return
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录