Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b5a41046
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b5a41046
编写于
10月 07, 2019
作者:
T
tangwei12
提交者:
GitHub
10月 07, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Trainer heartbeat for async mode (#19600)
Heartbeat for distributed async training.
上级
76ba55e8
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
327 addition
and
4 deletion
+327
-4
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+4
-1
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+1
-0
paddle/fluid/operators/distributed/heart_beat_monitor.cc
paddle/fluid/operators/distributed/heart_beat_monitor.cc
+97
-0
paddle/fluid/operators/distributed/heart_beat_monitor.h
paddle/fluid/operators/distributed/heart_beat_monitor.h
+136
-0
paddle/fluid/operators/distributed/heart_beat_monitor_test.cc
...le/fluid/operators/distributed/heart_beat_monitor_test.cc
+57
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+6
-1
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+4
-0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+19
-2
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+1
-0
未找到文件。
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
b5a41046
...
...
@@ -12,6 +12,9 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
cc_library
(
async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool
)
cc_test
(
async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder
)
cc_library
(
heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool
)
cc_test
(
heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor
)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
WITH_GRPC
)
...
...
@@ -23,7 +26,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${
GRPC_SRCS
}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
async_sparse_param_update_recorder
)
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
async_sparse_param_update_recorder
heart_beat_monitor
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
RPC_DEPS sendrecvop_rpc
${
GRPC_DEPS
}
)
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
b5a41046
...
...
@@ -392,6 +392,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
s
->
Prepare
(
h
,
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_trainer_id
(
trainer_id_
);
req
.
set_varname
(
COMPLETE_MESSAGE
);
platform
::
RecordRPCEvent
record_event
(
method
);
...
...
paddle/fluid/operators/distributed/heart_beat_monitor.cc
0 → 100644
浏览文件 @
b5a41046
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include <chrono> // NOLINT
#include <ctime>
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
DEFINE_int32
(
worker_update_interval_secs
,
900
,
" the longest time interval between the worker update variables"
);
inline
int
GetCurrentUS
()
{
// current date/time based on current system
time_t
t
=
std
::
time
(
0
);
int
now
=
static_cast
<
int
>
(
t
);
return
now
;
}
void
HeartBeatMonitor
::
Update
(
const
int
worker_id
,
std
::
string
be_monitored_var
,
WorkerStatus
status
)
{
if
(
status
==
UNINITED
)
{
LOG
(
WARNING
)
<<
"HeartBeatMonitor receive UNINITED status can not be used "
"in Update, something error"
;
}
if
(
!
is_chief_
)
{
return
;
}
if
((
be_monitored_var
==
be_monitored_var_
&&
status
==
RUNNING
)
||
status
==
COMPLETED
)
{
auto
timestamp
=
GetCurrentUS
();
UnderMonitoredWorker
&
worker
=
worker_status_map_
.
at
(
worker_id
);
if
(
worker
.
status
!=
COMPLETED
)
{
worker
.
status
=
status
;
}
worker
.
timestamp
=
timestamp
;
return
;
}
}
void
HeartBeatMonitor
::
LostWorkerMonitor
()
{
VLOG
(
1
)
<<
"worker heartbeat monitor start at No.0 parameter server"
;
while
(
running_
)
{
for
(
int
id
=
0
;
id
<
workers_
;
++
id
)
{
auto
&
worker
=
worker_status_map_
.
at
(
id
);
if
(
worker
.
status
==
UNINITED
)
{
VLOG
(
4
)
<<
"worker "
<<
worker
.
id
<<
" is under UNINITED"
;
continue
;
}
if
(
worker
.
status
==
COMPLETED
)
{
VLOG
(
4
)
<<
"worker "
<<
worker
.
id
<<
" is under COMPLETED"
;
continue
;
}
auto
timestamp
=
GetCurrentUS
();
VLOG
(
4
)
<<
"worker "
<<
worker
.
id
<<
" status is "
<<
worker
.
status
<<
" timestamp is "
<<
worker
.
timestamp
<<
" the interval is "
<<
timestamp
-
worker
.
timestamp
;
if
(
timestamp
-
worker
.
timestamp
>=
FLAGS_worker_update_interval_secs
)
{
PADDLE_THROW
(
"the latest update of worker %d is %d secs ago, we doubt the "
"the worker is not alive and this may have a bad effect on the "
"fitting result, please check"
,
worker
.
id
,
FLAGS_worker_update_interval_secs
);
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
30
*
1000
));
}
VLOG
(
1
)
<<
"worker heartbeat monitor stopped, thread exit"
;
}
std
::
once_flag
HeartBeatMonitor
::
init_flag_
;
std
::
unique_ptr
<
HeartBeatMonitor
>
HeartBeatMonitor
::
monitor_
(
nullptr
);
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/heart_beat_monitor.h
0 → 100644
浏览文件 @
b5a41046
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
enum
WorkerStatus
{
UNINITED
=
0
,
RUNNING
,
COMPLETED
};
struct
UnderMonitoredWorker
{
int
id
;
WorkerStatus
status
;
int
timestamp
;
UnderMonitoredWorker
()
{}
explicit
UnderMonitoredWorker
(
int
worker_id
)
{
this
->
id
=
worker_id
;
this
->
status
=
UNINITED
;
this
->
timestamp
=
0
;
}
};
class
HeartBeatMonitor
{
public:
explicit
HeartBeatMonitor
(
int
workers
,
bool
is_chief
,
std
::
string
be_monitored_var
)
:
workers_
(
workers
),
is_chief_
(
is_chief
),
be_monitored_var_
(
be_monitored_var
),
running_
(
true
)
{
PADDLE_ENFORCE_GT
(
workers
,
0
,
"trainers must have one or more"
);
for
(
auto
worker_id
=
0
;
worker_id
<
workers
;
worker_id
++
)
{
UnderMonitoredWorker
worker
(
worker_id
);
worker_status_map_
[
worker_id
]
=
std
::
move
(
worker
);
}
// we define the No.0 pserver is the first parameter server
// only No.0 will check the heartbeat of all trainers
if
(
is_chief
)
{
monitor_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
HeartBeatMonitor
::
LostWorkerMonitor
,
this
)));
}
}
~
HeartBeatMonitor
()
{
running_
=
false
;
if
(
monitor_thread_
)
monitor_thread_
->
join
();
}
static
void
Init
(
int
workers
,
bool
is_chief
,
std
::
string
be_monitored_var
)
{
std
::
call_once
(
init_flag_
,
&
HeartBeatMonitor
::
InitImpl
,
workers
,
is_chief
,
be_monitored_var
);
}
static
HeartBeatMonitor
*
GetInstance
()
{
if
(
monitor_
==
nullptr
)
{
PADDLE_THROW
(
"HeartBeatMonitor is not inited, call "
"HeartBeatMonitor::Init first"
);
}
return
monitor_
.
get
();
}
void
Stop
()
{
running_
=
false
;
if
(
!
monitor_
)
{
VLOG
(
0
)
<<
"HeartBeatMonitor is not inited, do nothing"
;
}
else
{
if
(
monitor_thread_
)
{
monitor_thread_
->
join
();
monitor_thread_
.
reset
(
nullptr
);
}
}
}
void
Update
(
const
int
worker_id
,
std
::
string
be_monitored_var
,
WorkerStatus
status
);
void
LostWorkerMonitor
();
private:
// Init is called by GetInstance.
static
void
InitImpl
(
int
workers
,
bool
is_chief
,
std
::
string
be_monitored_var
)
{
if
(
monitor_
==
nullptr
)
{
monitor_
.
reset
(
new
HeartBeatMonitor
(
workers
,
is_chief
,
be_monitored_var
));
}
}
static
std
::
once_flag
init_flag_
;
static
std
::
unique_ptr
<
HeartBeatMonitor
>
monitor_
;
int
workers_
;
bool
is_chief_
;
std
::
string
be_monitored_var_
;
std
::
unordered_map
<
int
,
UnderMonitoredWorker
>
worker_status_map_
;
std
::
unique_ptr
<
std
::
thread
>
monitor_thread_
{
nullptr
};
std
::
mutex
mutex_
;
bool
running_
=
false
;
};
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/heart_beat_monitor_test.cc
0 → 100644
浏览文件 @
b5a41046
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include <algorithm>
#include <thread> // NOLINT
#include "gtest/gtest.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
void
run
(
HeartBeatMonitor
*
monitor
)
{
monitor
->
LostWorkerMonitor
();
}
TEST
(
HeartBeatMonitor
,
All
)
{
int
trainers
=
10
;
int
pserver_id
=
0
;
std
::
string
var
=
"nce_w@GRAD.block0"
;
std
::
string
var2
=
"nce_w@GRAD.block2"
;
HeartBeatMonitor
::
Init
(
trainers
,
pserver_id
==
0
,
var
);
auto
*
monitor
=
HeartBeatMonitor
::
GetInstance
();
std
::
vector
<
int
>
ids
{
1
,
3
,
5
,
7
};
for
(
auto
&
id
:
ids
)
{
monitor
->
Update
(
id
,
var
,
RUNNING
);
}
monitor
->
Update
(
9
,
var2
,
RUNNING
);
monitor
->
Update
(
2
,
var
,
COMPLETED
);
std
::
thread
t
(
run
,
monitor
);
t
.
detach
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
45
*
1000
));
monitor
->
Stop
();
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
b5a41046
...
...
@@ -22,12 +22,14 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
...
...
@@ -51,6 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv complete message"
;
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
""
,
COMPLETED
);
rpc_server_
->
Complete
();
}
else
{
// Async
...
...
@@ -61,6 +64,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
HeartBeatMonitor
::
GetInstance
()
->
Update
(
trainer_id
,
varname
,
RUNNING
);
std
::
string
run_varname
=
varname
;
...
...
@@ -82,6 +86,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
}
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
run_varname
].
get
(),
scope
);
return
true
;
}
else
{
// sync
rpc_server_
->
WaitCond
(
kRequestSend
);
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
b5a41046
...
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
...
...
@@ -116,6 +117,9 @@ void StartServer(const std::string& rpc_name) {
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
rpc_name
,
g_req_handler
.
get
());
distributed
::
HeartBeatMonitor
::
Init
(
2
,
true
,
"w@grad"
);
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
b5a41046
...
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...
...
@@ -338,14 +339,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
pserver_id
=
Attr
<
int
>
(
"pserver_id"
);
auto
inputs
=
Inputs
(
"X"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
VLOG
(
4
)
<<
"
sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
VLOG
(
4
)
<<
"
pserver_id: "
<<
pserver_id
<<
", sync_mode:"
<<
sync_mode
<<
",
fan_in:"
<<
fan_in
<<
",
end_point:"
<<
endpoint
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
;
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
...
...
@@ -466,6 +468,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
}
else
{
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
fan_in
,
sparse_grad_name_to_param_name
);
VLOG
(
2
)
<<
"RunAsyncLoop"
;
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
if
(
grad_to_block_id_str
.
size
()
==
0
)
{
VLOG
(
0
)
<<
"there are no gradients on this parameter server"
;
}
else
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
grad_to_block_id_str
[
0
],
':'
,
&
pieces
);
distributed
::
HeartBeatMonitor
::
Init
(
fan_in
,
pserver_id
==
0
,
pieces
[
0
]);
}
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
}
}
...
...
@@ -482,6 +496,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
int
>
(
"pserver_id"
,
"(int, default -1), the parameter server index id"
)
.
SetDefault
(
-
1
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
,
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
...
...
python/paddle/fluid/__init__.py
浏览文件 @
b5a41046
...
...
@@ -189,6 +189,8 @@ def __bootstrap__():
read_env_flags
.
append
(
'rpc_prefetch_thread_num'
)
read_env_flags
.
append
(
'rpc_disable_reuse_port'
)
read_env_flags
.
append
(
'worker_update_interval_secs'
)
# env for communicator
read_env_flags
.
append
(
'communicator_independent_recv_thread'
)
read_env_flags
.
append
(
'communicator_send_queue_size'
)
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b5a41046
...
...
@@ -1193,6 +1193,7 @@ class DistributeTranspiler(object):
attrs
=
{
"optimize_blocks"
:
optimize_blocks
,
"endpoint"
:
endpoint
,
"pserver_id"
:
self
.
pserver_endpoints
.
index
(
endpoint
),
"Fanin"
:
self
.
trainer_num
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录