Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9de18095
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9de18095
编写于
1月 15, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fluid distributed on CUDA place
上级
cb6b468e
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
23 addition
and
13 deletion
+23
-13
paddle/framework/tensor_util.h
paddle/framework/tensor_util.h
+2
-3
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+9
-4
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+4
-1
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+4
-3
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+4
-2
未找到文件。
paddle/framework/tensor_util.h
浏览文件 @
9de18095
...
@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
...
@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
desc
.
data_type
(),
desc
.
data_type
(),
DeserializedDataFunctor
(
&
buf
,
&
cpu_tensor
,
ctx
.
GetPlace
()));
DeserializedDataFunctor
(
&
buf
,
&
cpu_tensor
,
ctx
.
GetPlace
()));
is
.
read
(
static_cast
<
char
*>
(
buf
),
cpu_tensor
.
memory_size
());
is
.
read
(
static_cast
<
char
*>
(
buf
),
cpu_tensor
.
memory_size
());
auto
cpu_place
=
new
platform
::
CPUPlace
();
auto
dst_place
=
dev_ctx
.
GetPlace
();
framework
::
Copy
(
cpu_tensor
,
*
cpu_place
,
dev_ctx
,
tensor
);
framework
::
Copy
(
cpu_tensor
,
dst_place
,
dev_ctx
,
tensor
);
delete
cpu_place
;
#else
#else
PADDLE_THROW
(
"Unexpected branch"
);
PADDLE_THROW
(
"Unexpected branch"
);
#endif
#endif
...
...
paddle/operators/detail/grpc_server.cc
浏览文件 @
9de18095
...
@@ -74,8 +74,12 @@ class RequestSend final : public RequestBase {
...
@@ -74,8 +74,12 @@ class RequestSend final : public RequestBase {
class
RequestGet
final
:
public
RequestBase
{
class
RequestGet
final
:
public
RequestBase
{
public:
public:
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
)
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
scope_
(
scope
)
{
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
scope_
(
scope
),
dev_ctx_
(
dev_ctx
)
{
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
}
...
@@ -85,7 +89,7 @@ class RequestGet final : public RequestBase {
...
@@ -85,7 +89,7 @@ class RequestGet final : public RequestBase {
// proc request.
// proc request.
std
::
string
var_name
=
request_
.
varname
();
std
::
string
var_name
=
request_
.
varname
();
auto
*
var
=
scope_
->
FindVar
(
var_name
);
auto
*
var
=
scope_
->
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
platform
::
CPUDeviceContext
()
,
&
reply_
);
SerializeToMessage
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
// TODO(gongwb): check var's info.
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
}
}
...
@@ -95,6 +99,7 @@ class RequestGet final : public RequestBase {
...
@@ -95,6 +99,7 @@ class RequestGet final : public RequestBase {
sendrecv
::
VariableMessage
reply_
;
sendrecv
::
VariableMessage
reply_
;
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
};
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
...
@@ -155,7 +160,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
...
@@ -155,7 +160,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if
(
is_shut_down_
)
{
if
(
is_shut_down_
)
{
return
;
return
;
}
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
);
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
);
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
}
...
...
paddle/operators/detail/grpc_server.h
浏览文件 @
9de18095
...
@@ -37,7 +37,7 @@ class RequestBase;
...
@@ -37,7 +37,7 @@ class RequestBase;
class
AsyncGRPCServer
final
:
public
sendrecv
::
SendRecvService
::
Service
{
class
AsyncGRPCServer
final
:
public
sendrecv
::
SendRecvService
::
Service
{
public:
public:
explicit
AsyncGRPCServer
(
std
::
string
address
)
{
address_
=
address
;
}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{
}
void
RunSyncUpdate
();
void
RunSyncUpdate
();
...
@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
const
MessageWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
const
MessageWithName
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
MessageWithName
&
msg
)
{
this
->
var_recv_queue_
.
Push
(
msg
);
}
void
Push
(
const
MessageWithName
&
msg
)
{
this
->
var_recv_queue_
.
Push
(
msg
);
}
...
@@ -74,6 +76,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -74,6 +76,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std
::
string
address_
;
std
::
string
address_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
9de18095
...
@@ -87,7 +87,11 @@ class RecvOp : public framework::OperatorBase {
...
@@ -87,7 +87,11 @@ class RecvOp : public framework::OperatorBase {
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
// FIXME(typhoonzero): no new scopes for every run.
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
...
@@ -134,9 +138,6 @@ class RecvOp : public framework::OperatorBase {
...
@@ -134,9 +138,6 @@ class RecvOp : public framework::OperatorBase {
}
}
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
}
...
...
paddle/operators/send_op.cc
浏览文件 @
9de18095
...
@@ -33,13 +33,15 @@ class SendOp : public framework::OperatorBase {
...
@@ -33,13 +33,15 @@ class SendOp : public framework::OperatorBase {
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
auto
ins
=
Inputs
(
"X"
);
auto
ins
=
Inputs
(
"X"
);
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
// FIXME(gongwb): DeviceContext?
// FIXME(gongwb): DeviceContext?
auto
ctx
=
platform
::
CPUDeviceContext
();
// auto ctx = platform::CPUDeviceContext();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录