Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
da3087ad
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
da3087ad
编写于
1月 11, 2018
作者:
G
gongweibao
提交者:
GitHub
1月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Async GRPC sendrecv (#7133)
Async GRPC sendrecv
上级
020630b7
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
775 addition
and
127 deletion
+775
-127
paddle/operators/detail/CMakeLists.txt
paddle/operators/detail/CMakeLists.txt
+1
-1
paddle/operators/detail/grpc_client.cc
paddle/operators/detail/grpc_client.cc
+147
-0
paddle/operators/detail/grpc_client.h
paddle/operators/detail/grpc_client.h
+147
-0
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+237
-0
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+91
-0
paddle/operators/detail/recv_impl.cc
paddle/operators/detail/recv_impl.cc
+0
-65
paddle/operators/detail/send_recv.proto
paddle/operators/detail/send_recv.proto
+0
-2
paddle/operators/detail/sendrecvop_utils.cc
paddle/operators/detail/sendrecvop_utils.cc
+68
-0
paddle/operators/detail/sendrecvop_utils.h
paddle/operators/detail/sendrecvop_utils.h
+42
-0
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+20
-23
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+21
-35
paddle/operators/send_recv_op_test.cc
paddle/operators/send_recv_op_test.cc
+1
-1
未找到文件。
paddle/operators/detail/CMakeLists.txt
浏览文件 @
da3087ad
grpc_library
(
sendrecvop_grpc SRCS
recv_impl.cc send_impl
.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
grpc_library
(
sendrecvop_grpc SRCS
sendrecvop_utils.cc grpc_client.cc grpc_server
.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
paddle/operators/detail/grpc_client.cc
0 → 100644
浏览文件 @
da3087ad
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
sendrecv
::
VariableMessage
req
;
auto
*
var
=
scope
.
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
ctx
,
&
req
);
// varhandle
VarHandle
var_h
;
var_h
.
ep
=
ep
;
var_h
.
scope
=
&
scope
;
var_h
.
name
=
var_name
;
var_h
.
ctx
=
&
ctx
;
// stub context
auto
ch
=
GetChannel
(
ep
);
SendProcessor
*
s
=
new
SendProcessor
(
ch
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
response_call_back_
=
NULL
;
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
req_count_
++
;
return
true
;
}
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
sendrecv
::
VariableMessage
&
ret_msg
)
{
auto
*
outvar
=
var_h
.
scope
->
FindVar
(
var_h
.
name
);
std
::
istringstream
iss
(
ret_msg
.
serialized
());
DeserializeFromMessage
(
ret_msg
,
*
var_h
.
ctx
,
outvar
);
}
bool
RPCClient
::
AsyncGetVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name
);
auto
*
var
=
scope
.
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
ctx
,
&
req
);
// varhandle
VarHandle
var_h
;
var_h
.
ep
=
ep
;
var_h
.
scope
=
&
scope
;
var_h
.
name
=
var_name
;
var_h
.
ctx
=
&
ctx
;
// stub context
auto
ch
=
GetChannel
(
ep
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
response_call_back_
=
ProcGetResponse
;
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
req_count_
++
;
return
true
;
}
bool
RPCClient
::
wait
()
{
bool
ok
=
true
;
while
(
true
)
{
if
(
req_count_
<=
0
)
{
break
;
}
if
(
!
Proceed
())
{
LOG
(
ERROR
)
<<
"Get meets CompletionQueue error"
;
return
false
;
}
}
return
ok
;
}
bool
RPCClient
::
Proceed
()
{
void
*
tag
=
NULL
;
bool
ok
=
false
;
// request counts.
if
(
!
cq_
.
Next
(
&
tag
,
&
ok
))
{
return
false
;
}
req_count_
--
;
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
tag
);
// TODO(gongwb): add more retries.
ClientBase
*
c
=
static_cast
<
ClientBase
*>
(
tag
);
if
(
!
c
->
status_
.
ok
())
{
delete
c
;
return
true
;
}
c
->
Process
();
delete
c
;
return
true
;
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
}
auto
ch
=
std
::
shared_ptr
<
grpc
::
Channel
>
(
grpc
::
CreateChannel
(
ep
,
grpc
::
InsecureChannelCredentials
()));
channels_
[
ep
]
=
ch
;
return
ch
;
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/operators/detail/grpc_client.h
0 → 100644
浏览文件 @
da3087ad
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <time.h>
#include <chrono>
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/framework/data_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
struct
VarHandle
{
std
::
string
ep
;
const
platform
::
DeviceContext
*
ctx
;
const
framework
::
Scope
*
scope
;
std
::
string
name
;
std
::
string
String
()
const
{
std
::
ostringstream
s
;
s
<<
"name:["
<<
name
<<
"] ep:["
<<
ep
<<
"]"
;
return
s
.
str
();
}
};
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
sendrecv
::
VariableMessage
&
msg
);
class
ClientBase
{
public:
explicit
ClientBase
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
context_
=
NULL
;
}
virtual
~
ClientBase
()
{}
virtual
void
Prepare
(
const
VarHandle
&
var_info
,
int64_t
time_out
)
{
context_
.
reset
(
new
grpc
::
ClientContext
());
var_h_
=
var_info
;
std
::
chrono
::
system_clock
::
time_point
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
milliseconds
(
time_out
);
context_
->
set_deadline
(
deadline
);
}
virtual
void
Process
()
=
0
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
grpc
::
ClientContext
>
context_
;
grpc
::
Status
status_
;
VarHandle
var_h_
;
};
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VoidMessage
&
)
>
RequestSendCallBack
;
class
SendProcessor
:
public
ClientBase
{
public:
explicit
SendProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
ClientBase
(
ch
)
{}
virtual
~
SendProcessor
()
{}
virtual
void
Process
()
{
if
(
response_call_back_
)
{
response_call_back_
(
var_h_
,
reply_
);
}
}
sendrecv
::
VoidMessage
reply_
;
RequestSendCallBack
response_call_back_
=
NULL
;
};
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VariableMessage
&
)
>
RequestGetCallBack
;
class
GetProcessor
:
public
ClientBase
{
public:
explicit
GetProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
ClientBase
(
ch
)
{}
virtual
~
GetProcessor
()
{}
virtual
void
Process
()
{
if
(
response_call_back_
)
{
response_call_back_
(
var_h_
,
reply_
);
}
}
sendrecv
::
VariableMessage
reply_
;
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
};
class
RPCClient
{
public:
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
AsyncGetVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
wait
();
private:
bool
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
private:
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
int64_t
req_count_
=
0
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/operators/detail/grpc_server.cc
0 → 100644
浏览文件 @
da3087ad
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
using
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
namespace
operators
{
namespace
detail
{
enum
CallStatus
{
PROCESS
=
0
,
FINISH
};
// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class
RequestBase
{
public:
explicit
RequestBase
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
)
{}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
protected:
grpc
::
ServerContext
ctx_
;
sendrecv
::
SendRecvService
::
AsyncService
*
service_
;
grpc
::
ServerCompletionQueue
*
cq_
;
CallStatus
status_
;
};
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestSend
final
:
public
RequestBase
{
public:
explicit
RequestSend
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
SimpleBlockQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
service_
->
RequestSendVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
virtual
~
RequestSend
()
{}
virtual
void
Process
()
{
MessageWithName
msg_with_name
=
std
::
make_pair
(
request_
.
varname
(),
std
::
move
(
request_
));
queue_
->
Push
(
std
::
move
(
msg_with_name
));
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
}
protected:
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VoidMessage
reply_
;
SimpleBlockQueue
<
MessageWithName
>*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
class
RequestGet
final
:
public
RequestBase
{
public:
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
)
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
scope_
(
scope
)
{
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
virtual
~
RequestGet
()
{}
virtual
void
Process
()
{
// proc request.
std
::
string
var_name
=
request_
.
varname
();
auto
*
var
=
scope_
->
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
platform
::
CPUDeviceContext
(),
&
reply_
);
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
}
protected:
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VariableMessage
reply_
;
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
};
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
cq_get_
=
builder
.
AddCompletionQueue
();
server_
=
builder
.
BuildAndStart
();
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
<<
std
::
endl
;
std
::
function
<
void
()
>
send_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
);
std
::
function
<
void
()
>
get_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
);
t_send_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
false
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
t_get_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
true
,
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
// wait server
server_
->
Wait
();
t_send_
->
join
();
t_get_
->
join
();
}
void
AsyncGRPCServer
::
ShutdownQueue
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
cq_send_
->
Shutdown
();
cq_get_
->
Shutdown
();
is_shut_down_
=
true
;
}
// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void
AsyncGRPCServer
::
ShutDown
()
{
server_
->
Shutdown
();
ShutdownQueue
();
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
&
var_recv_queue_
);
VLOG
(
4
)
<<
"create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
);
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
SetFinishOrDelete
(
RequestBase
*&
last
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
delete
last
;
last
=
NULL
;
return
;
}
last
->
SetStatus
(
FINISH
);
return
;
}
void
AsyncGRPCServer
::
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
TryToRegisterNewOne
();
void
*
tag
=
NULL
;
bool
ok
=
false
;
while
(
true
)
{
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" get CompletionQueue shutdown!"
;
break
;
}
if
(
wait
&&
!
done_
)
{
Wait
();
}
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
if
(
!
ok
)
{
VLOG
(
4
)
<<
cq_name
<<
" recv no regular event"
;
TryToRegisterNewOne
();
delete
base
;
continue
;
}
switch
(
base
->
Status
())
{
case
PROCESS
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
SetFinishOrDelete
(
base
);
break
;
}
case
FINISH
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
delete
base
;
break
;
}
default:
{
assert
(
false
);
}
}
}
}
void
AsyncGRPCServer
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
}
void
AsyncGRPCServer
::
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
void
AsyncGRPCServer
::
Done
()
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
}
condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/operators/detail/grpc_server.h
0 → 100644
浏览文件 @
da3087ad
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <thread>
#include "paddle/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
AsyncGRPCServer
final
:
public
sendrecv
::
SendRecvService
::
Service
{
public:
explicit
AsyncGRPCServer
(
std
::
string
address
)
{
address_
=
address
;
}
void
RunSyncUpdate
();
void
Reset
();
void
Done
();
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
const
MessageWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
MessageWithName
&
msg
)
{
this
->
var_recv_queue_
.
Push
(
msg
);
}
void
ShutDown
();
protected:
void
Wait
();
void
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
void
SetFinishOrDelete
(
RequestBase
*&
last
);
void
ShutdownQueue
();
private:
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<
grpc
::
ServerCompletionQueue
>
cq_get_
;
sendrecv
::
SendRecvService
::
AsyncService
service_
;
std
::
unique_ptr
<
grpc
::
Server
>
server_
;
std
::
string
address_
;
framework
::
Scope
*
scope_
;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
// condition of the sub program
std
::
mutex
mutex_
;
volatile
mutable
bool
done_
;
std
::
condition_variable
condition_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/operators/detail/recv_impl.cc
已删除
100644 → 0
浏览文件 @
020630b7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "send_recv_impl.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
Status
SendRecvServerImpl
::
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VoidMessage
*
out_var
)
{
MessageWithName
msg_with_name
=
std
::
make_pair
(
in_var
->
varname
(),
std
::
move
(
*
in_var
));
var_recv_queue_
.
Push
(
std
::
move
(
msg_with_name
));
return
Status
::
OK
;
}
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VariableMessage
*
out_var
)
{
std
::
string
get_var_name
=
in_var
->
varname
();
auto
*
var
=
scope_
->
FindVar
(
get_var_name
);
SerializeToMessage
(
get_var_name
,
var
,
platform
::
CPUDeviceContext
(),
out_var
);
return
Status
::
OK
;
}
Status
SendRecvServerImpl
::
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
}
return
Status
::
OK
;
}
void
SendRecvServerImpl
::
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
void
SendRecvServerImpl
::
Done
()
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
}
condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/operators/detail/send_recv.proto
浏览文件 @
da3087ad
...
@@ -21,8 +21,6 @@ service SendRecvService {
...
@@ -21,8 +21,6 @@ service SendRecvService {
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
// wait for one execution of the program
rpc
Wait
(
VoidMessage
)
returns
(
VoidMessage
)
{}
}
}
// VariableMessage is serialized paddle variable message.
// VariableMessage is serialized paddle variable message.
...
...
paddle/operators/detail/send
_recv_impl.h
→
paddle/operators/detail/send
recvop_utils.cc
浏览文件 @
da3087ad
...
@@ -12,87 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,87 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
using
grpc
::
Channel
;
using
grpc
::
Server
;
using
grpc
::
ServerContext
;
using
grpc
::
ServerReader
;
using
grpc
::
ServerBuilder
;
using
grpc
::
ClientContext
;
using
grpc
::
ClientReader
;
using
grpc
::
ClientReaderWriter
;
using
grpc
::
ClientWriter
;
using
grpc
::
Status
;
using
sendrecv
::
SendRecvService
;
using
sendrecv
::
VariableMessage
;
using
sendrecv
::
VoidMessage
;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
void
SerializeToMessage
(
const
std
::
string
&
name
,
const
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
class
SendRecvServerImpl
final
:
public
SendRecvService
::
Service
{
sendrecv
::
VariableMessage
*
msg
)
{
public:
explicit
SendRecvServerImpl
()
{}
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
Status
GetVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VariableMessage
*
out_var
)
override
;
Status
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
void
Reset
();
void
Done
();
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
};
const
MessageWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
MessageWithName
&
msg
)
{
this
->
var_recv_queue_
.
Push
(
msg
);
}
private:
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
framework
::
Scope
*
scope_
;
// condition of the sub program
std
::
mutex
mutex_
;
bool
done_
;
std
::
condition_variable
condition_
;
};
// RPCClient is a class to send tensors to pserver sub-network
// using different hashing methods.
class
RPCClient
{
public:
RPCClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
);
void
Wait
();
private:
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
};
inline
void
SerializeToMessage
(
const
std
::
string
&
name
,
const
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VariableMessage
*
msg
)
{
msg
->
set_varname
(
name
);
msg
->
set_varname
(
name
);
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
...
@@ -114,10 +42,9 @@ inline void SerializeToMessage(const std::string &name,
...
@@ -114,10 +42,9 @@ inline void SerializeToMessage(const std::string &name,
msg
->
set_serialized
(
oss
.
str
());
msg
->
set_serialized
(
oss
.
str
());
}
}
inline
void
DeserializeFromMessage
(
const
VariableMessage
&
msg
,
void
DeserializeFromMessage
(
const
sendrecv
::
VariableMessage
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
)
{
framework
::
Variable
*
var
)
{
using
namespace
paddle
::
framework
::
proto
;
std
::
istringstream
iss
(
msg
.
serialized
());
std
::
istringstream
iss
(
msg
.
serialized
());
switch
(
msg
.
type
())
{
switch
(
msg
.
type
())
{
case
sendrecv
::
VarType
::
LOD_TENSOR
:
case
sendrecv
::
VarType
::
LOD_TENSOR
:
...
...
paddle/operators/detail/send
_impl.cc
→
paddle/operators/detail/send
recvop_utils.h
浏览文件 @
da3087ad
...
@@ -12,56 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,56 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "send_recv_impl.h"
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/framework/data_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/var_type.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
bool
RPCClient
::
SendVariable
(
const
framework
::
Scope
&
scope
,
void
SerializeToMessage
(
const
std
::
string
&
name
,
const
framework
::
Variable
*
var
,
const
std
::
string
&
inname
)
{
const
platform
::
DeviceContext
&
ctx
,
ClientContext
context
;
sendrecv
::
VariableMessage
*
msg
);
VariableMessage
msg
;
VoidMessage
out_msg
;
// FIXME(typhoonzero): pass device context to here.
auto
ctx
=
platform
::
CPUDeviceContext
();
auto
*
var
=
scope
.
FindVar
(
inname
);
PADDLE_ENFORCE
(
var
);
SerializeToMessage
(
inname
,
var
,
ctx
,
&
msg
);
Status
status
=
stub_
->
SendVariable
(
&
context
,
msg
,
&
out_msg
);
if
(
!
status
.
ok
())
{
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
}
return
true
;
}
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
)
{
ClientContext
context
;
VariableMessage
call_msg
,
ret_msg
;
call_msg
.
set_varname
(
outname
);
auto
ctx
=
platform
::
CPUDeviceContext
();
Status
status
=
stub_
->
GetVariable
(
&
context
,
call_msg
,
&
ret_msg
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
if
(
!
status
.
ok
())
{
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
}
std
::
istringstream
iss
(
ret_msg
.
serialized
());
DeserializeFromMessage
(
ret_msg
,
ctx
,
outvar
);
return
true
;
}
void
RPCClient
::
Wait
()
{
ClientContext
context
;
VoidMessage
call_msg
,
ret_msg
;
stub_
->
Wait
(
&
context
,
call_msg
,
&
ret_msg
);
}
void
DeserializeFromMessage
(
const
sendrecv
::
VariableMessage
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
);
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/recv_op.cc
浏览文件 @
da3087ad
...
@@ -24,7 +24,8 @@ limitations under the License. */
...
@@ -24,7 +24,8 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/send_recv_impl.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/simple_block_queue.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
...
@@ -32,6 +33,11 @@ limitations under the License. */
...
@@ -32,6 +33,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
service
->
RunSyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
CreateTensorFromMessageType
(
framework
::
Variable
*
var
,
static
void
CreateTensorFromMessageType
(
framework
::
Variable
*
var
,
sendrecv
::
VarType
var_type
)
{
sendrecv
::
VarType
var_type
)
{
if
(
var_type
==
sendrecv
::
VarType
::
LOD_TENSOR
)
{
if
(
var_type
==
sendrecv
::
VarType
::
LOD_TENSOR
)
{
...
@@ -46,18 +52,6 @@ static void CreateTensorFromMessageType(framework::Variable *var,
...
@@ -46,18 +52,6 @@ static void CreateTensorFromMessageType(framework::Variable *var,
}
}
}
}
void
RunServer
(
Server
**
rpc_server
,
std
::
shared_ptr
<
detail
::
SendRecvServerImpl
>
service
,
const
std
::
string
&
server_address
)
{
ServerBuilder
builder
;
builder
.
AddListeningPort
(
server_address
,
grpc
::
InsecureServerCredentials
());
builder
.
RegisterService
(
service
.
get
());
std
::
unique_ptr
<
Server
>
server
(
builder
.
BuildAndStart
());
*
rpc_server
=
server
.
get
();
LOG
(
INFO
)
<<
"Server listening on "
<<
server_address
;
server
->
Wait
();
}
class
RecvOp
:
public
framework
::
OperatorBase
{
class
RecvOp
:
public
framework
::
OperatorBase
{
public:
public:
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
...
@@ -65,10 +59,9 @@ class RecvOp : public framework::OperatorBase {
...
@@ -65,10 +59,9 @@ class RecvOp : public framework::OperatorBase {
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
if
(
!
rpc_service_
)
{
if
(
!
rpc_service_
)
{
rpc_service_
.
reset
(
new
detail
::
SendRecvServerImpl
());
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
server_thread_
.
reset
(
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
new
std
::
thread
(
RunServer
,
&
rpc_server_
,
rpc_service_
,
endpoint
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
}
}
}
}
...
@@ -76,7 +69,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -76,7 +69,7 @@ class RecvOp : public framework::OperatorBase {
detail
::
MessageWithName
term_msg
;
detail
::
MessageWithName
term_msg
;
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
rpc_service_
->
Push
(
term_msg
);
rpc_service_
->
Push
(
term_msg
);
rpc_serv
er_
->
Shutd
own
();
rpc_serv
ice_
->
ShutD
own
();
server_thread_
->
join
();
server_thread_
->
join
();
}
}
...
@@ -99,10 +92,12 @@ class RecvOp : public framework::OperatorBase {
...
@@ -99,10 +92,12 @@ class RecvOp : public framework::OperatorBase {
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
size_t
param_count
=
param_list
.
size
();
size_t
param_count
=
param_list
.
size
();
rpc_service_
->
Reset
();
rpc_service_
->
Reset
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
while
(
!
exit_flag
)
{
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
// the gradient arrives, just add suffix 0~n then average the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
...
@@ -110,6 +105,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -110,6 +105,7 @@ class RecvOp : public framework::OperatorBase {
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
VLOG
(
4
)
<<
"received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit"
;
exit_flag
=
true
;
exit_flag
=
true
;
break
;
break
;
}
}
...
@@ -118,10 +114,12 @@ class RecvOp : public framework::OperatorBase {
...
@@ -118,10 +114,12 @@ class RecvOp : public framework::OperatorBase {
if
(
it
!=
grad_list
.
end
())
{
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
}
else
{
LOG
(
ERROR
)
<<
"grad have no paired param found!"
;
LOG
(
ERROR
)
<<
"grad have no paired param found!
\"
"
<<
grad_var_name
<<
"
\"
"
;
}
}
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
<<
" updating param: "
<<
param_var_name
;
auto
*
merged_grad
=
recv_scope
.
FindVar
(
grad_var_name
);
auto
*
merged_grad
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
merged_grad
==
nullptr
)
{
if
(
merged_grad
==
nullptr
)
{
auto
*
ptr
=
recv_scope
.
Var
(
grad_var_name
);
auto
*
ptr
=
recv_scope
.
Var
(
grad_var_name
);
...
@@ -141,9 +139,11 @@ class RecvOp : public framework::OperatorBase {
...
@@ -141,9 +139,11 @@ class RecvOp : public framework::OperatorBase {
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
}
if
(
exit_flag
)
{
if
(
exit_flag
)
{
break
;
break
;
}
}
rpc_service_
->
Reset
();
rpc_service_
->
Reset
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
...
@@ -158,17 +158,14 @@ class RecvOp : public framework::OperatorBase {
...
@@ -158,17 +158,14 @@ class RecvOp : public framework::OperatorBase {
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
}
rpc_service_
->
Done
();
rpc_service_
->
Done
();
grads_counter_
.
clear
();
grads_counter_
.
clear
();
}
// while(true)
}
// while(true)
}
}
protected:
protected:
// grpc server instance to track status and gracefully shutdown.
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
// borrow an pointer from server thread.
Server
*
rpc_server_
{
nullptr
};
// grpc send/recv service implement to register.
std
::
shared_ptr
<
detail
::
SendRecvServerImpl
>
rpc_service_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
};
};
...
...
paddle/operators/send_op.cc
浏览文件 @
da3087ad
...
@@ -19,59 +19,45 @@ limitations under the License. */
...
@@ -19,59 +19,45 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include
"paddle/operators/detail/send_recv_impl.h"
#include
<future>
#include "paddle/operators/detail/
simple_block_queue
.h"
#include "paddle/operators/detail/
grpc_client
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
// TODO(typhoonzero): this is a simple implementation which only send
// one tensor
class
SendOp
:
public
framework
::
OperatorBase
{
class
SendOp
:
public
framework
::
OperatorBase
{
public:
public:
SendOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
SendOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
// init client when the operator is created at runtime.
std
::
vector
<
std
::
string
>
endpoints
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
for
(
auto
ep
:
endpoints
)
{
client_map_
[
ep
].
reset
(
new
detail
::
RPCClient
(
grpc
::
CreateChannel
(
ep
,
grpc
::
InsecureChannelCredentials
())));
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
ins
=
Inputs
(
"X"
);
auto
ins
=
Inputs
(
"X"
);
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
// TODO(typhoonzero): use async calls to send multiple variable asyncly.
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
// FIXME(gongwb): DeviceContext?
bool
ret
=
client_map_
[
epmap
[
i
]]
->
SendVariable
(
scope
,
ins
[
i
]);
auto
ctx
=
platform
::
CPUDeviceContext
();
if
(
!
ret
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
LOG
(
ERROR
)
<<
"send variable error: "
<<
ins
[
i
];
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
}
// TODO(typhoonzero): support async optimization
client_map_
[
epmap
[
0
]]
->
Wait
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
++
i
)
{
bool
ret
=
client_map_
[
epmap
[
i
]]
->
GetVariable
(
scope
,
outs
[
i
]);
if
(
!
ret
)
{
LOG
(
ERROR
)
<<
"GetVariable error: "
<<
outs
[
i
];
}
}
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
client_
.
wait
();
}
}
protected:
private:
mutable
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
detail
::
RPCClient
>>
mutable
detail
::
RPCClient
client_
;
client_map_
;
};
};
class
SendOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
SendOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
SendOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SendOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor) Input tensor to be send"
).
AsDuplicable
();
AddInput
(
"X"
,
"(Tensor) Input tensor to be send"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor to get from server"
)
AddOutput
(
"Out"
,
"(Tensor) Output tensor to get from server"
)
...
...
paddle/operators/send_recv_op_test.cc
浏览文件 @
da3087ad
...
@@ -140,7 +140,7 @@ void StartServerNet(bool is_sparse) {
...
@@ -140,7 +140,7 @@ void StartServerNet(bool is_sparse) {
TEST
(
SendRecvOp
,
CPUDense
)
{
TEST
(
SendRecvOp
,
CPUDense
)
{
std
::
thread
server_thread
(
StartServerNet
,
false
);
std
::
thread
server_thread
(
StartServerNet
,
false
);
sleep
(
3
);
// wait server to start
sleep
(
10
);
// wait server to start
// local net
// local net
f
::
Scope
scope
;
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUPlace
place
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录