Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1e549563
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e549563
编写于
12月 18, 2017
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multi trainers
上级
e6079390
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
56 addition
and
30 deletion
+56
-30
paddle/operators/detail/recv_impl.cc
paddle/operators/detail/recv_impl.cc
+25
-6
paddle/operators/detail/send_impl.cc
paddle/operators/detail/send_impl.cc
+7
-6
paddle/operators/detail/send_recv.proto
paddle/operators/detail/send_recv.proto
+3
-1
paddle/operators/detail/send_recv_impl.h
paddle/operators/detail/send_recv_impl.h
+12
-10
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+9
-7
未找到文件。
paddle/operators/detail/recv_impl.cc
浏览文件 @
1e549563
...
@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
...
@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
}
}
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
const
V
oid
Message
*
in_var
,
const
V
ariable
Message
*
in_var
,
VariableMessage
*
out_var
)
{
VariableMessage
*
out_var
)
{
// Block util the sub graph is done.
std
::
string
get_var_name
=
in_var
->
varname
();
auto
out_tensor_with_name
=
var_return_queue_
.
Pop
();
auto
*
var
=
scope_
->
FindVar
(
get_var_name
);
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
framework
::
SerializeToStream
(
oss
,
out_tensor_with_name
.
second
,
framework
::
SerializeToStream
(
oss
,
tensor
,
platform
::
CPUDeviceContext
());
platform
::
CPUDeviceContext
());
std
::
string
*
varname
=
out_var
->
mutable_varname
();
std
::
string
*
varname
=
out_var
->
mutable_varname
();
*
varname
=
out_tensor_with_name
.
first
;
*
varname
=
get_var_name
;
std
::
string
*
serialized
=
out_var
->
mutable_serialized
();
std
::
string
*
serialized
=
out_var
->
mutable_serialized
();
*
serialized
=
oss
.
str
();
*
serialized
=
oss
.
str
();
return
Status
::
OK
;
return
Status
::
OK
;
}
}
Status
SendRecvServerImpl
::
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
return
Status
::
OK
;
}
void
SendRecvServerImpl
::
Start
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
void
SendRecvServerImpl
::
Done
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/detail/send_impl.cc
浏览文件 @
1e549563
...
@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
...
@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
return
true
;
return
true
;
}
}
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
)
{
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
)
{
ClientContext
context
;
ClientContext
context
;
VariableMessage
msg
;
VariableMessage
call_msg
,
ret_
msg
;
VoidMessage
void_msg
;
call_msg
.
set_varname
(
outname
)
;
auto
ctx
=
platform
::
CPUDeviceContext
();
auto
ctx
=
platform
::
CPUDeviceContext
();
Status
status
=
stub_
->
GetVariable
(
&
context
,
void_msg
,
&
msg
);
Status
status
=
stub_
->
GetVariable
(
&
context
,
call_msg
,
&
ret_
msg
);
if
(
!
status
.
ok
())
{
if
(
!
status
.
ok
())
{
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
return
false
;
}
}
std
::
istringstream
iss
(
msg
.
serialized
());
std
::
istringstream
iss
(
ret_
msg
.
serialized
());
auto
outname
=
msg
.
varname
();
framework
::
LoDTensor
ret_tensor
;
framework
::
LoDTensor
ret_tensor
;
framework
::
DeserializeFromStream
(
iss
,
&
ret_tensor
);
framework
::
DeserializeFromStream
(
iss
,
&
ret_tensor
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
...
...
paddle/operators/detail/send_recv.proto
浏览文件 @
1e549563
...
@@ -22,7 +22,9 @@ service SendRecvService {
...
@@ -22,7 +22,9 @@ service SendRecvService {
// TODO(typhoonzero): add streaming API
// TODO(typhoonzero): add streaming API
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VoidMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
// wait for one execution of the program
rpc
Wait
(
VoidMessage
)
returns
(
VoidMessage
)
{}
}
}
// VariableMessage is serialized paddle variable message.
// VariableMessage is serialized paddle variable message.
...
...
paddle/operators/detail/send_recv_impl.h
浏览文件 @
1e549563
...
@@ -20,10 +20,6 @@
...
@@ -20,10 +20,6 @@
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/simple_block_queue.h"
// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
...
@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
...
@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
VoidMessage
*
out_var
)
override
;
Status
GetVariable
(
ServerContext
*
context
,
const
V
oid
Message
*
in_var
,
Status
GetVariable
(
ServerContext
*
context
,
const
V
ariable
Message
*
in_var
,
VariableMessage
*
out_var
)
override
;
VariableMessage
*
out_var
)
override
;
Status
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
void
Start
();
void
Done
();
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
};
const
TensorWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
const
TensorWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
TensorWithName
&
var
)
{
this
->
var_return_queue_
.
Push
(
var
);
}
private:
private:
// received variable from RPC, operators fetch variable from this queue.
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
TensorWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
TensorWithName
>
var_recv_queue_
;
// calculated variable should push to this queue.
framework
::
Scope
*
scope_
;
SimpleBlockQueue
<
TensorWithName
>
var_return_queue_
;
// condition of the sub program
std
::
mutex
mutex_
;
bool
done_
;
std
::
condition_variable
condition_
;
};
};
// RPCClient is a class to send tensors to pserver sub-network
// RPCClient is a class to send tensors to pserver sub-network
...
@@ -78,7 +80,7 @@ class RPCClient {
...
@@ -78,7 +80,7 @@ class RPCClient {
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
GetVariable
(
const
framework
::
Scope
&
scope
);
bool
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
);
private:
private:
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
1e549563
...
@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
...
@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
// FIXME(typhoonzero): no new scopes for every run.
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
rpc_service_
.
SetScope
(
&
recv_scope
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
size_t
param_count
=
param_list
.
size
();
size_t
param_count
=
param_list
.
size
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while
(
true
)
{
while
(
true
)
{
rpc_service_
.
Start
();
// Get from multiple trainers, we don't care about order in which
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
// the gradient arrives, just add suffix 0~n then average the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
...
@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
...
@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
}
for
(
size_t
i
=
0
;
i
<
param_count
;
++
i
)
{
//
for (size_t i = 0; i < param_count; ++i) {
auto
*
out_var
=
recv_scope
.
FindVar
(
param_list
[
i
]);
//
auto *out_var = recv_scope.FindVar(param_list[i]);
detail
::
TensorWithName
out
;
//
detail::TensorWithName out;
out
.
first
=
param_list
[
i
];
//
out.first = param_list[i];
out
.
second
=
out_var
->
Get
<
framework
::
LoDTensor
>
();
//
out.second = out_var->Get<framework::LoDTensor>();
rpc_service_
->
Push
(
out
);
//
rpc_service_->Push(out);
}
//
}
}
// while(true)
}
// while(true)
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录