Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
be853853
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be853853
编写于
4月 08, 2018
作者:
Y
Yancey
提交者:
GitHub
4月 08, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9593 from Yancey1989/prefech_prog_on_server
run prefetch prog on server
上级
7d397250
974b253e
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
147 addition
and
28 deletion
+147
-28
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+3
-2
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+5
-1
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+1
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+1
-1
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+18
-9
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+5
-0
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+89
-11
paddle/fluid/operators/detail/send_recv.proto
paddle/fluid/operators/detail/send_recv.proto
+3
-1
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+5
-1
paddle/fluid/operators/detail/sendrecvop_utils.h
paddle/fluid/operators/detail/sendrecvop_utils.h
+2
-1
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+14
-0
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+1
-0
未找到文件。
paddle/fluid/framework/scope.cc
浏览文件 @
be853853
...
@@ -15,7 +15,6 @@ limitations under the License. */
...
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include <set>
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
...
@@ -39,6 +38,7 @@ Scope::~Scope() {
...
@@ -39,6 +38,7 @@ Scope::~Scope() {
}
}
Scope
&
Scope
::
NewScope
()
const
{
Scope
&
Scope
::
NewScope
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
kids_
.
push_back
(
new
Scope
(
this
));
kids_
.
push_back
(
new
Scope
(
this
));
return
*
kids_
.
back
();
return
*
kids_
.
back
();
}
}
...
@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
...
@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}
}
void
Scope
::
DeleteScope
(
Scope
*
scope
)
{
void
Scope
::
DeleteScope
(
Scope
*
scope
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
std
::
find
(
this
->
kids_
.
begin
(),
this
->
kids_
.
end
(),
scope
);
auto
it
=
std
::
find
(
this
->
kids_
.
begin
(),
this
->
kids_
.
end
(),
scope
);
PADDLE_ENFORCE
(
it
!=
this
->
kids_
.
end
(),
"Cannot find %p as kid scope"
,
scope
);
PADDLE_ENFORCE
(
it
!=
this
->
kids_
.
end
(),
"Cannot find %p as kid scope"
,
scope
);
this
->
kids_
.
erase
(
it
);
this
->
kids_
.
erase
(
it
);
...
@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
...
@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}
}
}
void
Scope
::
EraseVars
(
std
::
vector
<
std
::
string
>&
var_names
)
{
void
Scope
::
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
)
{
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
std
::
set
<
std
::
string
>
var_set
(
var_names
.
begin
(),
var_names
.
end
());
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
for
(
auto
it
=
vars_
.
begin
();
it
!=
vars_
.
end
();)
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
if
(
var_set
.
find
(
it
->
first
)
!=
var_set
.
end
())
{
...
...
paddle/fluid/framework/scope.h
浏览文件 @
be853853
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <list>
#include <list>
#include <mutex> // NOLINT
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
@@ -51,7 +52,7 @@ class Scope {
...
@@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name.
/// Create a variable with a scope-unique name.
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
Variable
*
Var
(
std
::
string
*
name
=
nullptr
);
void
EraseVars
(
std
::
vector
<
std
::
string
>&
var_names
);
void
EraseVars
(
const
std
::
vector
<
std
::
string
>&
var_names
);
/// Find a variable in the scope or any of its ancestors. Returns
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
/// nullptr if cannot find.
...
@@ -88,6 +89,9 @@ class Scope {
...
@@ -88,6 +89,9 @@ class Scope {
Scope
const
*
parent_
{
nullptr
};
Scope
const
*
parent_
{
nullptr
};
DISABLE_COPY_AND_ASSIGN
(
Scope
);
DISABLE_COPY_AND_ASSIGN
(
Scope
);
private:
mutable
std
::
mutex
mutex_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
be853853
...
@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
...
@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc
)
cares zlib protobuf sendrecvop_grpc
)
cc_test
(
grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
cc_test
(
grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
executor proto_desc lookup_table_op
)
endif
()
endif
()
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
be853853
...
@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto
*
var
=
p_scope
->
FindVar
(
in_var_name_val
);
auto
*
var
=
p_scope
->
FindVar
(
in_var_name_val
);
::
grpc
::
ByteBuffer
req
;
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
in_var_name_val
,
var
,
*
p_ctx
,
&
req
);
SerializeToByteBuffer
(
in_var_name_val
,
var
,
*
p_ctx
,
&
req
,
out_var_name_val
);
// var handle
// var handle
VarHandle
var_h
;
VarHandle
var_h
;
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
be853853
...
@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
...
@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
int
blkid
)
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
scope_
(
scope
),
executor_
(
executor
),
executor_
(
executor
),
program_
(
program
),
program_
(
program
),
blkid_
(
blkid
)
{
prefetch_ctx_
(
prefetch_ctx
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq
_
,
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder
_
,
cq_
,
this
);
cq_
,
cq_
,
this
);
}
}
virtual
~
RequestPrefetch
()
{}
virtual
~
RequestPrefetch
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
v
arname
();
}
virtual
std
::
string
GetReqName
()
{
return
request_
->
V
arname
();
}
virtual
void
Process
()
{
virtual
void
Process
()
{
// prefetch process...
// prefetch process...
::
grpc
::
ByteBuffer
reply
;
::
grpc
::
ByteBuffer
reply
;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG
(
3
)
<<
"RequestPrefetch Process in"
;
std
::
string
var_name
=
request_
->
OutVarname
();
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
&
scope_
->
NewScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
scope_
,
false
,
false
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
status_
=
FINISH
;
}
}
protected:
protected:
s
endrecv
::
VariableMessage
request_
;
s
td
::
shared_ptr
<
VariableResponse
>
request_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
framework
::
Executor
*
executor_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
blkid_
;
int
blkid_
;
};
};
...
@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
...
@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
}
RequestPrefetch
*
prefetch
=
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
scope_
,
dev_ctx_
,
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_
blk_id
_
);
executor_
,
program_
,
prefetch_
ctx
_
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
}
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
be853853
...
@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
...
@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
framework
::
ExecutorPrepareContext
*
prepared
)
{
prefetch_ctx_
=
prepared
;
}
int
GetSelectedPort
()
{
return
selected_port_
;
}
int
GetSelectedPort
()
{
return
selected_port_
;
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
...
@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
...
@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
int
prefetch_blk_id_
;
int
prefetch_blk_id_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
int
selected_port_
;
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
be853853
...
@@ -20,43 +20,121 @@ limitations under the License. */
...
@@ -20,43 +20,121 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
framework
=
paddle
::
framework
;
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
detail
=
paddle
::
operators
::
detail
;
USE_OP
(
lookup_table
);
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
auto
root_block
=
program
->
MutableBlock
(
0
);
auto
*
block
=
program
->
AppendBlock
(
*
root_block
);
framework
::
VariableNameMap
input
({{
"W"
,
{
"w"
}},
{
"Ids"
,
{
"ids"
}}});
framework
::
VariableNameMap
output
({{
"Output"
,
{
"out"
}}});
auto
op
=
block
->
AppendOp
();
op
->
SetType
(
"lookup_table"
);
op
->
SetInput
(
"W"
,
{
"w"
});
op
->
SetInput
(
"Ids"
,
{
"ids"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
auto
&
out
=
*
root_block
->
Var
(
"out"
);
out
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
out
.
SetShape
({
10
,
10
});
return
block
;
}
void
CreateVarsOnScope
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
)
{
auto
w_var
=
scope
->
Var
(
"w"
);
w_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
out_var
=
scope
->
Var
(
"out"
);
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
ids_var
=
scope
->
Var
(
"ids"
);
ids_var
->
GetMutable
<
framework
::
SelectedRows
>
();
}
void
InitTensorsOnClient
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
int64_t
rows_numel
)
{
CreateVarsOnScope
(
scope
,
place
);
auto
ids_var
=
scope
->
Var
(
"ids"
)
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
rows
=
ids_var
->
mutable_rows
();
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
rows
->
push_back
(
i
*
2
);
ids_var
->
mutable_value
()
->
Resize
({
rows_numel
,
1
});
ids_var
->
mutable_value
()
->
mutable_data
<
float
>
(
*
place
);
}
void
InitTensorsOnServer
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
int64_t
rows_numel
)
{
CreateVarsOnScope
(
scope
,
place
);
auto
w
=
scope
->
Var
(
"w"
)
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
rows
=
w
->
mutable_rows
();
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
rows
->
push_back
(
i
);
auto
w_value
=
w
->
mutable_value
();
w_value
->
Resize
({
rows_numel
,
10
});
auto
ptr
=
w_value
->
mutable_data
<
float
>
(
*
place
);
for
(
int64_t
i
=
0
;
i
<
w_value
->
numel
();
++
i
)
{
ptr
[
i
]
=
static_cast
<
float
>
(
i
/
10
);
}
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
framework
::
Executor
exe
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
auto
*
block
=
AppendPrefetchBlcok
(
&
program
);
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
rpc_service_
->
SetProgram
(
&
program
);
rpc_service_
->
SetPrefetchPreparedCtx
(
prepared
.
get
());
rpc_service_
->
SetDevCtx
(
&
ctx
);
rpc_service_
->
SetScope
(
&
scope
);
rpc_service_
->
SetExecutor
(
&
exe
);
rpc_service_
->
RunSyncUpdate
();
rpc_service_
->
RunSyncUpdate
();
}
}
TEST
(
PREFETCH
,
CPU
)
{
TEST
(
PREFETCH
,
CPU
)
{
// start up a server instance backend
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std
::
thread
server_thread
(
StartServer
,
"127.0.0.1:8889"
);
std
::
thread
server_thread
(
StartServer
,
"127.0.0.1:8889"
);
sleep
(
2
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
// create var on local scope
// create var on local scope
std
::
string
in_var_name
(
"in"
);
int64_t
rows_numel
=
5
;
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
std
::
string
out_var_name
(
"out"
);
auto
*
in_var
=
scope
.
Var
(
in_var_name
);
auto
*
in_tensor
=
in_var
->
GetMutable
<
framework
::
LoDTensor
>
();
in_tensor
->
Resize
({
10
,
10
});
VLOG
(
3
)
<<
"before mutable_data"
;
in_tensor
->
mutable_data
<
int
>
(
place
);
scope
.
Var
(
out_var_name
);
VLOG
(
3
)
<<
"before fetch"
;
detail
::
RPCClient
client
;
detail
::
RPCClient
client
;
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
out_var_name
);
client
.
Wait
();
client
.
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
rpc_service_
->
ShutDown
();
rpc_service_
->
ShutDown
();
server_thread
.
join
();
server_thread
.
join
();
rpc_service_
.
reset
(
nullptr
);
rpc_service_
.
reset
(
nullptr
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
}
}
}
paddle/fluid/operators/detail/send_recv.proto
浏览文件 @
be853853
...
@@ -21,7 +21,7 @@ service SendRecvService {
...
@@ -21,7 +21,7 @@ service SendRecvService {
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
//
Prefetch variable by
Ids
//
pre-fetch variable by given variable name and
Ids
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
}
}
...
@@ -67,6 +67,8 @@ message VariableMessage {
...
@@ -67,6 +67,8 @@ message VariableMessage {
bytes
serialized
=
8
;
bytes
serialized
=
8
;
// selected_rows data
// selected_rows data
bytes
rows
=
9
;
bytes
rows
=
9
;
// Look up table block execution output variable name.
string
out_varname
=
10
;
}
}
message
VoidMessage
{}
message
VoidMessage
{}
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
be853853
...
@@ -30,7 +30,8 @@ namespace detail {
...
@@ -30,7 +30,8 @@ namespace detail {
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
)
{
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
using
VarMsg
=
sendrecv
::
VariableMessage
;
sendrecv
::
VariableMessage
request
;
sendrecv
::
VariableMessage
request
;
std
::
string
header
;
std
::
string
header
;
...
@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
}
}
if
(
!
out_name
.
empty
())
{
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
}
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/detail/sendrecvop_utils.h
浏览文件 @
be853853
...
@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);
...
@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
);
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_varname
=
std
::
string
());
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
be853853
...
@@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
...
@@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
}
}
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kOutVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_out_varname
(
temp
);
break
;
}
default:
{
default:
{
// Unknown tag, return unknown error.
// Unknown tag, return unknown error.
...
...
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
be853853
...
@@ -55,6 +55,7 @@ class VariableResponse {
...
@@ -55,6 +55,7 @@ class VariableResponse {
int
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
);
int
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
);
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
// should call parse first.
// should call parse first.
framework
::
Variable
*
GetVar
()
{
return
scope_
->
FindVar
(
meta_
.
varname
());
}
framework
::
Variable
*
GetVar
()
{
return
scope_
->
FindVar
(
meta_
.
varname
());
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录