Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cb66c53c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cb66c53c
编写于
2月 01, 2021
作者:
T
Thunderbrook
提交者:
GitHub
2月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dump to cpu (#30750)
* dump to cpu * format * format * format
上级
d3fac0ea
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
89 addition
and
11 deletion
+89
-11
paddle/fluid/framework/fleet/heter_ps/feature_value.h
paddle/fluid/framework/fleet/heter_ps/feature_value.h
+1
-0
paddle/fluid/framework/fleet/heter_ps/hashtable.h
paddle/fluid/framework/fleet/heter_ps/hashtable.h
+4
-1
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
+35
-0
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+7
-2
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+28
-0
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
+1
-1
paddle/fluid/framework/fleet/heter_ps/heter_ps.h
paddle/fluid/framework/fleet/heter_ps/heter_ps.h
+2
-2
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
+1
-1
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
+3
-3
paddle/fluid/framework/fleet/heter_ps/test_comm.cu
paddle/fluid/framework/fleet/heter_ps/test_comm.cu
+1
-1
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+1
-0
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+3
-0
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
+2
-0
未找到文件。
paddle/fluid/framework/fleet/heter_ps/feature_value.h
浏览文件 @
cb66c53c
...
@@ -33,6 +33,7 @@ struct FeatureValue {
...
@@ -33,6 +33,7 @@ struct FeatureValue {
float
lr_g2sum
;
float
lr_g2sum
;
int
mf_size
;
int
mf_size
;
float
mf
[
MF_DIM
+
1
];
float
mf
[
MF_DIM
+
1
];
uint64_t
cpu_ptr
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
FeatureValue
&
val
)
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
FeatureValue
&
val
)
{
out
<<
"show: "
<<
val
.
show
<<
" clk: "
<<
val
.
clk
<<
" slot: "
<<
val
.
slot
out
<<
"show: "
<<
val
.
show
<<
" clk: "
<<
val
.
clk
<<
" slot: "
<<
val
.
slot
...
...
paddle/fluid/framework/fleet/heter_ps/hashtable.h
浏览文件 @
cb66c53c
...
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
...
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <glog/logging.h>
#include <limits>
#include <limits>
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "common_value.h" // NOLINT
#include "thrust/pair.h"
#include "thrust/pair.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
...
@@ -47,6 +49,7 @@ class HashTable {
...
@@ -47,6 +49,7 @@ class HashTable {
void
get
(
const
KeyType
*
d_keys
,
ValType
*
d_vals
,
size_t
len
,
void
get
(
const
KeyType
*
d_keys
,
ValType
*
d_vals
,
size_t
len
,
cudaStream_t
stream
);
cudaStream_t
stream
);
void
show
();
void
show
();
void
dump_to_cpu
(
int
devid
,
cudaStream_t
stream
);
template
<
typename
GradType
,
typename
Sgd
>
template
<
typename
GradType
,
typename
Sgd
>
void
update
(
const
KeyType
*
d_keys
,
const
GradType
*
d_grads
,
size_t
len
,
void
update
(
const
KeyType
*
d_keys
,
const
GradType
*
d_grads
,
size_t
len
,
...
@@ -60,5 +63,5 @@ class HashTable {
...
@@ -60,5 +63,5 @@ class HashTable {
};
};
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
#include "hashtable
.tpp
"
#include "hashtable
_inl.h
"
#endif
#endif
paddle/fluid/framework/fleet/heter_ps/hashtable
.tpp
→
paddle/fluid/framework/fleet/heter_ps/hashtable
_inl.h
浏览文件 @
cb66c53c
...
@@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
...
@@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals
,
len
);
d_vals
,
len
);
}
}
template
<
typename
KeyType
,
typename
ValType
>
void
HashTable
<
KeyType
,
ValType
>::
dump_to_cpu
(
int
devid
,
cudaStream_t
stream
)
{
container_
->
prefetch
(
cudaCpuDeviceId
,
stream
);
size_t
num
=
container_
->
size
();
KeyType
unuse_key
=
std
::
numeric_limits
<
KeyType
>::
max
();
thrust
::
pair
<
KeyType
,
ValType
>*
kv
=
container_
->
data
();
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
if
(
kv
[
i
].
first
==
unuse_key
)
{
continue
;
}
ValType
&
gpu_val
=
kv
[
i
].
second
;
auto
*
downpour_value
=
(
paddle
::
ps
::
DownpourFixedFeatureValue
*
)(
gpu_val
.
cpu_ptr
);
int
downpour_value_size
=
downpour_value
->
size
();
if
(
gpu_val
.
mf_size
>
0
&&
downpour_value_size
==
7
)
{
downpour_value
->
resize
(
gpu_val
.
mf_size
+
downpour_value_size
);
}
float
*
cpu_val
=
downpour_value
->
data
();
cpu_val
[
0
]
=
0
;
cpu_val
[
1
]
=
gpu_val
.
delta_score
;
cpu_val
[
2
]
=
gpu_val
.
show
;
cpu_val
[
3
]
=
gpu_val
.
clk
;
cpu_val
[
4
]
=
gpu_val
.
lr
;
cpu_val
[
5
]
=
gpu_val
.
lr_g2sum
;
cpu_val
[
6
]
=
gpu_val
.
slot
;
if
(
gpu_val
.
mf_size
>
0
)
{
for
(
int
x
=
0
;
x
<
gpu_val
.
mf_size
;
x
++
)
{
cpu_val
[
x
+
7
]
=
gpu_val
.
mf
[
x
];
}
}
}
container_
->
prefetch
(
devid
,
stream
);
}
template
<
typename
KeyType
,
typename
ValType
>
template
<
typename
KeyType
,
typename
ValType
>
template
<
typename
GradType
,
typename
Sgd
>
template
<
typename
GradType
,
typename
Sgd
>
void
HashTable
<
KeyType
,
ValType
>::
update
(
const
KeyType
*
d_keys
,
void
HashTable
<
KeyType
,
ValType
>::
update
(
const
KeyType
*
d_keys
,
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
浏览文件 @
cb66c53c
...
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
...
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <thread>
#include <vector>
#include <vector>
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#include "hashtable.h"
#include "hashtable.h"
#include "heter_resource.h"
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
.h
"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -72,6 +73,10 @@ class HeterComm {
...
@@ -72,6 +73,10 @@ class HeterComm {
return
((
send_id
/
4
!=
receive_id
/
4
)
&&
(
send_id
+
4
)
%
8
!=
receive_id
);
return
((
send_id
/
4
!=
receive_id
/
4
)
&&
(
send_id
+
4
)
%
8
!=
receive_id
);
}
}
// void dump_to_cpu(int index);
void
end_pass
();
int
get_transfer_devid
(
int
send_id
)
{
return
(
send_id
+
4
)
%
8
;
}
int
get_transfer_devid
(
int
send_id
)
{
return
(
send_id
+
4
)
%
8
;
}
struct
Node
{
struct
Node
{
...
@@ -110,5 +115,5 @@ class HeterComm {
...
@@ -110,5 +115,5 @@ class HeterComm {
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm
.tpp
"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm
_inl.h
"
#endif
#endif
paddle/fluid/framework/fleet/heter_ps/heter_comm
.tpp
→
paddle/fluid/framework/fleet/heter_ps/heter_comm
_inl.h
浏览文件 @
cb66c53c
...
@@ -595,6 +595,34 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
...
@@ -595,6 +595,34 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
}
}
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
end_pass
()
{
int
total_gpu
=
resource_
->
total_gpu
();
std
::
vector
<
std
::
thread
>
threads
;
auto
dump_to_cpu_func
=
[
this
](
int
index
)
{
auto
stream
=
resource_
->
local_stream
(
index
,
0
);
int
dev_id
=
resource_
->
dev_id
(
index
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
tables_
[
index
]
->
dump_to_cpu
(
dev_id
,
stream
);
};
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
threads
.
push_back
(
std
::
thread
(
dump_to_cpu_func
,
i
));
}
for
(
auto
&
t
:
threads
)
{
t
.
join
();
}
}
// template <typename KeyType, typename ValType, typename GradType>
// void HeterComm<KeyType, ValType, GradType>::dump_to_cpu(int index) {
// auto stream = resource_->local_stream(index, 0);
// int dev_id = resource_->dev_id(index);
// platform::CUDADeviceGuard guard(dev_id);
// tables_[index]->dump_to_cpu(dev_id, stream);
//}
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
#endif
#endif
paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
浏览文件 @
cb66c53c
...
@@ -48,7 +48,7 @@ int HeterPs::get_index_by_devid(int devid) {
...
@@ -48,7 +48,7 @@ int HeterPs::get_index_by_devid(int devid) {
return
comm_
->
get_index_by_devid
(
devid
);
return
comm_
->
get_index_by_devid
(
devid
);
}
}
void
HeterPs
::
dump
()
{
}
void
HeterPs
::
end_pass
()
{
comm_
->
end_pass
();
}
void
HeterPs
::
show_one_table
(
int
gpu_num
)
{
comm_
->
show_one_table
(
gpu_num
);
}
void
HeterPs
::
show_one_table
(
int
gpu_num
)
{
comm_
->
show_one_table
(
gpu_num
);
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_ps.h
浏览文件 @
cb66c53c
...
@@ -16,7 +16,7 @@ limitations under the License. */
...
@@ -16,7 +16,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
.h
"
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
...
@@ -35,7 +35,7 @@ class HeterPs : public HeterPsBase {
...
@@ -35,7 +35,7 @@ class HeterPs : public HeterPsBase {
size_t
len
)
override
;
size_t
len
)
override
;
virtual
void
build_ps
(
int
num
,
FeatureKey
*
h_keys
,
FeatureValue
*
h_vals
,
virtual
void
build_ps
(
int
num
,
FeatureKey
*
h_keys
,
FeatureValue
*
h_vals
,
size_t
len
,
size_t
chunk_size
,
int
stream_num
)
override
;
size_t
len
,
size_t
chunk_size
,
int
stream_num
)
override
;
virtual
void
dump
()
override
;
virtual
void
end_pass
()
override
;
virtual
int
get_index_by_devid
(
int
devid
)
override
;
virtual
int
get_index_by_devid
(
int
devid
)
override
;
virtual
void
show_one_table
(
int
gpu_num
)
override
;
virtual
void
show_one_table
(
int
gpu_num
)
override
;
virtual
void
push_sparse
(
int
num
,
FeatureKey
*
d_keys
,
virtual
void
push_sparse
(
int
num
,
FeatureKey
*
d_keys
,
...
...
paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
浏览文件 @
cb66c53c
...
@@ -35,7 +35,7 @@ class HeterPsBase {
...
@@ -35,7 +35,7 @@ class HeterPsBase {
virtual
void
build_ps
(
int
num
,
FeatureKey
*
h_keys
,
FeatureValue
*
h_vals
,
virtual
void
build_ps
(
int
num
,
FeatureKey
*
h_keys
,
FeatureValue
*
h_vals
,
size_t
len
,
size_t
chunk_size
,
int
stream_num
)
=
0
;
size_t
len
,
size_t
chunk_size
,
int
stream_num
)
=
0
;
virtual
int
get_index_by_devid
(
int
devid
)
=
0
;
virtual
int
get_index_by_devid
(
int
devid
)
=
0
;
virtual
void
dump
()
=
0
;
virtual
void
end_pass
()
=
0
;
virtual
void
show_one_table
(
int
gpu_num
)
=
0
;
virtual
void
show_one_table
(
int
gpu_num
)
=
0
;
virtual
void
push_sparse
(
int
num
,
FeatureKey
*
d_keys
,
virtual
void
push_sparse
(
int
num
,
FeatureKey
*
d_keys
,
FeaturePushValue
*
d_grads
,
size_t
len
)
=
0
;
FeaturePushValue
*
d_grads
,
size_t
len
)
=
0
;
...
...
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
→
paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
.h
浏览文件 @
cb66c53c
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <vector>
#include <curand_kernel.h>
#include <curand_kernel.h>
#include <vector>
#include "optimizer_conf.h"
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
...
@@ -111,8 +111,8 @@ class Optimizer {
...
@@ -111,8 +111,8 @@ class Optimizer {
curandState
state
;
curandState
state
;
curand_init
(
clock64
(),
tid_x
,
0
,
&
state
);
curand_init
(
clock64
(),
tid_x
,
0
,
&
state
);
for
(
int
i
=
0
;
i
<
MF_DIM
;
++
i
)
{
for
(
int
i
=
0
;
i
<
MF_DIM
;
++
i
)
{
val
.
mf
[
i
+
1
]
=
(
curand_uniform
(
&
state
))
*
val
.
mf
[
i
+
1
]
=
optimizer_config
::
mf_initial_range
;
(
curand_uniform
(
&
state
))
*
optimizer_config
::
mf_initial_range
;
}
}
}
}
}
else
{
}
else
{
...
...
paddle/fluid/framework/fleet/heter_ps/test_comm.cu
浏览文件 @
cb66c53c
...
@@ -17,7 +17,7 @@ limitations under the License. */
...
@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh
.h
"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
浏览文件 @
cb66c53c
...
@@ -183,6 +183,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
...
@@ -183,6 +183,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
val
.
slot
=
ptr_val
[
6
];
val
.
slot
=
ptr_val
[
6
];
val
.
lr
=
ptr_val
[
4
];
val
.
lr
=
ptr_val
[
4
];
val
.
lr_g2sum
=
ptr_val
[
5
];
val
.
lr_g2sum
=
ptr_val
[
5
];
val
.
cpu_ptr
=
(
uint64_t
)(
task_ptrs
[
dev
][
j
]);
if
(
dim
>
7
)
{
if
(
dim
>
7
)
{
val
.
mf_size
=
MF_DIM
+
1
;
val
.
mf_size
=
MF_DIM
+
1
;
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
浏览文件 @
cb66c53c
...
@@ -162,6 +162,9 @@ class PSGPUWrapper {
...
@@ -162,6 +162,9 @@ class PSGPUWrapper {
slot_vector_
=
slot_vector
;
slot_vector_
=
slot_vector
;
}
}
void
EndPass
()
{
HeterPs_
->
end_pass
();
}
void
ShowOneTable
(
int
index
)
{
HeterPs_
->
show_one_table
(
index
);
}
private:
private:
static
std
::
shared_ptr
<
PSGPUWrapper
>
s_instance_
;
static
std
::
shared_ptr
<
PSGPUWrapper
>
s_instance_
;
Dataset
*
dataset_
;
Dataset
*
dataset_
;
...
...
paddle/fluid/pybind/ps_gpu_wrapper_py.cc
浏览文件 @
cb66c53c
...
@@ -45,6 +45,8 @@ void BindPSGPUWrapper(py::module* m) {
...
@@ -45,6 +45,8 @@ void BindPSGPUWrapper(py::module* m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"init_gpu_ps"
,
&
framework
::
PSGPUWrapper
::
InitializeGPU
,
.
def
(
"init_gpu_ps"
,
&
framework
::
PSGPUWrapper
::
InitializeGPU
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"end_pass"
,
&
framework
::
PSGPUWrapper
::
EndPass
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"build_gpu_ps"
,
&
framework
::
PSGPUWrapper
::
BuildGPUPS
,
.
def
(
"build_gpu_ps"
,
&
framework
::
PSGPUWrapper
::
BuildGPUPS
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
// end PSGPUWrapper
}
// end PSGPUWrapper
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录