Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1b8bc3c5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b8bc3c5
编写于
1月 29, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename rpc ops
上级
686024c0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
58 addition
and
162 deletion
+58
-162
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+38
-150
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+10
-3
paddle/operators/send_recv_op_test.cc
paddle/operators/send_recv_op_test.cc
+10
-9
未找到文件。
paddle/operators/recv_op.cc
浏览文件 @
1b8bc3c5
...
...
@@ -12,179 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/framework/executor.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#include <future>
#include "paddle/operators/detail/grpc_client.h"
namespace
paddle
{
namespace
operators
{
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
// Entry point handed to the std::thread that hosts the gRPC server.
// Runs the service's synchronous-update loop and logs once that call returns.
// NOTE(review): RunSyncUpdate() presumably blocks until the service is shut
// down -- confirm against AsyncGRPCServer.
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  service->RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}
// Makes sure `var` holds an object of the C++ type that matches the wire-level
// `var_type`: a LoDTensor for LOD_TENSOR, a SelectedRows for SELECTED_ROWS.
// Any other value is rejected through PADDLE_THROW.
static void CreateTensorFromMessageType(framework::Variable *var,
                                        sendrecv::VarType var_type) {
  switch (var_type) {
    case sendrecv::VarType::LOD_TENSOR:
      var->GetMutable<framework::LoDTensor>();
      break;
    case sendrecv::VarType::SELECTED_ROWS:
      var->GetMutable<framework::SelectedRows>();
      break;
    default:
      PADDLE_THROW(
          "VariableMessage type %d is not in "
          "[LoDTensor, SelectedRows]",
          var_type);
  }
}
class
RecvOp
:
public
framework
::
OperatorBase
{
class
SendOp
:
public
framework
::
OperatorBase
{
public:
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
if
(
!
rpc_service_
)
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
SendOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
void
Stop
()
override
{
detail
::
MessageWithName
term_msg
;
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
rpc_service_
->
Push
(
term_msg
);
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
}
// Returns "<varname>.trainer_<k>", where k counts how many times this
// gradient name has been requested since grads_counter_ was last cleared
// (Run() clears it at the end of every update round). The per-name counter
// starts at 0 and is advanced by one on each call.
std::string GetGradVarNameForTrainer(const std::string &varname) const {
  auto slot = grads_counter_.find(varname);
  if (slot == grads_counter_.end()) {
    slot = grads_counter_.emplace(varname, 0).first;
  }
  const int trainer_idx = slot->second++;
  return string::Sprintf("%s.trainer_%d", varname, trainer_idx);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
size_t
param_count
=
param_list
.
size
();
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
program
=
block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
size_t
barrier_size
=
param_count
*
fan_in
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradients.
rpc_service_
->
SetCond
(
0
);
for
(
size_t
i
=
0
;
i
<
barrier_size
;
++
i
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
auto
it
=
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
LOG
(
ERROR
)
<<
"grad has no paired param:"
<<
grad_var_name
;
}
VLOG
(
3
)
<<
"received grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
if
(
fan_in
>
1
)
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
grad_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
if
(
exit_flag
)
{
break
;
}
try
{
executor
.
Run
(
*
program
,
&
recv_scope
,
block
->
ID
(),
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
barrier_size
);
grads_counter_
.
clear
();
}
// while(true)
PADDLE_ENFORCE
(
client_
.
Wait
());
}
protected:
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
private:
mutable
detail
::
RPCClient
client_
;
};
class
Recv
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Send
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SendOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"RX"
,
"(Tensor) Input tensor to be optimized"
).
AsDuplicable
();
AddInput
(
"X"
,
"(Tensor) Input tensor to be sent"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor to be received from server"
)
.
AsDuplicable
();
AddComment
(
R"DOC(
Recv
operator
Send
operator
This operator will
recieve tensor from send_op
This operator will
send tensor to recv_op at the parameter server.
)DOC"
);
AddAttr
<
std
::
string
>
(
"endpoint"
,
"(string, default 127.0.0.1:6164)"
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"Serialized ProgramDesc string for recv to run."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
}
};
...
...
@@ -193,4 +81,4 @@ This operator will recieve tensor from send_op
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
recv
,
ops
::
RecvOp
,
ops
::
Recv
OpMaker
);
REGISTER_OPERATOR
(
send
,
ops
::
SendOp
,
ops
::
Send
OpMaker
);
paddle/operators/send_op.cc
浏览文件 @
1b8bc3c5
...
...
@@ -37,6 +37,7 @@ class SendOp : public framework::OperatorBase {
auto
ins
=
Inputs
(
"X"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
bool
do_get
=
Attr
<
bool
>
(
"DoGet"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -46,9 +47,11 @@ class SendOp : public framework::OperatorBase {
}
PADDLE_ENFORCE
(
client_
.
Wait
());
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
if
(
do_get
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
PADDLE_ENFORCE
(
client_
.
Wait
());
...
...
@@ -79,6 +82,10 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"DoGet"
,
"(bool, default true)"
"Whether do GetVariable call after send"
)
.
SetDefault
(
true
);
}
};
...
...
paddle/operators/send_recv_op_test.cc
浏览文件 @
1b8bc3c5
...
...
@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/string/printf.h"
USE_NO_KERNEL_OP
(
send
);
USE_NO_KERNEL_OP
(
rec
v
);
USE_NO_KERNEL_OP
(
listen_and_ser
v
);
USE_OP
(
sum
);
namespace
f
=
paddle
::
framework
;
...
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
namespace
m
=
paddle
::
operators
::
math
;
// global for simplicity.
std
::
unique_ptr
<
f
::
OperatorBase
>
rec
v_op
;
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_ser
v_op
;
void
InitTensorsInScope
(
f
::
Scope
&
scope
,
p
::
CPUPlace
&
place
)
{
p
::
CPUDeviceContext
ctx
(
place
);
...
...
@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) {
InitTensorsInScope
(
scope
,
place
);
}
// sub program run in
rec
v_op, for simple test we use sum
// sub program run in
listen_and_ser
v_op, for simple test we use sum
f
::
ProgramDesc
program
;
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
// X for server side tensors, RX for received tensors, must be of same shape.
...
...
@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"ParamList"
,
std
::
vector
<
std
::
string
>
({
"Out"
})});
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"OptimizeBlock"
,
block
});
recv_op
=
f
::
OpRegistry
::
CreateOp
(
"recv"
,
{{
"RX"
,
{
"x1"
}}},
{},
attrs
);
recv_op
->
Run
(
scope
,
place
);
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"RX"
,
{
"x1"
}}},
{},
attrs
);
listen_and_serv_op
->
Run
(
scope
,
place
);
}
TEST
(
SendRecvOp
,
CPUDense
)
{
...
...
@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) {
for
(
int64_t
i
=
0
;
i
<
target
->
numel
();
++
i
)
{
EXPECT_EQ
(
expected
[
i
]
*
2
,
actual
[
i
]);
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
rec
v_op
.
reset
(
nullptr
);
listen_and_ser
v_op
.
reset
(
nullptr
);
}
TEST
(
SendRecvOp
,
CPUSparse
)
{
...
...
@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) {
EXPECT_EQ
(
expect_value
->
mutable_data
<
float
>
(
place
)[
i
],
actual
->
mutable_data
<
float
>
(
place
)[
i
]);
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
rec
v_op
.
reset
();
listen_and_ser
v_op
.
reset
();
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录