PaddlePaddle / Paddle
Commit f1fb64b1 (unverified)
Authored Dec 07, 2018 by gongweibao; committed via GitHub on Dec 07, 2018
Add reduce sparse tensor feature. (#14757)
Parent: c83d5b7a
Showing 26 changed files with 1,013 additions and 28 deletions (+1013 −28)
Changed files:
paddle/fluid/framework/details/CMakeLists.txt                      +14   -2
paddle/fluid/framework/details/build_strategy.cc                   +15   -4
paddle/fluid/framework/details/build_strategy.h                     +2   -0
paddle/fluid/framework/details/reduce_and_gather.h                  +1   -1
paddle/fluid/framework/details/reduce_op_handle.cc                +142   -2
paddle/fluid/framework/details/reduce_op_handle.h                  +39   -0
paddle/fluid/operators/distributed/CMakeLists.txt                  +12   -2
paddle/fluid/operators/distributed/collective_client.cc            +59   -0
paddle/fluid/operators/distributed/collective_client.h             +93   -0
paddle/fluid/operators/distributed/collective_server.cc            +74   -0
paddle/fluid/operators/distributed/collective_server.h            +110   -0
paddle/fluid/operators/distributed/collective_server_test.cc      +115   -0
paddle/fluid/operators/distributed/grpc_client.cc                  +53   -6
paddle/fluid/operators/distributed/grpc_client.h                   +17   -7
paddle/fluid/operators/distributed/grpc_server.cc                 +100   -2
paddle/fluid/operators/distributed/grpc_service.h                   +7   -1
paddle/fluid/operators/distributed/request_handler.h                +2   -0
paddle/fluid/operators/distributed/rpc_client.h                    +11   -1
paddle/fluid/operators/distributed/rpc_server.cc                   +90   -0
paddle/fluid/operators/distributed/rpc_server.h                    +31   -0
paddle/fluid/operators/distributed/send_recv.proto.in               +3   -0
paddle/fluid/operators/math/softmax_impl.h                          +1   -0
paddle/fluid/pybind/pybind.cc                                      +12   -0
python/paddle/fluid/framework.py                                    +1   -0
python/paddle/fluid/parallel_executor.py                            +8   -0
python/paddle/fluid/transpiler/distribute_transpiler.py             +1   -0
paddle/fluid/framework/details/CMakeLists.txt
@@ -15,14 +15,26 @@ cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_ro
 if(WITH_GPU)
   nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
           dynload_cuda variable_visitor)
-  nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
+  if(WITH_DISTRIBUTE)
+    nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
+  else()
+    nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim dynload_cuda selected_rows_functor)
+  endif()
   nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
   nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 else()
   cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
           variable_visitor)
-  cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
+  if(WITH_DISTRIBUTE)
+    cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim selected_rows_functor sendrecvop_grpc)
+  else()
+    cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim selected_rows_functor)
+  endif()
   cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
   cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 endif()
paddle/fluid/framework/details/build_strategy.cc
@@ -58,6 +58,17 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       }
     }

+    CollectiveContext *context = CollectiveContext::GetInstance();
+    context->endpoints_ = strategy_.trainers_endpoints_;
+    context->trainer_id_ = strategy_.trainer_id_;
+    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
+    if (strategy_.trainer_id_ > 0) {
+      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
+                         strategy_.trainers_endpoints_.size(),
+                     "trainer_id_ < endpoints_ size");
+    }
+    VLOG(1) << "CollectiveContext:" << context->String();
+
     // Convert graph to run on multi-devices.
     auto multi_devices_pass = AppendPass("multi_devices_pass");
     multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",

@@ -135,16 +146,16 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
     } else if (pass->Type() == "sequential_execution_pass") {
-      VLOG(1) << "set enable_sequential_execution:"
-              << enable_sequential_execution_;
+      LOG(INFO) << "set enable_sequential_execution:"
+                << enable_sequential_execution_;

       pass->Erase(kAllOpDescs);
       pass->Set<const std::vector<OpDesc *>>(
           kAllOpDescs,
           new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
     } else if (pass->Type() == "all_reduce_deps_pass") {
-      VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
-              << ", num_trainers:" << num_trainers_;
+      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
+                << ", num_trainers:" << num_trainers_;

       pass->Erase(kAllOpDescs);
       pass->Set<const std::vector<OpDesc *>>(
paddle/fluid/framework/details/build_strategy.h
@@ -74,6 +74,8 @@ struct BuildStrategy {
   bool fuse_broadcast_op_{false};

   int num_trainers_{1};
+  int trainer_id_{0};
+  std::vector<std::string> trainers_endpoints_;
   bool remove_unnecessary_lock_{false};

   // NOTE:
paddle/fluid/framework/details/reduce_and_gather.h
@@ -53,7 +53,7 @@ struct ReduceLoDTensor {
   }
 };

-inline void GatherSelectedRows(
+inline void GatherLocalSelectedRows(
     const std::vector<const SelectedRows *> &src_selecte_rows_,
     const std::vector<platform::Place> &in_places,
     const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
paddle/fluid/framework/details/reduce_op_handle.cc
@@ -16,6 +16,12 @@
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/collective_client.h"
+#include "paddle/fluid/operators/distributed/collective_server.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
+#endif
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/profiler.h"

 DEFINE_bool(

@@ -26,6 +32,112 @@ namespace paddle {
 namespace framework {
 namespace details {

+std::once_flag CollectiveContext::init_flag_;
+std::unique_ptr<CollectiveContext> CollectiveContext::context_;
+
+static inline std::string GetRemoteVarName(const std::string &var_name,
+                                           int trainer_id) {
+  return string::Sprintf("%s_merged_tmp@trainer_%d", var_name, trainer_id);
+}
+
+void ReduceOpHandle::Wait(
+    const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes) {
+  // TODO(gongwb): use event wait?
+  for (auto &dev_ctx : dev_ctxes) {
+    dev_ctx.second->Wait();
+  }
+}
+
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+template <typename DevCtx, typename DataType>
+void ReduceOpHandle::GatherSelectedRows(
+    const std::vector<const SelectedRows *> &src_selected_rows,
+    const std::vector<platform::Place> &in_places,
+    const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
+    VarHandle *out_var_handle, const platform::Place &out_place,
+    SelectedRows *dst_selected_rows) {
+  const CollectiveContext &collective_context =
+      *CollectiveContext::GetInstance();
+
+  // 1. gather local selected rows, merge them
+  std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp";
+  auto scope = local_scopes_.at(out_var_handle->scope_idx_);
+  auto gathered_var_mid = scope->Var(gathered_var_name);
+  auto gathered_select_rows =
+      gathered_var_mid->GetMutable<framework::SelectedRows>();
+  GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place,
+                          gathered_select_rows);
+  // FIXME(gongwb): remove this Wait.
+  Wait(dev_ctxes);
+
+  // merge them
+  auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place));
+  std::string merged_var_name =
+      GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_);
+  auto merged_select_rows =
+      scope->Var(merged_var_name)->GetMutable<SelectedRows>();
+  operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func;
+  merge_func(*merged_dev_ctx, *gathered_select_rows, merged_select_rows);
+
+  // 2. start collective server if it doesn't exist
+  operators::distributed::CollectiveServer *server =
+      operators::distributed::CollectiveServer::GetInstance(
+          collective_context.endpoints_[collective_context.trainer_id_],
+          collective_context.endpoints_.size() - 1);
+
+  auto rpc_server = server->GetRPCServer();
+  rpc_server->RegisterVar(merged_var_name,
+                          operators::distributed::kRequestGetMonomerVariable,
+                          scope, merged_dev_ctx);
+
+  // 3. gather them from all remote nodes.
+  std::vector<const SelectedRows *> remote;
+  operators::distributed::CollectiveClient *client =
+      operators::distributed::CollectiveClient::GetInstance();
+
+  std::vector<operators::distributed::RemoteVar> vars;
+  for (unsigned int i = 0; i < collective_context.endpoints_.size(); i++) {
+    if (i == (unsigned)collective_context.trainer_id_) continue;
+
+    operators::distributed::RemoteVar var;
+    var.trainer_id_ = i;
+    var.var_name_ = GetRemoteVarName(out_var_handle->name_, i);
+    var.ep_ = collective_context.endpoints_[i];
+
+    vars.push_back(var);
+    VLOG(4) << "gather from:" << var.String();
+  }
+
+  // erase gathered vars
+  merged_dev_ctx->Wait();
+  scope->EraseVars(std::vector<std::string>{gathered_var_name});
+
+  PADDLE_ENFORCE(client->Gather(vars, &remote, *merged_dev_ctx, scope));
+  PADDLE_ENFORCE(remote.size() == vars.size());
+
+  // 4. merged local selected rows.
+  std::vector<const SelectedRows *> all;
+  all.resize(collective_context.endpoints_.size());
+  for (auto v : vars) {
+    all[v.trainer_id_] =
+        scope->FindVar(v.var_name_)->GetMutable<SelectedRows>();
+  }
+  all[collective_context.trainer_id_] = merged_select_rows;
+
+  merge_func(*merged_dev_ctx, all, dst_selected_rows);
+
+  rpc_server->WaitVarBarrier(merged_var_name);
+  rpc_server->ClearVar(merged_var_name);
+
+  // 5. clear mid vars
+  std::vector<std::string> tmp_vars{merged_var_name};
+  for (auto r : vars) {
+    tmp_vars.push_back(r.var_name_);
+  }
+  scope->EraseVars(tmp_vars);
+}
+#endif
+
 void ReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);

@@ -90,8 +202,36 @@ void ReduceOpHandle::RunImpl() {
     this->RunAndRecordEvent([&] {
       std::vector<const SelectedRows *> in_selected_rows =
           GetInputValues<SelectedRows>(in_var_handles, var_scopes);
-      GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
-                         out_var->GetMutable<framework::SelectedRows>());
+
+      const CollectiveContext &collective_context =
+          *CollectiveContext::GetInstance();
+      VLOG(10) << "GatherSelectedRows CollectiveContext:"
+               << collective_context.String();
+
+      // TODO(gongwb): add cpu support
+      if (collective_context.endpoints_.size() <= 1 ||
+          is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
+        GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_,
+                                t_out_p,
+                                out_var->GetMutable<framework::SelectedRows>());
+        return;
+      }
+
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+      if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+        GatherSelectedRows<platform::CUDADeviceContext, float>(
+            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
+            out_var->GetMutable<framework::SelectedRows>());
+      } else if (framework::IsType<const double>(
+                     in_selected_rows[0]->value().type())) {
+        GatherSelectedRows<platform::CUDADeviceContext, double>(
+            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
+            out_var->GetMutable<framework::SelectedRows>());
+      } else {
+        PADDLE_ENFORCE(false,
+                       "only support double or float when gahter SelectedRows");
+      }
+#endif
     });
   } else {
     std::vector<const LoDTensor *> lod_tensors =
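For orientation, the path above does: gather each device's SelectedRows locally, MergeAdd them into one merged variable, publish that variable through the local CollectiveServer, pull every other trainer's merged variable with CollectiveClient::Gather, and MergeAdd once more across trainers. The snippet below is only a dependency-free illustration of the merge-by-row-id semantics that scatter::MergeAdd provides; TinySelectedRows and the std::map-based merge are stand-ins for illustration, not Paddle APIs.

// Illustrative only: rows that appear in several sparse inputs have their
// value rows summed, everything else is concatenated in row-id order.
#include <cstdio>
#include <map>
#include <vector>

struct TinySelectedRows {
  std::vector<int64_t> rows;               // row indices into the dense table
  std::vector<std::vector<float>> values;  // one value row per index
};

TinySelectedRows MergeAdd(const std::vector<TinySelectedRows>& inputs) {
  std::map<int64_t, std::vector<float>> merged;
  for (const auto& in : inputs) {
    for (size_t i = 0; i < in.rows.size(); ++i) {
      auto& acc = merged[in.rows[i]];
      if (acc.empty()) acc.assign(in.values[i].size(), 0.0f);
      for (size_t j = 0; j < in.values[i].size(); ++j) acc[j] += in.values[i][j];
    }
  }
  TinySelectedRows out;
  for (auto& kv : merged) {
    out.rows.push_back(kv.first);
    out.values.push_back(kv.second);
  }
  return out;
}

int main() {
  TinySelectedRows a{{1, 3}, {{1, 1}, {2, 2}}};
  TinySelectedRows b{{3, 7}, {{10, 10}, {5, 5}}};
  TinySelectedRows m = MergeAdd({a, b});
  for (size_t i = 0; i < m.rows.size(); ++i)
    std::printf("row %lld -> %.0f %.0f\n", (long long)m.rows[i],
                m.values[i][0], m.values[i][1]);  // row 3 becomes 12 12
  return 0;
}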
paddle/fluid/framework/details/reduce_op_handle.h
@@ -30,6 +30,32 @@
 namespace paddle {
 namespace framework {
 namespace details {

+struct CollectiveContext {
+  std::vector<std::string> endpoints_;
+  int trainer_id_{0};
+
+  std::string String() const {
+    std::stringstream ss;
+    ss << "endpoints_:";
+    for (auto e : endpoints_) {
+      ss << e << ",";
+    }
+
+    ss << "trainer_id_:" << trainer_id_;
+
+    return ss.str();
+  }
+
+  static CollectiveContext *GetInstance() {
+    std::call_once(init_flag_,
+                   [&]() { context_.reset(new CollectiveContext()); });
+    return context_.get();
+  }
+
+ private:
+  static std::once_flag init_flag_;
+  static std::unique_ptr<CollectiveContext> context_;
+};
+
 struct ReduceOpHandle : public OpHandleBase {
   std::vector<Scope *> local_scopes_;

@@ -64,6 +90,19 @@ struct ReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;

+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+  template <typename DevCtx, typename DataType>
+  void GatherSelectedRows(
+      const std::vector<const SelectedRows *> &src_selecte_rows_,
+      const std::vector<platform::Place> &in_places,
+      const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
+      VarHandle *out_var_handle, const platform::Place &out_place,
+      SelectedRows *dst_selecte_rows);
+#endif
+
+  void Wait(
+      const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes);
+
   template <typename T>
   std::vector<const T *> GetInputValues(
       const std::vector<VarHandle *> &in_var_handles,
paddle/fluid/operators/distributed/CMakeLists.txt
@@ -13,16 +13,26 @@ set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor
 if(WITH_GRPC)
   grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
-      request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
+      request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
     PROTO send_recv.proto
-    DEPS lod_tensor selected_rows memory)
+    DEPS lod_tensor selected_rows_functor memory)
+
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+
   cc_test(grpc_serde_test SRCS grpc_serde_test.cc
     DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
+
   cc_test(rpc_server_test SRCS rpc_server_test.cc
     DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
+
   cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
+
+  if(WITH_GPU)
+    cc_test(collective_server_test SRCS collective_server_test.cc
+      DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
+      selected_rows_functor scope math_function SERIAL)
+  endif()
+
   cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
 else()
   set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
paddle/fluid/operators/distributed/collective_client.cc (new file, mode 100644)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <condition_variable>  // NOLINT
#include <string>

#include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/collective_client.h"

DECLARE_int32(rpc_deadline);

namespace paddle {
namespace operators {
namespace distributed {
std::once_flag CollectiveClient::init_flag_;
std::unique_ptr<CollectiveClient> CollectiveClient::client_(nullptr);

bool CollectiveClient::Gather(const std::vector<RemoteVar>& remote_vars,
                              std::vector<const framework::SelectedRows*>* dst,
                              const platform::DeviceContext& ctx,
                              framework::Scope* scope, int64_t time_out) {
  for (auto r : remote_vars) {
    VLOG(50) << "begin gather from ep:" << r.String();
    scope->Var(r.var_name_)->GetMutable<framework::SelectedRows>();
    VarHandlePtr ptr = rpc_client_->AsyncGetMonomerVariable(
        r.ep_, ctx, *scope, r.var_name_, time_out);
  }

  rpc_client_->Wait();

  for (auto r : remote_vars) {
    auto select_rows =
        scope->FindVar(r.var_name_)->GetMutable<framework::SelectedRows>();
    dst->push_back(select_rows);

    VLOG(4) << "gather from ep:" << r.String()
            << ", select_rows:" << GetSelectedRowsInfo(*select_rows);

    rpc_client_->AsyncGetMonomerBarrier(r.ep_, r.var_name_);
  }

  rpc_client_->Wait();
  return true;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/collective_client.h (new file, mode 100644)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <condition_variable>  // NOLINT
#include <string>
#include <vector>

#include "gflags/gflags.h"

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler.h"

DECLARE_int32(rpc_deadline);

namespace paddle {
namespace operators {
namespace distributed {

inline std::string GetSelectedRowsInfo(const framework::SelectedRows& slr) {
  std::stringstream ss;
  ss << ", height:" << slr.height() << ", rows:[";
  for (unsigned int i = 0; i < slr.rows().size(); i++) {
    if (i != slr.rows().size() - 1) {
      ss << slr.rows()[i] << ",";
    } else {
      ss << slr.rows()[i];
    }
  }
  ss << "], dims:" << slr.value().dims();
  return ss.str();
}

struct RemoteVar {
  std::string ep_;
  std::string var_name_;
  int trainer_id_{0};

  std::string String() {
    std::stringstream ss;
    ss << "ep:" << ep_ << ", var_name:" << var_name_
       << ", trainer_id:" << trainer_id_;

    return ss.str();
  }
};

class CollectiveClient {
 public:
  CollectiveClient() {
    rpc_client_.reset(new RPCCLIENT_T());
    rpc_client_->InitImpl();
  }
  virtual ~CollectiveClient() {}

  // note this function will retain the rank order.
  bool Gather(const std::vector<RemoteVar>& remote_vars,
              std::vector<const framework::SelectedRows*>* dst,
              const platform::DeviceContext& ctx, framework::Scope* scope,
              int64_t time_out = FLAGS_rpc_deadline);

  static CollectiveClient* GetInstance() {
    std::call_once(init_flag_, [&]() {
      if (client_.get() == nullptr) {
        client_.reset(new CollectiveClient());
      }
    });
    return client_.get();
  }

 private:
  std::unique_ptr<RPCClient> rpc_client_;

  static std::once_flag init_flag_;
  static std::unique_ptr<CollectiveClient> client_;
};
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/collective_server.cc
0 → 100644
浏览文件 @
f1fb64b1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/distributed/collective_server.h"
DEFINE_int32
(
collective_get_thread_num
,
5
,
"number of threads for rpc get"
);
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
std
::
once_flag
CollectiveServer
::
init_flag_
;
std
::
shared_ptr
<
CollectiveServer
>
CollectiveServer
::
collective_server_
(
nullptr
);
CollectiveServer
::
CollectiveServer
(
const
std
::
string
&
end_point
,
int
fan_in
)
{
VLOG
(
1
)
<<
"Create colllective server:"
<<
end_point
<<
", fan_in:"
<<
fan_in
;
rpc_server_
.
reset
(
new
RPCSERVER_T
(
end_point
,
fan_in
));
}
void
CollectiveServer
::
Stop
()
{
rpc_server_
->
ShutDown
();
server_thread_
->
join
();
loop_thread_
->
join
();
}
void
CollectiveServer
::
StartServer
()
{
get_monomer_handler_
.
reset
(
new
GetMonomerHandler
());
get_monomer_handler_
->
SetRPCServer
(
rpc_server_
.
get
());
get_barrier_handler_
.
reset
(
new
GetMonomerBarrierHandler
());
get_barrier_handler_
->
SetRPCServer
(
rpc_server_
.
get
());
rpc_server_
->
RegisterRPC
(
distributed
::
kRequestGetMonomerVariable
,
get_monomer_handler_
.
get
(),
FLAGS_collective_get_thread_num
);
rpc_server_
->
RegisterRPC
(
distributed
::
kRequestGetMonomerBarrier
,
get_barrier_handler_
.
get
(),
1
);
server_thread_
.
reset
(
new
std
::
thread
([
&
]()
{
rpc_server_
->
StartServer
();
}));
rpc_server_
->
WaitServerReady
();
loop_thread_
.
reset
(
new
std
::
thread
([
&
]()
{
while
(
true
)
{
if
(
rpc_server_
->
IsExit
())
{
LOG
(
WARNING
)
<<
"get exit!rpc_processor break!"
;
break
;
}
sleep
(
1
);
}
VLOG
(
1
)
<<
"CollectiveServer loop_thread end"
;
}));
}
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/collective_server.h (new file, mode 100644)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "gflags/gflags.h"

#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"

namespace paddle {
namespace operators {
namespace distributed {

class CollectiveServer;

class GetMonomerHandler final : public RequestHandler {
 public:
  GetMonomerHandler() : RequestHandler(true) {}
  virtual ~GetMonomerHandler() {}
  bool Handle(const std::string& var_name, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
              const int trainer_id, const std::string& out_var_name = "",
              const std::string& table_name = "") override {
    VLOG(50) << "GetMonomerHandler recv " << var_name;

    *outvar = scope->FindVar(var_name);
    PADDLE_ENFORCE(outvar != nullptr, "%s not found", var_name);

    return true;
  }
};

class GetMonomerBarrierHandler final : public RequestHandler {
 public:
  GetMonomerBarrierHandler() : RequestHandler(true) {}
  virtual ~GetMonomerBarrierHandler() {}
  bool Handle(const std::string& var_name, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
              const int trainer_id, const std::string& out_var_name = "",
              const std::string& table_name = "") override {
    VLOG(50) << "GetMonomerHandler recv " << var_name;

    rpc_server_->IncreaseVarBarrier(var_name);

    return true;
  }
};

class CollectiveServer final {
 public:
  explicit CollectiveServer(const std::string& end_point, int fan_in);

  virtual ~CollectiveServer() {}

  void StartServer();

  static CollectiveServer* GetInstance(const std::string& end_point,
                                       int fan_in) {
    std::call_once(init_flag_, [&]() {
      if (collective_server_.get() == nullptr) {
        collective_server_.reset(new CollectiveServer(end_point, fan_in));
        collective_server_->StartServer();
      }
    });

    return collective_server_.get();
  }

  std::shared_ptr<RPCServer> GetRPCServer() { return rpc_server_; }

  void Stop();

 private:
  std::unique_ptr<GetMonomerHandler> get_monomer_handler_;
  std::unique_ptr<GetMonomerBarrierHandler> get_barrier_handler_;

  std::shared_ptr<distributed::RPCServer> rpc_server_;
  std::shared_ptr<std::thread> server_thread_;
  std::shared_ptr<std::thread> loop_thread_;

  bool ready_{false};

  static std::once_flag init_flag_;
  static std::shared_ptr<CollectiveServer> collective_server_;
};

};  // namespace distributed
};  // namespace operators
};  // namespace paddle
paddle/fluid/operators/distributed/collective_server_test.cc (new file, mode 100644)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <string>
#include <thread>  // NOLINT

#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;

std::unique_ptr<distributed::CollectiveServer> StartServer(
    const std::string& ep, int fan_in, framework::Scope* scope,
    platform::DeviceContext* dev_ctx) {
  distributed::CollectiveServer* server =
      distributed::CollectiveServer::GetInstance(ep, fan_in);

  auto rpc_server = server->GetRPCServer();
  rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable,
                          scope, dev_ctx);

  std::cout << "StartServer return" << std::endl;
  return std::unique_ptr<distributed::CollectiveServer>(server);
}

std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  framework::Scope* scope = new framework::Scope();
  framework::Variable* var = scope->Var("var1");
  auto* slr = var->GetMutable<framework::SelectedRows>();
  slr->set_height(1000);

  auto* tensor = slr->mutable_value();
  auto* rows = slr->mutable_rows();

  tensor->Resize(framework::make_ddim({3, 5}));
  tensor->mutable_data<float>(place);

  paddle::operators::math::set_constant(ctx, tensor, 32.7);
  for (int i = 0; i < 3; ++i) rows->push_back(i);

  std::cout << "src:" << distributed::GetSelectedRowsInfo(*slr);

  return std::unique_ptr<framework::Scope>(scope);
}

void Gather(const std::vector<distributed::RemoteVar>& vars,
            platform::DeviceContext* dev_ctx) {
  distributed::CollectiveClient* client =
      distributed::CollectiveClient::GetInstance();

  framework::Scope* scope = new framework::Scope();
  framework::Variable* var = scope->Var("var1");
  var->GetMutable<framework::SelectedRows>();

  std::vector<const framework::SelectedRows*> dst;
  client->Gather(vars, &dst, *dev_ctx, scope);
  std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
}

TEST(PREFETCH, GPU) {
  platform::CUDAPlace place;
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  std::string ep = "127.0.0.1:7164";
  auto scope = GenerateVars(place);

  auto* v1 = scope->FindVar("var1");
  std::cout << "var1:" << v1 << std::endl;

  auto server = StartServer(ep, 2, scope.get(), &ctx);
  auto rpc_server = server->GetRPCServer();

  distributed::RemoteVar var;
  var.ep_ = ep;
  var.var_name_ = "var1";
  var.trainer_id_ = 0;

  std::vector<distributed::RemoteVar> vars{var};
  Gather(vars, &ctx);
  Gather(vars, &ctx);

  std::cout << "begin WaitVarBarrier" << std::endl;
  rpc_server->WaitVarBarrier("var1");
  rpc_server->ClearRegisteredVars();
  server->Stop();

  scope.release();
  server.release();
}
paddle/fluid/operators/distributed/grpc_client.cc
@@ -28,11 +28,11 @@ namespace paddle {
 namespace operators {
 namespace distributed {

-void GRPCClient::InitImpl() { InitEventLoop(); }
-
-void GRPCClient::InitEventLoop() {
+void GRPCClient::InitImpl() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
+  PADDLE_ENFORCE(client_thread_ == nullptr,
+                 "please not re init proceed thread");
   client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }

@@ -106,6 +106,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
 void ProcGetResponse(const VarHandle& var_h,
                      const ::grpc::ByteBuffer& ret_msg) {
   VLOG(100) << "ProcGetResponse";
   framework::Variable* outvar = nullptr;
+  // get response's trainer_id is not used
   int trainer_id;

@@ -126,6 +127,24 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name,
+                      "/sendrecv.SendRecvService/GetVariable", time_out);
+}
+
+VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name,
+                      "/sendrecv.SendRecvService/GetMonomerVariable", time_out);
+}
+
+VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
+                                      const platform::DeviceContext& ctx,
+                                      const framework::Scope& scope,
+                                      const std::string& var_name,
+                                      const std::string& rpc_path,
+                                      int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;

@@ -136,7 +155,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
+  framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);

@@ -151,8 +170,8 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
     platform::RecordRPCEvent record_event(method, p_ctx);

-    auto call = s->stub_g_.PrepareUnaryCall(
-        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
+    auto call =
+        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
     call->StartCall();
     call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

@@ -268,6 +287,34 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
   return h;
 }

+VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
+                                                const std::string& var_name,
+                                                int64_t time_out) {
+  const auto ch = GetChannel(ep);
+  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
+  const std::string method = "SendMonomerFetchBarrierRPC";
+  VarHandlePtr h(
+      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
+  s->Prepare(h, time_out);
+
+  VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
+
+  sendrecv::VariableMessage req;
+  req.set_varname(var_name);
+
+  platform::RecordRPCEvent record_event(method, nullptr);
+
+  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
+  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+  req_count_++;
+
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    h->Wait();
+  }
+
+  return h;
+}
+
 VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                            int64_t time_out) {
   const auto ch = GetChannel(ep);
paddle/fluid/operators/distributed/grpc_client.h
@@ -189,6 +189,11 @@ class GRPCClient : public RPCClient {
                           const std::string& var_name,
                           int64_t time_out = FLAGS_rpc_deadline) override;

+  VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+
   VarHandlePtr AsyncPrefetchVar(const std::string& ep,
                                 const platform::DeviceContext& ctx,
                                 const framework::Scope& scope,

@@ -200,8 +205,12 @@ class GRPCClient : public RPCClient {
   VarHandlePtr AsyncSendBatchBarrier(const std::string& ep,
                                      int64_t time_out = FLAGS_rpc_deadline) override;

-  VarHandlePtr AsyncSendFetchBarrier(const std::string& ep,
-                                     int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendFetchBarrier(const std::string& ep,
+                                     int64_t time_out) override;
+
+  VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;

   VarHandlePtr AsyncCheckpointNotify(const std::string& ep,
                                      const std::string& dir,

@@ -214,21 +223,22 @@ class GRPCClient : public RPCClient {
   void SendComplete() override;

  protected:
   void InitImpl() override;

  private:
-  // InitEventLoop should only be called by Init()
-  void InitEventLoop();
-
   void Proceed();

   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  VarHandlePtr _AsyncGetVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            const std::string& rpc, int64_t time_out);

  private:
   grpc::CompletionQueue cq_;
   std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  std::unique_ptr<std::thread> client_thread_;
+  std::unique_ptr<std::thread> client_thread_{nullptr};

   // mutex for Wait client sync
   std::mutex sync_mutex_;
paddle/fluid/operators/distributed/grpc_server.cc
@@ -158,6 +158,98 @@ class RequestGet final : public RequestBase {
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
 };

+class RequestGetMonomerVariable final : public RequestBase {
+ public:
+  explicit RequestGetMonomerVariable(GrpcService::AsyncService* service,
+                                     ::grpc::ServerCompletionQueue* cq,
+                                     RequestHandler* request_handler,
+                                     int req_id, RPCServer* rpc_server)
+      : RequestBase(service, cq, request_handler, req_id),
+        responder_(&ctx_),
+        rpc_server_(rpc_server) {
+    auto method_id =
+        static_cast<int>(distributed::GrpcMethod::kGetMonomerVariable);
+    service_->RequestAsyncUnary(
+        method_id, &ctx_, &request_, &responder_, cq_, cq_,
+        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
+  }
+
+  virtual ~RequestGetMonomerVariable() {}
+
+  std::string GetReqName() override { return request_.varname(); }
+
+  void Process() override {
+    // proc request.
+    std::string varname = request_.varname();
+
+    rpc_server_->WaitVarCond(varname);
+    MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    auto scope = h.scope_;
+    auto invar = scope->FindVar(varname);
+    framework::Variable* outvar = nullptr;
+
+    request_handler_->Handle(varname, scope, invar, &outvar,
+                             request_.trainer_id());
+
+    if (outvar) {
+      SerializeToByteBuffer(varname, outvar, *h.dev_ctx_, &reply_);
+    }
+    Finish(reply_, &responder_);
+  }
+
+ protected:
+  sendrecv::VariableMessage request_;
+  ::grpc::ByteBuffer reply_;
+  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
+  RPCServer* rpc_server_{nullptr};
+};
+
+class RequestGetMonomerBarrier final : public RequestBase {
+ public:
+  explicit RequestGetMonomerBarrier(GrpcService::AsyncService* service,
+                                    ::grpc::ServerCompletionQueue* cq,
+                                    RequestHandler* request_handler, int req_id,
+                                    RPCServer* rpc_server)
+      : RequestBase(service, cq, request_handler, req_id),
+        responder_(&ctx_),
+        rpc_server_(rpc_server) {
+    auto method_id =
+        static_cast<int>(distributed::GrpcMethod::kGetMonomerBarrier);
+    service_->RequestAsyncUnary(
+        method_id, &ctx_, &request_, &responder_, cq_, cq_,
+        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
+  }
+
+  virtual ~RequestGetMonomerBarrier() {}
+
+  std::string GetReqName() override { return request_.varname(); }
+
+  void Process() override {
+    // proc request.
+    std::string varname = request_.varname();
+    VLOG(4) << "RequestGetMonomerBarrier " << varname;
+
+    rpc_server_->WaitVarCond(varname);
+    MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    framework::Scope* scope = nullptr;
+    framework::Variable* invar = nullptr;
+    framework::Variable* outvar = nullptr;
+
+    request_handler_->Handle(varname, scope, invar, &outvar,
+                             request_.trainer_id());
+
+    Finish(reply_, &responder_);
+  }
+
+ protected:
+  sendrecv::VariableMessage request_;
+  sendrecv::VoidMessage reply_;
+  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
+  RPCServer* rpc_server_{nullptr};
+};
+
 class RequestPrefetch final : public RequestBase {
  public:
   explicit RequestPrefetch(GrpcService::AsyncService* service,

@@ -249,7 +341,7 @@ class RequestCheckpointNotify final : public RequestBase {
 };

 void AsyncGRPCServer::WaitServerReady() {
-  VLOG(4) << "AsyncGRPCServer is wait server ready";
+  VLOG(4) << "AsyncGRPCServer is waiting server ready";
   std::unique_lock<std::mutex> lock(this->mutex_ready_);
   condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
   VLOG(4) << "AsyncGRPCServer WaitSeverReady";

@@ -368,6 +460,12 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
     b = new RequestSend(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestGet) {
     b = new RequestGet(&service_, cq.get(), handler, req_id);
+  } else if (rpc_name == kRequestGetMonomerVariable) {
+    b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
+                                      this);
+  } else if (rpc_name == kRequestGetMonomerBarrier) {
+    b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id,
+                                     this);
   } else if (rpc_name == kRequestPrefetch) {
     b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestCheckpoint) {

@@ -378,7 +476,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,

   reqs[req_id] = b;

-  VLOG(4) << "Create RequestSend status:" << b->Status();
+  VLOG(4) << "TryToRegisterNewOne status:" << b->Status();
 }

 void AsyncGRPCServer::HandleRequest(
paddle/fluid/operators/distributed/grpc_service.h
@@ -81,10 +81,12 @@ enum class GrpcMethod {
   kGetVariable,
   kPrefetchVariable,
   kCheckpointNotify,
+  kGetMonomerVariable,
+  kGetMonomerBarrier,
 };

 static const int kGrpcNumMethods =
-    static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;
+    static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1;

 inline const char* GrpcMethodName(GrpcMethod id) {
   switch (id) {

@@ -92,6 +94,10 @@ inline const char* GrpcMethodName(GrpcMethod id) {
       return "/sendrecv.SendRecvService/SendVariable";
     case GrpcMethod::kGetVariable:
       return "/sendrecv.SendRecvService/GetVariable";
+    case GrpcMethod::kGetMonomerVariable:
+      return "/sendrecv.SendRecvService/GetMonomerVariable";
+    case GrpcMethod::kGetMonomerBarrier:
+      return "/sendrecv.SendRecvService/GetMonomerBarrier";
     case GrpcMethod::kPrefetchVariable:
       return "/sendrecv.SendRecvService/PrefetchVariable";
     case GrpcMethod::kCheckpointNotify:
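The kGrpcNumMethods edit above matters because the method count is derived from the last enumerator: once kGetMonomerVariable and kGetMonomerBarrier are appended, a count still anchored to kCheckpointNotify would silently undersize the service's method table. The following standalone snippet only illustrates that "last enumerator + 1" counting idiom; GrpcMethodDemo and kNumMethodsDemo are demo stand-ins, not Paddle identifiers.

// Demonstrates why the count sentinel must track the final enumerator.
#include <iostream>

enum class GrpcMethodDemo {
  kSendVariable,
  kGetVariable,
  kPrefetchVariable,
  kCheckpointNotify,
  kGetMonomerVariable,
  kGetMonomerBarrier,
};

// Must name the *last* enumerator, exactly as the hunk changes it.
static const int kNumMethodsDemo =
    static_cast<int>(GrpcMethodDemo::kGetMonomerBarrier) + 1;

int main() {
  std::cout << kNumMethodsDemo << std::endl;  // 6: room for every method
  return 0;
}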
paddle/fluid/operators/distributed/request_handler.h
@@ -37,6 +37,8 @@ namespace distributed {

 constexpr char kRequestSend[] = "RequestSend";
 constexpr char kRequestGet[] = "RequestGet";
+constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
+constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
 constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
 constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
paddle/fluid/operators/distributed/rpc_client.h
@@ -45,6 +45,11 @@ class RPCClient {
                                   const std::string& var_name,
                                   int64_t time_out = FLAGS_rpc_deadline) = 0;

+  virtual VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
+
   virtual VarHandlePtr AsyncPrefetchVar(const std::string& ep,
                                         const platform::DeviceContext& ctx,
                                         const framework::Scope& scope,
                                         const std::string& in_var_name,

@@ -57,6 +62,10 @@ class RPCClient {
   virtual VarHandlePtr AsyncSendFetchBarrier(
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;

+  virtual VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
+
   virtual VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dir,
       int64_t time_out = FLAGS_rpc_deadline) = 0;

@@ -87,8 +96,9 @@ class RPCClient {
     }
   }

- protected:
   virtual void InitImpl() {}
+
+ protected:
   // each trainer have exact one trainer id, it should be static
   static int trainer_id_;
paddle/fluid/operators/distributed/rpc_server.cc
@@ -132,6 +132,96 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
       lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
 }

+void RPCServer::RegisterVar(const std::string& var_name,
+                            const std::string& rpc_name,
+                            framework::Scope* scope,
+                            platform::DeviceContext* dev_ctx) {
+  MonomerHandle h;
+  h.var_name_ = var_name;
+  h.rpc_name_ = rpc_name;
+  h.scope_ = scope;
+  h.dev_ctx_ = dev_ctx;
+
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (var_map_.find(var_name) != var_map_.end()) {
+      PADDLE_ENFORCE(false, "%s alreay in var_map", var_name);
+    }
+    var_map_[var_name] = h;
+  }
+
+  rpc_cond_.notify_all();
+  VLOG(4) << "RegisterVar context:" << h.String();
+}
+
+void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
+  int b = 0;
+  MonomerHandle h;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    b = ++var_map_[var_name].barrier_;
+    h = var_map_[var_name];
+  }
+
+  if (b >= client_num_) {
+    barrier_cond_.notify_all();
+  }
+
+  VLOG(4) << "IncreaseVarBarrier context:" << h.String();
+}
+
+void RPCServer::WaitVarBarrier(const std::string& var_name) {
+  VLOG(4) << "WaitBarrier var_name:" << var_name;
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  barrier_cond_.wait(lock, [&]() {
+    return ((var_map_[var_name].barrier_ >= client_num_ && client_num_ != 0) ||
+            exit_flag_.load());
+  });
+
+  VLOG(4) << "WaitBarrier context: " << var_map_[var_name].String();
+}
+
+void RPCServer::SetVarCond(const std::string& var_name) {
+  VLOG(4) << "SetVarCond var_name:" << var_name;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (var_map_.find(var_name) != var_map_.end()) {
+      rpc_cond_.notify_all();
+    }
+  }
+}
+
+void RPCServer::WaitVarCond(const std::string& var_name) {
+  VLOG(4) << "WaitVarCond var_name:" << var_name;
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  rpc_cond_.wait(lock, [=] {
+    return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
+  });
+
+  VLOG(4) << "WaitVarCond var_name:" << var_name << " end";
+}
+
+MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {
+  MonomerHandle h;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    h = var_map_[var_name];
+  }
+
+  return h;
+}
+
+void RPCServer::ClearRegisteredVars() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  var_map_.clear();
+}
+
+void RPCServer::ClearVar(const std::string& var_name) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  var_map_.erase(var_name);
+}
+
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
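IncreaseVarBarrier and WaitVarBarrier together behave as a per-variable counting barrier: a waiter is released only after the increase has been reported client_num_ times for that variable. The sketch below reproduces just that counting behaviour in isolation; CountingBarrier is an illustrative class, and it omits the exit-flag wakeup that the real server also checks.

// Minimal counting barrier: Wait() returns once Increase() has been called
// client_num times, mirroring WaitVarBarrier/IncreaseVarBarrier above.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class CountingBarrier {
 public:
  explicit CountingBarrier(int client_num) : client_num_(client_num) {}

  void Increase() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (++count_ >= client_num_) cond_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [&] { return count_ >= client_num_; });
  }

 private:
  int client_num_;
  int count_{0};
  std::mutex mutex_;
  std::condition_variable cond_;
};

int main() {
  CountingBarrier barrier(3);
  std::vector<std::thread> clients;
  for (int i = 0; i < 3; ++i)
    clients.emplace_back([&barrier, i] {
      std::cout << "client " << i << " done\n";
      barrier.Increase();
    });
  barrier.Wait();  // returns only after all three clients reported in
  std::cout << "barrier released\n";
  for (auto& t : clients) t.join();
  return 0;
}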
paddle/fluid/operators/distributed/rpc_server.h
@@ -21,12 +21,30 @@
 #include <utility>
 #include <vector>

 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/operators/distributed/request_handler.h"
+#include "paddle/fluid/platform/device_context.h"

 namespace paddle {
 namespace operators {
 namespace distributed {

+struct MonomerHandle {
+  std::string var_name_;
+  std::string rpc_name_;
+  framework::Scope* scope_{nullptr};
+  platform::DeviceContext* dev_ctx_{nullptr};
+  int64_t barrier_{0};
+
+  std::string String() {
+    std::stringstream ss;
+    ss << "var_name:" << var_name_ << ", rpc_name:" << rpc_name_
+       << ", scope:" << scope_ << ", dev_ctx:" << dev_ctx_
+       << ", barrier_:" << barrier_;
+    return ss.str();
+  }
+};
+
 class RPCServer {
  public:
   explicit RPCServer(const std::string& address, int client_num)

@@ -67,6 +85,16 @@ class RPCServer {
   void WaitCond(const std::string& rpc_name);
   void IncreaseBatchBarrier(const std::string rpc_name);

+  void RegisterVar(const std::string& var_name, const std::string& rpc_name,
+                   framework::Scope* scope, platform::DeviceContext* dev_ctx);
+  void IncreaseVarBarrier(const std::string& var_name);
+  void WaitVarBarrier(const std::string& var_name);
+  void SetVarCond(const std::string& var_name);
+  void WaitVarCond(const std::string& var_name);
+  void ClearRegisteredVars();
+  void ClearVar(const std::string& var_name);
+  MonomerHandle GetMonomer(const std::string& var_name);
+
   void Complete();

   void ResetBarrierCounter();

@@ -95,6 +123,9 @@ class RPCServer {
   std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
   std::unordered_map<std::string, int> rpc_thread_num_;
   friend class RequestHandler;
+
+  // TODO(gongwb): use more cond to notify or wait;
+  std::unordered_map<std::string, MonomerHandle> var_map_;
 };

 };  // namespace distributed
paddle/fluid/operators/distributed/send_recv.proto.in
@@ -28,6 +28,9 @@ service SendRecvService {
   rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}

   rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
+
+  rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
+  rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
 }

 // VariableMessage is serialized paddle variable message.
paddle/fluid/operators/math/softmax_impl.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <vector>

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
paddle/fluid/pybind/pybind.cc
@@ -925,6 +925,18 @@ All parameter, weight, gradient are variables in Paddle.
           [](BuildStrategy &self, int num_trainers) {
             self.num_trainers_ = num_trainers;
           })
+      .def_property(
+          "trainers_endpoints",
+          [](const BuildStrategy &self) { return self.trainers_endpoints_; },
+          [](BuildStrategy &self,
+             const std::vector<std::string> &trainers_endpoints) {
+            self.trainers_endpoints_ = trainers_endpoints;
+          })
+      .def_property("trainer_id",
+                    [](const BuildStrategy &self) { return self.trainer_id_; },
+                    [](BuildStrategy &self, int trainer_id) {
+                      self.trainer_id_ = trainer_id;
+                    })
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
python/paddle/fluid/framework.py
@@ -1483,6 +1483,7 @@ class Program(object):
         self._is_chief = False
         self._slice_vars_and_attrs = []
         self._endpoints = []
+        self._trainers_endpoints = []
         self._distributed_lookup_table = None

     @property
python/paddle/fluid/parallel_executor.py
@@ -135,9 +135,17 @@ class ParallelExecutor(object):
             build_strategy = BuildStrategy()

         build_strategy.num_trainers = num_trainers
+        build_strategy.trainer_id = trainer_id

         main = main_program
         main = main if main else framework.default_main_program()

+        trainers_endpoints = main._trainers_endpoints
+        if num_trainers > 1 and trainers_endpoints:
+            assert num_trainers == len(
+                trainers_endpoints), "num_trainers == len(end_points)"
+            build_strategy.trainers_endpoints = trainers_endpoints
+
         if scope == None:
             scope = executor.global_scope()
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -305,6 +305,7 @@ class DistributeTranspiler(object):
         if self.config.mode == "nccl2":
             assert (isinstance(trainers, str))
+            self.origin_program._trainers_endpoints = trainers.split(",")
             self._transpile_nccl2(
                 trainer_id,
                 trainers,