Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2028a8ef
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2028a8ef
编写于
6月 06, 2018
作者:
G
gongweibao
提交者:
GitHub
6月 06, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rpc_client interface. (#11154)
上级
ca2d6d3c
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
191 addition
and
89 deletion
+191
-89
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+1
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+25
-39
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+27
-29
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+4
-2
paddle/fluid/operators/detail/rpc_client.cc
paddle/fluid/operators/detail/rpc_client.cc
+26
-0
paddle/fluid/operators/detail/rpc_client.h
paddle/fluid/operators/detail/rpc_client.h
+82
-0
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+3
-1
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+4
-3
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+3
-3
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+3
-2
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-1
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+4
-3
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+3
-2
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+4
-3
未找到文件。
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
2028a8ef
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
request_handler_impl.cc rpc_
client.cc rpc_
server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
2028a8ef
...
@@ -25,29 +25,15 @@ namespace paddle {
...
@@ -25,29 +25,15 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
void
GRPCClient
::
InitImpl
()
{
InitEventLoop
();
}
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
void
GRPCClient
::
InitEventLoop
()
{
RPCClient
*
RPCClient
::
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
);
return
rpc_client_
.
get
();
}
void
RPCClient
::
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
}
rpc_client_
->
InitEventLoop
();
}
void
RPCClient
::
InitEventLoop
()
{
// start the client process thread
// start the client process thread
// TODO(wuyi): can make this in a threadpool
// TODO(wuyi): can make this in a threadpool
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
RPCClient
::
Proceed
,
this
)));
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
G
RPCClient
::
Proceed
,
this
)));
}
}
RPCClient
::~
RPCClient
()
{
GRPCClient
::~
G
RPCClient
()
{
Wait
();
Wait
();
cq_
.
Shutdown
();
cq_
.
Shutdown
();
{
{
...
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
...
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
client_thread_
->
join
();
client_thread_
->
join
();
}
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
...
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
...
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result
->
Swap
(
&
tmp
);
result
->
Swap
(
&
tmp
);
}
}
bool
RPCClient
::
AsyncGetVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
...
@@ -155,7 +139,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -155,7 +139,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return
true
;
return
true
;
}
}
bool
RPCClient
::
AsyncPrefetchVariable
(
const
std
::
string
&
ep
,
bool
GRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
in_var_name
,
...
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
return
true
;
return
true
;
}
}
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
...
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
req_count_
++
;
}
}
void
RPCClient
::
Wait
()
{
void
G
RPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
}
}
void
RPCClient
::
Proceed
()
{
void
G
RPCClient
::
Proceed
()
{
void
*
tag
=
nullptr
;
void
*
tag
=
nullptr
;
bool
ok
=
false
;
bool
ok
=
false
;
...
@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
...
@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
}
}
}
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
std
::
shared_ptr
<
grpc
::
Channel
>
G
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
// TODO(Yancey1989): make grpc client completely thread-safe
// TODO(Yancey1989): make grpc client completely thread-safe
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
auto
it
=
channels_
.
find
(
ep
);
auto
it
=
channels_
.
find
(
ep
);
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
2028a8ef
...
@@ -38,6 +38,7 @@ limitations under the License. */
...
@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...
@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
...
@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
class
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
RPCClient
()
{}
G
RPCClient
()
{}
~
RPCClient
();
virtual
~
G
RPCClient
();
static
RPCClient
*
GetInstance
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
AsyncGetVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
AsyncPrefetchVar
iable
(
const
std
::
string
&
ep
,
bool
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
600
*
1000
)
;
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendBatchBarrier
(
int64_t
time_out
=
600
*
1000
);
const
std
::
string
&
ep
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendFetchBarrier
(
int64_t
time_out
=
600
*
1000
);
const
std
::
string
&
ep
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
Wait
();
void
Wait
()
override
;
protected:
void
InitImpl
()
override
;
private:
// InitEventLoop should only be called by Init()
// InitEventLoop should only be called by Init()
void
InitEventLoop
();
void
InitEventLoop
();
private:
void
Proceed
();
void
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
static
void
Init
();
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
...
@@ -218,9 +218,7 @@ class RPCClient {
...
@@ -218,9 +218,7 @@ class RPCClient {
// mutex for GetChannel thread safety
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
std
::
mutex
chan_mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
DISABLE_COPY_AND_ASSIGN
(
GRPCClient
);
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
};
};
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
2028a8ef
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
...
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
...
@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
...
@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
client
->
AsyncPrefetchVar
iable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
AsyncPrefetchVar
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
...
...
paddle/fluid/operators/detail/rpc_client.cc
0 → 100644
浏览文件 @
2028a8ef
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
std
::
once_flag
RPCClient
::
init_flag_
;
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_client.h
0 → 100644
浏览文件 @
2028a8ef
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RPCClient
{
public:
virtual
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
bool
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
rpc_time_out
)
=
0
;
virtual
void
Wait
()
=
0
;
static
constexpr
int64_t
rpc_time_out
=
120
*
1000
;
template
<
typename
T
>
static
RPCClient
*
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
<
T
>
);
return
rpc_client_
.
get
();
}
// Init is called by GetInstance.
template
<
typename
T
>
static
void
Init
()
{
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
T
());
rpc_client_
->
InitImpl
();
}
}
protected:
virtual
void
InitImpl
()
{}
private:
static
std
::
once_flag
init_flag_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
2028a8ef
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
...
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
rpc_client
->
Wait
();
rpc_client
->
Wait
();
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
2028a8ef
...
@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoint_list
=
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
detail
::
RPCClient
client
;
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
auto
&
ep
:
endpoint_list
)
{
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
*
scope
,
NCCL_ID_VARNAME
);
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
*
scope
,
NCCL_ID_VARNAME
);
}
}
client
.
Wait
();
client
->
Wait
();
VLOG
(
3
)
<<
"sending completed..."
;
VLOG
(
3
)
<<
"sending completed..."
;
}
}
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
2028a8ef
...
@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
...
@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" to get "
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" to get "
<<
outs
[
i
]
<<
" back"
;
<<
outs
[
i
]
<<
" back"
;
rpc_client
->
AsyncPrefetchVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
],
rpc_client
->
AsyncPrefetchVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
],
outs
[
i
]);
outs
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
2028a8ef
...
@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
...
@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVar
iable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
if
(
sync_mode
)
{
if
(
sync_mode
)
{
rpc_client
->
Wait
();
rpc_client
->
Wait
();
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
2028a8ef
...
@@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
2028a8ef
...
@@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase {
...
@@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
rpc_client
->
AsyncSendVar
iable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
rpc_client
->
AsyncSendVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
@@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase {
...
@@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase {
if
(
outs
.
size
()
>
0
)
{
if
(
outs
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVar
iable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
rpc_client
->
Wait
();
rpc_client
->
Wait
();
// tell pservers that current trainer have called fetch
// tell pservers that current trainer have called fetch
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
2028a8ef
...
@@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase {
...
@@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling
// For profiling
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
]
<<
" to "
<<
epmap
[
i
];
// TODO(Yancey1989): we need to use an IO threadpool which has
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
// a larger number of threads than the computing threadpool.
rpc_client
->
AsyncSendVar
iable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
rpc_client
->
AsyncSendVar
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
else
{
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
2028a8ef
...
@@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) {
...
@@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) {
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
();
detail
::
RPCClient
*
client
=
LOG
(
INFO
)
<<
"connect to server "
<<
ep
;
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
client
->
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
->
Wait
();
client
->
Wait
();
client
->
AsyncSendBatchBarrier
(
ep
);
client
->
AsyncSendBatchBarrier
(
ep
);
client
->
Wait
();
client
->
Wait
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录