Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
82e41ce3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82e41ce3
编写于
2月 01, 2018
作者:
武
武毅
提交者:
GitHub
2月 01, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7947 from typhoonzero/rename_rpc_ops
Rename rpc ops
上级
1c2b071a
4d12a813
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
294 addition
and
177 deletion
+294
-177
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+4
-2
paddle/operators/listen_and_serv_op.cc
paddle/operators/listen_and_serv_op.cc
+207
-0
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+30
-157
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+8
-4
paddle/operators/send_recv_op_test.cc
paddle/operators/send_recv_op_test.cc
+10
-9
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+2
-2
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+1
-1
python/paddle/v2/fluid/layers/io.py
python/paddle/v2/fluid/layers/io.py
+30
-2
python/paddle/v2/fluid/tests/test_recv_op.py
python/paddle/v2/fluid/tests/test_recv_op.py
+2
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
82e41ce3
...
@@ -122,9 +122,11 @@ if(WITH_DISTRIBUTE)
...
@@ -122,9 +122,11 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
send_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
send_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
op_library
(
recv_op DEPS
${
DISTRIBUTE_DEPS
}
)
op_library
(
recv_op DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
recv_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
recv_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor
)
op_library
(
listen_and_serv_op DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
test_send_recv SRCS send_recv_op_test.cc DEPS send_op listen_and_serv_op sum_op executor
)
else
()
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
send_op recv_op
)
set
(
DEPS_OPS
${
DEPS_OPS
}
send_op recv_op
listen_and_serv_op
)
endif
()
endif
()
op_library
(
cond_op DEPS framework_proto tensor net_op
)
op_library
(
cond_op DEPS framework_proto tensor net_op
)
...
...
paddle/operators/listen_and_serv_op.cc
0 → 100644
浏览文件 @
82e41ce3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
namespace
paddle
{
namespace
operators
{
// Name of the operator attribute that carries the block to execute on the
// server side (read in ListenAndServOp::Run, declared in ListenAndServOpMaker).
constexpr char kOptimizeBlock[] = "OptimizeBlock";
// Thread entry point for the RPC server: runs the service's synchronous
// update loop and logs once that call returns.
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  service->RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}
// Makes `var` hold the container that matches the wire-level variable type:
// a LoDTensor for LOD_TENSOR, a SelectedRows for SELECTED_ROWS. Any other
// type is rejected with PADDLE_THROW.
// NOTE(review): no caller is visible in this part of the file — confirm it is
// still needed.
static void CreateTensorFromMessageType(framework::Variable *var,
                                        sendrecv::VarType var_type) {
  switch (var_type) {
    case sendrecv::VarType::LOD_TENSOR:
      var->GetMutable<framework::LoDTensor>();
      break;
    case sendrecv::VarType::SELECTED_ROWS:
      var->GetMutable<framework::SelectedRows>();
      break;
    default:
      PADDLE_THROW(
          "VariableMessage type %d is not in "
          "[LoDTensor, SelectedRows]",
          var_type);
  }
}
// Operator that owns an RPC server and a thread running it. Run() loops:
// it collects variables pushed by trainers until `Fanin` batch-barrier
// messages have arrived, executes the block stored in the "OptimizeBlock"
// attribute on the received variables, then lets clients fetch the results.
// The loop ends when a LISTEN_TERMINATE_MESSAGE is received (see Stop()).
class ListenAndServOp : public framework::OperatorBase {
 public:
  // Creates the RPC server on the "endpoint" attribute and starts the thread
  // that runs it (RunServer).
  ListenAndServOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
    if (!rpc_service_) {
      std::string endpoint = Attr<std::string>("endpoint");
      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
      server_thread_.reset(new std::thread(RunServer, rpc_service_));
    }
  }

  // Pushes a terminate message so that a running Run() loop exits, shuts the
  // RPC service down and joins the server thread.
  void Stop() override {
    detail::MessageWithName term_msg;
    term_msg.first = LISTEN_TERMINATE_MESSAGE;
    rpc_service_->Push(term_msg);
    rpc_service_->ShutDown();
    server_thread_->join();
  }

  // Returns "<varname>.trainer_<k>" where k counts how many times `varname`
  // has been requested since grads_counter_ was last cleared (starting at 0).
  std::string GetGradVarNameForTrainer(const std::string &varname) const {
    if (grads_counter_.find(varname) == grads_counter_.end()) {
      grads_counter_[varname] = 0;
    }
    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
  }

  // Serves batches until a terminate message arrives.
  //   scope:     parent scope; received variables live in a child scope of it.
  //   dev_place: place used for the device context and the executor.
  void Run(const framework::Scope &scope,
           const platform::Place &dev_place) const override {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    framework::Scope &recv_scope = scope.NewScope();

    // FIXME(Yancey1989): initialize rpc server with lazy mode.
    rpc_service_->SetScope(&recv_scope);
    rpc_service_->SetDevCtx(&dev_ctx);
    auto param_list = Attr<std::vector<std::string>>("ParamList");
    auto grad_list = Attr<std::vector<std::string>>("GradList");
    auto fan_in = Attr<int>("Fanin");

    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
    auto *program = block->Program();
    framework::Executor executor(dev_place);

    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
    bool exit_flag = false;
    while (!exit_flag) {
      // Get from multiple trainers, we don't care about the order in which
      // the gradients arrive, just add suffix 0~n and merge the gradient.
      rpc_service_->SetCond(0);
      size_t recv_var_cnt = 0;
      int batch_barrier = 0;
      while (batch_barrier != fan_in) {
        const detail::MessageWithName &v = rpc_service_->Get();
        auto grad_var_name = v.first;
        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
        } else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
          VLOG(3) << "recv batch barrier message";
          batch_barrier++;
          continue;
        } else {
          // receive a variable
          recv_var_cnt++;
          auto it =
              std::find(grad_list.begin(), grad_list.end(), grad_var_name);
          std::string param_var_name;
          if (it != grad_list.end()) {
            param_var_name = param_list[it - grad_list.begin()];
          } else {
            LOG(ERROR) << "grad has no paired param:" << grad_var_name;
          }
          VLOG(3) << "received grad: " << grad_var_name
                  << " updating param: " << param_var_name;
          if (fan_in > 1) {
            // Several trainers send the same gradient name; give each copy
            // its own ".trainer_<k>" variable.
            grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
          }
          auto *var = recv_scope.FindVar(grad_var_name);
          if (var == nullptr) {
            LOG(ERROR) << "Can not find server side var: " << grad_var_name;
            PADDLE_THROW("Can not find server side var");
          }
          detail::DeserializeFromMessage(v.second, dev_ctx, var);
        }
      }
      VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
      // TODO(Yancey1989): merge SelectedRows variables here
      if (exit_flag) {
        // Terminating: do not run the optimize block on a possibly partial
        // batch and do not wait for client fetches that can no longer arrive
        // once the server is shut down. Release any waiter first
        // (presumably handlers block until the condition is 1 — confirm
        // against AsyncGRPCServer), then shut down and leave the loop.
        rpc_service_->SetCond(1);
        rpc_service_->ShutDown();
        break;
      }
      try {
        executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
                     false /*create_local_scope*/, false /*create_vars*/);
      } catch (std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
      rpc_service_->SetCond(1);
      rpc_service_->WaitClientGet(recv_var_cnt);
      grads_counter_.clear();
    }  // while(true)
  }

 protected:
  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
  std::shared_ptr<std::thread> server_thread_;
  // Per-batch counter used to build the ".trainer_<k>" suffixes; mutable
  // because it is updated from the const Run() path.
  mutable std::unordered_map<std::string, int> grads_counter_;
};
// Declares the proto (comment and attributes) of the listen_and_serv
// operator: "endpoint", the optimize block, the ParamList/GradList name
// lists and "Fanin".
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddComment(R"DOC(
ListenAndServ operator
This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164)"
                         "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                    "BlockID to run on server side.");
    // The type hint and the description are one comment string. Passing them
    // as two separate arguments would put the description into AddAttr's
    // third parameter (presumably the `generated` flag — confirm against
    // OpProtoAndCheckerMaker::AddAttr) instead of the attribute comment.
    AddAttr<std::vector<std::string>>(
        "ParamList",
        "type list of string, "
        "grad->param name mapping to find which parameters to optimize.")
        .SetDefault({});
    AddAttr<std::vector<std::string>>(
        "GradList",
        "type list of string, "
        "grad->param name mapping to find which parameters to optimize.")
        .SetDefault({});
    AddAttr<int>("Fanin",
                 "type int, "
                 "Number of trainers in the current cluster job")
        .SetDefault(1);
  }
};
}
// namespace operators
}
// namespace paddle
namespace ops = paddle::operators;

