Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
be6a315f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be6a315f
编写于
6月 12, 2020
作者:
T
tangwei12
提交者:
GitHub
6月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix/sync barrier (#25016)
* fix sync barrier with barrier monitor, test=develop
上级
8db66fc3
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
666 addition
and
305 deletion
+666
-305
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+21
-7
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+3
-1
paddle/fluid/operators/distributed/barrier_monitor.cc
paddle/fluid/operators/distributed/barrier_monitor.cc
+166
-0
paddle/fluid/operators/distributed/barrier_monitor.h
paddle/fluid/operators/distributed/barrier_monitor.h
+186
-0
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+27
-66
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+152
-157
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+13
-7
paddle/fluid/operators/distributed_ops/CMakeLists.txt
paddle/fluid/operators/distributed_ops/CMakeLists.txt
+2
-2
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
+50
-32
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+21
-26
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
+24
-6
python/paddle/fluid/tests/unittests/test_communicator_half_async.py
...dle/fluid/tests/unittests/test_communicator_half_async.py
+1
-1
未找到文件。
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
浏览文件 @
be6a315f
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
...
...
@@ -95,22 +93,39 @@ class CGenNCCLIdOp : public framework::OperatorBase {
new
RPCSERVER_T
(
endpoint
,
1
));
rpc_service
->
RegisterRPC
(
distributed
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
rpc_service
.
get
());
distributed
::
RequestNotifyHandler
notify_h
(
distributed
::
DistributedMode
::
kSync
,
-
1
);
rpc_service
->
RegisterRPC
(
distributed
::
kRequestSend
,
&
rpc_h
);
rpc_service
->
RegisterRPC
(
distributed
::
kRequestNotify
,
&
notify_h
);
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_h
.
SetRPCServer
(
rpc_service
.
get
());
rpc_h
.
SetScope
(
scope
);
rpc_h
.
SetDevCtx
(
&
dev_ctx
);
rpc_h
.
SetProgram
(
&
empty_program
);
rpc_h
.
SetExecutor
(
&
executor
);
notify_h
.
SetRPCServer
(
rpc_service
.
get
());
notify_h
.
SetScope
(
scope
);
notify_h
.
SetDevCtx
(
&
dev_ctx
);
notify_h
.
SetProgram
(
&
empty_program
);
notify_h
.
SetExecutor
(
&
executor
);
distributed
::
BarrierMonitor
::
Init
(
1
);
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
std
::
thread
server_thread
(
std
::
bind
(
&
distributed
::
RPCServer
::
StartServer
,
rpc_service
.
get
()));
rpc_service
->
SetCond
(
distributed
::
kRequestSend
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
rpc_service
->
WaitBarrier
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
();
barrier
->
ServerWeakup
();
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
barrier
->
Stop
();
rpc_service
->
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
server_thread
.
join
();
...
...
@@ -123,7 +138,6 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"Raw variable contains a NCCL UniqueId instaces."
);
AddComment
(
R"DOC(
CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC"
);
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
be6a315f
...
...
@@ -15,6 +15,8 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r
cc_library
(
heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool
)
cc_test
(
heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor
)
cc_library
(
barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto
)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
WITH_GRPC
)
...
...
@@ -26,7 +28,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${
GRPC_SRCS
}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
async_sparse_param_update_recorder heart_beat_monitor
)
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
async_sparse_param_update_recorder heart_beat_monitor
barrier_monitor
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
RPC_DEPS sendrecvop_rpc
${
GRPC_DEPS
}
)
...
...
paddle/fluid/operators/distributed/barrier_monitor.cc
0 → 100644
浏览文件 @
be6a315f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
bool
BarrierMonitor
::
IncreaseBarrier
(
const
int
worker_id
,
const
std
::
string
&
barrier
)
{
release_
=
false
;
if
(
barrier
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
4
)
<<
"BarrierMonitor send queue recv trainer: "
<<
worker_id
;
send_barrier_queue
->
Push
(
worker_id
);
}
else
if
(
barrier
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
4
)
<<
"BarrierMonitor recv queue recv trainer: "
<<
worker_id
;
recv_barrier_queue
->
Push
(
worker_id
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"unknown Message status %s, only "
"BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE"
,
barrier
));
}
return
Wait
();
}
void
BarrierMonitor
::
DecreaseWorker
()
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
workers_
--
;
VLOG
(
1
)
<<
"decrement worker num to "
<<
workers_
;
}
void
BarrierMonitor
::
Reset
(
int
workers
,
BarrierType
type
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
server_mutex_
);
workers_
=
workers
;
barrier_type
=
type
;
send_barrier_queue
->
Clear
();
recv_barrier_queue
->
Clear
();
VLOG
(
2
)
<<
"reset monitor workers: "
<<
workers_
<<
" type: "
<<
barrier_type
;
}
void
BarrierMonitor
::
Monitor
()
{
while
(
!
IsReady
()
&&
running_
)
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
VLOG
(
3
)
<<
"sync at first time, wait all trainer ready"
;
}
while
(
running_
)
{
int
timer
=
0
;
if
(
IsReady
())
{
Swap
(
true
);
}
else
{
VLOG
(
4
)
<<
"running timer: "
<<
timer
<<
" barrier: "
<<
barrier_type
<<
" sendQ:"
<<
send_barrier_queue
->
Size
()
<<
" recvQ: "
<<
recv_barrier_queue
->
Size
();
timer
++
;
if
(
max_wait_ms
==
-
1
||
timer
<
max_wait_ms
)
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1
));
}
else
{
VLOG
(
1
)
<<
"time out of "
<<
max_wait_ms
<<
", need barreir: "
<<
barrier_type
<<
" retry"
;
Swap
(
false
);
}
}
}
}
bool
BarrierMonitor
::
IsReady
()
{
if
(
barrier_type
==
BarrierType
::
kSendBarrier
)
{
return
static_cast
<
int
>
(
send_barrier_queue
->
Size
())
==
workers_
;
}
else
{
return
static_cast
<
int
>
(
recv_barrier_queue
->
Size
())
==
workers_
;
}
}
void
BarrierMonitor
::
Swap
(
bool
is_valid
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
valid_
=
is_valid
;
release_
=
true
;
if
(
barrier_type
==
BarrierType
::
kSendBarrier
)
{
barrier_type
=
BarrierType
::
kRecvBarrier
;
send_barrier_queue
->
Clear
();
VLOG
(
4
)
<<
"barrier monitor server clean up queue and barrier"
;
ServerWeakup
();
VLOG
(
4
)
<<
"barrier monitor server weak up sync to do"
;
WaitServerWeakup
();
VLOG
(
4
)
<<
"barrier monitor server weak up sync done"
;
}
else
{
barrier_type
=
BarrierType
::
kSendBarrier
;
recv_barrier_queue
->
Clear
();
VLOG
(
4
)
<<
"barrier monitor server switch to send barrier"
;
}
worker_cv_
.
notify_all
();
}
void
BarrierMonitor
::
Stop
()
{
valid_
=
true
;
release_
=
true
;
running_
=
false
;
barrier_type
=
BarrierType
::
kRecvBarrier
;
send_barrier_queue
->
Clear
();
recv_barrier_queue
->
Clear
();
worker_cv_
.
notify_all
();
server_cv_
.
notify_all
();
if
(
monitor_thread_
)
monitor_thread_
->
join
();
monitor_thread_
=
nullptr
;
}
bool
BarrierMonitor
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mutex_
);
worker_cv_
.
wait
(
lk
,
[
this
]
{
return
(
release_
);
});
return
valid_
;
}
void
BarrierMonitor
::
WaitServerWeakup
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
server_mutex_
);
server_cv_
.
wait
(
lk
);
}
void
BarrierMonitor
::
ServerWeakup
()
{
server_cv_
.
notify_all
();
}
std
::
once_flag
BarrierMonitor
::
init_flag_
;
std
::
unique_ptr
<
BarrierMonitor
>
BarrierMonitor
::
monitor_
(
nullptr
);
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/barrier_monitor.h
0 → 100644
浏览文件 @
be6a315f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <deque>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
enum
BarrierType
{
kSendBarrier
,
kRecvBarrier
};
constexpr
int64_t
kMaxWaitMS
=
120000
;
template
<
typename
T
>
class
BlockingQueueForBarrier
{
public:
explicit
BlockingQueueForBarrier
(
size_t
capacity
)
:
capacity_
(
capacity
)
{
PADDLE_ENFORCE_GT
(
capacity_
,
0
,
platform
::
errors
::
InvalidArgument
(
"The capacity must be greater than 0."
));
}
bool
Push
(
const
T
&
elem
)
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
worker_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
;
});
queue_
.
push_back
(
elem
);
}
worker_cv_
.
notify_one
();
return
true
;
}
bool
Push
(
T
&&
elem
)
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
worker_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
;
});
queue_
.
emplace_back
(
std
::
move
(
elem
));
}
worker_cv_
.
notify_one
();
return
true
;
}
T
Pop
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
worker_cv_
.
wait
(
lock
,
[
=
]
{
return
!
queue_
.
empty
();
});
T
rc
(
std
::
move
(
queue_
.
front
()));
queue_
.
pop_front
();
worker_cv_
.
notify_one
();
return
rc
;
}
size_t
Cap
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
capacity_
;
}
size_t
Size
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
queue_
.
size
();
}
void
Clear
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
deque
<
T
>
().
swap
(
queue_
);
}
private:
const
size_t
capacity_
;
std
::
deque
<
T
>
queue_
;
mutable
std
::
mutex
mutex_
;
std
::
condition_variable
worker_cv_
;
};
class
BarrierMonitor
{
public:
explicit
BarrierMonitor
(
int
workers
)
:
BarrierMonitor
(
workers
,
BarrierType
::
kRecvBarrier
,
kMaxWaitMS
)
{}
explicit
BarrierMonitor
(
int
workers
,
BarrierType
type
,
int64_t
max_wait_times
)
:
workers_
(
workers
),
barrier_type
(
type
),
max_wait_ms
(
max_wait_times
)
{
PADDLE_ENFORCE_GT
(
workers
,
0
,
platform
::
errors
::
InvalidArgument
(
"trainers must have one or more"
));
send_barrier_queue
=
std
::
make_shared
<
BlockingQueueForBarrier
<
int
>>
(
workers
);
recv_barrier_queue
=
std
::
make_shared
<
BlockingQueueForBarrier
<
int
>>
(
workers
);
running_
=
true
;
monitor_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
BarrierMonitor
::
Monitor
,
this
)));
}
static
BarrierMonitor
*
Init
(
int
workers
)
{
InitImpl
(
workers
);
return
GetInstance
();
}
static
BarrierMonitor
*
GetInstance
()
{
return
monitor_
.
get
();
}
bool
IncreaseBarrier
(
const
int
worker_id
,
const
std
::
string
&
barrier
);
void
DecreaseWorker
();
int
GetWorkerNum
()
{
return
workers_
;
}
void
Monitor
();
void
Swap
(
bool
is_valid
);
void
Stop
();
bool
IsReady
();
bool
Wait
();
void
WaitServerWeakup
();
void
ServerWeakup
();
void
WorkerWeakup
();
void
Reset
(
int
workers
,
BarrierType
type
);
private:
// Init is called by GetInstance.
static
void
InitImpl
(
int
workers
)
{
monitor_
.
reset
(
new
BarrierMonitor
(
workers
));
}
static
std
::
once_flag
init_flag_
;
static
std
::
unique_ptr
<
BarrierMonitor
>
monitor_
;
int
workers_
;
bool
running_
=
false
;
bool
valid_
=
false
;
bool
release_
=
false
;
std
::
condition_variable
worker_cv_
;
std
::
condition_variable
server_cv_
;
std
::
mutex
server_mutex_
;
std
::
mutex
mutex_
;
BarrierType
barrier_type
;
int64_t
max_wait_ms
;
std
::
unique_ptr
<
std
::
thread
>
monitor_thread_
{
nullptr
};
std
::
shared_ptr
<
BlockingQueueForBarrier
<
int
>>
send_barrier_queue
;
std
::
shared_ptr
<
BlockingQueueForBarrier
<
int
>>
recv_barrier_queue
;
};
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
be6a315f
...
...
@@ -306,52 +306,19 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr
GRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
kBatchBarrierRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
BATCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
BATCH_BARRIER_MESSAGE
);
platform
::
RecordRPCEvent
record_event
(
method
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
platform
::
CPUDeviceContext
ctx
;
auto
*
scope
=
new
framework
::
Scope
();
auto
h
=
AsyncDistributeNotify
(
ep
,
ctx
,
*
scope
,
BATCH_BARRIER_MESSAGE
);
delete
scope
;
return
h
;
}
VarHandlePtr
GRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
const
std
::
string
method
=
kFetchBarrierRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
FETCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
platform
::
RecordRPCEvent
record_event
(
method
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
platform
::
CPUDeviceContext
ctx
;
auto
*
scope
=
new
framework
::
Scope
();
auto
h
=
AsyncDistributeNotify
(
ep
,
ctx
,
*
scope
,
FETCH_BARRIER_MESSAGE
);
delete
scope
;
return
h
;
}
...
...
@@ -384,27 +351,10 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
VarHandlePtr
GRPCClient
::
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
kSendCompleteRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
COMPLETE_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_trainer_id
(
trainer_id_
);
req
.
set_varname
(
COMPLETE_MESSAGE
);
platform
::
RecordRPCEvent
record_event
(
method
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
platform
::
CPUDeviceContext
ctx
;
auto
*
scope
=
new
framework
::
Scope
();
auto
h
=
AsyncDistributeNotify
(
ep
,
ctx
,
*
scope
,
COMPLETE_MESSAGE
);
delete
scope
;
return
h
;
}
...
...
@@ -454,10 +404,21 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
s
->
Prepare
(
h
,
time_out
);
framework
::
AsyncIO
([
var_name_val
,
p_scope
,
p_ctx
,
s
,
method
,
h
,
this
]
{
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
)
;
::
grpc
::
ByteBuffer
buf
;
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
req
,
""
,
trainer_id_
);
if
(
var_name_val
==
BATCH_BARRIER_MESSAGE
||
var_name_val
==
FETCH_BARRIER_MESSAGE
||
var_name_val
==
COMPLETE_MESSAGE
)
{
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
var_name_val
);
req
.
set_trainer_id
(
trainer_id_
);
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
}
else
{
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
buf
,
""
,
trainer_id_
);
}
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
...
...
@@ -467,7 +428,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
platform
::
RecordRPCEvent
record_event
(
method
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/DistributeNotify"
,
req
,
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/DistributeNotify"
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
be6a315f
...
...
@@ -28,6 +28,7 @@
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace
paddle
{
...
...
@@ -38,161 +39,130 @@ namespace distributed {
// to directory specified.
constexpr
char
LOOKUP_TABLE_PATH
[]
=
"kLookupTablePath"
;
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv complete message"
;
if
(
invar
==
nullptr
)
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"sync: Can not find server side var: %s"
,
varname
));
return
false
;
}
if
(
HeartBeatMonitor
::
GetInstance
()
!=
nullptr
)
{
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
""
,
COMPLETED
)
;
}
if
(
distributed_mode_
==
DistributedMode
::
kSync
)
{
return
true
;
}
rpc_server_
->
Complete
();
}
else
{
// Async
if
(
distributed_mode_
!=
DistributedMode
::
kSync
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
PADDLE_THROW
(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
varname
,
RUNNING
);
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
varname
,
RUNNING
);
std
::
string
run_varname
=
varname
;
std
::
string
run_varname
=
varname
;
string
::
Piece
part_piece
(
"@PIECE"
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
string
::
Piece
part_piece
(
"@PIECE"
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
if
(
string
::
Contains
(
var_name_piece
,
part_piece
))
{
auto
varname_splits
=
paddle
::
string
::
Split
(
varname
,
'@'
);
run_varname
=
varname_splits
[
0
];
scope
->
Rename
(
varname
,
run_varname
);
}
if
(
string
::
Contains
(
var_name_piece
,
part_piece
))
{
auto
varname_splits
=
paddle
::
string
::
Split
(
varname
,
'@'
);
PADDLE_ENFORCE_EQ
(
varname_splits
.
size
(),
3
);
run_varname
=
varname_splits
[
0
];
scope
->
Rename
(
varname
,
run_varname
);
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
auto
&
grad_slr
=
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
grad_slr
.
rows
());
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
auto
&
grad_slr
=
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
grad_slr
.
rows
());
}
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
run_varname
].
get
(),
scope
);
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
run_varname
].
get
(),
scope
);
return
true
;
}
else
{
// sync
rpc_server_
->
WaitCond
(
kRequestSend
);
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
PADDLE_ENFORCE_NOT_NULL
(
invar
,
platform
::
errors
::
NotFound
(
"sync: Can not find server side var %s."
,
varname
));
}
}
return
true
;
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
3
)
<<
"RequestGetHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" table_name: "
<<
table_name
;
if
(
distributed_mode_
==
DistributedMode
::
kSync
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
else
{
rpc_server_
->
WaitCond
(
kRequestGet
);
*
outvar
=
scope_
->
FindVar
(
varname
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
}
else
{
if
(
varname
!=
FETCH_BARRIER_MESSAGE
&&
varname
!=
COMPLETE_MESSAGE
)
{
if
(
enable_dc_asgd_
)
{
// NOTE: the format is determined by distribute_transpiler.py
std
::
string
param_bak_name
=
string
::
Sprintf
(
"%s.trainer_%d_bak"
,
varname
,
trainer_id
);
VLOG
(
3
)
<<
"getting "
<<
param_bak_name
<<
" trainer_id "
<<
trainer_id
;
auto
var
=
scope_
->
FindVar
(
varname
);
auto
t_orig
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
param_bak
=
scope_
->
Var
(
param_bak_name
);
auto
t
=
param_bak
->
GetMutable
<
framework
::
LoDTensor
>
();
t
->
mutable_data
(
dev_ctx_
->
GetPlace
(),
t_orig
.
type
());
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
if
(
distributed_mode_
==
DistributedMode
::
kGeo
)
{
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
);
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
varname
,
trainer_id
,
&
updated_rows
);
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
updated_rows
)
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"updated_rows size: "
<<
updated_rows
.
size
()
<<
" "
<<
sstream
.
str
();
}
auto
&
origin_tensor
=
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
auto
&
dims
=
origin_tensor
.
dims
();
*
outvar
=
scope
->
Var
();
auto
*
out_slr
=
(
*
outvar
)
->
GetMutable
<
framework
::
SelectedRows
>
();
out_slr
->
set_rows
(
updated_rows
);
out_slr
->
set_height
(
dims
[
0
]);
auto
out_dims
=
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
updated_rows
.
size
()),
dims
[
1
]});
auto
*
data
=
out_slr
->
mutable_value
()
->
mutable_data
<
float
>
(
out_dims
,
origin_tensor
.
place
());
auto
width
=
dims
[
1
];
for
(
size_t
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
updated_rows
[
i
],
dims
[
0
]);
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
sizeof
(
float
)
*
width
);
if
(
enable_dc_asgd_
)
{
// NOTE: the format is determined by distribute_transpiler.py
std
::
string
param_bak_name
=
string
::
Sprintf
(
"%s.trainer_%d_bak"
,
varname
,
trainer_id
);
VLOG
(
3
)
<<
"getting "
<<
param_bak_name
<<
" trainer_id "
<<
trainer_id
;
auto
var
=
scope_
->
FindVar
(
varname
);
auto
t_orig
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
param_bak
=
scope_
->
Var
(
param_bak_name
);
auto
t
=
param_bak
->
GetMutable
<
framework
::
LoDTensor
>
();
t
->
mutable_data
(
dev_ctx_
->
GetPlace
(),
t_orig
.
type
());
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
varname
,
trainer_id
,
&
updated_rows
);
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
updated_rows
)
{
sstream
<<
row_id
<<
", "
;
}
}
else
{
*
outvar
=
scope_
->
FindVar
(
varname
);
sstream
<<
"]"
;
VLOG
(
3
)
<<
"updated_rows size: "
<<
updated_rows
.
size
()
<<
" "
<<
sstream
.
str
();
}
auto
&
origin_tensor
=
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
auto
&
dims
=
origin_tensor
.
dims
();
*
outvar
=
scope
->
Var
();
auto
*
out_slr
=
(
*
outvar
)
->
GetMutable
<
framework
::
SelectedRows
>
();
out_slr
->
set_rows
(
updated_rows
);
out_slr
->
set_height
(
dims
[
0
]);
auto
out_dims
=
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
updated_rows
.
size
()),
dims
[
1
]});
auto
*
data
=
out_slr
->
mutable_value
()
->
mutable_data
<
float
>
(
out_dims
,
origin_tensor
.
place
());
auto
width
=
dims
[
1
];
for
(
size_t
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
updated_rows
[
i
],
dims
[
0
],
platform
::
errors
::
OutOfRange
(
"expected >= 0 and < %ld, but got %ld."
,
dims
[
0
],
updated_rows
[
i
]));
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
sizeof
(
float
)
*
width
);
}
}
else
{
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
return
true
;
}
bool
RequestGetNoBarrierHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestGetNoBarrierHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetNoBarrierHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
;
...
...
@@ -207,18 +177,19 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
*
outvar
=
scope_
->
FindVar
(
var_name_piece
.
ToString
());
return
true
;
}
else
{
PADDLE_THROW
(
"GetNoBarrier must contain %s"
,
WITHOUT_BARRIER_MESSAGE
);
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"GetNoBarrier must contain %s"
,
WITHOUT_BARRIER_MESSAGE
));
}
return
true
;
}
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
if
(
table_name
.
empty
())
{
...
...
@@ -236,19 +207,20 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return
true
;
}
bool
RequestCheckpointHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestCheckpointHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
PADDLE_ENFORCE
(
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
PADDLE_ENFORCE_NE
(
checkpoint_notify_id
,
-
1
,
platform
::
errors
::
Unavailable
(
"when checkpoint_notify_id = -1, there should be no RPC invoke."
));
// TODO(tangwei12): find out why scope will be error.
auto
*
lt_var
=
scope_
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
auto
*
lt_var
=
scope_
->
FindVar
(
LOOKUP_TABLE_PATH
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update var kLookupTablePath to: "
...
...
@@ -257,33 +229,56 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return
true
;
}
bool
RequestNotifyHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
bool
RequestNotifyHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestNotifyHandler: "
<<
varname
;
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
<<
", trainer_id: "
<<
trainer_id
;
string
::
Piece
decay_piece
(
LEARNING_RATE_DECAY_COUNTER
);
string
::
Piece
batch_piece
(
BATCH_BARRIER_MESSAGE
);
string
::
Piece
fetch_piece
(
FETCH_BARRIER_MESSAGE
);
string
::
Piece
complete_piece
(
COMPLETE_MESSAGE
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
if
(
string
::
Contains
(
var_name_piece
,
decay_piece
))
{
if
(
string
::
Contains
(
var_name_piece
,
batch_piece
))
{
return
BarrierMonitor
::
GetInstance
()
->
IncreaseBarrier
(
trainer_id
,
BATCH_BARRIER_MESSAGE
);
}
else
if
(
string
::
Contains
(
var_name_piece
,
fetch_piece
))
{
return
BarrierMonitor
::
GetInstance
()
->
IncreaseBarrier
(
trainer_id
,
FETCH_BARRIER_MESSAGE
);
}
else
if
(
string
::
Contains
(
var_name_piece
,
complete_piece
))
{
if
(
HeartBeatMonitor
::
GetInstance
()
!=
nullptr
)
{
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
""
,
COMPLETED
);
}
rpc_server_
->
Complete
();
BarrierMonitor
::
GetInstance
()
->
DecreaseWorker
();
return
true
;
}
else
if
(
string
::
Contains
(
var_name_piece
,
decay_piece
))
{
VLOG
(
3
)
<<
"LearningRate Decay Counter Update"
;
PADDLE_ENFORCE_NE
(
lr_decay_block_id
,
-
1
,
"when lr_decay_block_id = -1, there should be no RPC invoke."
);
auto
*
origin_var
=
scope_
->
FindVar
(
varname
);
platform
::
errors
::
InvalidArgument
(
"when lr_decay_block_id = -1, there should be no RPC invoke."
));
auto
*
origin_var
=
scope_
->
FindVar
(
varname
);
auto
origin_var_tensor
=
origin_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
send_var
=
scope
->
FindVar
(
varname
);
auto
*
send_var
=
scope
->
FindVar
(
varname
);
auto
send_var_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
int64_t
*
origin_value
=
int64_t
*
origin_value
=
origin_var_tensor
.
mutable_data
<
int64_t
>
(
origin_var_tensor
.
place
());
int64_t
*
send_value
=
int64_t
*
send_value
=
send_var_tensor
.
mutable_data
<
int64_t
>
(
send_var_tensor
.
place
());
origin_value
[
0
]
+=
send_value
[
0
];
executor_
->
RunPreparedContext
(
lr_decay_prepared_ctx_
.
get
(),
scope_
);
return
true
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"unkown varname %s with RequestNotifyHandler"
,
varname
));
}
return
true
;
}
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
be6a315f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
...
...
@@ -119,6 +117,7 @@ void StartServer(const std::string& rpc_name) {
g_rpc_service
->
RegisterRPC
(
rpc_name
,
g_req_handler
.
get
());
distributed
::
HeartBeatMonitor
::
Init
(
2
,
true
,
"w@grad"
);
distributed
::
BarrierMonitor
::
Init
(
2
);
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
...
...
@@ -164,6 +163,9 @@ TEST(PREFETCH, CPU) {
}
}
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
barrier
->
Stop
();
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
...
...
@@ -174,20 +176,24 @@ TEST(PREFETCH, CPU) {
TEST
(
COMPLETE
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_req_handler
.
reset
(
new
distributed
::
RequestNotifyHandler
(
distributed
::
DistributedMode
::
kSync
,
-
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
PADDLE_ENFORCE
(
client
!=
nullptr
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequest
Send
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequest
Notify
);
g_rpc_service
->
WaitServerReady
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
client
->
AsyncSendComplete
(
ep
);
client
->
Wait
();
EXPECT_EQ
(
g_rpc_service
->
GetClientNum
(),
1
);
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
EXPECT_EQ
(
barrier
->
GetWorkerNum
(),
1
);
barrier
->
Stop
();
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
...
...
paddle/fluid/operators/distributed_ops/CMakeLists.txt
浏览文件 @
be6a315f
...
...
@@ -2,9 +2,9 @@ include(operators)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv
barrier_monitor
communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node
)
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv
barrier_monitor
communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
浏览文件 @
be6a315f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -21,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
...
...
@@ -30,16 +28,16 @@ namespace operators {
class
GenNCCLIdOp
:
public
framework
::
OperatorBase
{
public:
GenNCCLIdOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
GenNCCLIdOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
// put nccl id in CPUPlace
auto
&
dev_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
auto
&
dev_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
int
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
std
::
vector
<
std
::
string
>
trainers
=
...
...
@@ -55,7 +53,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std
::
string
endpoint
=
trainers
[
trainer_id
];
framework
::
Scope
&
local_scope
=
scope
.
NewScope
();
framework
::
Scope
&
local_scope
=
scope
.
NewScope
();
int
nccl_comm_num
=
Attr
<
int
>
(
"nccl_comm_num"
);
int
use_hierarchical_allreduce
=
Attr
<
bool
>
(
"use_hierarchical_allreduce"
);
...
...
@@ -171,10 +169,10 @@ class GenNCCLIdOp : public framework::OperatorBase {
}
private:
void
GenerateAndSend
(
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
&
dev_ctx
,
const
std
::
string
&
nccl_id_name
,
const
std
::
vector
<
std
::
string
>
&
endpoint_list
)
const
{
void
GenerateAndSend
(
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
&
dev_ctx
,
const
std
::
string
&
nccl_id_name
,
const
std
::
vector
<
std
::
string
>
&
endpoint_list
)
const
{
auto
var
=
scope
->
FindVar
(
nccl_id_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Variable with name %s is not found"
,
...
...
@@ -182,76 +180,96 @@ class GenNCCLIdOp : public framework::OperatorBase {
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
id
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
for
(
auto
&
ep
:
endpoint_list
)
{
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl_id_var:"
<<
nccl_id_name
<<
" to "
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
*
scope
,
nccl_id_name
);
}
client
->
Wait
();
for
(
auto
&
ep
:
endpoint_list
)
{
for
(
auto
&
ep
:
endpoint_list
)
{
client
->
AsyncSendBatchBarrier
(
ep
);
}
client
->
Wait
();
VLOG
(
3
)
<<
"sending completed..."
;
}
void
GetIdByServer
(
const
std
::
string
&
endpoint
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
&
dev_ctx
,
int
nccl_comm_num
,
void
GetIdByServer
(
const
std
::
string
&
endpoint
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
&
dev_ctx
,
int
nccl_comm_num
,
bool
use_hierarchical_allreduce
,
int
trainer_id
,
int
inter_trainer_id
,
int
exter_trainer_id
)
const
{
// std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
distributed
::
RequestNotifyHandler
notify_h
(
distributed
::
DistributedMode
::
kSync
,
-
1
);
rpc_service
->
RegisterRPC
(
distributed
::
kRequestSend
,
&
rpc_h
);
rpc_
h
.
SetRPCServer
(
rpc_service
.
get
()
);
rpc_
service
->
RegisterRPC
(
distributed
::
kRequestNotify
,
&
notify_h
);
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_h
.
SetRPCServer
(
rpc_service
.
get
());
rpc_h
.
SetScope
(
scope
);
rpc_h
.
SetDevCtx
(
&
dev_ctx
);
rpc_h
.
SetProgram
(
&
empty_program
);
rpc_h
.
SetExecutor
(
&
executor
);
notify_h
.
SetRPCServer
(
rpc_service
.
get
());
notify_h
.
SetScope
(
scope
);
notify_h
.
SetDevCtx
(
&
dev_ctx
);
notify_h
.
SetProgram
(
&
empty_program
);
notify_h
.
SetExecutor
(
&
executor
);
distributed
::
BarrierMonitor
::
Init
(
1
);
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
std
::
thread
server_thread
(
std
::
bind
(
&
distributed
::
RPCServer
::
StartServer
,
rpc_service
.
get
()));
for
(
int
i
=
0
;
i
<
nccl_comm_num
;
i
++
)
{
rpc_service
->
SetCond
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
barrier
->
ServerWeakup
();
VLOG
(
3
)
<<
"trainer_id:"
<<
trainer_id
<<
" start getting nccl id from trainer 0, nccl_comm_no:"
<<
i
;
rpc_service
->
WaitBarrier
(
distributed
::
kRequestSend
);
rpc_service
->
ResetBarrierCounter
();
}
if
(
use_hierarchical_allreduce
)
{
if
(
inter_trainer_id
>
0
)
{
for
(
int
i
=
0
;
i
<
nccl_comm_num
;
i
++
)
{
rpc_service
->
SetCond
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
barrier
->
ServerWeakup
();
VLOG
(
3
)
<<
"trainer_id:"
<<
trainer_id
<<
", inter_trainer_id:"
<<
inter_trainer_id
<<
" start getting nccl id from inter_trainer:"
<<
i
;
rpc_service
->
WaitBarrier
(
distributed
::
kRequestSend
);
rpc_service
->
ResetBarrierCounter
();
}
}
if
(
exter_trainer_id
>
0
)
{
for
(
int
i
=
0
;
i
<
nccl_comm_num
;
i
++
)
{
rpc_service
->
SetCond
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
barrier
->
ServerWeakup
();
VLOG
(
3
)
<<
"trainer_id:"
<<
trainer_id
<<
", exter_trainer_id:"
<<
exter_trainer_id
<<
" start getting nccl id from exter_trainer 0, nccl_comm_no:"
<<
i
;
rpc_service
->
WaitBarrier
(
distributed
::
kRequestSend
);
rpc_service
->
ResetBarrierCounter
();
}
}
}
...
...
@@ -260,6 +278,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
<<
", inter_trainer_id:"
<<
inter_trainer_id
<<
", exter_trainer_id:"
<<
exter_trainer_id
<<
" got nccl id and stop server..."
;
barrier
->
Stop
();
rpc_service
->
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
server_thread
.
join
();
...
...
@@ -272,7 +291,6 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"NCCLID"
,
"Raw variable contains a NCCL UniqueId instaces."
);
AddComment
(
R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC"
);
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
be6a315f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -25,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...
...
@@ -38,10 +36,13 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace
paddle
{
namespace
operators
{
volatile
sig_atomic_t
gSignalStatus
;
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
)
{
service
->
StartServer
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
...
...
@@ -126,6 +127,7 @@ void ListenAndServOp::RunSyncLoop(
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
optimize_blocks_list
.
push_back
(
i
);
}
auto
optimize_prepared
=
executor
->
Prepare
(
*
program
,
optimize_blocks_list
);
// Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran.
...
...
@@ -135,21 +137,15 @@ void ListenAndServOp::RunSyncLoop(
// Trainers will get all parameters from pserver in the
// startup program, so we will wait RequestGet first
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
ResetBarrierCounter
();
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG
(
3
)
<<
"wait all clients to send gradient"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
VLOG
(
3
)
<<
"wait all clients to send send_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
();
if
(
rpc_service_
->
IsExit
()
)
{
if
(
gSignalStatus
!=
0
)
{
LOG
(
WARNING
)
<<
"get exit!rpc_processor break!"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
break
;
}
...
...
@@ -180,12 +176,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG
(
3
)
<<
"ResetReceivedVars"
;
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
VLOG
(
3
)
<<
"wait all clients to get parameters back"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"wait all clients to send fetch_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"ResetBarrierCounter"
;
rpc_service_
->
ResetBarrierCounter
();
barrier
->
ServerWeakup
();
VLOG
(
3
)
<<
"kRecvBarrier to push params to trainers"
;
}
// while(true)
}
...
...
@@ -281,7 +273,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
while
(
true
)
{
if
(
rpc_service_
->
IsExit
()
)
{
if
(
gSignalStatus
!=
0
)
{
VLOG
(
4
)
<<
"get exit!rpc_processor break!"
;
break
;
}
...
...
@@ -391,7 +383,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestGetNoBarrier
,
request_get_no_barrier_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestNotify
,
request_notify_handler_
.
get
(),
rpc_send_thread_num
);
request_notify_handler_
.
get
(),
fan_in
*
2
);
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
...
@@ -440,6 +432,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
prefetch_var_name_to_prepared_ctx
;
for
(
size_t
i
=
0
;
i
<
prefetch_block_id_list
.
size
();
++
i
)
{
auto
block_id
=
prefetch_block_id_list
[
i
];
auto
prefetch_var_name
=
block_id_to_prefetch_var_name
[
block_id
];
...
...
@@ -448,8 +441,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std
::
unordered_map
<
std
::
string
,
std
::
string
>
sparse_grad_name_to_param_name
;
auto
sparse_grad_name_to_param_name_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSparseGradToParam
);
for
(
const
auto
&
sparse_grad_name_and_param_name
:
sparse_grad_name_to_param_name_str
)
{
std
::
vector
<
std
::
string
>
pieces
;
...
...
@@ -477,17 +472,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
distributed
::
BarrierMonitor
::
Init
(
fan_in
);
if
(
distributed_mode
==
distributed
::
DistributedMode
::
kSync
)
{
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
rpc_service_
->
WaitServerReady
();
CacheVarsType
(
inputs
,
recv_scope
);
// Write to a file of server selected port for python use.
SavePort
();
CacheVarsType
(
inputs
,
recv_scope
);
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
...
...
@@ -574,9 +570,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
remove
(
file_path
.
c_str
());
exit
(
0
);
distributed
::
BarrierMonitor
::
GetInstance
()
->
Stop
();
gSignalStatus
=
signal_num
;
}
}
// namespace operators
...
...
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
浏览文件 @
be6a315f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -20,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...
...
@@ -42,6 +40,7 @@ namespace string = paddle::string;
std
::
unique_ptr
<
distributed
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
distributed
::
RequestHandler
>
g_req_handler
;
std
::
unique_ptr
<
distributed
::
RequestNotifyHandler
>
g_notify_handler
;
void
StartServer
()
{
f
::
Scope
scope
;
...
...
@@ -52,21 +51,35 @@ void StartServer() {
f
::
ProgramDesc
empty_program
;
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetDevCtx
(
&
dev_ctx
);
g_req_handler
->
SetProgram
(
&
empty_program
);
g_req_handler
->
SetExecutor
(
&
executor
);
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
g_notify_handler
.
SetRPCServer
(
rpc_service
.
get
());
g_notify_handler
.
SetScope
(
scope
);
g_notify_handler
.
SetDevCtx
(
&
dev_ctx
);
g_notify_handler
.
SetProgram
(
&
empty_program
);
g_notify_handler
.
SetExecutor
(
&
executor
);
g_rpc_service
->
RegisterRPC
(
distributed
::
kRequestSend
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
g_rpc_service
->
RegisterRPC
(
distributed
::
RequestNotifyHandler
,
g_notify_handler
.
get
());
distributed
::
BarrierMonitor
::
Init
(
1
);
auto
*
barrier
=
distributed
::
BarrierMonitor
::
GetInstance
();
barrier
->
Reset
(
1
,
distributed
::
BarrierType
::
kSendBarrier
);
std
::
thread
server_thread
(
std
::
bind
(
&
distributed
::
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
g_rpc_service
->
SetCond
(
distributed
::
kRequestSend
);
g_rpc_service
->
WaitBarrier
(
distributed
::
kRequestSend
);
barrier
->
WaitServerWeakup
(
);
barrier
->
ServerWeakup
(
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
barrier
->
Stop
();
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
}
...
...
@@ -74,6 +87,10 @@ void StartServer() {
TEST
(
SendNcclId
,
RPCServer
)
{
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_notify_handler
.
reset
(
new
distributed
::
RequestNotifyHandler
(
distributed
::
DistributedMode
::
kSync
,
-
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
...
...
@@ -104,4 +121,5 @@ TEST(SendNcclId, RPCServer) {
server_thread
.
join
();
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
g_notify_handler
.
reset
(
nullptr
);
}
python/paddle/fluid/tests/unittests/test_communicator_half_async.py
浏览文件 @
be6a315f
...
...
@@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
current_id
=
0
,
role
=
role_maker
.
Role
.
WORKER
if
training_role
==
"TRAINER"
else
role_maker
.
Role
.
SERVER
,
worker_num
=
2
,
worker_num
=
1
,
server_endpoints
=
[
"127.0.0.1:6002"
])
if
training_role
==
"TRAINER"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录