PaddlePaddle/Paddle — commit b5a41046 (unverified)
Authored Oct 07, 2019 by tangwei12; committed via GitHub on Oct 07, 2019
Parent: 76ba55e8

Trainer heartbeat for async mode (#19600)

Heartbeat for distributed async training.
Showing 10 changed files with 327 additions and 4 deletions (+327 -4)
paddle/fluid/operators/distributed/CMakeLists.txt              +4   -1
paddle/fluid/operators/distributed/grpc/grpc_client.cc         +1   -0
paddle/fluid/operators/distributed/heart_beat_monitor.cc       +97  -0
paddle/fluid/operators/distributed/heart_beat_monitor.h        +136 -0
paddle/fluid/operators/distributed/heart_beat_monitor_test.cc  +57  -0
paddle/fluid/operators/distributed/request_handler_impl.cc     +6   -1
paddle/fluid/operators/distributed/rpc_server_test.cc          +4   -0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc   +19  -2
python/paddle/fluid/__init__.py                                +2   -0
python/paddle/fluid/transpiler/distribute_transpiler.py        +1   -0
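Read together, the hunks below wire the monitor in as follows: listen_and_serv_op.cc creates a process-wide HeartBeatMonitor on each parameter server (only the pserver with pserver_id == 0 spawns the checking thread), request_handler_impl.cc marks a trainer RUNNING on every received send of the monitored gradient and COMPLETED on a COMPLETE_MESSAGE, and the background thread raises if a RUNNING trainer stays silent longer than FLAGS_worker_update_interval_secs (default 900 s). A minimal sketch of that call sequence, assuming the headers from this commit are on the include path inside the Paddle tree; it is illustrative, not code from the commit:

#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"

namespace pd = paddle::operators::distributed;

int main() {
  // On pserver start-up: 4 trainers, this pserver is the chief (pserver_id 0),
  // and sends of "w@GRAD.block0" count as heartbeats.
  pd::HeartBeatMonitor::Init(/*workers=*/4, /*is_chief=*/true, "w@GRAD.block0");

  // In the send handler: every gradient received from trainer 2 refreshes its
  // timestamp and marks it RUNNING.
  pd::HeartBeatMonitor::GetInstance()->Update(2, "w@GRAD.block0", pd::RUNNING);

  // A COMPLETE_MESSAGE marks the trainer COMPLETED; the background
  // LostWorkerMonitor thread stops checking it.
  pd::HeartBeatMonitor::GetInstance()->Update(2, "", pd::COMPLETED);

  // Stop joins the checker thread (it wakes every 30 seconds, so this can
  // block briefly).
  pd::HeartBeatMonitor::GetInstance()->Stop();
  return 0;
}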
paddle/fluid/operators/distributed/CMakeLists.txt

@@ -12,6 +12,9 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
 cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
 cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
+cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
+cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
 # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 if(WITH_GRPC)

@@ -23,7 +26,7 @@ if(WITH_GRPC)
       collective_client.cc collective_server.cc ${GRPC_SRCS}
       PROTO send_recv.proto
-      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder)
+      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})

paddle/fluid/operators/distributed/grpc/grpc_client.cc

@@ -392,6 +392,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
   s->Prepare(h, time_out);

   sendrecv::VariableMessage req;
+  req.set_trainer_id(trainer_id_);
   req.set_varname(COMPLETE_MESSAGE);

   platform::RecordRPCEvent record_event(method);

paddle/fluid/operators/distributed/heart_beat_monitor.cc  (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"

#include <chrono>  // NOLINT
#include <ctime>

namespace paddle {
namespace operators {
namespace distributed {

DEFINE_int32(worker_update_interval_secs, 900,
             " the longest time interval between the worker update variables");

inline int GetCurrentUS() {
  // current date/time based on current system
  time_t t = std::time(0);
  int now = static_cast<int>(t);
  return now;
}

void HeartBeatMonitor::Update(const int worker_id,
                              std::string be_monitored_var,
                              WorkerStatus status) {
  if (status == UNINITED) {
    LOG(WARNING) << "HeartBeatMonitor receive UNINITED status can not be used "
                    "in Update, something error";
  }

  if (!is_chief_) {
    return;
  }

  if ((be_monitored_var == be_monitored_var_ && status == RUNNING) ||
      status == COMPLETED) {
    auto timestamp = GetCurrentUS();
    UnderMonitoredWorker& worker = worker_status_map_.at(worker_id);

    if (worker.status != COMPLETED) {
      worker.status = status;
    }
    worker.timestamp = timestamp;
    return;
  }
}

void HeartBeatMonitor::LostWorkerMonitor() {
  VLOG(1) << "worker heartbeat monitor start at No.0 parameter server";
  while (running_) {
    for (int id = 0; id < workers_; ++id) {
      auto& worker = worker_status_map_.at(id);

      if (worker.status == UNINITED) {
        VLOG(4) << "worker " << worker.id << " is under UNINITED";
        continue;
      }
      if (worker.status == COMPLETED) {
        VLOG(4) << "worker " << worker.id << " is under COMPLETED";
        continue;
      }

      auto timestamp = GetCurrentUS();

      VLOG(4) << "worker " << worker.id << " status is " << worker.status
              << " timestamp is " << worker.timestamp << " the interval is "
              << timestamp - worker.timestamp;

      if (timestamp - worker.timestamp >= FLAGS_worker_update_interval_secs) {
        PADDLE_THROW(
            "the latest update of worker %d is %d secs ago, we doubt the "
            "the worker is not alive and this may have a bad effect on the "
            "fitting result, please check",
            worker.id, FLAGS_worker_update_interval_secs);
      }
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(30 * 1000));
  }
  VLOG(1) << "worker heartbeat monitor stopped, thread exit";
}

std::once_flag HeartBeatMonitor::init_flag_;
std::unique_ptr<HeartBeatMonitor> HeartBeatMonitor::monitor_(nullptr);

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
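The checking loop above is a plain watchdog: wake every 30 seconds, compare each RUNNING worker's last-seen timestamp against FLAGS_worker_update_interval_secs, and fail loudly when a worker has gone silent. Below is the same pattern reduced to standard C++ only, with shortened timings so it runs in a few seconds; all names are illustrative and not from the commit:

#include <atomic>
#include <chrono>
#include <cstdio>
#include <ctime>
#include <mutex>
#include <thread>
#include <unordered_map>

// Illustrative stand-in for the commit's UnderMonitoredWorker bookkeeping.
struct WorkerRecord {
  bool running = false;
  std::time_t last_seen = 0;
};

class Watchdog {
 public:
  Watchdog(int workers, int timeout_secs)
      : timeout_secs_(timeout_secs), stop_(false) {
    for (int i = 0; i < workers; ++i) workers_[i] = WorkerRecord{};
    thread_ = std::thread([this] { Loop(); });
  }

  ~Watchdog() {
    stop_ = true;
    thread_.join();
  }

  // Called whenever a heartbeat (e.g. a gradient send) arrives.
  void Touch(int id) {
    std::lock_guard<std::mutex> lock(mu_);
    workers_[id].running = true;
    workers_[id].last_seen = std::time(nullptr);
  }

 private:
  void Loop() {
    while (!stop_) {
      {
        std::lock_guard<std::mutex> lock(mu_);
        for (auto& kv : workers_) {
          if (!kv.second.running) continue;  // never seen, nothing to check
          auto silent = std::time(nullptr) - kv.second.last_seen;
          if (silent >= timeout_secs_) {
            std::fprintf(stderr, "worker %d silent for %ld secs\n", kv.first,
                         static_cast<long>(silent));
          }
        }
      }
      // The commit sleeps 30 * 1000 ms here; shortened for the demo.
      std::this_thread::sleep_for(std::chrono::seconds(1));
    }
  }

  int timeout_secs_;
  std::atomic<bool> stop_;
  std::mutex mu_;
  std::unordered_map<int, WorkerRecord> workers_;
  std::thread thread_;
};

int main() {
  Watchdog dog(/*workers=*/2, /*timeout_secs=*/2);
  dog.Touch(0);                                           // worker 0 heartbeats once
  std::this_thread::sleep_for(std::chrono::seconds(4));   // then goes silent
  return 0;                                               // warnings printed above
}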
paddle/fluid/operators/distributed/heart_beat_monitor.h  (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <gflags/gflags.h>

#include <functional>
#include <future>  // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <thread>  // NOLINT

#include <ThreadPool.h>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace distributed {

enum WorkerStatus { UNINITED = 0, RUNNING, COMPLETED };

struct UnderMonitoredWorker {
  int id;
  WorkerStatus status;
  int timestamp;

  UnderMonitoredWorker() {}

  explicit UnderMonitoredWorker(int worker_id) {
    this->id = worker_id;
    this->status = UNINITED;
    this->timestamp = 0;
  }
};

class HeartBeatMonitor {
 public:
  explicit HeartBeatMonitor(int workers, bool is_chief,
                            std::string be_monitored_var)
      : workers_(workers),
        is_chief_(is_chief),
        be_monitored_var_(be_monitored_var),
        running_(true) {
    PADDLE_ENFORCE_GT(workers, 0, "trainers must have one or more");

    for (auto worker_id = 0; worker_id < workers; worker_id++) {
      UnderMonitoredWorker worker(worker_id);
      worker_status_map_[worker_id] = std::move(worker);
    }

    // we define the No.0 pserver is the first parameter server
    // only No.0 will check the heartbeat of all trainers
    if (is_chief) {
      monitor_thread_.reset(new std::thread(
          std::bind(&HeartBeatMonitor::LostWorkerMonitor, this)));
    }
  }

  ~HeartBeatMonitor() {
    running_ = false;
    if (monitor_thread_) monitor_thread_->join();
  }

  static void Init(int workers, bool is_chief, std::string be_monitored_var) {
    std::call_once(init_flag_, &HeartBeatMonitor::InitImpl, workers, is_chief,
                   be_monitored_var);
  }

  static HeartBeatMonitor* GetInstance() {
    if (monitor_ == nullptr) {
      PADDLE_THROW(
          "HeartBeatMonitor is not inited, call "
          "HeartBeatMonitor::Init first");
    }
    return monitor_.get();
  }

  void Stop() {
    running_ = false;
    if (!monitor_) {
      VLOG(0) << "HeartBeatMonitor is not inited, do nothing";
    } else {
      if (monitor_thread_) {
        monitor_thread_->join();
        monitor_thread_.reset(nullptr);
      }
    }
  }

  void Update(const int worker_id, std::string be_monitored_var,
              WorkerStatus status);

  void LostWorkerMonitor();

 private:
  // Init is called by GetInstance.
  static void InitImpl(int workers, bool is_chief,
                       std::string be_monitored_var) {
    if (monitor_ == nullptr) {
      monitor_.reset(
          new HeartBeatMonitor(workers, is_chief, be_monitored_var));
    }
  }

  static std::once_flag init_flag_;
  static std::unique_ptr<HeartBeatMonitor> monitor_;

  int workers_;
  bool is_chief_;
  std::string be_monitored_var_;
  std::unordered_map<int, UnderMonitoredWorker> worker_status_map_;
  std::unique_ptr<std::thread> monitor_thread_{nullptr};
  std::mutex mutex_;
  bool running_ = false;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
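The header implements a process-wide lazy singleton: Init may be reached from more than one place (listen_and_serv and the RPC test both call it), but std::call_once guarantees InitImpl runs exactly once and GetInstance only hands out the stored pointer. The idiom in isolation, with illustrative names rather than code from the commit:

#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

// Illustrative singleton mirroring the Init / GetInstance / InitImpl layout above.
class Registry {
 public:
  static void Init(const std::string& name) {
    // Exactly one InitImpl ever runs, even if Init races on several threads;
    // later calls with a different argument are silently ignored.
    std::call_once(init_flag_, &Registry::InitImpl, name);
  }

  static Registry* GetInstance() {
    if (instance_ == nullptr) {
      throw std::runtime_error("Registry::Init must be called first");
    }
    return instance_.get();
  }

  const std::string& name() const { return name_; }

 private:
  explicit Registry(const std::string& name) : name_(name) {}

  static void InitImpl(const std::string& name) {
    instance_.reset(new Registry(name));
  }

  static std::once_flag init_flag_;
  static std::unique_ptr<Registry> instance_;
  std::string name_;
};

std::once_flag Registry::init_flag_;
std::unique_ptr<Registry> Registry::instance_;

int main() {
  Registry::Init("first");
  Registry::Init("second");  // no effect: already initialized
  std::cout << Registry::GetInstance()->name() << std::endl;  // prints "first"
  return 0;
}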
paddle/fluid/operators/distributed/heart_beat_monitor_test.cc  (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"

#include <algorithm>
#include <thread>  // NOLINT

#include "gtest/gtest.h"

namespace paddle {
namespace operators {
namespace distributed {

void run(HeartBeatMonitor* monitor) { monitor->LostWorkerMonitor(); }

TEST(HeartBeatMonitor, All) {
  int trainers = 10;
  int pserver_id = 0;
  std::string var = "nce_w@GRAD.block0";
  std::string var2 = "nce_w@GRAD.block2";

  HeartBeatMonitor::Init(trainers, pserver_id == 0, var);

  auto* monitor = HeartBeatMonitor::GetInstance();

  std::vector<int> ids{1, 3, 5, 7};

  for (auto& id : ids) {
    monitor->Update(id, var, RUNNING);
  }

  monitor->Update(9, var2, RUNNING);
  monitor->Update(2, var, COMPLETED);

  std::thread t(run, monitor);
  t.detach();

  std::this_thread::sleep_for(std::chrono::milliseconds(45 * 1000));

  monitor->Stop();
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/request_handler_impl.cc

@@ -22,12 +22,14 @@
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable_helper.h"
-#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
 #include "paddle/fluid/string/piece.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/split.h"

+#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
+#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"

 namespace paddle {
 namespace operators {
 namespace distributed {

@@ -51,6 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
       rpc_server_->IncreaseBatchBarrier(kRequestSend);
     } else if (varname == COMPLETE_MESSAGE) {
       VLOG(3) << "sync: recv complete message";
+      HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
       rpc_server_->Complete();
     } else {
       // Async

@@ -61,6 +64,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
           "async mode should not recv BATCH_BARRIER_MESSAGE or "
           "COMPLETE_MESSAGE");
     }
+    HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
     std::string run_varname = varname;

@@ -82,6 +86,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
     }
     executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
                                   scope);
     return true;
   } else {  // sync
     rpc_server_->WaitCond(kRequestSend);

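The two Update calls above drive a small per-trainer state machine: a send of the monitored gradient moves a trainer to RUNNING and refreshes its timestamp, a COMPLETE_MESSAGE moves it to COMPLETED, and COMPLETED is terminal (a late send never flips it back; see Update in heart_beat_monitor.cc). A compact, self-contained restatement of that transition rule, with illustrative names:

#include <cassert>
#include <ctime>
#include <string>

enum WorkerState { UNINITED = 0, RUNNING, COMPLETED };

struct Tracked {
  WorkerState state = UNINITED;
  std::time_t last_seen = 0;
};

// Mirrors the guard logic of HeartBeatMonitor::Update for a single worker:
// only sends of the monitored variable (or a completion) count as heartbeats,
// and COMPLETED is never overwritten, though the timestamp is still refreshed.
void OnMessage(Tracked* w, const std::string& varname,
               const std::string& monitored_var, WorkerState incoming) {
  bool is_heartbeat =
      (varname == monitored_var && incoming == RUNNING) || incoming == COMPLETED;
  if (!is_heartbeat) return;
  if (w->state != COMPLETED) w->state = incoming;
  w->last_seen = std::time(nullptr);
}

int main() {
  Tracked t;
  OnMessage(&t, "other@GRAD", "w@GRAD.block0", RUNNING);     // ignored
  assert(t.state == UNINITED);
  OnMessage(&t, "w@GRAD.block0", "w@GRAD.block0", RUNNING);  // heartbeat
  assert(t.state == RUNNING);
  OnMessage(&t, "", "w@GRAD.block0", COMPLETED);             // COMPLETE_MESSAGE
  assert(t.state == COMPLETED);
  OnMessage(&t, "w@GRAD.block0", "w@GRAD.block0", RUNNING);  // late send
  assert(t.state == COMPLETED);                              // stays terminal
  return 0;
}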
paddle/fluid/operators/distributed/rpc_server_test.cc

@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
+#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed/rpc_client.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"

@@ -116,6 +117,9 @@ void StartServer(const std::string& rpc_name) {
   g_req_handler->SetExecutor(&exe);
   g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());

+  distributed::HeartBeatMonitor::Init(2, true, "w@grad");
+
   g_req_handler->SetRPCServer(g_rpc_service.get());

   std::thread server_thread(

paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"

 #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
+#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"

@@ -338,14 +339,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   bool sync_mode = Attr<bool>("sync_mode");
   bool dc_sgd = Attr<bool>("dc_asgd");
   auto fan_in = Attr<int>("Fanin");
+  auto pserver_id = Attr<int>("pserver_id");
   auto inputs = Inputs("X");

   PADDLE_ENFORCE(!rpc_service_);
   std::string endpoint = Attr<std::string>("endpoint");
   int checkpoint_block_id = Attr<int>(kCheckpointBlockId);

-  VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
-          << ", end_point:" << endpoint;
+  VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode
+          << ", fan_in:" << fan_in << ", end_point:" << endpoint
+          << ", checkpoint_block_id: " << checkpoint_block_id;

   rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));

@@ -466,6 +468,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   } else {
     distributed::AsyncSparseParamUpdateRecorder::Init(
         fan_in, sparse_grad_name_to_param_name);
+
+    VLOG(2) << "RunAsyncLoop";
+    auto grad_to_block_id_str =
+        Attr<std::vector<std::string>>("grad_to_block_id");
+
+    if (grad_to_block_id_str.size() == 0) {
+      VLOG(0) << "there are no gradients on this parameter server";
+    } else {
+      std::vector<std::string> pieces;
+      split(grad_to_block_id_str[0], ':', &pieces);
+      distributed::HeartBeatMonitor::Init(fan_in, pserver_id == 0, pieces[0]);
+    }
+
     RunAsyncLoop(&executor, program, &recv_scope);
   }
 }

@@ -482,6 +496,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
                          "IP address to listen on.")
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
+    AddAttr<int>("pserver_id",
+                 "(int, default -1), the parameter server index id")
+        .SetDefault(-1);
     AddAttr<std::vector<std::string>>(
         "grad_to_block_id",
         "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "

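HeartBeatMonitor::Init above picks the monitored variable by splitting the first grad_to_block_id entry (e.g. "param1@GRAD.block0:1") on ':' and keeping the gradient name, and treats the pserver whose transpiler-assigned pserver_id is 0 as the chief. A self-contained sketch of that selection; the split helper in the commit comes from paddle/fluid/string/split.h, here it is redone with plain std::string, and the variable names are illustrative:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // What the transpiler hands to listen_and_serv as the "grad_to_block_id" attribute.
  std::vector<std::string> grad_to_block_id = {"param1@GRAD.block0:1",
                                               "param2@GRAD.blockn:2"};
  int pserver_id = 0;  // index of this endpoint in the pserver endpoint list
  int fan_in = 4;      // number of trainers (the "Fanin" attribute)

  if (grad_to_block_id.empty()) {
    std::cout << "there are no gradients on this parameter server\n";
    return 0;
  }

  // Same effect as split(grad_to_block_id_str[0], ':', &pieces); pieces[0] is
  // the gradient variable whose sends count as heartbeats.
  const std::string& first = grad_to_block_id[0];
  std::string monitored_var = first.substr(0, first.find(':'));

  bool is_chief = (pserver_id == 0);  // only pserver 0 runs the checker thread
  std::cout << "monitor " << fan_in << " trainers on var " << monitored_var
            << (is_chief ? " (chief)" : "") << "\n";
  // In the commit this is where distributed::HeartBeatMonitor::Init(
  //     fan_in, pserver_id == 0, pieces[0]) is called.
  return 0;
}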
python/paddle/fluid/__init__.py

@@ -189,6 +189,8 @@ def __bootstrap__():
         read_env_flags.append('rpc_prefetch_thread_num')
         read_env_flags.append('rpc_disable_reuse_port')

+        read_env_flags.append('worker_update_interval_secs')

         # env for communicator
         read_env_flags.append('communicator_independent_recv_thread')
         read_env_flags.append('communicator_send_queue_size')

python/paddle/fluid/transpiler/distribute_transpiler.py

@@ -1193,6 +1193,7 @@ class DistributeTranspiler(object):
         attrs = {
             "optimize_blocks": optimize_blocks,
             "endpoint": endpoint,
+            "pserver_id": self.pserver_endpoints.index(endpoint),
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,