// Register the operator under the name "listen_and_serv" together with its
// proto maker.
REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
                  ops::ListenAndServOpMaker);
\ No newline at end of file
paddle/operators/recv_op.cc
浏览文件 @
82e41ce3
...
@@ -12,187 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,187 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include <future>
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/grpc_client.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
service
->
RunSyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
CreateTensorFromMessageType
(
framework
::
Variable
*
var
,
sendrecv
::
VarType
var_type
)
{
if
(
var_type
==
sendrecv
::
VarType
::
LOD_TENSOR
)
{
var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
else
if
(
var_type
==
sendrecv
::
VarType
::
SELECTED_ROWS
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
();
}
else
{
PADDLE_THROW
(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]"
,
var_type
);
}
}
class
RecvOp
:
public
framework
::
OperatorBase
{
class
RecvOp
:
public
framework
::
OperatorBase
{
public:
public:
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
if
(
!
rpc_service_
)
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
void
Run
(
const
framework
::
Scope
&
scope
,
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
const
platform
::
Place
&
place
)
const
override
{
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
auto
outs
=
Outputs
(
"Out"
);
}
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
PADDLE_ENFORCE
(
client_
.
Wait
());
void
Stop
()
override
{
detail
::
MessageWithName
term_msg
;
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
rpc_service_
->
Push
(
term_msg
);
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
}
std
::
string
GetGradVarNameForTrainer
(
const
std
::
string
&
varname
)
const
{
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
grads_counter_
[
varname
]
=
0
;
}
return
string
::
Sprintf
(
"%s.trainer_%d"
,
varname
,
grads_counter_
[
varname
]
++
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
program
=
block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
size_t
recv_var_cnt
=
0
;
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
if
(
grad_var_name
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"recv batch barrier message"
;
batch_barrier
++
;
continue
;
}
else
{
// receive a variable
recv_var_cnt
++
;
auto
it
=
std
::
find
(
grad_list
.
begin
(),
grad_list
.
end
(),
grad_var_name
);
std
::
string
param_var_name
;
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
LOG
(
ERROR
)
<<
"grad has no paired param:"
<<
grad_var_name
;
}
}
VLOG
(
3
)
<<
"received grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
if
(
fan_in
>
1
)
{
private:
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
mutable
detail
::
RPCClient
client_
;
}
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
grad_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
}
VLOG
(
3
)
<<
"recv "
<<
recv_var_cnt
<<
" parmeters for one barrier."
;
// TODO(Yancey1989): merge SelectedRows variables here
if
(
exit_flag
)
{
break
;
}
try
{
executor
.
Run
(
*
program
,
&
recv_scope
,
block
->
ID
(),
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
recv_var_cnt
);
grads_counter_
.
clear
();
}
// while(true)
}
protected:
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
grads_counter_
;
};
};
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RecvOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"(Tensor) Variables to get from server."
).
AsDuplicable
();
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Recv operator
Recv operator
This operator
will recieve tensor from send_op
This operator
can get variables from server side.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
string
>
(
"endpoint"
,
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string, default 127.0.0.1:6164)"
"(string vector, default 127.0.0.1:6164)"
"IP address to listen on."
)
"Server endpoints in the order of input "
.
SetDefault
(
"127.0.0.1:6164"
)
"variables for mapping"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"Serialized ProgramDesc string for recv to run."
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which parameters to optimize."
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
}
}
};
};
...
...
paddle/operators/send_op.cc
浏览文件 @
82e41ce3
...
@@ -62,11 +62,13 @@ class SendOp : public framework::OperatorBase {
...
@@ -62,11 +62,13 @@ class SendOp : public framework::OperatorBase {
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
outs
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
PADDLE_ENFORCE
(
client_
.
Wait
());
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
}
};
};
...
@@ -85,6 +87,8 @@ Send operator
...
@@ -85,6 +87,8 @@ Send operator
This operator will send tensor to recv_op at the parameter server.
This operator will send tensor to recv_op at the parameter server.
)DOC"
);
)DOC"
);
// TODO(typhoonzero): remove this attr generate de-duplicated vector from
// epmap when initializing.
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
"Server endpoints to send variables to."
)
...
...
paddle/operators/send_recv_op_test.cc
浏览文件 @
82e41ce3
...
@@ -25,7 +25,7 @@ limitations under the License. */
...
@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/string/printf.h"
#include "paddle/string/printf.h"
USE_NO_KERNEL_OP
(
send
);
USE_NO_KERNEL_OP
(
send
);
USE_NO_KERNEL_OP
(
rec
v
);
USE_NO_KERNEL_OP
(
listen_and_ser
v
);
USE_OP
(
sum
);
USE_OP
(
sum
);
namespace
f
=
paddle
::
framework
;
namespace
f
=
paddle
::
framework
;
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
...
@@ -33,7 +33,7 @@ namespace p = paddle::platform;
namespace
m
=
paddle
::
operators
::
math
;
namespace
m
=
paddle
::
operators
::
math
;
// global for simplicity.
// global for simplicity.
std
::
unique_ptr
<
f
::
OperatorBase
>
rec
v_op
;
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_ser
v_op
;
void
InitTensorsInScope
(
f
::
Scope
&
scope
,
p
::
CPUPlace
&
place
)
{
void
InitTensorsInScope
(
f
::
Scope
&
scope
,
p
::
CPUPlace
&
place
)
{
p
::
CPUDeviceContext
ctx
(
place
);
p
::
CPUDeviceContext
ctx
(
place
);
...
@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) {
...
@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) {
InitTensorsInScope
(
scope
,
place
);
InitTensorsInScope
(
scope
,
place
);
}
}
// sub program run in
rec
v_op, for simple test we use sum
// sub program run in
listen_and_ser
v_op, for simple test we use sum
f
::
ProgramDesc
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
// X for server side tensors, RX for received tensers, must be of same shape.
// X for server side tensors, RX for received tensers, must be of same shape.
...
@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) {
...
@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"ParamList"
,
std
::
vector
<
std
::
string
>
({
"Out"
})});
attrs
.
insert
({
"ParamList"
,
std
::
vector
<
std
::
string
>
({
"Out"
})});
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"OptimizeBlock"
,
block
});
attrs
.
insert
({
"OptimizeBlock"
,
block
});
recv_op
=
f
::
OpRegistry
::
CreateOp
(
"recv"
,
{{
"RX"
,
{
"x1"
}}},
{},
attrs
);
listen_and_serv_op
=
recv_op
->
Run
(
scope
,
place
);
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{},
{},
attrs
);
listen_and_serv_op
->
Run
(
scope
,
place
);
}
}
TEST
(
SendRecvOp
,
CPUDense
)
{
TEST
(
SendRecvOp
,
CPUDense
)
{
...
@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) {
...
@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) {
for
(
int64_t
i
=
0
;
i
<
target
->
numel
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
target
->
numel
();
++
i
)
{
EXPECT_EQ
(
expected
[
i
]
*
2
,
actual
[
i
]);
EXPECT_EQ
(
expected
[
i
]
*
2
,
actual
[
i
]);
}
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
server_thread
.
join
();
rec
v_op
.
reset
(
nullptr
);
listen_and_ser
v_op
.
reset
(
nullptr
);
}
}
TEST
(
SendRecvOp
,
CPUSparse
)
{
TEST
(
SendRecvOp
,
CPUSparse
)
{
...
@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) {
...
@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) {
EXPECT_EQ
(
expect_value
->
mutable_data
<
float
>
(
place
)[
i
],
EXPECT_EQ
(
expect_value
->
mutable_data
<
float
>
(
place
)[
i
],
actual
->
mutable_data
<
float
>
(
place
)[
i
]);
actual
->
mutable_data
<
float
>
(
place
)[
i
]);
}
}
rec
v_op
->
Stop
();
listen_and_ser
v_op
->
Stop
();
server_thread
.
join
();
server_thread
.
join
();
rec
v_op
.
reset
();
listen_and_ser
v_op
.
reset
();
}
}
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
82e41ce3
...
@@ -478,9 +478,9 @@ class DistributeTranspiler:
...
@@ -478,9 +478,9 @@ class DistributeTranspiler:
else
:
else
:
self
.
_append_pserver_non_opt_ops
(
optimize_sub_program
,
self
.
_append_pserver_non_opt_ops
(
optimize_sub_program
,
pserver_program
,
opt_op
)
pserver_program
,
opt_op
)
# Append the
rec
v op
# Append the
listen_and_ser
v op
pserver_program
.
global_block
().
append_op
(
pserver_program
.
global_block
().
append_op
(
type
=
"
rec
v"
,
type
=
"
listen_and_ser
v"
,
inputs
=
{},
inputs
=
{},
outputs
=
{},
outputs
=
{},
attrs
=
{
attrs
=
{
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
82e41ce3
...
@@ -489,7 +489,7 @@ class Operator(object):
...
@@ -489,7 +489,7 @@ class Operator(object):
no_kernel_op_set
=
{
no_kernel_op_set
=
{
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'feed'
,
'fetch'
,
'save'
,
'load'
,
'recurrent'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'rnn_memory_helper_grad'
,
'conditional_block'
,
'while'
,
'send'
,
'recv'
,
'parallel_do'
'recv'
,
'
listen_and_serv'
,
'
parallel_do'
}
}
if
type
not
in
no_kernel_op_set
:
if
type
not
in
no_kernel_op_set
:
self
.
desc
.
infer_var_type
(
self
.
block
.
desc
)
self
.
desc
.
infer_var_type
(
self
.
block
.
desc
)
...
...
python/paddle/v2/fluid/layers/io.py
浏览文件 @
82e41ce3
...
@@ -108,7 +108,7 @@ class ListenAndServ(object):
...
@@ -108,7 +108,7 @@ class ListenAndServ(object):
"""
"""
def
__init__
(
self
,
endpoint
,
fan_in
=
1
,
optimizer_mode
=
True
):
def
__init__
(
self
,
endpoint
,
fan_in
=
1
,
optimizer_mode
=
True
):
self
.
helper
=
LayerHelper
(
"
rec
v"
)
self
.
helper
=
LayerHelper
(
"
listen_and_ser
v"
)
self
.
inputs
=
[]
self
.
inputs
=
[]
self
.
outputs
=
[]
self
.
outputs
=
[]
self
.
endpoint
=
endpoint
self
.
endpoint
=
endpoint
...
@@ -158,7 +158,7 @@ class ListenAndServ(object):
...
@@ -158,7 +158,7 @@ class ListenAndServ(object):
param_names
=
[
p
.
name
for
p
in
params
]
param_names
=
[
p
.
name
for
p
in
params
]
grad_names
=
[
g
.
name
for
g
in
grads
]
grad_names
=
[
g
.
name
for
g
in
grads
]
parent_block
.
append_op
(
parent_block
.
append_op
(
type
=
'
rec
v'
,
type
=
'
listen_and_ser
v'
,
inputs
=
{},
inputs
=
{},
outputs
=
{},
outputs
=
{},
attrs
=
{
attrs
=
{
...
@@ -196,3 +196,31 @@ def Send(endpoints, send_vars, get_vars):
...
@@ -196,3 +196,31 @@ def Send(endpoints, send_vars, get_vars):
outputs
=
{
"Out"
:
get_vars
},
outputs
=
{
"Out"
:
get_vars
},
attrs
=
{
"endpoints"
:
endpoints
,
attrs
=
{
"endpoints"
:
endpoints
,
"epmap"
:
epmap
})
"epmap"
:
epmap
})
def Recv(endpoints, get_vars):
    """
    Recv layer

    Args:
        endpoints: comma separated IP:PORT pairs, one per entry of
                   get_vars and in the same order
        get_vars: list of vars to get from the server side.

    Appends a "recv" op that fetches get_vars from the given endpoints.
    """
    # Only get_vars is validated: this function has no send_vars argument,
    # so asserting on it raised NameError on every call.
    assert (type(get_vars) == list)

    epmap = endpoints.split(",")
    # De-duplicated endpoint list; epmap keeps the per-variable mapping.
    endpoints = list(set(epmap))

    helper = LayerHelper("Recv", **locals())
    helper.append_op(
        type="recv",
        inputs={"X": get_vars},
        outputs={"Out": get_vars},
        attrs={"endpoints": endpoints,
               "epmap": epmap})
python/paddle/v2/fluid/tests/test_recv_op.py
浏览文件 @
82e41ce3
...
@@ -19,6 +19,7 @@ import paddle.v2.fluid.layers as layers
...
@@ -19,6 +19,7 @@ import paddle.v2.fluid.layers as layers
import
numpy
import
numpy
from
multiprocessing
import
Process
from
multiprocessing
import
Process
import
os
,
sys
import
os
,
sys
import
time
class
TestRecvOp
(
unittest
.
TestCase
):
class
TestRecvOp
(
unittest
.
TestCase
):
...
@@ -28,6 +29,7 @@ class TestRecvOp(unittest.TestCase):
...
@@ -28,6 +29,7 @@ class TestRecvOp(unittest.TestCase):
p
=
Process
(
target
=
self
.
init_serv
,
args
=
(
place
,
))
p
=
Process
(
target
=
self
.
init_serv
,
args
=
(
place
,
))
p
.
daemon
=
True
p
.
daemon
=
True
p
.
start
()
p
.
start
()
time
.
sleep
(
1
)
self
.
init_client
(
place
)
self
.
init_client
(
place
)
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
os
.
system
(
"kill -9 %d"
%
p
.
pid
)
os
.
system
(
"kill -9 %d"
%
p
.
pid
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录