Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0763ae9a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0763ae9a
编写于
4月 20, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove unused file
上级
dc3d2dc8
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
0 addition
and
691 deletion
+0
-691
paddle/fluid/operators/async_listen_and_serv_op.cc
paddle/fluid/operators/async_listen_and_serv_op.cc
+0
-214
paddle/fluid/operators/async_listen_and_serv_op.h
paddle/fluid/operators/async_listen_and_serv_op.h
+0
-55
paddle/fluid/operators/detail/async_grpc_server.cc
paddle/fluid/operators/detail/async_grpc_server.cc
+0
-311
paddle/fluid/operators/detail/async_grpc_server.h
paddle/fluid/operators/detail/async_grpc_server.h
+0
-111
未找到文件。
paddle/fluid/operators/async_listen_and_serv_op.cc
已删除
100644 → 0
浏览文件 @
dc3d2dc8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <ostream>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/async_listen_and_serv_op.h"
#include "paddle/utils/StringUtil.h"
namespace
paddle
{
namespace
operators
{
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
if
(
str
.
empty
())
{
return
;
}
size_t
pos
=
0
;
size_t
next
=
str
.
find
(
sep
,
pos
);
while
(
next
!=
std
::
string
::
npos
)
{
pieces
->
push_back
(
str
.
substr
(
pos
,
next
-
pos
));
pos
=
next
+
1
;
next
=
str
.
find
(
sep
,
pos
);
}
if
(
!
str
.
substr
(
pos
).
empty
())
{
pieces
->
push_back
(
str
.
substr
(
pos
));
}
}
void
RunServer
(
std
::
shared_ptr
<
detail
::
SyncGRPCServer
>
service
)
{
service
->
RunAsyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
AsyncExecuteBlock
(
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
,
framework
::
Scope
*
scope
)
{
framework
::
Async
([
&
executor
,
&
prepared
,
&
scope
]()
{
try
{
executor
->
RunPreparedContext
(
prepared
,
scope
,
false
,
false
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
}
AsyncListenAndServOp
::
AsyncListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
int
AsyncListenAndServOp
::
GetSelectedPort
()
const
{
return
rpc_service_
->
GetSelectedPort
();
}
void
AsyncListenAndServOp
::
Stop
()
{
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
server_thread_
->
join
();
}
void
AsyncListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
if
(
!
rpc_service_
)
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
SyncGRPCServer
(
endpoint
));
}
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
auto
grad_map_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_id"
);
for
(
auto
&
grad_and_id
:
grad_map_str
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
grad_and_id
,
' '
,
&
pieces
);
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_id
[
pieces
[
0
]]
=
block_id
;
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
program
=
optimize_block
->
Program
();
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
framework
::
Executor
executor
(
dev_place
);
std
::
vector
<
int
>
block_list
;
for
(
size_t
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
if
(
blkid
!=
static_cast
<
size_t
>
(
prefetch_block
->
ID
()))
{
block_list
.
push_back
(
blkid
);
}
}
PADDLE_ENFORCE_EQ
(
grad_map_str
.
size
(),
block_list
.
size
(),
"grad num should be equal to optimize block num"
);
auto
optimize_prepared
=
executor
.
Prepare
(
*
program
,
block_list
);
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
grad_to_prepared
;
for
(
size_t
i
=
0
;
i
<
block_list
.
size
();
++
i
)
{
grad_to_prepared
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
// set proper fields for table lookup and update
rpc_service_
->
SetExecutor
(
&
executor
);
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
rpc_service_
->
SetPrefetchPreparedCtx
(
prefetch_prepared
.
get
());
prefetch_prepared
.
release
();
rpc_service_
->
SetProgram
(
program
);
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
sleep
(
5
);
// Write to a file of server selected port for python use.
std
::
ofstream
port_file
;
port_file
.
open
(
"/tmp/paddle.selected_port"
);
port_file
<<
rpc_service_
->
GetSelectedPort
();
port_file
.
close
();
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
AsyncExecuteBlock
(
&
executor
,
grad_to_prepared
[
recv_var_name
].
get
(),
&
recv_scope
);
// TODO(qiao): explain why
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
}
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
break
;
}
}
// while(true)
}
class
AsyncListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
AsyncListenAndServOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor) Variables that server recv."
).
AsDuplicable
();
AddComment
(
R"DOC(
ListenAndServ operator
This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC"
);
AddAttr
<
std
::
string
>
(
"endpoint"
,
"(string, default 127.0.0.1:6164)"
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])"
,
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"BlockID to run on server side."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
,
"prefetch block to run on server side."
);
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
.
SetDefault
(
1
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
async_listen_and_serv
,
ops
::
AsyncListenAndServOp
,
ops
::
AsyncListenAndServOpMaker
);
paddle/fluid/operators/async_listen_and_serv_op.h
已删除
100644 → 0
浏览文件 @
dc3d2dc8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/async_grpc_server.h"
namespace
paddle
{
namespace
operators
{
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
SyncGRPCServer
>
service
);
class
AsyncListenAndServOp
:
public
framework
::
OperatorBase
{
public:
AsyncListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
int
GetSelectedPort
()
const
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
protected:
mutable
std
::
shared_ptr
<
detail
::
SyncGRPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/async_grpc_server.cc
已删除
100644 → 0
浏览文件 @
dc3d2dc8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/async_grpc_server.h"
#include <limits>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
namespace
operators
{
namespace
detail
{
enum
CallStatus
{
PROCESS
=
0
,
FINISH
};
// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class
RequestBase
{
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
PADDLE_ENFORCE
(
cq_
);
}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
return
""
;
}
protected:
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
class
RequestSend
final
:
public
RequestBase
{
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
true
,
scope
,
dev_ctx_
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
}
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
virtual
void
Process
()
{
queue_
->
Push
(
std
::
make_pair
(
request_
->
Varname
(),
request_
));
sendrecv
::
VoidMessage
reply
;
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
class
RequestGet
final
:
public
RequestBase
{
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
virtual
~
RequestGet
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
virtual
void
Process
()
{
// proc request.
std
::
string
var_name
=
request_
.
varname
();
auto
*
var
=
scope_
->
FindVar
(
var_name
);
::
grpc
::
ByteBuffer
reply
;
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
protected:
sendrecv
::
VariableMessage
request_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
};
class
RequestPrefetch
final
:
public
RequestBase
{
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
)
{
request_
.
reset
(
new
VariableResponse
(
true
,
scope
,
dev_ctx_
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
}
virtual
~
RequestPrefetch
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
virtual
void
Process
()
{
// prefetch process...
::
grpc
::
ByteBuffer
reply
;
std
::
string
var_name
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"prefetch var "
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
&
scope_
->
NewScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
scope_
,
false
,
false
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
};
void
SyncGRPCServer
::
RunAsyncUpdate
()
{
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
(),
&
selected_port_
);
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
cq_get_
=
builder
.
AddCompletionQueue
();
cq_prefetch_
=
builder
.
AddCompletionQueue
();
server_
=
builder
.
BuildAndStart
();
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
<<
" selected port: "
<<
selected_port_
;
std
::
function
<
void
()
>
send_register
=
std
::
bind
(
&
SyncGRPCServer
::
TryToRegisterNewSendOne
,
this
);
std
::
function
<
void
()
>
get_register
=
std
::
bind
(
&
SyncGRPCServer
::
TryToRegisterNewGetOne
,
this
);
std
::
function
<
void
()
>
prefetch_register
=
std
::
bind
(
&
SyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
);
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
SyncGRPCServer
::
HandleRequest
,
this
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
t_get_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
SyncGRPCServer
::
HandleRequest
,
this
,
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
t_prefetch_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
SyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
// wait server
server_
->
Wait
();
t_send_
->
join
();
t_get_
->
join
();
t_prefetch_
->
join
();
}
void
SyncGRPCServer
::
ShutdownQueue
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
cq_send_
->
Shutdown
();
cq_get_
->
Shutdown
();
cq_prefetch_
->
Shutdown
();
}
// This URL explains why shutdown is complicate:
void
SyncGRPCServer
::
ShutDown
()
{
is_shut_down_
=
true
;
ShutdownQueue
();
server_
->
Shutdown
();
}
void
SyncGRPCServer
::
TryToRegisterNewSendOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
scope_
,
&
var_recv_queue_
,
dev_ctx_
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
SyncGRPCServer
::
TryToRegisterNewGetOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
SyncGRPCServer
::
TryToRegisterNewPrefetchOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
void
SyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
TryToRegisterNewOne
();
void
*
tag
=
NULL
;
bool
ok
=
false
;
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while in"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue shutdown!"
;
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while after Next"
;
PADDLE_ENFORCE
(
tag
);
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
);
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
();
delete
base
;
continue
;
}
switch
(
base
->
Status
())
{
case
PROCESS
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
break
;
}
case
FINISH
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
delete
base
;
break
;
}
default:
{
assert
(
false
);
}
}
}
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/async_grpc_server.h
已删除
100644 → 0
浏览文件 @
dc3d2dc8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <thread> // NOLINT
#include <utility>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
VariableResponse
>>
ReceivedMessage
;
typedef
SimpleBlockQueue
<
ReceivedMessage
>
ReceivedQueue
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
SyncGRPCServer
final
{
public:
explicit
SyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
void
RunAsyncUpdate
();
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
framework
::
ExecutorPrepareContext
*
prepared
)
{
prefetch_ctx_
=
prepared
;
}
int
GetSelectedPort
()
{
return
selected_port_
;
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
std
::
string
&
msg_name
)
{
this
->
var_recv_queue_
.
Push
(
std
::
make_pair
(
msg_name
,
nullptr
));
}
void
ShutDown
();
protected:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
void
TryToRegisterNewPrefetchOne
();
void
ShutdownQueue
();
private:
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_prefetch_
;
GrpcService
::
AsyncService
service_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// client send variable to this queue.
ReceivedQueue
var_recv_queue_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录