Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
990d6396
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
990d6396
编写于
3月 22, 2018
作者:
G
gongweibao
提交者:
GitHub
3月 22, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reuduce memory copy when communication between trainer and pserver. (#9271)
上级
b594251f
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
1021 addition
and
322 deletion
+1021
-322
benchmark/cluster/vgg16/vgg16_fluid.py
benchmark/cluster/vgg16/vgg16_fluid.py
+25
-27
benchmark/cluster/vgg16/vgg16_tf.py
benchmark/cluster/vgg16/vgg16_tf.py
+7
-3
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+4
-2
paddle/fluid/operators/detail/bytebuffer_stream.h
paddle/fluid/operators/detail/bytebuffer_stream.h
+134
-0
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+30
-9
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+23
-15
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+49
-43
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+23
-13
paddle/fluid/operators/detail/grpc_service.h
paddle/fluid/operators/detail/grpc_service.h
+118
-0
paddle/fluid/operators/detail/send_recv.proto
paddle/fluid/operators/detail/send_recv.proto
+5
-1
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+10
-119
paddle/fluid/operators/detail/sendrecvop_utils.h
paddle/fluid/operators/detail/sendrecvop_utils.h
+3
-9
paddle/fluid/operators/detail/test_serde.cc
paddle/fluid/operators/detail/test_serde.cc
+104
-73
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+400
-0
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+81
-0
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+3
-6
python/paddle/fluid/debuger.py
python/paddle/fluid/debuger.py
+0
-2
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+2
-0
未找到文件。
benchmark/cluster/vgg16/vgg16_fluid.py
浏览文件 @
990d6396
...
@@ -18,12 +18,13 @@ import sys
...
@@ -18,12 +18,13 @@ import sys
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
paddle.
v2.
fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.
v2.
fluid.core
as
core
import
paddle.fluid.core
as
core
import
paddle.
v2.
fluid.profiler
as
profiler
import
paddle.fluid.profiler
as
profiler
import
argparse
import
argparse
import
functools
import
functools
import
os
import
os
from
paddle.fluid
import
debuger
def
str2bool
(
v
):
def
str2bool
(
v
):
...
@@ -182,28 +183,27 @@ def main():
...
@@ -182,28 +183,27 @@ def main():
start_time
=
time
.
time
()
start_time
=
time
.
time
()
num_samples
=
0
num_samples
=
0
train_pass_acc
.
reset
()
train_pass_acc
.
reset
()
with
profiler
.
profiler
(
"CPU"
,
'total'
)
as
prof
:
for
batch_id
,
data
in
enumerate
(
train_reader
()):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
ts
=
time
.
time
()
ts
=
time
.
time
()
img_data
=
np
.
array
(
img_data
=
np
.
array
(
map
(
lambda
x
:
x
[
0
].
reshape
(
data_shape
),
data
)).
astype
(
map
(
lambda
x
:
x
[
0
].
reshape
(
data_shape
),
data
)).
astype
(
"float32"
)
"float32"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
y_data
=
y_data
.
reshape
([
-
1
,
1
])
y_data
=
y_data
.
reshape
([
-
1
,
1
])
loss
,
acc
,
b_size
=
exe
.
run
(
loss
,
acc
,
b_size
=
exe
.
run
(
trainer_prog
,
trainer_prog
,
feed
=
{
"pixel"
:
img_data
,
feed
=
{
"pixel"
:
img_data
,
"label"
:
y_data
},
"label"
:
y_data
},
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size
])
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size
])
iters
+=
1
iters
+=
1
num_samples
+=
len
(
data
)
num_samples
+=
len
(
data
)
train_pass_acc
.
add
(
value
=
acc
,
weight
=
b_size
)
train_pass_acc
.
add
(
value
=
acc
,
weight
=
b_size
)
print
(
print
(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
%
(
pass_id
,
iters
,
loss
,
acc
,
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
ts
))
len
(
data
)
/
(
time
.
time
()
-
ts
))
)
# The accuracy is the accumulation of batches, but not the current batch.
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed
=
time
.
time
()
-
start_time
pass_elapsed
=
time
.
time
()
-
start_time
pass_train_acc
=
train_pass_acc
.
eval
()
pass_train_acc
=
train_pass_acc
.
eval
()
...
@@ -254,9 +254,7 @@ def main():
...
@@ -254,9 +254,7 @@ def main():
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
pserver_prog
)
print
(
"starting server side startup"
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_startup
)
print
(
"starting parameter server..."
)
exe
.
run
(
pserver_prog
)
exe
.
run
(
pserver_prog
)
elif
training_role
==
"TRAINER"
:
elif
training_role
==
"TRAINER"
:
# Parameter initialization
# Parameter initialization
...
...
benchmark/cluster/vgg16/vgg16_tf.py
浏览文件 @
990d6396
...
@@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server):
...
@@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server):
return
np
.
mean
(
test_accs
)
return
np
.
mean
(
test_accs
)
config
=
tf
.
ConfigProto
(
config
=
tf
.
ConfigProto
(
intra_op_parallelism_threads
=
1
,
inter_op_parallelism_threads
=
1
)
intra_op_parallelism_threads
=
1
,
inter_op_parallelism_threads
=
1
,
log_device_placement
=
True
)
config
.
gpu_options
.
allow_growth
=
True
config
.
gpu_options
.
allow_growth
=
True
hooks
=
[
tf
.
train
.
StopAtStepHook
(
last_step
=
1000000
)]
hooks
=
[
tf
.
train
.
StopAtStepHook
(
last_step
=
1000000
)]
with
tf
.
train
.
MonitoredTrainingSession
(
with
tf
.
train
.
MonitoredTrainingSession
(
master
=
server
.
target
,
is_chief
=
(
args
.
task_index
==
0
),
master
=
server
.
target
,
hooks
=
hooks
)
as
sess
:
is_chief
=
(
args
.
task_index
==
0
),
hooks
=
hooks
,
config
=
config
)
as
sess
:
iters
,
num_samples
,
start_time
=
0
,
0
,
0.0
iters
,
num_samples
,
start_time
=
0
,
0
,
0.0
for
pass_id
in
range
(
args
.
num_passes
):
for
pass_id
in
range
(
args
.
num_passes
):
# train
# train
...
...
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
990d6396
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
test_serde.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
test_serde.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc
)
cc_test
(
serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc
)
endif
()
endif
()
paddle/fluid/operators/detail/bytebuffer_stream.h
浏览文件 @
990d6396
...
@@ -23,9 +23,107 @@ limitations under the License. */
...
@@ -23,9 +23,107 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
namespace
grpc
{
// A ZeroCopyInputStream that reads from grpc_byte_buffer
class
GrpcBufferReader
final
:
public
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
{
typedef
void
(
CoreCodegenInterface
::*
OldReaderInitAPI
)(
grpc_byte_buffer_reader
*
reader
,
grpc_byte_buffer
*
buffer
);
typedef
int
(
CoreCodegenInterface
::*
NewReaderInitAPI
)(
grpc_byte_buffer_reader
*
reader
,
grpc_byte_buffer
*
buffer
);
void
ReaderInit
(
OldReaderInitAPI
ptr
,
grpc_byte_buffer_reader
*
reader
,
grpc_byte_buffer
*
buffer
)
{
(
g_core_codegen_interface
->*
ptr
)(
reader
,
buffer
);
}
void
ReaderInit
(
NewReaderInitAPI
ptr
,
grpc_byte_buffer_reader
*
reader
,
grpc_byte_buffer
*
buffer
)
{
int
result
=
(
g_core_codegen_interface
->*
ptr
)(
reader
,
buffer
);
(
void
)
result
;
}
public:
explicit
GrpcBufferReader
(
grpc_byte_buffer
*
buffer
)
:
byte_count_
(
0
),
backup_count_
(
0
)
{
ReaderInit
(
&
CoreCodegenInterface
::
grpc_byte_buffer_reader_init
,
&
reader_
,
buffer
);
}
~
GrpcBufferReader
()
override
{
g_core_codegen_interface
->
grpc_byte_buffer_reader_destroy
(
&
reader_
);
}
bool
Next
(
const
void
**
data
,
int
*
size
)
override
{
if
(
backup_count_
>
0
)
{
*
data
=
GRPC_SLICE_START_PTR
(
slice_
)
+
GRPC_SLICE_LENGTH
(
slice_
)
-
backup_count_
;
GPR_CODEGEN_ASSERT
(
backup_count_
<=
INT_MAX
);
*
size
=
(
int
)
backup_count_
;
backup_count_
=
0
;
return
true
;
}
if
(
!
g_core_codegen_interface
->
grpc_byte_buffer_reader_next
(
&
reader_
,
&
slice_
))
{
return
false
;
}
g_core_codegen_interface
->
grpc_slice_unref
(
slice_
);
*
data
=
GRPC_SLICE_START_PTR
(
slice_
);
// On win x64, int is only 32bit
GPR_CODEGEN_ASSERT
(
GRPC_SLICE_LENGTH
(
slice_
)
<=
INT_MAX
);
byte_count_
+=
*
size
=
(
int
)
GRPC_SLICE_LENGTH
(
slice_
);
return
true
;
}
void
BackUp
(
int
count
)
override
{
backup_count_
=
count
;
}
bool
Skip
(
int
count
)
override
{
const
void
*
data
;
int
size
;
while
(
Next
(
&
data
,
&
size
))
{
if
(
size
>=
count
)
{
BackUp
(
size
-
count
);
return
true
;
}
// size < count;
count
-=
size
;
}
// error or we have too large count;
return
false
;
}
::
google
::
protobuf
::
int64
ByteCount
()
const
override
{
return
byte_count_
-
backup_count_
;
}
private:
int64_t
byte_count_
;
int64_t
backup_count_
;
grpc_byte_buffer_reader
reader_
;
grpc_slice
slice_
;
};
};
// namespace grpc
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class
Source
{
public:
virtual
~
Source
()
{}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
=
0
;
};
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class
GrpcByteBufferSource
class
GrpcByteBufferSource
...
@@ -46,6 +144,42 @@ class GrpcByteBufferSource
...
@@ -46,6 +144,42 @@ class GrpcByteBufferSource
::
google
::
protobuf
::
int64
byte_count_
;
::
google
::
protobuf
::
int64
byte_count_
;
};
};
class
GrpcByteBufferSourceWrapper
:
public
Source
{
public:
GrpcByteBufferSourceWrapper
(
GrpcByteBufferSource
*
source
)
:
source_
(
source
)
{}
virtual
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
override
{
return
source_
;
}
private:
GrpcByteBufferSource
*
source_
;
};
class
GrpcByteSource
:
public
Source
{
public:
explicit
GrpcByteSource
(
grpc_byte_buffer
*
buffer
)
:
buffer_
(
buffer
)
{}
~
GrpcByteSource
()
override
{
DeleteStream
();
}
typedef
::
grpc
::
GrpcBufferReader
Reader
;
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
override
{
DeleteStream
();
stream_
=
new
(
&
space_
)
Reader
(
buffer_
);
return
stream_
;
}
private:
void
DeleteStream
()
{
if
(
stream_
)
{
stream_
->~
Reader
();
}
}
grpc_byte_buffer
*
buffer_
;
// Not owned
Reader
*
stream_
=
nullptr
;
// Points into space_ if non-nullptr
char
space_
[
sizeof
(
Reader
)];
};
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
990d6396
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "grpc_client.h"
#include "grpc_client.h"
#include <sys/time.h>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
...
@@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
framework
::
Async
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
this
]
{
framework
::
Async
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
this
]
{
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
sendrecv
::
VariableMessage
req
;
SerializeToMessage
(
var_name_val
,
var
,
*
p_ctx
,
&
req
);
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
req
);
// varhandle
// varhandle
VarHandle
var_h
;
VarHandle
var_h
;
...
@@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
s
->
Prepare
(
var_h
,
time_out
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
response_call_back_
=
NULL
;
s
->
response_call_back_
=
NULL
;
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
auto
call
=
std
::
move
(
s
->
stub_g_
.
PrepareUnaryCall
(
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
));
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
});
});
req_count_
++
;
req_count_
++
;
...
@@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
}
}
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
sendrecv
::
VariableMessage
&
ret_msg
)
{
// const sendrecv::VariableMessage& ret_msg) {
auto
*
outvar
=
var_h
.
scope
->
FindVar
(
var_h
.
name
);
const
::
grpc
::
ByteBuffer
&
ret_msg
)
{
DeserializeFromMessage
(
ret_msg
,
*
var_h
.
ctx
,
outvar
);
framework
::
Variable
*
outvar
=
NULL
;
DeserializeFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
,
var_h
.
scope
,
outvar
);
}
template
<
typename
T
>
void
RequestToByteBuffer
(
const
T
&
proto
,
::
grpc
::
ByteBuffer
*
result
)
{
::
grpc
::
Slice
slice
(
proto
.
ByteSizeLong
());
proto
.
SerializeWithCachedSizesToArray
(
const_cast
<
uint8_t
*>
(
reinterpret_cast
<
const
uint8_t
*>
(
slice
.
begin
())));
::
grpc
::
ByteBuffer
tmp
(
&
slice
,
1
);
result
->
Swap
(
&
tmp
);
}
}
bool
RPCClient
::
AsyncGetVariable
(
const
std
::
string
&
ep
,
bool
RPCClient
::
AsyncGetVariable
(
const
std
::
string
&
ep
,
...
@@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
s
->
Prepare
(
var_h
,
time_out
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
response_call_back_
=
ProcGetResponse
;
s
->
response_call_back_
=
ProcGetResponse
;
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
::
grpc
::
ByteBuffer
buf
;
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
auto
call
=
std
::
move
(
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/GetVariable"
,
buf
,
&
cq_
));
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
});
});
req_count_
++
;
req_count_
++
;
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
990d6396
...
@@ -25,6 +25,11 @@ limitations under the License. */
...
@@ -25,6 +25,11 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/grpc++.h>
#include <grpc++/support/byte_buffer.h>
#include <grpc++/support/slice.h>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -49,15 +54,11 @@ struct VarHandle {
...
@@ -49,15 +54,11 @@ struct VarHandle {
}
}
};
};
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
grpc
::
ByteBuffer
&
msg
);
const
sendrecv
::
VariableMessage
&
msg
);
class
BaseProcessor
{
class
BaseProcessor
{
public:
public:
explicit
BaseProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
explicit
BaseProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
context_
=
NULL
;
}
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
context_
=
NULL
;
}
virtual
~
BaseProcessor
()
{}
virtual
~
BaseProcessor
()
{}
...
@@ -82,19 +83,18 @@ class BaseProcessor {
...
@@ -82,19 +83,18 @@ class BaseProcessor {
virtual
void
Process
()
=
0
;
virtual
void
Process
()
=
0
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
grpc
::
ClientContext
>
context_
;
std
::
unique_ptr
<
grpc
::
ClientContext
>
context_
;
grpc
::
Status
status_
;
grpc
::
Status
status_
;
VarHandle
var_h_
;
VarHandle
var_h_
;
};
};
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VoidMessage
&
)
>
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
::
grpc
::
ByteBuffer
&
)
>
RequestSendCallBack
;
RequestSendCallBack
;
class
SendProcessor
:
public
BaseProcessor
{
class
SendProcessor
:
public
BaseProcessor
{
public:
public:
explicit
SendProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
explicit
SendProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
:
BaseProcessor
(
ch
)
,
stub_g_
(
ch
)
{}
virtual
~
SendProcessor
()
{}
virtual
~
SendProcessor
()
{}
...
@@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor {
...
@@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor {
}
}
}
}
sendrecv
::
VoidMessage
reply_
;
::
grpc
::
GenericStub
stub_g_
;
::
grpc
::
ByteBuffer
reply_
;
RequestSendCallBack
response_call_back_
=
NULL
;
RequestSendCallBack
response_call_back_
=
NULL
;
};
};
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VariableMessage
&
)
>
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
::
grpc
::
ByteBuffer
&
)
>
RequestGetCallBack
;
RequestGetCallBack
;
class
GetProcessor
:
public
BaseProcessor
{
class
GetProcessor
:
public
BaseProcessor
{
public:
public:
explicit
GetProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
explicit
GetProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
:
BaseProcessor
(
ch
)
,
stub_g_
(
ch
)
{}
virtual
~
GetProcessor
()
{}
virtual
~
GetProcessor
()
{}
...
@@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor {
...
@@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor {
}
}
}
}
sendrecv
::
VariableMessage
reply_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
GenericStub
stub_g_
;
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
};
};
class
BatchBarrierProcessor
:
public
BaseProcessor
{
class
BatchBarrierProcessor
:
public
BaseProcessor
{
public:
public:
explicit
BatchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
explicit
BatchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
:
BaseProcessor
(
ch
)
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
}
virtual
~
BatchBarrierProcessor
()
{}
virtual
~
BatchBarrierProcessor
()
{}
virtual
void
Process
()
{}
virtual
void
Process
()
{}
sendrecv
::
VoidMessage
reply_
;
sendrecv
::
VoidMessage
reply_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
class
FetchBarrierProcessor
:
public
BaseProcessor
{
class
FetchBarrierProcessor
:
public
BaseProcessor
{
public:
public:
explicit
FetchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
explicit
FetchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
:
BaseProcessor
(
ch
)
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
}
virtual
~
FetchBarrierProcessor
()
{}
virtual
~
FetchBarrierProcessor
()
{}
virtual
void
Process
()
{}
virtual
void
Process
()
{}
sendrecv
::
VariableMessage
reply_
;
sendrecv
::
VariableMessage
reply_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
class
RPCClient
{
class
RPCClient
{
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
990d6396
...
@@ -14,7 +14,7 @@ limitations under the License. */
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
using
grpc
::
ServerAsyncResponseWriter
;
using
::
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH };
...
@@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH };
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class
RequestBase
{
class
RequestBase
{
public:
public:
explicit
RequestBase
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
)
::
grpc
::
ServerCompletionQueue
*
cq
,
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
)
{
const
platform
::
DeviceContext
*
dev_ctx
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
PADDLE_ENFORCE
(
cq_
);
PADDLE_ENFORCE
(
cq_
);
}
}
virtual
~
RequestBase
()
{}
virtual
~
RequestBase
()
{}
...
@@ -42,55 +43,58 @@ class RequestBase {
...
@@ -42,55 +43,58 @@ class RequestBase {
}
}
protected:
protected:
grpc
::
ServerContext
ctx_
;
::
grpc
::
ServerContext
ctx_
;
sendrecv
::
SendRecv
Service
::
AsyncService
*
service_
;
Grpc
Service
::
AsyncService
*
service_
;
grpc
::
ServerCompletionQueue
*
cq_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
CallStatus
status_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
};
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestSend
final
:
public
RequestBase
{
class
RequestSend
final
:
public
RequestBase
{
public:
public:
explicit
RequestSend
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
SimpleBlockQueue
<
MessageWithName
>*
queue
)
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
:
RequestBase
(
service
,
cq
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
const
platform
::
DeviceContext
*
dev_ctx
)
service_
->
RequestSendVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
:
RequestBase
(
service
,
cq
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
this
);
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
}
}
virtual
~
RequestSend
()
{}
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
v
arname
();
}
virtual
std
::
string
GetReqName
()
{
return
request_
->
V
arname
();
}
virtual
void
Process
()
{
virtual
void
Process
()
{
MessageWithName
msg_with_name
=
queue_
->
Push
(
std
::
make_pair
(
request_
->
Varname
(),
request_
));
std
::
make_pair
(
request_
.
varname
(),
std
::
move
(
request_
));
queue_
->
Push
(
std
::
move
(
msg_with_name
))
;
sendrecv
::
VoidMessage
reply
;
responder_
.
Finish
(
reply
_
,
grpc
::
Status
::
OK
,
this
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
status_
=
FINISH
;
}
}
protected:
protected:
sendrecv
::
VariableMessage
request_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
sendrecv
::
VoidMessage
reply_
;
ReceivedQueue
*
queue_
;
SimpleBlockQueue
<
MessageWithName
>*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
};
class
RequestGet
final
:
public
RequestBase
{
class
RequestGet
final
:
public
RequestBase
{
public:
public:
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
const
platform
::
DeviceContext
*
dev_ctx
,
SimpleBlockQueue
<
MessageWithName
>*
queue
)
SimpleBlockQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
),
:
RequestBase
(
service
,
cq
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
scope_
(
scope
),
dev_ctx_
(
dev_ctx
),
queue_
(
queue
)
{
queue_
(
queue
)
{
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
}
virtual
~
RequestGet
()
{}
virtual
~
RequestGet
()
{}
...
@@ -101,24 +105,26 @@ class RequestGet final : public RequestBase {
...
@@ -101,24 +105,26 @@ class RequestGet final : public RequestBase {
// proc request.
// proc request.
std
::
string
var_name
=
request_
.
varname
();
std
::
string
var_name
=
request_
.
varname
();
auto
*
var
=
scope_
->
FindVar
(
var_name
);
auto
*
var
=
scope_
->
FindVar
(
var_name
);
::
grpc
::
ByteBuffer
reply
;
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
SerializeTo
Message
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
SerializeTo
ByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
}
}
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply
_
,
grpc
::
Status
::
OK
,
this
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
status_
=
FINISH
;
MessageWithName
msg_with_name
=
// request name reply
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
std
::
make_pair
(
var_name
,
std
::
move
(
reply_
));
sendrecv
::
VariableMessage
msg
;
queue_
->
Push
(
msg_with_name
);
MessageWithName
msg_with_name
=
std
::
make_pair
(
var_name
,
msg
);
queue_
->
Push
(
msg_with_name
);
}
}
}
protected:
protected:
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VariableMessage
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
SimpleBlockQueue
<
MessageWithName
>*
queue_
;
SimpleBlockQueue
<
MessageWithName
>*
queue_
;
};
};
...
@@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
...
@@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
}
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
grpc
::
ServerBuilder
builder
;
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
());
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
builder
.
RegisterService
(
&
service_
);
...
@@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
...
@@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
if
(
is_shut_down_
)
{
if
(
is_shut_down_
)
{
return
;
return
;
}
}
RequestSend
*
send
=
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
scope_
,
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
&
var_recv_queue
_
);
&
var_recv_queue_
,
dev_ctx
_
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
}
...
@@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
...
@@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
}
// FIXME(typhoonzero): change cq_name to enum.
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
grpc
::
ServerCompletionQueue
*
cq
,
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
TryToRegisterNewOne
();
TryToRegisterNewOne
();
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
990d6396
...
@@ -14,28 +14,35 @@ limitations under the License. */
...
@@ -14,28 +14,35 @@ limitations under the License. */
#pragma once
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
#include "paddle/fluid/operators/detail/grpc_service.h"
#include <grpc/support/log.h>
#include <thread>
//#include <grpc/support/log.h>
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
VariableResponse
>>
ReceivedMessage
;
typedef
SimpleBlockQueue
<
ReceivedMessage
>
ReceivedQueue
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
RequestBase
;
class
AsyncGRPCServer
final
:
public
sendrecv
::
SendRecvService
::
Service
{
class
AsyncGRPCServer
final
{
public:
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
...
@@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
const
MessageWithNam
e
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
const
ReceivedMessag
e
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
void
Push
(
const
MessageWithName
&
msg
)
{
this
->
var_recv_queue_
.
Push
(
msg
);
}
void
Push
(
const
std
::
string
&
msg_name
)
{
this
->
var_recv_queue_
.
Push
(
std
::
make_pair
(
msg_name
,
nullptr
));
}
void
ShutDown
();
void
ShutDown
();
protected:
protected:
void
HandleRequest
(
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
void
TryToRegisterNewGetOne
();
...
@@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
private:
private:
std
::
mutex
cq_mutex_
;
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<
::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<
::
grpc
::
ServerCompletionQueue
>
cq_get_
;
sendrecv
::
SendRecv
Service
::
AsyncService
service_
;
Grpc
Service
::
AsyncService
service_
;
std
::
unique_ptr
<
grpc
::
Server
>
server_
;
std
::
unique_ptr
<
::
grpc
::
Server
>
server_
;
std
::
string
address_
;
std
::
string
address_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
MessageWithName
>
var_get_queue_
;
SimpleBlockQueue
<
MessageWithName
>
var_get_queue_
;
ReceivedQueue
var_recv_queue_
;
// condition of the sub program
// condition of the sub program
std
::
mutex
barrier_mutex_
;
std
::
mutex
barrier_mutex_
;
...
...
paddle/fluid/operators/detail/grpc_service.h
0 → 100644
浏览文件 @
990d6396
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"
// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC
// requests without too much copying of the tensor data.
namespace
grpc
{
class
CompletionQueue
;
class
Channel
;
class
RpcService
;
class
ServerCompletionQueue
;
class
ServerContext
;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template
<
>
class
SerializationTraits
<
paddle
::
operators
::
detail
::
VariableResponse
>
{
public:
static
Status
Serialize
(
const
paddle
::
operators
::
detail
::
VariableResponse
&
msg
,
grpc_byte_buffer
**
bp
,
bool
*
own_buffer
)
{
PADDLE_ENFORCE
(
false
,
"SerializationTraits::Serialize not implemented!"
);
return
Status
();
}
static
Status
Deserialize
(
grpc_byte_buffer
*
buffer
,
paddle
::
operators
::
detail
::
VariableResponse
*
msg
,
int
max_message_size
=
INT_MAX
)
{
if
(
buffer
==
nullptr
)
{
return
Status
(
StatusCode
::
INTERNAL
,
"No payload"
);
}
Status
result
=
g_core_codegen_interface
->
ok
();
if
(
result
.
ok
())
{
paddle
::
operators
::
detail
::
GrpcByteSource
source
(
buffer
);
int
ret
=
msg
->
Parse
(
&
source
);
if
(
ret
!=
0
)
{
result
=
Status
(
StatusCode
::
INTERNAL
,
"VariableResponse parse error"
);
}
}
g_core_codegen_interface
->
grpc_byte_buffer_destroy
(
buffer
);
return
result
;
}
};
}
// namespace grpc
namespace
paddle
{
namespace
operators
{
namespace
detail
{
enum
class
GrpcMethod
{
kSendVariable
,
kGetVariable
,
};
static
const
int
kGrpcNumMethods
=
static_cast
<
int
>
(
GrpcMethod
::
kGetVariable
)
+
1
;
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
switch
(
id
)
{
case
GrpcMethod
::
kSendVariable
:
return
"/sendrecv.SendRecvService/SendVariable"
;
case
GrpcMethod
::
kGetVariable
:
return
"/sendrecv.SendRecvService/GetVariable"
;
}
// Shouldn't be reached.
PADDLE_ENFORCE
(
false
,
"Invalid id: not found valid method name"
);
return
nullptr
;
}
class
GrpcService
final
{
public:
class
AsyncService
:
public
::
grpc
::
Service
{
public:
AsyncService
()
{
for
(
int
i
=
0
;
i
<
kGrpcNumMethods
;
++
i
)
{
AddMethod
(
new
::
grpc
::
internal
::
RpcServiceMethod
(
GrpcMethodName
(
static_cast
<
GrpcMethod
>
(
i
)),
::
grpc
::
internal
::
RpcMethod
::
NORMAL_RPC
,
nullptr
));
::
grpc
::
Service
::
MarkMethodAsync
(
i
);
}
}
virtual
~
AsyncService
()
{}
// Make RequestAsyncUnary public for grpc_call.h
using
::
grpc
::
Service
::
RequestAsyncUnary
;
};
};
}
// namespace detail
}
// namespace operator
}
// namespace paddle
paddle/fluid/operators/detail/send_recv.proto
浏览文件 @
990d6396
...
@@ -32,6 +32,9 @@ enum VarType {
...
@@ -32,6 +32,9 @@ enum VarType {
SELECTED_ROWS
=
1
;
SELECTED_ROWS
=
1
;
}
}
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message
VariableMessage
{
message
VariableMessage
{
enum
Type
{
enum
Type
{
// Pod Types
// Pod Types
...
@@ -45,7 +48,6 @@ message VariableMessage {
...
@@ -45,7 +48,6 @@ message VariableMessage {
}
}
message
LodData
{
repeated
int64
lod_data
=
1
;
}
message
LodData
{
repeated
int64
lod_data
=
1
;
}
string
varname
=
1
;
string
varname
=
1
;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType
type
=
2
;
VarType
type
=
2
;
...
@@ -64,3 +66,5 @@ message VariableMessage {
...
@@ -64,3 +66,5 @@ message VariableMessage {
}
}
message
VoidMessage
{}
message
VoidMessage
{}
message
TestMessage
{
int64
test_1
=
1
;
}
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
990d6396
...
@@ -13,61 +13,19 @@ See the License for the specific language governing permissions and
...
@@ -13,61 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include <sys/time.h>
#include <thread>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
void
SerializeToMessage
(
const
std
::
string
&
name
,
const
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
sendrecv
::
VariableMessage
*
msg
)
{
msg
->
set_varname
(
name
);
std
::
ostringstream
oss
;
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
msg
->
set_type
(
sendrecv
::
VarType
::
LOD_TENSOR
);
framework
::
SerializeToStream
(
oss
,
var
->
Get
<
framework
::
LoDTensor
>
(),
ctx
);
break
;
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
msg
->
set_type
(
sendrecv
::
VarType
::
SELECTED_ROWS
);
framework
::
SerializeToStream
(
oss
,
var
->
Get
<
framework
::
SelectedRows
>
(),
ctx
);
break
;
default:
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
break
;
}
}
msg
->
set_serialized
(
oss
.
str
());
}
void
DeserializeFromMessage
(
const
sendrecv
::
VariableMessage
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
)
{
std
::
istringstream
iss
(
msg
.
serialized
());
switch
(
msg
.
type
())
{
case
sendrecv
::
VarType
::
LOD_TENSOR
:
DeserializeFromStream
(
iss
,
var
->
GetMutable
<
framework
::
LoDTensor
>
(),
ctx
);
break
;
case
sendrecv
::
VarType
::
SELECTED_ROWS
:
{
DeserializeFromStream
(
iss
,
var
->
GetMutable
<
framework
::
SelectedRows
>
(),
ctx
);
break
;
}
default:
{
PADDLE_THROW
(
"Deserialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
break
;
}
}
}
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
)
{
::
grpc
::
ByteBuffer
*
msg
)
{
...
@@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
memory_size
();
auto
copy_size
=
tensor
.
memory_size
();
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
...
@@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
memory
::
Free
(
cpu
,
backing
);
};
};
#endif
#endif
}
else
{
}
else
{
payload
=
tensor
.
data
<
void
>
();
payload
=
tensor
.
data
<
void
>
();
...
@@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
)
{
const
framework
::
Scope
*
scope
,
sendrecv
::
VariableMessage
meta
;
framework
::
Variable
*&
var
)
{
GrpcByteBufferSource
source
;
operators
::
detail
::
VariableResponse
resp
(
scope
,
&
ctx
);
source
.
Init
(
msg
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
::
google
::
protobuf
::
io
::
CodedInputStream
input
(
&
source
);
var
=
resp
.
GetVar
();
// do zerocopy parsing
PADDLE_ENFORCE
(
meta
.
ParseFromCodedStream
(
&
input
));
PADDLE_ENFORCE
(
input
.
ConsumedEntireMessage
());
// dims is needed by both tensor and selectedrows
std
::
vector
<
int
>
vecdims
;
for
(
auto
&
d
:
meta
.
dims
())
{
vecdims
.
push_back
(
d
);
}
framework
::
DDim
dims
=
framework
::
make_ddim
(
vecdims
);
if
(
meta
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
dims
);
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
paddle
::
operators
::
detail
::
ToTypeIndex
(
meta
.
data_type
()));
framework
::
LoD
lod
;
for
(
int
i
=
0
;
i
<
meta
.
lod_level
();
++
i
)
{
framework
::
Vector
<
size_t
>
v
;
for
(
int
j
=
0
;
j
<
meta
.
lod
(
i
).
lod_data_size
();
++
j
)
{
v
.
push_back
(
meta
.
lod
(
i
).
lod_data
(
j
));
}
lod
.
push_back
(
v
);
}
tensor
->
set_lod
(
lod
);
// How to avoid copying and use the message buffer directly?
// Maybe need to find a way to release all memory except tensor content.
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
tensor_data
,
cpu
,
reinterpret_cast
<
const
void
*>
(
meta
.
serialized
().
data
()),
meta
.
serialized
().
size
(),
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
#endif
}
else
{
memcpy
(
tensor_data
,
reinterpret_cast
<
const
void
*>
(
meta
.
serialized
().
data
()),
meta
.
serialized
().
size
());
}
}
else
if
(
meta
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
tensor
=
slr
->
mutable_value
();
int64_t
*
rows_data
=
slr
->
mutable_rows
()
->
data
();
tensor
->
Resize
(
dims
);
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
paddle
::
operators
::
detail
::
ToTypeIndex
(
meta
.
data_type
()));
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
tensor_data
,
cpu
,
reinterpret_cast
<
const
void
*>
(
meta
.
serialized
().
data
()),
meta
.
serialized
().
size
(),
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
#endif
}
else
{
memcpy
(
tensor_data
,
reinterpret_cast
<
const
void
*>
(
meta
.
serialized
().
data
()),
meta
.
serialized
().
size
());
}
// copy rows CPU data, GPU data will be copied lazly
memcpy
(
rows_data
,
reinterpret_cast
<
const
void
*>
(
meta
.
rows
().
data
()),
meta
.
rows
().
size
());
}
}
}
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/detail/sendrecvop_utils.h
浏览文件 @
990d6396
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
...
@@ -36,21 +37,14 @@ namespace detail {
...
@@ -36,21 +37,14 @@ namespace detail {
typedef
void
(
*
DestroyCallback
)(
void
*
);
typedef
void
(
*
DestroyCallback
)(
void
*
);
void
SerializeToMessage
(
const
std
::
string
&
name
,
const
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
sendrecv
::
VariableMessage
*
msg
);
void
DeserializeFromMessage
(
const
sendrecv
::
VariableMessage
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
);
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
);
::
grpc
::
ByteBuffer
*
msg
);
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
Variable
*
var
);
const
framework
::
Scope
*
scope
,
framework
::
Variable
*&
var
);
inline
std
::
type_index
ToTypeIndex
(
sendrecv
::
VariableMessage
::
Type
type
)
{
inline
std
::
type_index
ToTypeIndex
(
sendrecv
::
VariableMessage
::
Type
type
)
{
switch
(
type
)
{
switch
(
type
)
{
...
...
paddle/fluid/operators/detail/test_serde.cc
浏览文件 @
990d6396
...
@@ -16,11 +16,13 @@ limitations under the License. */
...
@@ -16,11 +16,13 @@ limitations under the License. */
#include <string>
#include <string>
#include <thread>
#include <thread>
#include <google/protobuf/text_format.h>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
...
@@ -31,19 +33,21 @@ namespace operators = paddle::operators;
...
@@ -31,19 +33,21 @@ namespace operators = paddle::operators;
namespace
math
=
paddle
::
operators
::
math
;
namespace
math
=
paddle
::
operators
::
math
;
namespace
memory
=
paddle
::
memory
;
namespace
memory
=
paddle
::
memory
;
void
RunSerdeTestTensor
(
platform
::
Place
place
)
{
void
RunSerdeTestSelectedRows
(
platform
::
Place
place
)
{
// serialize var to ByteBuffer
framework
::
Variable
var
;
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
framework
::
LoD
lod
;
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
// serialize var to ByteBuffer
framework
::
Variable
var
;
auto
*
slr
=
var
.
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
rows
=
slr
->
mutable_rows
();
tensor
->
Resize
(
framework
::
make_ddim
({
2
,
10
}));
tensor
->
mutable_data
<
float
>
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
math
::
set_constant
(
ctx
,
tensor
,
31.9
);
int
tensor_numel
=
2
*
10
;
math
::
set_constant
(
ctx
,
tensor
,
32.7
);
rows
->
push_back
(
3
);
rows
->
push_back
(
10
);
::
grpc
::
ByteBuffer
msg
;
::
grpc
::
ByteBuffer
msg
;
operators
::
detail
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
);
operators
::
detail
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
);
...
@@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) {
...
@@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) {
for
(
const
auto
&
s
:
slices
)
{
for
(
const
auto
&
s
:
slices
)
{
tmp
.
append
(
reinterpret_cast
<
const
char
*>
(
s
.
begin
()),
s
.
size
());
tmp
.
append
(
reinterpret_cast
<
const
char
*>
(
s
.
begin
()),
s
.
size
());
}
}
sendrecv
::
VariableMessage
varmsg
;
sendrecv
::
VariableMessage
varmsg
;
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
type
(),
1
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
EXPECT_EQ
(
varmsg
.
lod_level
(),
1
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
0
),
1
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
1
),
3
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
2
),
8
);
const
float
*
tensor_data
=
const
float
*
tensor_data
=
reinterpret_cast
<
const
float
*>
(
varmsg
.
serialized
().
data
());
reinterpret_cast
<
const
float
*>
(
varmsg
.
serialized
().
data
());
const
int64_t
*
rows_data
=
reinterpret_cast
<
const
int64_t
*>
(
varmsg
.
rows
().
data
());
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
EXPECT_FLOAT_EQ
(
tensor_data
[
i
],
3
1.9
);
EXPECT_FLOAT_EQ
(
tensor_data
[
i
],
3
2.7
);
}
}
EXPECT_EQ
(
rows_data
[
0
],
3
);
EXPECT_EQ
(
rows_data
[
1
],
10
);
// deserialize zero-copy
// deserialize zero-copy
framework
::
Variable
var2
;
// framework::Variable var2;
operators
::
detail
::
DeserializeFromByteBuffer
(
msg
,
ctx
,
&
var2
);
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto
tensor2
=
var2
.
Get
<
framework
::
LoDTensor
>
();
framework
::
Scope
scope
;
scope
.
Var
(
"myvar"
);
operators
::
detail
::
TensorResponse
resp
(
&
scope
,
&
ctx
);
EXPECT_EQ
(
resp
.
Parse
(
msg
),
0
);
framework
::
Variable
*
var2
=
resp
.
GetVar
();
auto
*
slr2
=
var2
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
tensor2
=
slr2
->
mutable_value
();
auto
*
rows2
=
slr2
->
mutable_rows
();
float
*
tensor_data2
=
nullptr
;
float
*
tensor_data2
=
nullptr
;
framework
::
Tensor
tmp_tensor
;
framework
::
Tensor
tmp_tensor
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
framework
::
TensorCopy
(
tensor2
,
cpu
,
&
tmp_tensor
);
framework
::
TensorCopy
(
*
tensor2
,
cpu
,
&
tmp_tensor
);
tensor_data2
=
tmp_tensor
.
data
<
float
>
();
tensor_data2
=
tmp_tensor
.
data
<
float
>
();
}
else
{
}
else
{
tensor_data2
=
const_cast
<
float
*>
(
tensor2
.
data
<
float
>
());
tensor_data2
=
const_cast
<
float
*>
(
tensor2
->
data
<
float
>
());
}
}
const
int64_t
*
rows_data2
=
rows2
->
data
();
EXPECT_EQ
(
varmsg
.
lod_level
(),
1
);
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
0
),
1
);
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
32.7
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
1
),
3
);
}
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
2
),
8
);
EXPECT_EQ
(
rows_data2
[
0
],
3
);
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
31.9
);
EXPECT_EQ
(
rows_data2
[
1
],
10
);
}
}
void
RunSerdeTestSelectedRows
(
platform
::
Place
place
)
{
void
RunTestLodTensor
(
platform
::
Place
place
,
int
from_type
=
0
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
// serialize var to ByteBuffer
// serialize var to ByteBuffer
framework
::
Variable
var
;
framework
::
Variable
var
;
auto
*
slr
=
var
.
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
slr
->
mutable_value
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
auto
*
rows
=
slr
->
mutable_rows
();
framework
::
LoD
lod
;
tensor
->
Resize
(
framework
::
make_ddim
({
2
,
10
}));
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
int
tensor_numel
=
2
*
10
;
math
::
set_constant
(
ctx
,
tensor
,
31.9
);
math
::
set_constant
(
ctx
,
tensor
,
32.7
);
rows
->
push_back
(
3
);
rows
->
push_back
(
10
);
::
grpc
::
ByteBuffer
msg
;
::
grpc
::
ByteBuffer
msg
;
operators
::
detail
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
);
operators
::
detail
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
);
...
@@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) {
...
@@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) {
}
}
sendrecv
::
VariableMessage
varmsg
;
sendrecv
::
VariableMessage
varmsg
;
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
1
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
EXPECT_EQ
(
varmsg
.
lod_level
(),
1
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
0
),
1
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
1
),
3
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
2
),
8
);
const
float
*
tensor_data
=
const
float
*
tensor_data
=
reinterpret_cast
<
const
float
*>
(
varmsg
.
serialized
().
data
());
reinterpret_cast
<
const
float
*>
(
varmsg
.
serialized
().
data
());
const
int64_t
*
rows_data
=
reinterpret_cast
<
const
int64_t
*>
(
varmsg
.
rows
().
data
());
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
EXPECT_FLOAT_EQ
(
tensor_data
[
i
],
3
2.7
);
EXPECT_FLOAT_EQ
(
tensor_data
[
i
],
3
1.9
);
}
}
EXPECT_EQ
(
rows_data
[
0
],
3
);
EXPECT_EQ
(
rows_data
[
1
],
10
);
// message binary
std
::
string
str
;
varmsg
.
SerializeToString
(
&
str
);
// message bytebuffer
::
grpc
::
Slice
slices_2
[
1
];
int
num_slices
=
1
;
slices_2
[
0
]
=
::
grpc
::
Slice
(
str
.
length
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices_2
[
0
].
begin
()),
str
.
c_str
(),
str
.
length
());
::
grpc
::
ByteBuffer
bytebuffer2
(
&
slices_2
[
0
],
num_slices
);
// deserialize zero-copy
// deserialize zero-copy
framework
::
Variable
var2
;
framework
::
Scope
scope
;
operators
::
detail
::
DeserializeFromByteBuffer
(
msg
,
ctx
,
&
var2
);
scope
.
Var
(
"myvar"
);
operators
::
detail
::
TensorResponse
resp
(
&
scope
,
&
ctx
);
if
(
from_type
==
0
)
{
EXPECT_EQ
(
resp
.
Parse
(
msg
),
0
);
}
else
{
EXPECT_EQ
(
resp
.
Parse
(
bytebuffer2
),
0
);
}
auto
*
slr2
=
var2
.
GetMutable
<
framework
::
SelectedRows
>
();
framework
::
Variable
*
var2
=
resp
.
GetVar
();
auto
*
tensor2
=
slr2
->
mutable_value
();
auto
*
rows2
=
slr2
->
mutable_rows
();
auto
tensor2
=
var2
->
Get
<
framework
::
LoDTensor
>
();
float
*
tensor_data2
=
nullptr
;
float
*
tensor_data2
=
nullptr
;
framework
::
Tensor
tmp_tensor
;
framework
::
Tensor
tmp_tensor
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
framework
::
TensorCopy
(
*
tensor2
,
cpu
,
&
tmp_tensor
);
framework
::
TensorCopy
(
tensor2
,
cpu
,
&
tmp_tensor
);
tensor_data2
=
tmp_tensor
.
data
<
float
>
();
tensor_data2
=
tmp_tensor
.
data
<
float
>
();
}
else
{
}
else
{
tensor_data2
=
const_cast
<
float
*>
(
tensor2
->
data
<
float
>
());
tensor_data2
=
const_cast
<
float
*>
(
tensor2
.
data
<
float
>
());
}
}
const
int64_t
*
rows_data2
=
rows2
->
data
();
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
EXPECT_EQ
(
varmsg
.
lod_level
(),
1
);
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
32.7
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
0
),
1
);
}
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
1
),
3
);
EXPECT_EQ
(
rows_data2
[
0
],
3
);
EXPECT_EQ
(
varmsg
.
lod
(
0
).
lod_data
(
2
),
8
);
EXPECT_EQ
(
rows_data2
[
1
],
10
);
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
31.9
);
}
TEST
(
LodTensor
,
GPU
)
{
platform
::
CUDAPlace
place
;
RunTestLodTensor
(
place
);
RunTestLodTensor
(
place
,
1
);
}
TEST
(
LodTensor
,
CPU
)
{
platform
::
CPUPlace
place
;
RunTestLodTensor
(
place
);
RunTestLodTensor
(
place
,
1
);
}
}
TEST
(
SelectedRows
,
CPU
)
{
TEST
(
SelectedRows
,
CPU
)
{
...
@@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) {
...
@@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) {
platform
::
CUDAPlace
place
;
platform
::
CUDAPlace
place
;
RunSerdeTestSelectedRows
(
place
);
RunSerdeTestSelectedRows
(
place
);
}
}
TEST
(
Tensor
,
CPU
)
{
platform
::
CPUPlace
place
;
RunSerdeTestTensor
(
place
);
}
TEST
(
Tensor
,
GPU
)
{
platform
::
CUDAPlace
place
;
RunSerdeTestTensor
(
place
);
}
\ No newline at end of file
paddle/fluid/operators/detail/variable_response.cc
0 → 100644
浏览文件 @
990d6396
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h>
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
enum
WireType
{
WIRETYPE_VARINT
=
0
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
};
inline
int
GetTagFieldNumber
(
uint32_t
tag
)
{
return
tag
>>
3
;
}
inline
WireType
GetTagWireType
(
uint32_t
tag
)
{
return
static_cast
<
WireType
>
(
tag
&
0x7
);
}
bool
ReadVarintSizeAsInt
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int
*
result
)
{
uint64_t
v
;
if
(
input
->
ReadVarint64
(
&
v
)
&&
v
<=
static_cast
<
uint64_t
>
(
INT_MAX
))
{
*
result
=
static_cast
<
int
>
(
v
);
return
true
;
}
else
{
return
false
;
}
}
bool
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
int
size
)
{
const
void
*
data
=
NULL
;
int
size_to_write
=
0
;
if
(
platform
::
is_gpu_place
(
place
))
{
#ifdef PADDLE_WITH_CUDA
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
platform
::
CPUPlace
cpu
;
char
*
p
=
reinterpret_cast
<
char
*>
(
dest
);
while
(
size
>
0
)
{
if
(
!
input
->
GetDirectBufferPointer
(
&
data
,
&
size_to_write
))
{
return
false
;
}
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
reinterpret_cast
<
void
*>
(
p
),
cpu
,
data
,
size_to_write
,
gpu_dev_ctx
.
stream
());
p
+=
size_to_write
;
size
-=
size_to_write
;
input
->
Skip
(
size_to_write
);
}
gpu_dev_ctx
.
Wait
();
#else
PADDLE_THROW
(
"Unexpected branch"
);
#endif
return
true
;
}
char
*
p
=
reinterpret_cast
<
char
*>
(
dest
);
while
(
size
>
0
)
{
if
(
!
input
->
GetDirectBufferPointer
(
&
data
,
&
size_to_write
))
{
return
false
;
}
// TODO(gongwb): can we avoid copy?
platform
::
CPUPlace
cpu
;
memory
::
Copy
(
cpu
,
reinterpret_cast
<
void
*>
(
p
),
cpu
,
data
,
size_to_write
);
p
+=
size_to_write
;
size
-=
size_to_write
;
input
->
Skip
(
size_to_write
);
}
return
true
;
}
bool
VariableResponse
::
CopyLodTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
)
{
auto
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
dims
);
framework
::
LoD
lod
;
for
(
int
i
=
0
;
i
<
meta_
.
lod_level
();
++
i
)
{
framework
::
Vector
<
size_t
>
v
;
for
(
int
j
=
0
;
j
<
meta_
.
lod
(
i
).
lod_data_size
();
++
j
)
{
v
.
push_back
(
meta_
.
lod
(
i
).
lod_data
(
j
));
}
lod
.
push_back
(
v
);
}
tensor
->
set_lod
(
lod
);
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
ToTypeIndex
(
meta_
.
data_type
()));
if
(
!
ReadRaw
(
input
,
ctx
,
tensor
->
place
(),
tensor_data
,
length
))
{
return
false
;
}
return
true
;
}
inline
framework
::
DDim
GetDims
(
const
::
google
::
protobuf
::
RepeatedField
<::
google
::
protobuf
::
int64
>&
dims
)
{
std
::
vector
<
int
>
vecdims
;
for
(
auto
&
d
:
dims
)
{
vecdims
.
push_back
(
d
);
}
return
framework
::
make_ddim
(
vecdims
);
}
bool
VariableResponse
::
CopySelectRowsTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
)
{
auto
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
tensor
=
slr
->
mutable_value
();
tensor
->
Resize
(
dims
);
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
paddle
::
operators
::
detail
::
ToTypeIndex
(
meta_
.
data_type
()));
if
(
!
ReadRaw
(
input
,
ctx
,
tensor
->
place
(),
tensor_data
,
length
))
{
return
false
;
}
return
true
;
}
bool
VariableResponse
::
CopySelectRowsData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
int
length
)
{
auto
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
int64_t
*
rows_data
=
slr
->
mutable_rows
()
->
data
();
// copy rows CPU data, GPU data will be copied lazily.
platform
::
CPUPlace
cpu
;
if
(
!
ReadRaw
(
input
,
ctx
,
cpu
,
rows_data
,
length
))
{
return
false
;
}
return
true
;
}
bool
ParseLodData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
std
::
vector
<
int64_t
>*
lod
)
{
while
(
true
)
{
auto
p
=
input
->
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
return
(
tag
==
0
);
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage_LodData
::
kLodDataFieldNumber
:
{
uint64_t
v
;
if
(
wt
==
WIRETYPE_VARINT
)
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
}
lod
->
push_back
(
v
);
break
;
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
length
))
{
return
tag
;
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
}
lod
->
push_back
(
v
);
}
break
;
}
return
false
;
}
default:
{
return
false
;
}
}
}
return
true
;
}
int
VariableResponse
::
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
)
{
GrpcByteBufferSource
source
;
source
.
Init
(
byte_buffer
);
GrpcByteBufferSourceWrapper
r
(
&
source
);
return
Parse
(
&
r
);
}
int
VariableResponse
::
Parse
(
Source
*
source
)
{
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
input_stream
=
source
->
contents
();
::
google
::
protobuf
::
io
::
CodedInputStream
input
(
input_stream
);
input
.
SetTotalBytesLimit
(
INT_MAX
,
INT_MAX
);
while
(
true
)
{
auto
p
=
input
.
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
if
(
tag
!=
0
)
{
return
-
1
;
}
return
0
;
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage
::
kVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_varname
(
temp
);
break
;
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint64_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_type
(
static_cast
<::
sendrecv
::
VarType
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_data_type
(
static_cast
<::
sendrecv
::
VariableMessage_Type
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDimsFieldNumber
:
{
// not packed
if
(
wt
==
WIRETYPE_VARINT
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
break
;
}
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
length
))
{
return
tag
;
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_lod_level
(
static_cast
<
int64_t
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kLodFieldNumber
:
{
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
std
::
pair
<::
google
::
protobuf
::
io
::
CodedInputStream
::
Limit
,
int
>
p
=
input
.
IncrementRecursionDepthAndPushLimit
(
length
);
std
::
vector
<
int64_t
>
lod_data
;
if
(
p
.
second
<
0
||
!
ParseLodData
(
&
input
,
&
lod_data
))
{
return
tag
;
}
if
(
!
input
.
DecrementRecursionDepthAndPopLimit
(
p
.
first
))
{
return
false
;
}
if
(
lod_data
.
size
()
==
0
)
{
break
;
}
auto
lod
=
meta_
.
add_lod
();
for
(
uint32_t
i
=
0
;
i
<
lod_data
.
size
();
i
++
)
{
lod
->
add_lod_data
(
lod_data
[
i
]);
}
break
;
}
case
sendrecv
::
VariableMessage
::
kSerializedFieldNumber
:
{
PADDLE_ENFORCE
((
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
||
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
&&
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
framework
::
DDim
dims
=
GetDims
(
meta_
.
dims
());
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
return
tag
;
}
break
;
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
return
tag
;
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kRowsFieldNumber
:
{
PADDLE_ENFORCE
((
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
||
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
&&
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
length
))
{
return
tag
;
}
break
;
}
default:
{
// Unknown tag, return unknown error.
return
-
1
;
}
}
}
return
0
;
}
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/variable_response.h
0 → 100644
浏览文件 @
990d6396
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
VariableResponse
{
public:
VariableResponse
(
const
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
scope_
(
scope
),
dev_ctx_
(
dev_ctx
){};
virtual
~
VariableResponse
(){};
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int
Parse
(
Source
*
source
);
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
);
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
// should call parse first.
framework
::
Variable
*
GetVar
()
{
return
scope_
->
FindVar
(
meta_
.
varname
());
}
private:
bool
CopySelectRowsTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
);
bool
CopySelectRowsData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
int
length
);
bool
CopyLodTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
);
private:
const
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// only Skeleton
sendrecv
::
VariableMessage
meta_
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
990d6396
...
@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
void
Stop
()
override
{
void
Stop
()
override
{
detail
::
MessageWithName
term_msg
;
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
term_msg
.
first
=
LISTEN_TERMINATE_MESSAGE
;
rpc_service_
->
Push
(
term_msg
);
rpc_service_
->
ShutDown
();
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
server_thread_
->
join
();
}
}
...
@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
size_t
recv_var_cnt
=
0
;
size_t
recv_var_cnt
=
0
;
int
batch_barrier
=
0
;
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
while
(
batch_barrier
!=
fan_in
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
LOG
(
INFO
)
<<
"received terminate message and exit"
;
...
@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase {
}
else
{
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
recv_var_cnt
++
;
recv_var_cnt
++
;
auto
*
var
=
recv_scope
.
FindVar
(
recv_var_name
);
auto
var
=
v
.
second
->
GetVar
(
);
if
(
var
==
nullptr
)
{
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
PADDLE_THROW
(
"Can not find server side var"
);
}
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars
.
push_back
(
var
);
sparse_vars
.
push_back
(
var
);
}
}
...
...
python/paddle/fluid/debuger.py
浏览文件 @
990d6396
...
@@ -16,7 +16,6 @@ import sys
...
@@ -16,7 +16,6 @@ import sys
import
re
import
re
from
graphviz
import
GraphPreviewGenerator
from
graphviz
import
GraphPreviewGenerator
import
proto.framework_pb2
as
framework_pb2
import
proto.framework_pb2
as
framework_pb2
import
paddle.fluid.core
as
core
_vartype2str_
=
[
_vartype2str_
=
[
"UNK"
,
"UNK"
,
...
@@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False):
...
@@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False):
def
is_var_backward
(
var_desc
):
def
is_var_backward
(
var_desc
):
return
"@GRAD"
in
var_desc
.
name
return
"@GRAD"
in
var_desc
.
name
#print(type(block_desc))
if
type
(
block_desc
)
is
not
framework_pb2
.
BlockDesc
:
if
type
(
block_desc
)
is
not
framework_pb2
.
BlockDesc
:
block_desc
=
framework_pb2
.
BlockDesc
.
FromString
(
block_desc
=
framework_pb2
.
BlockDesc
.
FromString
(
block_desc
.
serialize_to_string
())
block_desc
.
serialize_to_string
())
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
990d6396
...
@@ -20,6 +20,7 @@ from layer_helper import LayerHelper
...
@@ -20,6 +20,7 @@ from layer_helper import LayerHelper
from
distributed_spliter
import
*
from
distributed_spliter
import
*
import
math
import
math
from
.
import
core
from
.
import
core
import
debuger
class
VarBlock
:
class
VarBlock
:
...
@@ -289,6 +290,7 @@ class DistributeTranspiler:
...
@@ -289,6 +290,7 @@ class DistributeTranspiler:
dtype
=
v
.
dtype
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
shape
=
v
.
shape
)
recv_inputs
.
append
(
var
)
recv_inputs
.
append
(
var
)
# step3
# step3
optimize_block
=
pserver_program
.
create_block
(
0
)
optimize_block
=
pserver_program
.
create_block
(
0
)
# step 4
# step 4
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录