Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7fc2ce50
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7fc2ce50
编写于
1月 05, 2021
作者:
T
Thunderbrook
提交者:
GitHub
1月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add topo-aware in heter-ps (#30087) (#30117)
* add topo aware * resource.h * topo aware * format
上级
faeee3c3
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
300 addition
and
117 deletion
+300
-117
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+17
-1
paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
...mework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
+1
-1
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+30
-0
paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
+173
-67
paddle/fluid/framework/fleet/heter_ps/heter_resource.cc
paddle/fluid/framework/fleet/heter_ps/heter_resource.cc
+29
-13
paddle/fluid/framework/fleet/heter_ps/heter_resource.h
paddle/fluid/framework/fleet/heter_ps/heter_resource.h
+11
-7
paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h
paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h
+13
-12
paddle/fluid/framework/ps_gpu_worker.cc
paddle/fluid/framework/ps_gpu_worker.cc
+8
-4
paddle/fluid/pybind/fleet_wrapper_py.cc
paddle/fluid/pybind/fleet_wrapper_py.cc
+4
-0
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
...e/fluid/incubate/fleet/parameter_server/pslib/__init__.py
+14
-12
未找到文件。
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
7fc2ce50
...
@@ -1225,6 +1225,13 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
...
@@ -1225,6 +1225,13 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
void
FleetWrapper
::
LoadWithWhitelist
(
const
uint64_t
table_id
,
void
FleetWrapper
::
LoadWithWhitelist
(
const
uint64_t
table_id
,
const
std
::
string
&
path
,
const
int
mode
)
{
const
std
::
string
&
path
,
const
int
mode
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
auto
ret
=
pslib_ptr_
->
_worker_ptr
->
load_with_whitelist
(
table_id
,
path
,
std
::
to_string
(
mode
));
ret
.
wait
();
if
(
ret
.
get
()
!=
0
)
{
LOG
(
ERROR
)
<<
"load model of table id: "
<<
table_id
<<
", from path: "
<<
path
<<
" failed"
;
}
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::LoadWhitelist does nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::LoadWhitelist does nothing when no pslib"
;
#endif
#endif
...
@@ -1349,7 +1356,16 @@ int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path,
...
@@ -1349,7 +1356,16 @@ int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path,
const
int
mode
,
const
int
mode
,
const
std
::
string
&
whitelist_path
)
{
const
std
::
string
&
whitelist_path
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
return
0
;
auto
ret
=
pslib_ptr_
->
_worker_ptr
->
save_with_whitelist
(
table_id
,
path
,
std
::
to_string
(
mode
),
whitelist_path
);
ret
.
wait
();
int32_t
feasign_cnt
=
ret
.
get
();
if
(
feasign_cnt
==
-
1
)
{
LOG
(
ERROR
)
<<
"table save cache failed"
;
sleep
(
sleep_seconds_before_fail_exit_
);
exit
(
-
1
);
}
return
feasign_cnt
;
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::SaveCache does nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::SaveCache does nothing when no pslib"
;
return
-
1
;
return
-
1
;
...
...
paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
浏览文件 @
7fc2ce50
...
@@ -765,7 +765,7 @@ x.second );
...
@@ -765,7 +765,7 @@ x.second );
unsigned
long
long
get_num_collisions
()
const
{
return
m_collisions
;
}
unsigned
long
long
get_num_collisions
()
const
{
return
m_collisions
;
}
void
print
()
{
void
print
()
{
for
(
size_type
i
=
0
;
i
<
m_hashtbl_size
;
++
i
)
{
for
(
size_type
i
=
0
;
i
<
10
;
++
i
)
{
std
::
cout
<<
i
<<
": "
<<
m_hashtbl_values
[
i
].
first
<<
","
std
::
cout
<<
i
<<
": "
<<
m_hashtbl_values
[
i
].
first
<<
","
<<
m_hashtbl_values
[
i
].
second
<<
std
::
endl
;
<<
m_hashtbl_values
[
i
].
second
<<
std
::
endl
;
}
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm.h
浏览文件 @
7fc2ce50
...
@@ -68,6 +68,34 @@ class HeterComm {
...
@@ -68,6 +68,34 @@ class HeterComm {
Sgd
&
sgd
);
Sgd
&
sgd
);
int
log2i
(
int
x
);
int
log2i
(
int
x
);
bool
need_transfer
(
int
send_id
,
int
receive_id
)
{
return
((
send_id
/
4
!=
receive_id
/
4
)
&&
(
send_id
+
4
)
%
8
!=
receive_id
);
}
int
get_transfer_devid
(
int
send_id
)
{
return
(
send_id
+
4
)
%
8
;
}
struct
Node
{
cudaStream_t
in_stream
;
cudaStream_t
out_stream
;
char
*
key_storage
;
char
*
val_storage
;
int
sync
;
int
key_bytes_len
;
int
val_bytes_len
;
int
gpu_num
;
};
struct
Path
{
std
::
vector
<
Node
>
nodes_
;
};
void
init_path
();
void
create_storage
(
int
start_index
,
int
end_index
,
int
keylen
,
int
vallen
,
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>&
local_strorage
);
void
walk_to_src
(
int
start_index
,
int
end_index
,
char
*
src_val
);
void
walk_to_dest
(
int
start_index
,
int
end_index
,
char
*
src_key
,
char
*
src_val
);
private:
private:
using
Table
=
HashTable
<
KeyType
,
ValType
>
;
using
Table
=
HashTable
<
KeyType
,
ValType
>
;
...
@@ -76,6 +104,8 @@ class HeterComm {
...
@@ -76,6 +104,8 @@ class HeterComm {
std
::
vector
<
Table
*>
tables_
;
std
::
vector
<
Table
*>
tables_
;
std
::
shared_ptr
<
HeterPsResource
>
resource_
;
std
::
shared_ptr
<
HeterPsResource
>
resource_
;
CustomGradMerger
merger_
;
CustomGradMerger
merger_
;
int
topo_aware_
{
1
};
std
::
vector
<
std
::
vector
<
Path
>>
path_
;
};
};
}
// end namespace framework
}
// end namespace framework
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp
浏览文件 @
7fc2ce50
...
@@ -100,6 +100,131 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
...
@@ -100,6 +100,131 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
auto
table
=
new
Table
(
capacity
/
load_factor_
);
auto
table
=
new
Table
(
capacity
/
load_factor_
);
tables_
.
push_back
(
table
);
tables_
.
push_back
(
table
);
}
}
init_path
();
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
init_path
()
{
int
total_gpu
=
resource_
->
total_gpu
();
path_
.
resize
(
total_gpu
);
if
(
!
topo_aware_
)
{
VLOG
(
1
)
<<
"init path without topo aware"
;
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
path_
[
i
].
resize
(
total_gpu
);
for
(
int
j
=
0
;
j
<
total_gpu
;
++
j
)
{
auto
&
nodes
=
path_
[
i
][
j
].
nodes_
;
nodes
.
resize
(
1
);
nodes
[
0
].
in_stream
=
resource_
->
comm_stream
(
i
,
j
);
nodes
[
0
].
out_stream
=
resource_
->
comm_stream
(
j
,
i
);
nodes
[
0
].
key_storage
=
NULL
;
nodes
[
0
].
val_storage
=
NULL
;
nodes
[
0
].
sync
=
0
;
nodes
[
0
].
gpu_num
=
j
;
}
}
}
else
{
VLOG
(
1
)
<<
"init path with topo aware"
;
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
path_
[
i
].
resize
(
total_gpu
);
for
(
int
j
=
0
;
j
<
total_gpu
;
++
j
)
{
auto
&
nodes
=
path_
[
i
][
j
].
nodes_
;
int
from
=
resource_
->
dev_id
(
i
);
int
to
=
resource_
->
dev_id
(
j
);
int
transfer_id
=
i
;
if
(
need_transfer
(
from
,
to
))
{
transfer_id
=
resource_
->
get_index_by_devid
(
get_transfer_devid
(
from
));
nodes
.
push_back
(
Node
());
Node
&
node
=
nodes
.
back
();
node
.
in_stream
=
resource_
->
comm_stream
(
i
,
transfer_id
);
node
.
out_stream
=
resource_
->
comm_stream
(
transfer_id
,
i
);
node
.
key_storage
=
NULL
;
node
.
val_storage
=
NULL
;
node
.
sync
=
1
;
node
.
gpu_num
=
transfer_id
;
}
nodes
.
push_back
(
Node
());
Node
&
node
=
nodes
.
back
();
node
.
in_stream
=
resource_
->
comm_stream
(
i
,
transfer_id
);
node
.
out_stream
=
resource_
->
comm_stream
(
transfer_id
,
i
);
node
.
key_storage
=
NULL
;
node
.
val_storage
=
NULL
;
node
.
sync
=
0
;
node
.
gpu_num
=
j
;
}
}
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
create_storage
(
int
start_index
,
int
end_index
,
int
keylen
,
int
vallen
,
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>&
local_storage
)
{
auto
&
nodes
=
path_
[
start_index
][
end_index
].
nodes_
;
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
nodes
[
i
].
gpu_num
));
platform
::
CUDAPlace
remote_place
=
platform
::
CUDAPlace
(
resource_
->
dev_id
(
nodes
[
i
].
gpu_num
));
auto
key_mem
=
memory
::
AllocShared
(
remote_place
,
keylen
);
local_storage
.
push_back
(
key_mem
);
nodes
[
i
].
key_storage
=
reinterpret_cast
<
char
*>
(
key_mem
->
ptr
());
auto
val_mem
=
memory
::
AllocShared
(
remote_place
,
vallen
);
local_storage
.
push_back
(
val_mem
);
nodes
[
i
].
val_storage
=
reinterpret_cast
<
char
*>
(
val_mem
->
ptr
());
nodes
[
i
].
key_bytes_len
=
keylen
;
nodes
[
i
].
val_bytes_len
=
vallen
;
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
walk_to_dest
(
int
start_index
,
int
end_index
,
char
*
src_key
,
char
*
src_val
)
{
int
need_copy_val
=
0
;
if
(
src_val
)
{
need_copy_val
=
1
;
}
auto
&
nodes
=
path_
[
start_index
][
end_index
].
nodes_
;
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
cudaMemcpyAsync
(
nodes
[
i
].
key_storage
,
src_key
,
nodes
[
i
].
key_bytes_len
,
cudaMemcpyDefault
,
nodes
[
i
].
in_stream
);
if
(
need_copy_val
)
{
cudaMemcpyAsync
(
nodes
[
i
].
val_storage
,
src_val
,
nodes
[
i
].
val_bytes_len
,
cudaMemcpyDefault
,
nodes
[
i
].
in_stream
);
}
if
(
nodes
[
i
].
sync
)
{
cudaStreamSynchronize
(
nodes
[
i
].
in_stream
);
}
// cudaStreamSynchronize(nodes[i].in_stream);
src_key
=
nodes
[
i
].
key_storage
;
src_val
=
nodes
[
i
].
val_storage
;
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
walk_to_src
(
int
start_index
,
int
end_index
,
char
*
src_val
)
{
auto
&
nodes
=
path_
[
start_index
][
end_index
].
nodes_
;
int
len
=
nodes
.
size
();
char
*
start
=
NULL
;
for
(
int
i
=
len
-
1
;
i
>=
0
;
--
i
)
{
if
(
start
==
NULL
)
{
start
=
nodes
[
i
].
val_storage
;
continue
;
}
cudaMemcpyAsync
(
nodes
[
i
].
val_storage
,
start
,
nodes
[
i
].
val_bytes_len
,
cudaMemcpyDefault
,
nodes
[
i
].
out_stream
);
if
(
nodes
[
i
].
sync
)
{
cudaStreamSynchronize
(
nodes
[
i
].
out_stream
);
}
start
=
nodes
[
i
].
val_storage
;
}
cudaMemcpyAsync
(
src_val
,
nodes
[
0
].
val_storage
,
nodes
[
0
].
val_bytes_len
,
cudaMemcpyDefault
,
nodes
[
0
].
out_stream
);
// cudaStreamSynchronize(nodes[0].out_stream);
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
...
@@ -131,7 +256,8 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
...
@@ -131,7 +256,8 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
build_ps
(
int
num
,
KeyType
*
h_keys
,
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
build_ps
(
int
num
,
KeyType
*
h_keys
,
ValType
*
h_vals
,
size_t
len
,
ValType
*
h_vals
,
size_t
len
,
size_t
chunk_size
,
size_t
chunk_size
,
int
stream_num
)
{
int
stream_num
)
{
if
(
len
<=
0
)
{
if
(
len
<=
0
)
{
...
@@ -182,13 +308,15 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
...
@@ -182,13 +308,15 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
merge_grad
(
int
gpu_num
,
KeyType
*
d_keys
,
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
merge_grad
(
int
gpu_num
,
KeyType
*
d_keys
,
GradType
*
d_grads
,
GradType
*
d_grads
,
size_t
len
,
int
&
uniq_len
)
{
size_t
len
,
int
&
uniq_len
)
{
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
auto
stream
=
resource_
->
stream
(
gpu_num
);
auto
stream
=
resource_
->
local_stream
(
gpu_num
,
0
);
size_t
temp_storage_bytes
;
size_t
temp_storage_bytes
;
...
@@ -240,7 +368,7 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
...
@@ -240,7 +368,7 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
auto
stream
=
resource_
->
stream
(
gpu_num
);
auto
stream
=
resource_
->
local_stream
(
gpu_num
,
0
);
auto
d_idx_tmp
=
memory
::
AllocShared
(
place
,
len
*
sizeof
(
int
));
auto
d_idx_tmp
=
memory
::
AllocShared
(
place
,
len
*
sizeof
(
int
));
int
*
d_idx_tmp_ptr
=
reinterpret_cast
<
int
*>
(
d_idx_tmp
->
ptr
());
int
*
d_idx_tmp_ptr
=
reinterpret_cast
<
int
*>
(
d_idx_tmp
->
ptr
());
...
@@ -272,7 +400,8 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
...
@@ -272,7 +400,8 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
}
}
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
template
<
typename
KeyType
,
typename
ValType
,
typename
GradType
>
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
pull_sparse
(
int
num
,
KeyType
*
d_keys
,
void
HeterComm
<
KeyType
,
ValType
,
GradType
>::
pull_sparse
(
int
num
,
KeyType
*
d_keys
,
ValType
*
d_vals
,
ValType
*
d_vals
,
size_t
len
)
{
size_t
len
)
{
if
(
len
==
0
)
{
if
(
len
==
0
)
{
...
@@ -283,7 +412,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
...
@@ -283,7 +412,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
int
dev_id
=
resource_
->
dev_id
(
num
);
int
dev_id
=
resource_
->
dev_id
(
num
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
auto
stream
=
resource_
->
stream
(
num
);
auto
stream
=
resource_
->
local_stream
(
num
,
0
);
int
grid_size
=
(
len
-
1
)
/
block_size_
+
1
;
int
grid_size
=
(
len
-
1
)
/
block_size_
+
1
;
...
@@ -318,28 +447,15 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
...
@@ -318,28 +447,15 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
cudaMemcpy
(
h_right
,
d_right_ptr
,
total_gpu
*
sizeof
(
int
),
cudaMemcpy
(
h_right
,
d_right_ptr
,
total_gpu
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
cudaMemcpyDeviceToHost
);
std
::
vector
<
KeyType
*>
d_remote_shard_keys_ptr
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
local_storage
;
std
::
vector
<
ValType
*>
d_remote_shard_vals_ptr
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
d_remote_shard_keys
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
d_remote_shard_vals
;
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
int
shard_len
=
h_right
[
i
]
-
h_left
[
i
]
+
1
;
int
shard_len
=
h_right
[
i
]
-
h_left
[
i
]
+
1
;
if
(
shard_len
==
0
)
{
if
(
shard_len
==
0
)
{
continue
;
continue
;
}
}
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
create_storage
(
num
,
i
,
shard_len
*
sizeof
(
KeyType
),
platform
::
CUDAPlace
remote_place
=
shard_len
*
sizeof
(
ValType
),
local_storage
);
platform
::
CUDAPlace
(
resource_
->
dev_id
(
i
));
d_remote_shard_keys
.
push_back
(
memory
::
AllocShared
(
remote_place
,
shard_len
*
sizeof
(
KeyType
)));
d_remote_shard_keys_ptr
.
push_back
(
reinterpret_cast
<
KeyType
*>
(
d_remote_shard_keys
[
i
]
->
ptr
()));
d_remote_shard_vals
.
push_back
(
memory
::
AllocShared
(
remote_place
,
shard_len
*
sizeof
(
ValType
)));
d_remote_shard_vals_ptr
.
push_back
(
reinterpret_cast
<
ValType
*>
(
d_remote_shard_vals
[
i
]
->
ptr
()));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
...
@@ -347,21 +463,23 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
...
@@ -347,21 +463,23 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
continue
;
continue
;
}
}
cudaMemcpyAsync
(
d_remote_shard_keys_ptr
[
i
],
d_shard_keys_ptr
+
h_left
[
i
]
,
walk_to_dest
(
num
,
i
,
reinterpret_cast
<
char
*>
(
d_shard_keys_ptr
+
h_left
[
i
])
,
shard_len
*
sizeof
(
KeyType
),
cudaMemcpyDefault
,
stream
);
NULL
);
}
}
cudaStreamSynchronize
(
stream
);
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
if
(
h_left
[
i
]
==
-
1
)
{
if
(
h_left
[
i
]
==
-
1
)
{
continue
;
continue
;
}
}
auto
&
node
=
path_
[
num
][
i
].
nodes_
.
back
();
cudaStreamSynchronize
(
node
.
in_stream
);
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
tables_
[
i
]
->
get
(
d_remote_shard_keys_ptr
[
i
],
d_remote_shard_vals_ptr
[
i
],
tables_
[
i
]
->
get
(
reinterpret_cast
<
KeyType
*>
(
node
.
key_storage
),
h_right
[
i
]
-
h_left
[
i
]
+
1
,
resource_
->
stream
(
i
));
reinterpret_cast
<
ValType
*>
(
node
.
val_storage
),
h_right
[
i
]
-
h_left
[
i
]
+
1
,
resource_
->
remote_stream
(
i
));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
cudaStreamSynchronize
(
resource_
->
stream
(
i
));
cudaStreamSynchronize
(
resource_
->
remote_
stream
(
i
));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
...
@@ -370,13 +488,12 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
...
@@ -370,13 +488,12 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num, KeyType* d_keys
continue
;
continue
;
}
}
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
cudaMemcpyAsync
(
d_shard_vals_ptr
+
h_left
[
i
],
d_remote_shard_vals_ptr
[
i
],
walk_to_src
(
num
,
i
,
reinterpret_cast
<
char
*>
(
d_shard_vals_ptr
+
h_left
[
i
]));
shard_len
*
sizeof
(
ValType
),
cudaMemcpyDefault
,
resource_
->
stream
(
i
));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
cudaStreamSynchronize
(
resource_
->
stream
(
i
));
auto
&
node
=
path_
[
num
][
i
].
nodes_
.
front
();
cudaStreamSynchronize
(
node
.
out_stream
);
}
}
fill_dvals
<<<
grid_size
,
block_size_
,
0
,
stream
>>>
(
d_shard_vals_ptr
,
d_vals
,
fill_dvals
<<<
grid_size
,
block_size_
,
0
,
stream
>>>
(
d_shard_vals_ptr
,
d_vals
,
...
@@ -398,7 +515,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
...
@@ -398,7 +515,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
int
dev_id
=
resource_
->
dev_id
(
gpu_num
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
platform
::
CUDADeviceGuard
guard
(
dev_id
);
auto
stream
=
resource_
->
stream
(
gpu_num
);
auto
stream
=
resource_
->
local_stream
(
gpu_num
,
0
);
int
h_left
[
total_gpu
];
int
h_left
[
total_gpu
];
int
h_right
[
total_gpu
];
int
h_right
[
total_gpu
];
...
@@ -439,28 +556,15 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
...
@@ -439,28 +556,15 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
cudaMemcpy
(
h_right
,
d_right_ptr
,
total_gpu
*
sizeof
(
int
),
cudaMemcpy
(
h_right
,
d_right_ptr
,
total_gpu
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
cudaMemcpyDeviceToHost
);
std
::
vector
<
KeyType
*>
d_remote_shard_keys_ptr
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
local_storage
;
std
::
vector
<
GradType
*>
d_remote_shard_grads_ptr
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
d_remote_shard_keys
;
std
::
vector
<
std
::
shared_ptr
<
memory
::
Allocation
>>
d_remote_shard_grads
;
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
int
shard_len
=
h_right
[
i
]
-
h_left
[
i
]
+
1
;
int
shard_len
=
h_right
[
i
]
-
h_left
[
i
]
+
1
;
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
continue
;
continue
;
}
}
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
create_storage
(
gpu_num
,
i
,
shard_len
*
sizeof
(
KeyType
),
platform
::
CUDAPlace
remote_place
=
shard_len
*
sizeof
(
GradType
),
local_storage
);
platform
::
CUDAPlace
(
resource_
->
dev_id
(
i
));
d_remote_shard_keys
.
push_back
(
memory
::
AllocShared
(
remote_place
,
shard_len
*
sizeof
(
KeyType
)));
d_remote_shard_keys_ptr
.
push_back
(
reinterpret_cast
<
KeyType
*>
(
d_remote_shard_keys
[
i
]
->
ptr
()));
d_remote_shard_grads
.
push_back
(
memory
::
AllocShared
(
remote_place
,
shard_len
*
sizeof
(
GradType
)));
d_remote_shard_grads_ptr
.
push_back
(
reinterpret_cast
<
GradType
*>
(
d_remote_shard_grads
[
i
]
->
ptr
()));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
...
@@ -468,24 +572,26 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
...
@@ -468,24 +572,26 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
continue
;
continue
;
}
}
cudaMemcpyAsync
(
d_remote_shard_keys_ptr
[
i
],
d_shard_keys_ptr
+
h_left
[
i
],
walk_to_dest
(
gpu_num
,
i
,
shard_len
*
sizeof
(
KeyType
),
cudaMemcpyDefault
,
stream
);
reinterpret_cast
<
char
*>
(
d_shard_keys_ptr
+
h_left
[
i
]),
cudaMemcpyAsync
(
d_remote_shard_grads_ptr
[
i
],
d_shard_grads_ptr
+
h_left
[
i
],
reinterpret_cast
<
char
*>
(
d_shard_grads_ptr
+
h_left
[
i
]));
shard_len
*
sizeof
(
GradType
),
cudaMemcpyDefault
,
stream
);
}
}
cudaStreamSynchronize
(
stream
);
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
if
(
h_left
[
i
]
==
-
1
||
h_right
[
i
]
==
-
1
)
{
continue
;
continue
;
}
}
auto
&
node
=
path_
[
gpu_num
][
i
].
nodes_
.
back
();
cudaStreamSynchronize
(
node
.
in_stream
);
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
platform
::
CUDADeviceGuard
guard
(
resource_
->
dev_id
(
i
));
tables_
[
i
]
->
update
(
d_remote_shard_keys_ptr
[
i
],
d_remote_shard_grads_ptr
[
i
],
tables_
[
i
]
->
update
(
reinterpret_cast
<
KeyType
*>
(
node
.
key_storage
),
h_right
[
i
]
-
h_left
[
i
]
+
1
,
sgd
,
resource_
->
stream
(
i
));
reinterpret_cast
<
GradType
*>
(
node
.
val_storage
),
h_right
[
i
]
-
h_left
[
i
]
+
1
,
sgd
,
resource_
->
remote_stream
(
i
));
}
}
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_gpu
;
++
i
)
{
cudaStreamSynchronize
(
resource_
->
stream
(
i
));
cudaStreamSynchronize
(
resource_
->
remote_
stream
(
i
));
}
}
}
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_resource.cc
浏览文件 @
7fc2ce50
...
@@ -19,23 +19,35 @@ limitations under the License. */
...
@@ -19,23 +19,35 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
GPUResource
::
GPUResource
(
int
dev_id
,
int
index
)
{
GPUResource
::
GPUResource
(
std
::
vector
<
int
>&
dev_ids
,
int
index
)
{
index_
=
index
;
index_
=
index
;
dev_id_
=
dev_id
;
dev_ids_
=
dev_ids
;
dev_id_
=
dev_ids_
[
index
];
platform
::
CUDADeviceGuard
guard
(
dev_id_
);
platform
::
CUDADeviceGuard
guard
(
dev_id_
);
local_streams_
.
resize
(
dev_ids_
.
size
());
comm_streams_
.
resize
(
dev_ids_
.
size
());
for
(
size_t
i
=
0
;
i
<
dev_ids_
.
size
();
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreateWithFlags
(
&
local_streams_
[
i
],
cudaStreamNonBlocking
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreateWithFlags
(
&
stream_
,
cudaStreamNonBlocking
));
cudaStreamCreateWithFlags
(
&
comm_streams_
[
i
],
cudaStreamNonBlocking
));
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamCreateWithFlags
(
&
copy
_stream_
,
cudaStreamNonBlocking
));
cudaStreamCreateWithFlags
(
&
remote
_stream_
,
cudaStreamNonBlocking
));
}
}
GPUResource
::~
GPUResource
()
{
GPUResource
::~
GPUResource
()
{
platform
::
CUDADeviceGuard
guard
(
dev_id_
);
platform
::
CUDADeviceGuard
guard
(
dev_id_
);
for
(
size_t
i
=
0
;
i
<
local_streams_
.
size
();
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
local_streams_
[
i
]));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
copy_stream_
));
}
for
(
size_t
i
=
0
;
i
<
comm_streams_
.
size
();
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
comm_streams_
[
i
]));
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamDestroy
(
remote_stream_
));
}
}
void
HeterPsResource
::
enable_p2p
()
{
void
HeterPsResource
::
enable_p2p
()
{
...
@@ -64,18 +76,22 @@ HeterPsResource::HeterPsResource(const std::vector<int>& dev_ids) {
...
@@ -64,18 +76,22 @@ HeterPsResource::HeterPsResource(const std::vector<int>& dev_ids) {
dev_ids_
=
dev_ids
;
dev_ids_
=
dev_ids
;
for
(
size_t
i
=
0
;
i
<
dev_ids_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dev_ids_
.
size
();
++
i
)
{
std
::
shared_ptr
<
GPUResource
>
resource
=
std
::
shared_ptr
<
GPUResource
>
resource
=
std
::
make_shared
<
GPUResource
>
(
dev_ids_
[
i
]
,
i
);
std
::
make_shared
<
GPUResource
>
(
dev_ids_
,
i
);
resources_
.
push_back
(
resource
);
resources_
.
push_back
(
resource
);
devid_2_index_
[
dev_ids_
[
i
]]
=
i
;
devid_2_index_
[
dev_ids_
[
i
]]
=
i
;
}
}
}
}
cudaStream_t
HeterPsResource
::
copy_stream
(
int
num
)
{
cudaStream_t
HeterPsResource
::
comm_stream
(
int
gpu_num
,
int
stream_num
)
{
return
resources_
[
num
]
->
copy_stream
();
return
resources_
[
gpu_num
]
->
comm_stream
(
stream_num
);
}
cudaStream_t
HeterPsResource
::
local_stream
(
int
gpu_num
,
int
stream_num
)
{
return
resources_
[
gpu_num
]
->
local_stream
(
stream_num
);
}
}
cudaStream_t
HeterPsResource
::
stream
(
int
num
)
{
cudaStream_t
HeterPsResource
::
remote_stream
(
int
gpu_
num
)
{
return
resources_
[
num
]
->
stream
();
return
resources_
[
gpu_num
]
->
remote_
stream
();
}
}
int
HeterPsResource
::
dev_id
(
int
num
)
{
return
dev_ids_
[
num
];
}
int
HeterPsResource
::
dev_id
(
int
num
)
{
return
dev_ids_
[
num
];
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_resource.h
浏览文件 @
7fc2ce50
...
@@ -27,20 +27,23 @@ namespace framework {
...
@@ -27,20 +27,23 @@ namespace framework {
class
GPUResource
{
class
GPUResource
{
public:
public:
GPUResource
(
int
device_id
,
int
index
);
GPUResource
(
std
::
vector
<
int
>&
device_id
,
int
index
);
virtual
~
GPUResource
();
virtual
~
GPUResource
();
GPUResource
(
const
GPUResource
&
)
=
delete
;
GPUResource
(
const
GPUResource
&
)
=
delete
;
GPUResource
&
operator
=
(
const
GPUResource
&
)
=
delete
;
GPUResource
&
operator
=
(
const
GPUResource
&
)
=
delete
;
int
dev_id
()
const
{
return
dev_id_
;
}
int
dev_id
()
const
{
return
dev_id_
;
}
int
index
()
const
{
return
index_
;
}
int
index
()
const
{
return
index_
;
}
cudaStream_t
stream
()
{
return
stream_
;
}
cudaStream_t
local_stream
(
int
num
)
{
return
local_streams_
[
num
];
}
cudaStream_t
copy_stream
()
{
return
copy_stream_
;
}
cudaStream_t
remote_stream
()
{
return
remote_stream_
;
}
cudaStream_t
comm_stream
(
int
num
)
{
return
comm_streams_
[
num
];
}
int
dev_id_
;
int
dev_id_
;
int
index_
;
int
index_
;
cudaStream_t
stream_
;
std
::
vector
<
int
>
dev_ids_
;
cudaStream_t
copy_stream_
;
cudaStream_t
remote_stream_
;
std
::
vector
<
cudaStream_t
>
local_streams_
;
std
::
vector
<
cudaStream_t
>
comm_streams_
;
};
};
class
HeterPsResource
{
class
HeterPsResource
{
...
@@ -52,9 +55,10 @@ class HeterPsResource {
...
@@ -52,9 +55,10 @@ class HeterPsResource {
void
enable_p2p
();
void
enable_p2p
();
int
total_gpu
();
int
total_gpu
();
int
get_index_by_devid
(
int
devid
);
int
get_index_by_devid
(
int
devid
);
cudaStream_t
stream
(
int
num
);
cudaStream_t
copy_stream
(
int
num
);
int
dev_id
(
int
num
);
int
dev_id
(
int
num
);
cudaStream_t
local_stream
(
int
gpu_num
,
int
stream_num
);
cudaStream_t
remote_stream
(
int
gpu_num
);
cudaStream_t
comm_stream
(
int
gpu_num
,
int
stream_num
);
std
::
vector
<
std
::
shared_ptr
<
GPUResource
>>
resources_
;
std
::
vector
<
std
::
shared_ptr
<
GPUResource
>>
resources_
;
std
::
vector
<
int
>
dev_ids_
;
std
::
vector
<
int
>
dev_ids_
;
...
...
paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h
浏览文件 @
7fc2ce50
...
@@ -15,18 +15,19 @@ limitations under the License. */
...
@@ -15,18 +15,19 @@ limitations under the License. */
#pragma once
#pragma once
namespace
optimizer_config
{
namespace
optimizer_config
{
__constant__
float
mf_create_thresholds
=
1
;
__constant__
float
nonclk_coeff
=
1
;
__constant__
float
mf_create_thresholds
=
0
;
__constant__
float
nonclk_coeff
=
0.1
;
__constant__
float
clk_coeff
=
1
;
__constant__
float
clk_coeff
=
1
;
__constant__
float
min_bound
=
-
10
000
;
__constant__
float
min_bound
=
-
10
;
__constant__
float
max_bound
=
10
000
;
__constant__
float
max_bound
=
10
;
__constant__
float
learning_rate
=
1
;
__constant__
float
learning_rate
=
0.05
;
__constant__
float
initial_g2sum
=
1
;
__constant__
float
initial_g2sum
=
3.0
;
__constant__
float
initial_range
=
1
;
__constant__
float
initial_range
=
1
e-4
;
__constant__
float
mf_learning_rate
=
1
;
__constant__
float
mf_learning_rate
=
0.05
;
__constant__
float
mf_initial_g2sum
=
1
;
__constant__
float
mf_initial_g2sum
=
3.0
;
__constant__
float
mf_initial_range
=
1
;
__constant__
float
mf_initial_range
=
1
e-4
;
__constant__
float
mf_min_bound
=
1
;
__constant__
float
mf_min_bound
=
-
10
;
__constant__
float
mf_max_bound
=
1
;
__constant__
float
mf_max_bound
=
1
0
;
}
}
paddle/fluid/framework/ps_gpu_worker.cc
浏览文件 @
7fc2ce50
...
@@ -143,16 +143,17 @@ void PSGPUWorker::SetNeedDump(bool need_dump_field) {
...
@@ -143,16 +143,17 @@ void PSGPUWorker::SetNeedDump(bool need_dump_field) {
void
PSGPUWorker
::
DumpParam
()
{}
void
PSGPUWorker
::
DumpParam
()
{}
void
PSGPUWorker
::
TrainFiles
()
{
void
PSGPUWorker
::
TrainFiles
()
{
VLOG
(
3
)
<<
"train file A"
;
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
1
);
platform
::
Timer
timeline
;
timeline
.
Start
();
int
total_ins_num
=
0
;
VLOG
(
3
)
<<
"train file B"
;
// how to accumulate fetched values here
// how to accumulate fetched values here
device_reader_
->
Start
();
device_reader_
->
Start
();
VLOG
(
3
)
<<
"train file C"
;
int
cur_batch
;
int
cur_batch
;
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
VLOG
(
3
)
<<
"train file D"
;
total_ins_num
+=
cur_batch
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
bool
need_skip
=
false
;
bool
need_skip
=
false
;
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
...
@@ -169,6 +170,9 @@ void PSGPUWorker::TrainFiles() {
...
@@ -169,6 +170,9 @@ void PSGPUWorker::TrainFiles() {
PrintFetchVars
();
PrintFetchVars
();
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
}
}
timeline
.
Pause
();
VLOG
(
1
)
<<
"GpuPs worker "
<<
thread_id_
<<
" train cost "
<<
timeline
.
ElapsedSec
()
<<
" seconds, ins_num: "
<<
total_ins_num
;
return
;
return
;
}
}
...
...
paddle/fluid/pybind/fleet_wrapper_py.cc
浏览文件 @
7fc2ce50
...
@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) {
...
@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) {
.
def
(
"get_cache_threshold"
,
&
framework
::
FleetWrapper
::
GetCacheThreshold
)
.
def
(
"get_cache_threshold"
,
&
framework
::
FleetWrapper
::
GetCacheThreshold
)
.
def
(
"cache_shuffle"
,
&
framework
::
FleetWrapper
::
CacheShuffle
)
.
def
(
"cache_shuffle"
,
&
framework
::
FleetWrapper
::
CacheShuffle
)
.
def
(
"save_cache"
,
&
framework
::
FleetWrapper
::
SaveCache
)
.
def
(
"save_cache"
,
&
framework
::
FleetWrapper
::
SaveCache
)
.
def
(
"save_model_with_whitelist"
,
&
framework
::
FleetWrapper
::
SaveWithWhitelist
)
.
def
(
"load_model"
,
&
framework
::
FleetWrapper
::
LoadModel
)
.
def
(
"load_model"
,
&
framework
::
FleetWrapper
::
LoadModel
)
.
def
(
"load_table_with_whitelist"
,
&
framework
::
FleetWrapper
::
LoadWithWhitelist
)
.
def
(
"clear_model"
,
&
framework
::
FleetWrapper
::
ClearModel
)
.
def
(
"clear_model"
,
&
framework
::
FleetWrapper
::
ClearModel
)
.
def
(
"clear_one_table"
,
&
framework
::
FleetWrapper
::
ClearOneTable
)
.
def
(
"clear_one_table"
,
&
framework
::
FleetWrapper
::
ClearOneTable
)
.
def
(
"stop_server"
,
&
framework
::
FleetWrapper
::
StopServer
)
.
def
(
"stop_server"
,
&
framework
::
FleetWrapper
::
StopServer
)
...
...
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
浏览文件 @
7fc2ce50
...
@@ -101,6 +101,7 @@ class PSLib(Fleet):
...
@@ -101,6 +101,7 @@ class PSLib(Fleet):
# barrier_all for init_worker
# barrier_all for init_worker
self
.
_role_maker
.
_barrier_all
()
self
.
_role_maker
.
_barrier_all
()
# prepare for client to client communication
# prepare for client to client communication
if
not
self
.
_opt_info
[
"use_ps_gpu"
]:
if
self
.
_role_maker
.
is_worker
():
if
self
.
_role_maker
.
is_worker
():
info
=
self
.
_fleet_ptr
.
get_clients_info
()
info
=
self
.
_fleet_ptr
.
get_clients_info
()
all_info
=
self
.
_role_maker
.
_worker_gather
(
info
[
0
])
all_info
=
self
.
_role_maker
.
_worker_gather
(
info
[
0
])
...
@@ -137,6 +138,7 @@ class PSLib(Fleet):
...
@@ -137,6 +138,7 @@ class PSLib(Fleet):
"var "
+
var_name
+
" not found in scope, "
"var "
+
var_name
+
" not found in scope, "
+
"you should run startup program first"
)
+
"you should run startup program first"
)
var_name_list
.
append
(
var_name
)
var_name_list
.
append
(
var_name
)
if
not
self
.
_opt_info
[
"use_ps_gpu"
]:
self
.
_fleet_ptr
.
init_model
(
scope
,
self
.
_fleet_ptr
.
init_model
(
scope
,
int
(
table
.
table_id
),
int
(
table
.
table_id
),
var_name_list
)
var_name_list
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录