Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1b8bc3c5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b8bc3c5
编写于
1月 29, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename rpc ops
上级
686024c0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
58 addition
and
162 deletion
+58
-162
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+38
-150
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+10
-3
paddle/operators/send_recv_op_test.cc
paddle/operators/send_recv_op_test.cc
+10
-9
未找到文件。
paddle/operators/recv_op.cc
浏览文件 @
1b8bc3c5
...
...
@@ -12,179 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/framework/executor.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#include <future>
#include "paddle/operators/detail/grpc_client.h"
namespace
paddle
{
namespace
operators
{
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
service
->
RunSyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
CreateTensorFromMessageType
(
framework
::
Variable
*
var
,
sendrecv
::
VarType
var_type
)
{
if
(
var_type
==
sendrecv
::
VarType
::
LOD_TENSOR
)
{
var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
else
if
(
var_type
==
sendrecv
::
VarType
::
SELECTED_ROWS
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
();
}
else
{
PADDLE_THROW
(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]"
,
var_type
);
}
}
class
RecvOp
:
public
framework
::
OperatorBase
{
class
SendOp
:
public
framework
::
OperatorBase
{
public:
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
if
(
!
rpc_service_
)
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
SendOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
void
Stop
()
override
{
detail
::
MessageWithName
term_msg
;
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
rpc_service_
->
Push
(
term_msg
);
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
}
std
::
string
GetGradVarNameForTrainer
(
const
std
::
string
&
varname
)
const
{
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
grads_counter_
[
varname
]
=
0
;
}
return
string
::
Sprintf
(
"%s.trainer_%d"
,
varname
,
grads_counter_
[
varname
]
++
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
size_t
param_count
=
param_list
.
size
();
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
program
=
block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
size_t
barrier_size
=
param_count
*
fan_in
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
for
(
size_t
i
=
0
;
i
<
barrier_size
;
++
i
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
auto
it
=
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
LOG
(
ERROR
)
<<
"grad has no paired param:"
<<
grad_var_name
;
}
VLOG
(
3
)
<<
"received grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
if
(
fan_in
>
1
)
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
grad_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
if
(
exit_flag
)
{
break
;
}
try
{
executor
.
Run
(
*
program
,
&
recv_scope
,
block
->
ID
(),
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
barrier_size
);
grads_counter_
.
clear
();
}
// while(true)
PADDLE_ENFORCE
(
client_
.
Wait
());
}
protected:
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
private:
mutable
detail
::
RPCClient
client_
;
};
class
Recv
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Send
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SendOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"RX"
,
"(Tensor) Input tensor to be optimized"
).
AsDuplicable
();
AddInput
(
"X"
,
"(Tensor) Input tensor to be sent"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor to be received from server"
)
.
AsDuplicable
();
AddComment
(
R"DOC(
Recv
operator
Send
operator
This operator will
recieve tensor from send_op
This operator will
send tensor to recv_op at the parameter server.
)DOC"
);
AddAttr
<
std
::
string
>
(
"endpoint"
,
"(string, default 127.0.0.1:6164)"
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"Serialized ProgramDesc string for recv to run."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
}
};
...
...
@@ -193,4 +81,4 @@ This operator will recieve tensor from send_op
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
recv
,
ops
::
RecvOp
,
ops
::
Recv
OpMaker
);
REGISTER_OPERATOR
(
send
,
ops
::
SendOp
,
ops
::
Send
OpMaker
);
paddle/operators/send_op.cc
浏览文件 @
1b8bc3c5
...
...
@@ -37,6 +37,7 @@ class SendOp : public framework::OperatorBase {
auto
ins
=
Inputs
(
"X"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
bool
do_get
=
Attr
<
bool
>
(
"DoGet"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -46,9 +47,11 @@ class SendOp : public framework::OperatorBase {
}
PADDLE_ENFORCE
(
client_
.
Wait
());
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
if
(
do_get
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
PADDLE_ENFORCE
(
client_
.
Wait
());
...
...
@@ -79,6 +82,10 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"DoGet"
,
"(bool, default true)"
"Whether do GetVariable call after send"
)
.
SetDefault
(
true
);
}
};
...
...
paddle/operators/send_recv_op_test.cc
浏览文件 @
1b8bc3c5
...
...
@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/string/printf.h"
USE_NO_KERNEL_OP
(
send
);
USE_NO_KERNEL_OP
(
rec
v
);
USE_NO_KERNEL_OP
(
listen_and_ser
v
);
USE_OP
(
sum
);
namespace
f
=
paddle
::
framework
;
...
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
namespace
m
=
paddle
::
operators
::
math
;
// global for simplicity.
std
::
unique_ptr
<
f
::
OperatorBase
>
rec
v_op
;
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_ser
v_op
;
void
InitTensorsInScope
(
f
::
Scope
&
scope
,
p
::
CPUPlace
&
place
)
{
p
::
CPUDeviceContext
ctx
(
place
);
...
...
@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) {
InitTensorsInScope
(
scope
,
place
);
}
// sub program run in
rec
v_op, for simple test we use sum
// sub program run in
listen_and_ser
v_op, for simple test we use sum
f
::
ProgramDesc
program
;
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
// X for server side tensors, RX for received tensers, must be of same shape.
...
...
@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"ParamList"
,
std
::
vector
<
std
::
string
>
({
"Out"
})});
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"OptimizeBlock"
,
block
});
recv_op
=
f
::
OpRegistry
::
CreateOp
(
"recv"
,
{{
"RX"
,
{
"x1"
}}},
{},
attrs
);
recv_op
->
Run
(
scope
,
place
);
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"RX"
,
{
"x1"
}}},
{},
attrs
);
listen_and_serv_op
->
Run
(
scope
,
place
);
}
TEST
(
SendRecvOp
,
CPUDense
)
{
...
...
@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) {
for
(
int64_t
i
=
0
;
i
<
target
->
numel
();
++
i
)
{
EXPECT_EQ
(
expected
[
i
]
*
2
,
actual
[
i
]);
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
rec
v_op
.
reset
(
nullptr
);
listen_and_ser
v_op
.
reset
(
nullptr
);
}
TEST
(
SendRecvOp
,
CPUSparse
)
{
...
...
@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) {
EXPECT_EQ
(
expect_value
->
mutable_data
<
float
>
(
place
)[
i
],
actual
->
mutable_data
<
float
>
(
place
)[
i
]);
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
rec
v_op
.
reset
();
listen_and_ser
v_op
.
reset
();
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录