Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2028a8ef
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2028a8ef
编写于
6月 06, 2018
作者:
G
gongweibao
提交者:
GitHub
6月 06, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rpc_client interface. (#11154)
上级
ca2d6d3c
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
191 addition
and
89 deletion
+191
-89
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+1
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+25
-39
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+27
-29
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+4
-2
paddle/fluid/operators/detail/rpc_client.cc
paddle/fluid/operators/detail/rpc_client.cc
+26
-0
paddle/fluid/operators/detail/rpc_client.h
paddle/fluid/operators/detail/rpc_client.h
+82
-0
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+3
-1
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+4
-3
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+3
-3
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+3
-2
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-1
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+4
-3
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+3
-2
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+4
-3
未找到文件。
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
2028a8ef
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
request_handler_impl.cc rpc_
client.cc rpc_
server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
2028a8ef
...
@@ -25,29 +25,15 @@ namespace paddle {
...
@@ -25,29 +25,15 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
void
GRPCClient
::
InitImpl
()
{
InitEventLoop
();
}
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
void
GRPCClient
::
InitEventLoop
()
{
RPCClient
*
RPCClient
::
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
);
return
rpc_client_
.
get
();
}
void
RPCClient
::
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
}
rpc_client_
->
InitEventLoop
();
}
void
RPCClient
::
InitEventLoop
()
{
// start the client process thread
// start the client process thread
// TODO(wuyi): can make this in a threadpool
// TODO(wuyi): can make this in a threadpool
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
RPCClient
::
Proceed
,
this
)));
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
G
RPCClient
::
Proceed
,
this
)));
}
}
RPCClient
::~
RPCClient
()
{
GRPCClient
::~
G
RPCClient
()
{
Wait
();
Wait
();
cq_
.
Shutdown
();
cq_
.
Shutdown
();
{
{
...
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
...
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
client_thread_
->
join
();
client_thread_
->
join
();
}
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
...
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
...
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result
->
Swap
(
&
tmp
);
result
->
Swap
(
&
tmp
);
}
}
bool
RPCClient
::
AsyncGetVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
...
@@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return
true
;
return
true
;
}
}
bool
RPCClient
::
AsyncPrefetchVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
in_var_name_val
=
in_var_name
;
...
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
return
true
;
return
true
;
}
}
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
...
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
void
RPCClient
::
Wait
()
{
void
G
RPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
}
}
void
RPCClient
::
Proceed
()
{
void
G
RPCClient
::
Proceed
()
{
void
*
tag
=
nullptr
;
void
*
tag
=
nullptr
;
bool
ok
=
false
;
bool
ok
=
false
;
...
@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
...
@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
}
}
}
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
std
::
shared_ptr
<
grpc
::
Channel
>
G
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
// TODO(Yancey1989): make grpc client completely thread-safe
// TODO(Yancey1989): make grpc client completely thread-safe
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
auto
it
=
channels_
.
find
(
ep
);
auto
it
=
channels_
.
find
(
ep
);
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
2028a8ef
...
@@ -38,6 +38,7 @@ limitations under the License. */
...
@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...
@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
...
@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
class
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
RPCClient
()
{}
G
RPCClient
()
{}
~
RPCClient
();
virtual
~
G
RPCClient
();
static
RPCClient
*
GetInstance
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
Async
GetVariable
(
const
std
::
string
&
ep
,
bool
Async
PrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
in_var_name
,
int64_t
time_out
=
600
*
1000
);
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
bool
AsyncPrefetchVariable
(
const
std
::
string
&
ep
,
void
AsyncSendBatchBarrier
(
const
platform
::
DeviceContext
&
ctx
,
const
std
::
string
&
ep
,
const
framework
::
Scope
&
scope
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
600
*
1000
);
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendFetchBarrier
(
int64_t
time_out
=
600
*
1000
);
const
std
::
string
&
ep
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
void
Wait
()
override
;
int64_t
time_out
=
600
*
1000
);
void
Wait
();
protected:
void
InitImpl
()
override
;
private:
// InitEventLoop should only be called by Init()
// InitEventLoop should only be called by Init()
void
InitEventLoop
();
void
InitEventLoop
();
private:
void
Proceed
();
void
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
static
void
Init
();
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
...
@@ -218,9 +218,7 @@ class RPCClient {
...
@@ -218,9 +218,7 @@ class RPCClient {
// mutex for GetChannel thread safety
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
std
::
mutex
chan_mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
DISABLE_COPY_AND_ASSIGN
(
GRPCClient
);
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
};
};
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
2028a8ef
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
...
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
...
@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
...
@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
client
->
AsyncPrefetchVar
iable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
AsyncPrefetchVar
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
...
...
paddle/fluid/operators/detail/rpc_client.cc
0 → 100644
浏览文件 @
2028a8ef
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_client.h
0 → 100644
浏览文件 @
2028a8ef
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RPCClient
{
public:
virtual
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
bool
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
Wait
()
=
0
;
static
constexpr
int64_t
rpc_time_out
=
120
*
1000
;
template
<
typename
T
>
static
RPCClient
*
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
<
T
>
);
return
rpc_client_
.
get
();
}
// Init is called by GetInstance.
template
<
typename
T
>
static
void
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
T
());
rpc_client_
->
InitImpl
();
}
}
protected:
virtual
void
InitImpl
()
{}
private:
static
std
::
once_flag
init_flag_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
2028a8ef
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
rpc_client
->
Wait
();
rpc_client
->
Wait
();
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
2028a8ef
...
@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoint_list
=
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
detail
::
RPCClient
client
;
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
auto
&
ep
:
endpoint_list
)
{
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
*
scope
,
NCCL_ID_VARNAME
);
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
*
scope
,
NCCL_ID_VARNAME
);
}
}
client
.
Wait
();
client
->
Wait
();
VLOG
(
3
)
<<
"sending completed..."
;
VLOG
(
3
)
<<
"sending completed..."
;
}
}
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
2028a8ef
...
@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
...
@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" to get "
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" to get "
<<
outs
[
i
]
<<
" back"
;
<<
outs
[
i
]
<<
" back"
;
rpc_client
->
AsyncPrefetchVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
],
rpc_client
->
AsyncPrefetchVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
],
outs
[
i
]);
outs
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
2028a8ef
...
@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
...
@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVar
iable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
if
(
sync_mode
)
{
if
(
sync_mode
)
{
rpc_client
->
Wait
();
rpc_client
->
Wait
();
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
2028a8ef
...
@@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
2028a8ef
...
@@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase {
...
@@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
rpc_client
->
AsyncSendVar
iable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
rpc_client
->
AsyncSendVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
@@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase {
...
@@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase {
if
(
outs
.
size
()
>
0
)
{
if
(
outs
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVar
iable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
rpc_client
->
Wait
();
rpc_client
->
Wait
();
// tell pservers that current trainer have called fetch
// tell pservers that current trainer have called fetch
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
2028a8ef
...
@@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase {
...
@@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
// TODO(Yancey1989): we need to use an IO threadpool which has
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
// a larger number of threads than the computing threadpool.
rpc_client
->
AsyncSendVar
iable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
rpc_client
->
AsyncSendVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
2028a8ef
...
@@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) {
...
@@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) {
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
client
=
LOG
(
INFO
)
<<
"connect to server "
<<
ep
;
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
client
->
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
->
Wait
();
client
->
Wait
();
client
->
AsyncSendBatchBarrier
(
ep
);
client
->
AsyncSendBatchBarrier
(
ep
);
client
->
Wait
();
client
->
Wait
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录