Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
1e549563
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e549563
编写于
12月 18, 2017
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multi trainers
上级
e6079390
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
56 addition
and
30 deletion
+56
-30
paddle/operators/detail/recv_impl.cc
paddle/operators/detail/recv_impl.cc
+25
-6
paddle/operators/detail/send_impl.cc
paddle/operators/detail/send_impl.cc
+7
-6
paddle/operators/detail/send_recv.proto
paddle/operators/detail/send_recv.proto
+3
-1
paddle/operators/detail/send_recv_impl.h
paddle/operators/detail/send_recv_impl.h
+12
-10
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+9
-7
未找到文件。
paddle/operators/detail/recv_impl.cc
浏览文件 @
1e549563
...
@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
...
@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
}
}
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
Status
SendRecvServerImpl
::
GetVariable
(
ServerContext
*
context
,
const
V
oid
Message
*
in_var
,
const
V
ariable
Message
*
in_var
,
VariableMessage
*
out_var
)
{
VariableMessage
*
out_var
)
{
// Block util the sub graph is done.
std
::
string
get_var_name
=
in_var
->
varname
();
auto
out_tensor_with_name
=
var_return_queue_
.
Pop
();
auto
*
var
=
scope_
->
FindVar
(
get_var_name
);
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
std
::
ostringstream
oss
;
std
::
ostringstream
oss
;
framework
::
SerializeToStream
(
oss
,
out_tensor_with_name
.
second
,
framework
::
SerializeToStream
(
oss
,
tensor
,
platform
::
CPUDeviceContext
());
platform
::
CPUDeviceContext
());
std
::
string
*
varname
=
out_var
->
mutable_varname
();
std
::
string
*
varname
=
out_var
->
mutable_varname
();
*
varname
=
out_tensor_with_name
.
first
;
*
varname
=
get_var_name
;
std
::
string
*
serialized
=
out_var
->
mutable_serialized
();
std
::
string
*
serialized
=
out_var
->
mutable_serialized
();
*
serialized
=
oss
.
str
();
*
serialized
=
oss
.
str
();
return
Status
::
OK
;
return
Status
::
OK
;
}
}
Status
SendRecvServerImpl
::
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
return
Status
::
OK
;
}
void
SendRecvServerImpl
::
Start
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
void
SendRecvServerImpl
::
Done
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/detail/send_impl.cc
浏览文件 @
1e549563
...
@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
...
@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
return
true
;
return
true
;
}
}
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
)
{
bool
RPCClient
::
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
)
{
ClientContext
context
;
ClientContext
context
;
VariableMessage
msg
;
VariableMessage
call_msg
,
ret_
msg
;
VoidMessage
void_msg
;
call_msg
.
set_varname
(
outname
)
;
auto
ctx
=
platform
::
CPUDeviceContext
();
auto
ctx
=
platform
::
CPUDeviceContext
();
Status
status
=
stub_
->
GetVariable
(
&
context
,
void_msg
,
&
msg
);
Status
status
=
stub_
->
GetVariable
(
&
context
,
call_msg
,
&
ret_
msg
);
if
(
!
status
.
ok
())
{
if
(
!
status
.
ok
())
{
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
LOG
(
ERROR
)
<<
"gRPC error: "
<<
status
.
error_message
();
return
false
;
return
false
;
}
}
std
::
istringstream
iss
(
msg
.
serialized
());
std
::
istringstream
iss
(
ret_
msg
.
serialized
());
auto
outname
=
msg
.
varname
();
framework
::
LoDTensor
ret_tensor
;
framework
::
LoDTensor
ret_tensor
;
framework
::
DeserializeFromStream
(
iss
,
&
ret_tensor
);
framework
::
DeserializeFromStream
(
iss
,
&
ret_tensor
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
auto
*
outvar
=
scope
.
FindVar
(
outname
);
...
...
paddle/operators/detail/send_recv.proto
浏览文件 @
1e549563
...
@@ -22,7 +22,9 @@ service SendRecvService {
...
@@ -22,7 +22,9 @@ service SendRecvService {
// TODO(typhoonzero): add streaming API
// TODO(typhoonzero): add streaming API
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VoidMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
// wait for one execution of the program
rpc
Wait
(
VoidMessage
)
returns
(
VoidMessage
)
{}
}
}
// VariableMessage is serialized paddle variable message.
// VariableMessage is serialized paddle variable message.
...
...
paddle/operators/detail/send_recv_impl.h
浏览文件 @
1e549563
...
@@ -20,10 +20,6 @@
...
@@ -20,10 +20,6 @@
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/simple_block_queue.h"
// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include "paddle/operators/detail/send_recv.pb.h"
...
@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
...
@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
Status
SendVariable
(
ServerContext
*
context
,
const
VariableMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
VoidMessage
*
out_var
)
override
;
Status
GetVariable
(
ServerContext
*
context
,
const
V
oid
Message
*
in_var
,
Status
GetVariable
(
ServerContext
*
context
,
const
V
ariable
Message
*
in_var
,
VariableMessage
*
out_var
)
override
;
VariableMessage
*
out_var
)
override
;
Status
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
override
;
void
Start
();
void
Done
();
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
};
const
TensorWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
const
TensorWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
TensorWithName
&
var
)
{
this
->
var_return_queue_
.
Push
(
var
);
}
private:
private:
// received variable from RPC, operators fetch variable from this queue.
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
TensorWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
TensorWithName
>
var_recv_queue_
;
// calculated variable should push to this queue.
framework
::
Scope
*
scope_
;
SimpleBlockQueue
<
TensorWithName
>
var_return_queue_
;
// condition of the sub program
std
::
mutex
mutex_
;
bool
done_
;
std
::
condition_variable
condition_
;
};
};
// RPCClient is a class to send tensors to pserver sub-network
// RPCClient is a class to send tensors to pserver sub-network
...
@@ -78,7 +80,7 @@ class RPCClient {
...
@@ -78,7 +80,7 @@ class RPCClient {
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
:
stub_
(
SendRecvService
::
NewStub
(
channel
))
{}
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
SendVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
inname
);
bool
GetVariable
(
const
framework
::
Scope
&
scope
);
bool
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
outname
);
private:
private:
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
SendRecvService
::
Stub
>
stub_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
1e549563
...
@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
...
@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
// FIXME(typhoonzero): no new scopes for every run.
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
rpc_service_
.
SetScope
(
&
recv_scope
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
size_t
param_count
=
param_list
.
size
();
size_t
param_count
=
param_list
.
size
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while
(
true
)
{
while
(
true
)
{
rpc_service_
.
Start
();
// Get from multiple trainers, we don't care about order in which
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
// the gradient arrives, just add suffix 0~n then average the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
...
@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
...
@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
}
for
(
size_t
i
=
0
;
i
<
param_count
;
++
i
)
{
//
for (size_t i = 0; i < param_count; ++i) {
auto
*
out_var
=
recv_scope
.
FindVar
(
param_list
[
i
]);
//
auto *out_var = recv_scope.FindVar(param_list[i]);
detail
::
TensorWithName
out
;
//
detail::TensorWithName out;
out
.
first
=
param_list
[
i
];
//
out.first = param_list[i];
out
.
second
=
out_var
->
Get
<
framework
::
LoDTensor
>
();
//
out.second = out_var->Get<framework::LoDTensor>();
rpc_service_
->
Push
(
out
);
//
rpc_service_->Push(out);
}
//
}
}
// while(true)
}
// while(true)
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录