Crayon鑫 / Paddle (fork of PaddlePaddle/Paddle)
Commit 3a6213f4 (unverified), authored July 20, 2018 by gongweibao; committed via GitHub on July 20, 2018.
Change grpc interface to compatible with brpc. (#12164)
Parent: b0630938

19 changed files with 748 additions and 550 deletions (+748 / -550).
Changed files:

  paddle/fluid/operators/distributed/CMakeLists.txt               +30   -20
  paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc     +1    -1
  paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h      +1   -19
  paddle/fluid/operators/distributed/grpc_client.cc                +1    -0
  paddle/fluid/operators/distributed/grpc_client.h                 +3   -17
  paddle/fluid/operators/distributed/grpc_serde.cc               +157    -0
  paddle/fluid/operators/distributed/grpc_serde.h                 +50    -0
  paddle/fluid/operators/distributed/grpc_serde_test.cc            +5    -3
  paddle/fluid/operators/distributed/grpc_server.cc               +11   -10
  paddle/fluid/operators/distributed/grpc_service.h                +5    -5
  paddle/fluid/operators/distributed/grpc_variable_response.cc   +308    -0
  paddle/fluid/operators/distributed/grpc_variable_response.h     +58    -0
  paddle/fluid/operators/distributed/request_handler.h            +17    -0
  paddle/fluid/operators/distributed/request_handler_impl.cc       +2    -3
  paddle/fluid/operators/distributed/send_recv.proto.in            +2    -1
  paddle/fluid/operators/distributed/sendrecvop_utils.cc          +13  -131
  paddle/fluid/operators/distributed/sendrecvop_utils.h            +7   -10
  paddle/fluid/operators/distributed/variable_response.cc         +38  -313
  paddle/fluid/operators/distributed/variable_response.h          +39   -17
paddle/fluid/operators/distributed/CMakeLists.txt

 if(NOT WITH_DISTRIBUTE)
   return()
 endif()

+if(WITH_GRPC)
+  set(cc_generic_services "false")
+else()
+  set(cc_generic_services "true")
+endif()
+configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
+
 if(WITH_GRPC)
-  grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
-        request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
-        selected_rows memory)
+  grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
+        request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
+      PROTO send_recv.proto
+      DEPS lod_tensor selected_rows memory)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-  cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
-      cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
-  cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
-      grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
-      proto_desc lookup_table_op SERIAL)
+  cc_test(grpc_serde_test SRCS grpc_serde_test.cc
+      DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
+  cc_test(grpc_server_test SRCS rpc_server_test.cc
+      DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
   return()
 endif()

 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
-set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
+set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
+    brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+
+brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
+        brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
       PROTO send_recv.proto
       DEPS lod_tensor selected_rows memory)

-find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
-ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
+set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)

-find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
-ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
+cc_test(brpc_server_test SRCS rpc_server_test.cc
+    DEPS ${brpc_test_depends} SERIAL)

-cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
-        brpc protobuf leveldb gflags glog
-        protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
+cc_test(brpc_serde_test SRCS brpc_serde_test.cc
+    DEPS ${brpc_test_depends} SERIAL)
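The WITH_GRPC branch above decides which transport library gets built; call sites then pick the matching client/server types at compile time (the test sources below include paddle/fluid/operators/detail/macros.h for exactly this purpose). A hypothetical sketch of such a switch; the RPCSERVER_T/RPCCLIENT_T macro names and the AsyncBRPCServer/BRPCClient class names are illustrative assumptions, not taken from this diff:

// Hypothetical compile-time transport switch; names are illustrative,
// not part of this commit.
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
// Operators refer to RPCSERVER_T/RPCCLIENT_T instead of a concrete class.
#define RPCSERVER_T paddle::operators::distributed::AsyncGRPCServer
#define RPCCLIENT_T paddle::operators::distributed::GRPCClient
#else
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/distributed/brpc_server.h"
#define RPCSERVER_T paddle::operators::distributed::AsyncBRPCServer
#define RPCCLIENT_T paddle::operators::distributed::BRPCClient
#endif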
paddle/fluid/operators/distributed/bytebuffer_stream.cc → paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc

@@ -17,7 +17,7 @@ limitations under the License. */
 // file and did some modifications so that we can send gRPC
 // requests without too much copying of the tensor data.

-#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
+#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"

 namespace paddle {
 namespace operators {
paddle/fluid/operators/distributed/bytebuffer_stream.h → paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h

@@ -24,6 +24,7 @@ limitations under the License. */
 #include "google/protobuf/io/coded_stream.h"
 #include "google/protobuf/io/zero_copy_stream.h"
 #include "grpc++/grpc++.h"
+#include "paddle/fluid/operators/distributed/variable_response.h"

 namespace grpc {
 // A ZeroCopyInputStream that reads from grpc_byte_buffer

@@ -107,25 +108,6 @@ class GrpcBufferReader final
 namespace paddle {
 namespace operators {
 namespace distributed {

-// Source provides a way for a particular RPC implementation to provide
-// received data to ParseFrom.
-class Source {
- public:
-  virtual ~Source() {}
-
-  // Return the stream that contains the data to be parsed.
-  // Note that this method might be invoked more than once if
-  // ParseFrom needs to fall back to a more expensive parsing method.
-  // Every call must return a stream pointing at the beginning of
-  // the serialized RecvTensorResponse.
-  //
-  // Note that a subsequent call to contents() invalidates previous
-  // results of contents().
-  //
-  // Ownership of the returned stream is retained by the Source and
-  // should not be deleted by the caller.
-  virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
-};
-
 // A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
 class GrpcByteBufferSource
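The Source interface deleted above does not disappear; it moves into variable_response.h so the transport-neutral parser depends only on protobuf streams. Any transport that can expose received bytes as a ZeroCopyInputStream can then drive Parse. A minimal illustrative Source over an in-memory buffer, assuming only the contract quoted in the removed comment (the StringSource class is this editor's sketch, not code from the commit):

#include <memory>
#include <string>
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "paddle/fluid/operators/distributed/variable_response.h"

// Illustrative Source over a flat buffer (e.g. what a brpc IOBuf could be
// flattened into). The Source keeps ownership of the returned stream.
class StringSource : public paddle::operators::distributed::Source {
 public:
  explicit StringSource(const std::string& data) : data_(data) {}

  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    // Restart at the beginning of the serialized message on every call,
    // invalidating the previously returned stream, per the contract.
    stream_.reset(new ::google::protobuf::io::ArrayInputStream(
        data_.data(), static_cast<int>(data_.size())));
    return stream_.get();
  }

 private:
  const std::string& data_;
  std::unique_ptr<::google::protobuf::io::ArrayInputStream> stream_;
};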
paddle/fluid/operators/distributed/grpc_client.cc

@@ -20,6 +20,7 @@ limitations under the License. */
 #include "glog/logging.h"  // For VLOG
 #include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/distributed/grpc_serde.h"
 #include "paddle/fluid/operators/distributed/request_handler.h"
 #include "paddle/fluid/platform/profiler.h"
paddle/fluid/operators/distributed/grpc_client.h

@@ -38,7 +38,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
 #include "paddle/fluid/operators/distributed/rpc_client.h"
+#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
+#include "paddle/fluid/operators/distributed/send_recv.pb.h"
 #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

@@ -46,23 +49,6 @@ namespace paddle {
 namespace operators {
 namespace distributed {

-struct VarHandle {
-  // RPC endpoint.
-  std::string ep;
-  const platform::DeviceContext* ctx;
-  const framework::Scope* scope;
-  // Variable name.
-  std::string name;
-  // RPC method name.
-  std::string method;
-
-  std::string String() const {
-    std::ostringstream s;
-    s << method << " name:[" << name << "], ep:[" << ep << "]";
-    return s.str();
-  }
-};
-
 void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);

 class BaseProcessor {
paddle/fluid/operators/distributed/grpc_serde.cc (new file, mode 100644)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread>  // NOLINT

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
namespace distributed {

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           ::grpc::ByteBuffer* msg,
                           const std::string& out_name) {
  // Default DestroyCallback does nothing, When using GPU
  // the CPU buffer need to be freed.
  DestroyCallback destroy_callback = [](void* backing) {};
  VarMsg request;
  void* payload = nullptr;
  size_t payload_size;

  request.set_varname(name);
  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
  if (platform::ShouldSendProfileState()) {
    if (platform::IsProfileEnabled()) {
      request.set_profile(platform::kEnableProfiler);
    } else {
      request.set_profile(platform::kDisableProfiler);
    }
  }
  if (!out_name.empty()) {
    request.set_out_varname(out_name);
  }
  if (var->IsType<framework::LoDTensor>()) {
    request.set_type(::sendrecv::LOD_TENSOR);
    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
  } else if (var->IsType<framework::SelectedRows>()) {
    request.set_type(::sendrecv::SELECTED_ROWS);
    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
  } else if (var->IsType<ncclUniqueId>()) {
    request.set_type(::sendrecv::NCCL_ID);
#endif
  } else {
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
  }

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    // GPU data is copied to CPU buffer when sending,
    // free the buffer when possible.
    destroy_callback = [](void* backing) {
      platform::CUDAPinnedPlace cuda_pinned;
      memory::Free(cuda_pinned, backing);
    };
#endif
  }

  std::string header;
  request.AppendToString(&header);
  auto buffer = std::unique_ptr<char[]>(new char[1024]);
  void* buf = buffer.get();
  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
  e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
  if (var->IsType<ncclUniqueId>()) {
    e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
                              NCCL_UNIQUE_ID_BYTES);
    const ncclUniqueId& uid = var->Get<ncclUniqueId>();
    e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));

    // for serialize NCCL_ID
    ::grpc::Slice slices(e.size());
    memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
    ::grpc::ByteBuffer tmp(&slices, 1);
    msg->Swap(&tmp);
    return;
  }
#endif
  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
  // steal reference of tensor data
  ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
  int num_slices = 2;       // only SelectedRows have rows buffer
  slices[0] = ::grpc::Slice(e.size());
  memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
  slices[1] = ::grpc::Slice(
      grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
                                    static_cast<char*>(payload)),
      ::grpc::Slice::STEAL_REF);

  if (var->IsType<framework::SelectedRows>()) {
    auto* slr = var->GetMutable<framework::SelectedRows>();
    ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
    size_t rows_memory_size =
        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
    e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
    slices[2] = ::grpc::Slice(e2.size());
    memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());

    slices[3] = ::grpc::Slice(
        grpc_slice_new_with_user_data(
            const_cast<void*>(
                reinterpret_cast<const void*>(slr->rows().data())),
            rows_memory_size, [](void* backing) {},
            const_cast<char*>(
                reinterpret_cast<const char*>(slr->rows().data()))),
        ::grpc::Slice::STEAL_REF);

    num_slices = 4;
  }

  ::grpc::ByteBuffer tmp(&slices[0], num_slices);
  msg->Swap(&tmp);
}

void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                               const platform::DeviceContext& ctx,
                               const framework::Scope* scope,
                               framework::Variable** var) {
  operators::distributed::GRPCVariableResponse resp(scope, &ctx);
  PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
  *var = resp.GetVar();
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
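The core trick in SerializeToByteBuffer is that the tensor payload is never copied into the protobuf message: grpc_slice_new_with_user_data wraps the existing buffer and registers destroy_callback, which gRPC invokes only after the last reference to the slice is dropped. A stripped-down sketch of that handoff, assuming a plain malloc'd buffer for simplicity (the helper name is illustrative):

#include <cstdlib>
#include <grpc++/support/byte_buffer.h>
#include <grpc/slice.h>

// Wrap an existing heap buffer in a grpc::Slice without copying it.
// STEAL_REF hands ownership to gRPC, which calls the destroy callback once
// the slice's refcount hits zero, i.e. after the bytes hit the wire.
::grpc::ByteBuffer WrapWithoutCopy(void* payload, size_t payload_size) {
  ::grpc::Slice slice(
      grpc_slice_new_with_user_data(
          payload, payload_size,
          [](void* backing) { std::free(backing); },  // destroy callback
          static_cast<char*>(payload)),
      ::grpc::Slice::STEAL_REF);
  return ::grpc::ByteBuffer(&slice, 1);
}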
paddle/fluid/operators/distributed/grpc_serde.h (new file, mode 100644)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"

#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"

namespace paddle {
namespace operators {
namespace distributed {

typedef void (*DestroyCallback)(void*);

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           ::grpc::ByteBuffer* msg,
                           const std::string& out_varname = std::string());

void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                               const platform::DeviceContext& ctx,
                               const framework::Scope* scope,
                               framework::Variable** var);

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
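A usage sketch of this header's two entry points, mirroring what grpc_serde_test.cc exercises below; the variable name "x", the shape, and the CPU context are assumptions for illustration, and tensor filling is elided:

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/platform/device_context.h"

void RoundTripSketch() {
  paddle::framework::Scope scope;
  paddle::framework::Variable* var = scope.Var("x");
  auto* tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize(paddle::framework::make_ddim({2, 3}));
  tensor->mutable_data<float>(paddle::platform::CPUPlace());
  paddle::platform::CPUDeviceContext ctx;

  // Variable -> grpc::ByteBuffer (header slice + zero-copy payload slice).
  ::grpc::ByteBuffer msg;
  paddle::operators::distributed::SerializeToByteBuffer("x", var, ctx, &msg);

  // grpc::ByteBuffer -> Variable, parsed back into the scope.
  paddle::framework::Variable* out = nullptr;
  paddle::operators::distributed::DeserializeFromByteBuffer(msg, ctx, &scope,
                                                            &out);
}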
paddle/fluid/operators/distributed/grpc_serde_test.cc

@@ -21,8 +21,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/distributed/grpc_serde.h"
+#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
 #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
-#include "paddle/fluid/operators/distributed/variable_response.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/printf.h"

@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
   framework::Scope scope;
   scope.Var("myvar");
-  operators::distributed::VariableResponse resp(&scope, &ctx);
+  operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
   EXPECT_EQ(resp.Parse(msg), 0);

   framework::Variable* var2 = resp.GetVar();

@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
   // deserialize zero-copy
   framework::Scope scope;
   scope.Var("myvar");
-  operators::distributed::VariableResponse resp(&scope, &ctx);
+  operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
   if (from_type == 0) {
     EXPECT_EQ(resp.Parse(msg), 0);
   } else {
paddle/fluid/operators/distributed/grpc_server.cc

@@ -15,6 +15,7 @@ limitations under the License. */
 #include <limits>
 #include <string>

+#include "paddle/fluid/operators/distributed/grpc_serde.h"
 #include "paddle/fluid/operators/distributed/grpc_server.h"

 using ::grpc::ServerAsyncResponseWriter;

@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase {
               ::grpc::ServerCompletionQueue* cq,
               RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new VariableResponse(request_handler->scope(),
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                         request_handler->dev_ctx(),
                                         !request_handler->sync_mode()));
     int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);

@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase {
  protected:
   sendrecv::VoidMessage reply_;
-  std::shared_ptr<VariableResponse> request_;
+  std::shared_ptr<GRPCVariableResponse> request_;
   ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
 };

@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase {
       : RequestBase(service, cq, request_handler, req_id),
         responder_(&ctx_),
         local_scope_(nullptr) {
-    request_.reset(new VariableResponse(request_handler->scope(),
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                         request_handler->dev_ctx(), true));
     int method_id =
         static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);

@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase {
   }

  protected:
-  std::shared_ptr<VariableResponse> request_;
+  std::shared_ptr<GRPCVariableResponse> request_;
   ::grpc::ByteBuffer reply_;
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
   framework::Scope* local_scope_;

@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase {
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
-    request_.reset(new VariableResponse(request_handler->scope(),
+    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                         request_handler->dev_ctx()));
     int method_id =
         static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);

@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase {
   }

  protected:
-  std::shared_ptr<VariableResponse> request_;
+  std::shared_ptr<GRPCVariableResponse> request_;
   sendrecv::VoidMessage reply_;
   ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
 };
paddle/fluid/operators/distributed/grpc_service.h

@@ -23,8 +23,7 @@
 #include <grpc++/impl/codegen/stub_options.h>
 #include <grpc++/impl/codegen/sync_stream.h>
 #include <grpc++/support/byte_buffer.h>
-#include "paddle/fluid/operators/distributed/variable_response.h"
-
+#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
 #include "paddle/fluid/platform/profiler.h"

 // NOTE: This method was originally created by tensorflow

@@ -42,17 +41,18 @@ class ServerContext;
 // Support parsing/unparsing of tensorflow::VariableResponse.
 // Wire-format is identical to RecvVariableResponse.
 template <>
-class SerializationTraits<paddle::operators::distributed::VariableResponse> {
+class SerializationTraits<
+    paddle::operators::distributed::GRPCVariableResponse> {
  public:
   static Status Serialize(
-      const paddle::operators::distributed::VariableResponse& msg,
+      const paddle::operators::distributed::GRPCVariableResponse& msg,
       grpc_byte_buffer** bp, bool* own_buffer) {
     PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
     return Status();
   }
   static Status Deserialize(
       grpc_byte_buffer* buffer,
-      paddle::operators::distributed::VariableResponse* msg,
+      paddle::operators::distributed::GRPCVariableResponse* msg,
       int max_message_size = INT_MAX) {
     if (buffer == nullptr) {
       return Status(StatusCode::INTERNAL, "No payload");
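Why this specialization exists: grpc++ routes every request/response type through SerializationTraits<T>, so specializing it for GRPCVariableResponse lets the async server hand incoming byte buffers straight to the zero-copy parser instead of materializing a full protobuf message first. Schematically, the consuming side behaves like the sketch below (a simplification by this editor, not the real grpc++ internals):

// Simplified illustration of how grpc++ uses a SerializationTraits
// specialization; the real library plumbing is more involved.
template <typename T>
::grpc::Status DeserializePayload(grpc_byte_buffer* buffer, T* msg) {
  // With T = paddle::operators::distributed::GRPCVariableResponse this
  // resolves to the Deserialize defined above, which feeds the raw
  // slices to GRPCVariableResponse::Parse.
  return ::grpc::SerializationTraits<T>::Deserialize(buffer, msg);
}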
paddle/fluid/operators/distributed/grpc_variable_response.cc (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <utility>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif

#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
namespace distributed {

enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};

inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }

inline WireType GetTagWireType(uint32_t tag) {
  return static_cast<WireType>(tag & 0x7);
}

bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
                         int* result) {
  uint64_t v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}

int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
  GrpcByteBufferSource source;
  source.Init(byte_buffer);
  GrpcByteBufferSourceWrapper r(&source);

  return Parse(&r);
}

bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
                  std::vector<int64_t>* lod) {
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);

    if (!p.second) {
      return (tag == 0);
    }

    switch (tag) {
      case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
        uint64_t v;
        if (wt == WIRETYPE_VARINT) {
          if (!input->ReadVarint64(&v)) {
            return false;
          }
          lod->push_back(v);
          break;
        }

        if (wt == WIRETYPE_LENGTH_DELIMITED) {
          int num_bytes = 0;
          if (!input->ReadVarintSizeAsInt(&num_bytes)) {
            return tag;
          }
          int start_pos = input->CurrentPosition();
          while (input->CurrentPosition() - start_pos < num_bytes) {
            uint64_t v;
            if (!input->ReadVarint64(&v)) {
              return tag;
            }
            lod->push_back(v);
          }
          break;
        }

        return false;
      }
      default: { return false; }
    }
  }

  return true;
}

int GRPCVariableResponse::Parse(Source* source) {
  ::google::protobuf::io::ZeroCopyInputStream* input_stream =
      source->contents();
  ::google::protobuf::io::CodedInputStream input(input_stream);
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);

  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      if (tag != 0) {
        return -1;
      }
      return 0;
    }

    switch (tag) {
      case sendrecv::VariableMessage::kVarnameFieldNumber: {
        uint32_t length;
        if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
          return tag;
        }

        std::string temp;
        if (!input.ReadString(&temp, length)) {
          return tag;
        }

        meta_.set_varname(temp);
        break;
      }
      case sendrecv::VariableMessage::kTypeFieldNumber: {
        uint32_t v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
          return tag;
        }

        meta_.set_type(static_cast<::sendrecv::VarType>(v));
        break;
      }
      case sendrecv::VariableMessage::kDataTypeFieldNumber: {
        uint32_t v = 0;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
          return tag;
        }

        meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
        break;
      }
      case sendrecv::VariableMessage::kDimsFieldNumber: {
        // not packed
        if (wt == WIRETYPE_VARINT) {
          uint64_t v;
          if (!input.ReadVarint64(&v)) {
            return tag;
          }
          meta_.add_dims(v);
          break;
        }

        // packed
        if (wt == WIRETYPE_LENGTH_DELIMITED) {
          int num_bytes = 0;
          if (!input.ReadVarintSizeAsInt(&num_bytes)) {
            return tag;
          }
          int start_pos = input.CurrentPosition();
          while (input.CurrentPosition() - start_pos < num_bytes) {
            uint64_t v;
            if (!input.ReadVarint64(&v)) {
              return tag;
            }
            meta_.add_dims(v);
          }
          break;
        }
        return tag;
      }
      case sendrecv::VariableMessage::kLodLevelFieldNumber: {
        uint64_t v = 0;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
          return tag;
        }
        meta_.set_lod_level(static_cast<int64_t>(v));
        break;
      }
      case sendrecv::VariableMessage::kLodFieldNumber: {
        int length = 0;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !ReadVarintSizeAsInt(&input, &length)) {
          return tag;
        }

        std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);

        std::vector<int64_t> lod_data;
        if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
          return tag;
        }

        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return tag;
        }

        if (lod_data.size() == 0) {
          break;
        }

        auto lod = meta_.add_lod();
        for (uint32_t i = 0; i < lod_data.size(); i++) {
          lod->add_lod_data(lod_data[i]);
        }
        break;
      }
      case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
        uint64_t v = 0;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
          return tag;
        }
        meta_.set_slr_height(static_cast<int64_t>(v));
        break;
      }
      case sendrecv::VariableMessage::kSerializedFieldNumber: {
        int num_bytes = 0;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !ReadVarintSizeAsInt(&input, &num_bytes)) {
          return tag;
        }

        if (!ProcSerializedField(tag, &input, num_bytes)) {
          return tag;
        }

        break;
      }
      case sendrecv::VariableMessage::kRowsFieldNumber: {
        PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                        meta_.type() == sendrecv::LOD_TENSOR) &&
                           meta_.varname() != "",
                       "meta info should be got first!");

        int num_bytes = 0;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !ReadVarintSizeAsInt(&input, &num_bytes)) {
          return tag;
        }

        if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
          return tag;
        }
        break;
      }
      case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
        uint32_t length;
        if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
          return tag;
        }

        std::string temp;
        if (!input.ReadString(&temp, length)) {
          return tag;
        }

        meta_.set_out_varname(temp);
        break;
      }
      case sendrecv::VariableMessage::kProfileFieldNumber: {
        uint64_t profiling = 0;
        if (!input.ReadVarint64(&profiling)) {
          return tag;
        }
        meta_.set_profile(profiling);
        int64_t listener_id = platform::ListenerId();
        if (listener_id <= 0) {
          break;
        }
        if (profiling == platform::kEnableProfiler &&
            !platform::IsProfileEnabled()) {
          platform::EnableProfiler(platform::ProfilerState::kCPU);
        } else if (profiling == platform::kDisableProfiler &&
                   platform::IsProfileEnabled()) {
          // TODO(panyx0718): Should we allow to customize file dir.
          platform::DisableProfiler(
              platform::EventSortingKey::kDefault,
              string::Sprintf("/tmp/profile_ps_%lld", listener_id));
        }
        break;
      }
      default: {
        // Unknown tag, return unknown error.
        return -1;
      }
    }
  }

  return 0;
}

};  // namespace distributed
};  // namespace operators
};  // namespace paddle
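Parse decodes protobuf wire format by hand: every field begins with a varint tag equal to (field_number << 3) | wire_type, which GetTagFieldNumber and GetTagWireType split apart. A self-contained sketch of the same decoding over a hand-built buffer (the field number 7 and the payload are arbitrary choices for illustration):

#include <cassert>
#include <cstdint>
#include <string>
#include "google/protobuf/io/coded_stream.h"

// Decode a hand-built length-delimited field #7: tag = (7 << 3) | 2 = 0x3A,
// then a varint length, then the bytes themselves.
void WireFormatExample() {
  const uint8_t buf[] = {0x3A, 0x03, 'a', 'b', 'c'};
  ::google::protobuf::io::CodedInputStream input(
      buf, static_cast<int>(sizeof(buf)));

  auto p = input.ReadTagWithCutoff(127);  // p.first == 0x3A, p.second == true
  assert(p.second);
  assert((p.first >> 3) == 7);   // field number, as in GetTagFieldNumber
  assert((p.first & 0x7) == 2);  // WIRETYPE_LENGTH_DELIMITED

  uint32_t length = 0;
  input.ReadVarint32(&length);       // length == 3
  std::string payload;
  input.ReadString(&payload, length);  // payload == "abc"
}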
paddle/fluid/operators/distributed/grpc_variable_response.h (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"

#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"

namespace paddle {
namespace operators {
namespace distributed {

class GRPCVariableResponse : public VariableResponse {
 public:
  GRPCVariableResponse(const framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx,
                       bool create_scope = false)
      : VariableResponse(scope, dev_ctx, create_scope) {}

  virtual ~GRPCVariableResponse() {}

  int Parse(Source* source) override;

  // return:
  // 0: ok.
  // -1: unkown error.
  // other: number of error field.
  int Parse(const ::grpc::ByteBuffer& byte_buffer);
};

};  // namespace distributed
};  // namespace operators
};  // namespace paddle
paddle/fluid/operators/distributed/request_handler.h

@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
 class RPCServer;

+struct VarHandle {
+  // RPC endpoint.
+  std::string ep;
+  const platform::DeviceContext* ctx;
+  const framework::Scope* scope;
+  // Variable name.
+  std::string name;
+  // RPC method name.
+  std::string method;
+
+  std::string String() const {
+    std::ostringstream s;
+    s << method << " name:[" << name << "], ep:[" << ep << "]";
+    return s.str();
+  }
+};
+
 class RequestHandler {
  public:
   explicit RequestHandler(bool sync_mode)
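With VarHandle hoisted into this transport-neutral header, both the gRPC and brpc clients share one handle type. Usage is plain field assignment plus String() for logging; a small sketch (the endpoint and variable names are made up for illustration):

#include <iostream>
#include "paddle/fluid/operators/distributed/request_handler.h"

void VarHandleExample(const paddle::platform::DeviceContext* ctx,
                      const paddle::framework::Scope* scope) {
  paddle::operators::distributed::VarHandle h;
  h.ep = "127.0.0.1:6174";   // parameter-server endpoint (illustrative)
  h.ctx = ctx;
  h.scope = scope;
  h.name = "fc_0.w_0";       // variable being sent (illustrative)
  h.method = "Send";
  // Prints: Send name:[fc_0.w_0], ep:[127.0.0.1:6174]
  std::cout << h.String() << std::endl;
}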
paddle/fluid/operators/distributed/request_handler_impl.cc

@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
   // Sync
   if (varname == BATCH_BARRIER_MESSAGE) {
-    VLOG(3) << "sync: recv batch barrier message";
+    VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
     rpc_server_->IncreaseBatchBarrier(kRequestSend);
   } else if (varname == BEGIN_PASS_MESSAGE) {
     VLOG(3) << "sync: recv begin pass message";

@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
     VLOG(3) << "sync: processing received var: " << varname;

     if (invar == nullptr) {
-      LOG(ERROR) << "sync: Can not find server side var: " << varname;
-      PADDLE_THROW("sync: Can not find server side var");
+      LOG(FATAL) << "sync: Can not find server side var: " << varname;
       return false;
     }
     if (invar->IsType<framework::SelectedRows>()) {
paddle/fluid/operators/distributed/send_recv.proto → paddle/fluid/operators/distributed/send_recv.proto.in

@@ -14,7 +15,7 @@ limitations under the License. */
 syntax = "proto3";
 package sendrecv;

-// option cc_generic_services = true;
+option cc_generic_services = @cc_generic_services@;

 service SendRecvService {
   // For parameter server round-robin like hashing, do not split tensors.
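The substituted option is the whole point of the .proto → .proto.in rename: brpc dispatches through protobuf's generic service classes, which protoc emits only when cc_generic_services is on, while gRPC uses its own codegen plugin and wants it off; the CMake configure_file above picks the value per build. With cc_generic_services = true, protoc generates roughly the base class sketched below for a brpc server to subclass (simplified by this editor; see the generated send_recv.pb.h for the real interface):

// Rough shape of protoc's generated code when cc_generic_services = true
// (simplified; method list follows the rpcs declared in the .proto).
#include <google/protobuf/service.h>

namespace sendrecv {

class SendRecvService : public ::google::protobuf::Service {
 public:
  // One virtual per rpc in SendRecvService; a brpc server overrides these.
  virtual void SendVariable(::google::protobuf::RpcController* controller,
                            const VariableMessage* request,
                            VoidMessage* response,
                            ::google::protobuf::Closure* done);
  // ... GetVariable, PrefetchVariable, etc.
};

}  // namespace sendrecv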
paddle/fluid/operators/distributed/sendrecvop_utils.cc

@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
-
 #ifdef PADDLE_WITH_CUDA
 #include <nccl.h>
 #endif
 #include <sys/time.h>
 #include <thread>  // NOLINT

-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream.h"
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
-#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
+#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
 #include "paddle/fluid/operators/distributed/variable_response.h"
-#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {

@@ -34,6 +28,11 @@ namespace distributed {
 using VarMsg = sendrecv::VariableMessage;

+void* GetVarPayLoad(const std::string varname, int64_t size) {
+  platform::CUDAPinnedPlace cuda_pinned;
+  return memory::Alloc(cuda_pinned, size);
+}
+
 void GetTensorPayload(framework::Variable* var,
                       const platform::DeviceContext& ctx, VarMsg* request,
                       void** payload, size_t* payload_size) {

@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var,
   if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
     PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-    platform::CUDAPinnedPlace cuda_pinned;
+    // platform::CUDAPinnedPlace cuda_pinned;
     auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
     auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-    *payload = memory::Alloc(cuda_pinned, copy_size);
+    *payload = GetVarPayLoad(request->varname(), copy_size);

+    platform::CUDAPinnedPlace cuda_pinned;
     memory::Copy(cuda_pinned, *payload,
                  boost::get<platform::CUDAPlace>(tensor.place()),
                  reinterpret_cast<const void*>(tensor.data<void>()),
                  copy_size, gpu_dev_ctx.stream());
     ctx.Wait();
 #endif
   } else {

@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var,
   auto* tensor = slr->mutable_value();
   if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
-    platform::CUDAPinnedPlace cuda_pinned;
     auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
     auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
-    *payload = memory::Alloc(cuda_pinned, copy_size);
+    *payload = GetVarPayLoad(request->varname(), copy_size);
+
+    platform::CUDAPinnedPlace cuda_pinned;
     memory::Copy(cuda_pinned, *payload,
                  boost::get<platform::CUDAPlace>(tensor->place()),
                  reinterpret_cast<const void*>(tensor->data<void>()),
                  copy_size,

@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var,
   *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
 }

-void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
-                           const platform::DeviceContext& ctx,
-                           ::grpc::ByteBuffer* msg,
-                           const std::string& out_name) {
-  /* ... roughly 100 removed lines: the function body moves essentially
-     verbatim to grpc_serde.cc, shown above ... */
-}
-
-void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
-                               const platform::DeviceContext& ctx,
-                               const framework::Scope* scope,
-                               framework::Variable** var) {
-  operators::distributed::VariableResponse resp(scope, &ctx);
-  PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
-  *var = resp.GetVar();
-}
-
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
paddle/fluid/operators/distributed/sendrecvop_utils.h

@@ -25,24 +25,21 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
 #include "paddle/fluid/operators/distributed/send_recv.pb.h"

 namespace paddle {
 namespace operators {
 namespace distributed {

-typedef void (*DestroyCallback)(void*);
+using VarMsg = sendrecv::VariableMessage;

-void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
-                           const platform::DeviceContext& ctx,
-                           ::grpc::ByteBuffer* msg,
-                           const std::string& out_varname = std::string());
+void GetTensorPayload(framework::Variable* var,
+                      const platform::DeviceContext& ctx, VarMsg* request,
+                      void** payload, size_t* payload_size);

-void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
-                               const platform::DeviceContext& ctx,
-                               const framework::Scope* scope,
-                               framework::Variable** var);
+void GetSelectedRowsPayload(framework::Variable* var,
+                            const platform::DeviceContext& ctx,
+                            VarMsg* request, void** payload,
+                            size_t* payload_size);

 inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
   switch (type) {
paddle/fluid/operators/distributed/variable_response.cc
浏览文件 @
3a6213f4
...
@@ -13,50 +13,20 @@
...
@@ -13,50 +13,20 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include <string>
#include <utility>
#include <vector>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
distributed
{
namespace
distributed
{
enum
WireType
{
bool
VariableResponse
::
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
WIRETYPE_VARINT
=
0
,
const
platform
::
DeviceContext
&
dev_ctx
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
platform
::
Place
place
,
void
*
dest
,
};
int64_t
size
)
{
inline
int
GetTagFieldNumber
(
uint32_t
tag
)
{
return
tag
>>
3
;
}
inline
WireType
GetTagWireType
(
uint32_t
tag
)
{
return
static_cast
<
WireType
>
(
tag
&
0x7
);
}
bool
ReadVarintSizeAsInt
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int
*
result
)
{
uint64_t
v
;
if
(
input
->
ReadVarint64
(
&
v
)
&&
v
<=
static_cast
<
uint64_t
>
(
INT_MAX
))
{
*
result
=
static_cast
<
int
>
(
v
);
return
true
;
}
else
{
return
false
;
}
}
bool
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
int
size
)
{
const
void
*
data
=
NULL
;
const
void
*
data
=
NULL
;
int
size_to_write
=
0
;
int
size_to_write
=
0
;
int
length
=
size
;
int
64_t
length
=
size
;
int
total_written
=
0
;
int
total_written
=
0
;
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
platform
::
is_gpu_place
(
place
))
{
...
@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData(
...
@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData(
return
true
;
return
true
;
}
}
bool
ParseLodData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
bool
VariableResponse
::
ProcSerializedField
(
std
::
vector
<
int64_t
>*
lod
)
{
int
tag
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
while
(
true
)
{
int64_t
num_bytes
)
{
auto
p
=
input
->
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
return
(
tag
==
0
);
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage_LodData
::
kLodDataFieldNumber
:
{
uint64_t
v
;
if
(
wt
==
WIRETYPE_VARINT
)
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
}
lod
->
push_back
(
v
);
break
;
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
int
start_pos
=
input
->
CurrentPosition
();
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
tag
;
}
lod
->
push_back
(
v
);
}
break
;
}
return
false
;
}
default:
{
return
false
;
}
}
}
return
true
;
}
int
VariableResponse
::
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
)
{
GrpcByteBufferSource
source
;
source
.
Init
(
byte_buffer
);
GrpcByteBufferSourceWrapper
r
(
&
source
);
return
Parse
(
&
r
);
}
-int VariableResponse::Parse(Source* source) {
-  ::google::protobuf::io::ZeroCopyInputStream* input_stream =
-      source->contents();
-  ::google::protobuf::io::CodedInputStream input(input_stream);
-  input.SetTotalBytesLimit(INT_MAX, INT_MAX);
-
-  while (true) {
-    auto p = input.ReadTagWithCutoff(127);
-    int tag = GetTagFieldNumber(p.first);
-    WireType wt = GetTagWireType(p.first);
-    if (!p.second) {
-      if (tag != 0) {
-        return -1;
-      }
-      return 0;
-    }
-
-    switch (tag) {
-      case sendrecv::VariableMessage::kVarnameFieldNumber: {
-        uint32_t length;
-        if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
-          return tag;
-        }
-
-        std::string temp;
-        if (!input.ReadString(&temp, length)) {
-          return tag;
-        }
-
-        meta_.set_varname(temp);
-        break;
-      }
-      case sendrecv::VariableMessage::kTypeFieldNumber: {
-        uint32_t v;
-        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
-          return tag;
-        }
-
-        meta_.set_type(static_cast<::sendrecv::VarType>(v));
-        break;
-      }
-      case sendrecv::VariableMessage::kDataTypeFieldNumber: {
-        uint32_t v = 0;
-        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
-          return tag;
-        }
-
-        meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
-        break;
-      }
-      case sendrecv::VariableMessage::kDimsFieldNumber: {
-        // not packed
-        if (wt == WIRETYPE_VARINT) {
-          uint64_t v;
-          if (!input.ReadVarint64(&v)) {
-            return tag;
-          }
-          meta_.add_dims(v);
-          break;
-        }
-
-        // packed
-        if (wt == WIRETYPE_LENGTH_DELIMITED) {
-          int num_bytes = 0;
-          if (!input.ReadVarintSizeAsInt(&num_bytes)) {
-            return tag;
-          }
-          int start_pos = input.CurrentPosition();
-          while (input.CurrentPosition() - start_pos < num_bytes) {
-            uint64_t v;
-            if (!input.ReadVarint64(&v)) {
-              return tag;
-            }
-            meta_.add_dims(v);
-          }
-          break;
-        }
-        return tag;
-      }
-      case sendrecv::VariableMessage::kLodLevelFieldNumber: {
-        uint64_t v = 0;
-        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
-          return tag;
-        }
-        meta_.set_lod_level(static_cast<int64_t>(v));
-        break;
-      }
-      case sendrecv::VariableMessage::kLodFieldNumber: {
-        int length = 0;
-        if (wt != WIRETYPE_LENGTH_DELIMITED ||
-            !ReadVarintSizeAsInt(&input, &length)) {
-          return tag;
-        }
-
-        std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
-            input.IncrementRecursionDepthAndPushLimit(length);
-
-        std::vector<int64_t> lod_data;
-        if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
-          return tag;
-        }
-
-        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
-          return false;
-        }
-
-        if (lod_data.size() == 0) {
-          break;
-        }
-
-        auto lod = meta_.add_lod();
-        for (uint32_t i = 0; i < lod_data.size(); i++) {
-          lod->add_lod_data(lod_data[i]);
-        }
-        break;
-      }
-      case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
-        uint64_t v = 0;
-        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
-          return tag;
-        }
-        meta_.set_slr_height(static_cast<int64_t>(v));
-        break;
-      }
-      case sendrecv::VariableMessage::kSerializedFieldNumber: {
+bool VariableResponse::ProcSerializedField(
+    int tag, ::google::protobuf::io::CodedInputStream* input,
+    int64_t num_bytes) {
         PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                         meta_.type() == sendrecv::LOD_TENSOR ||
                         meta_.type() == sendrecv::NCCL_ID) &&
                            meta_.varname() != "",
                        "meta info should be got first!");
-        int num_bytes = 0;
-        if (wt != WIRETYPE_LENGTH_DELIMITED ||
-            !ReadVarintSizeAsInt(&input, &num_bytes)) {
-          return tag;
-        }
         if (meta_.type() == sendrecv::NCCL_ID) {
 #ifdef PADDLE_WITH_CUDA
           auto* var = scope_->FindVar(meta_.varname());
           if (var != nullptr) {
             ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
-            if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+            if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
                          num_bytes)) {
-              return tag;
+              return false;
             }
           }
-          break;
+          return true;
 #else
           PADDLE_THROW("Not compiled with CUDA!");
+          return false;
 #endif
         }
         framework::DDim dims = GetDims(meta_.dims());
         if (meta_.type() == sendrecv::LOD_TENSOR) {
           PADDLE_ENFORCE(meta_.lod_size() >= 0,
                          "lod info should be got first!");
-          if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) {
+          if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
-            return tag;
+            return false;
           }
-          break;
+          return true;
         }
         if (meta_.type() == sendrecv::SELECTED_ROWS) {
-          if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) {
+          if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
-            return tag;
+            return false;
           }
-          break;
+          return true;
         }
-
-        return tag;
-      }
-      case sendrecv::VariableMessage::kRowsFieldNumber: {
-        PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
-                        meta_.type() == sendrecv::LOD_TENSOR) &&
-                           meta_.varname() != "",
-                       "meta info should be got first!");
-
-        int num_bytes = 0;
-        if (wt != WIRETYPE_LENGTH_DELIMITED ||
-            !ReadVarintSizeAsInt(&input, &num_bytes)) {
-          return tag;
-        }
-
-        if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
-          return tag;
-        }
-        break;
-      }
-      case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
-        uint32_t length;
-        if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
-          return tag;
-        }
-
-        std::string temp;
-        if (!input.ReadString(&temp, length)) {
-          return tag;
-        }
-
-        meta_.set_out_varname(temp);
-        break;
-      }
-      case sendrecv::VariableMessage::kProfileFieldNumber: {
-        uint64_t profiling = 0;
-        if (!input.ReadVarint64(&profiling)) {
-          return tag;
-        }
-        meta_.set_profile(profiling);
-        int64_t listener_id = platform::ListenerId();
-        if (listener_id <= 0) {
-          break;
-        }
-        if (profiling == platform::kEnableProfiler &&
-            !platform::IsProfileEnabled()) {
-          platform::EnableProfiler(platform::ProfilerState::kCPU);
-        } else if (profiling == platform::kDisableProfiler &&
-                   platform::IsProfileEnabled()) {
-          // TODO(panyx0718): Should we allow to customize file dir.
-          platform::DisableProfiler(
-              platform::EventSortingKey::kDefault,
-              string::Sprintf("/tmp/profile_ps_%lld", listener_id));
-        }
-        break;
-      }
-      default: {
-        // Unknown tag, return unknown error.
-        return -1;
-      }
-    }
-  }
+  return true;
+}
-
-  return 0;
-}

 };  // namespace distributed
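The kLodFieldNumber branch above bounds its reads with IncrementRecursionDepthAndPushLimit before handing the stream to ParseLodData. A standalone demo of the underlying PushLimit/PopLimit mechanism in plain protobuf, with hypothetical bytes:

#include <cstdint>
#include <cstdio>
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

int main() {
  std::string buf = {0x01, 0x02, 0x03, 0x04};
  ::google::protobuf::io::ArrayInputStream raw(buf.data(),
                                               static_cast<int>(buf.size()));
  ::google::protobuf::io::CodedInputStream in(&raw);

  auto limit = in.PushLimit(2);  // only the first two bytes are visible
  uint64_t v = 0;
  while (in.ReadVarint64(&v)) {  // reads 1 and 2, then hits the limit
    std::printf("%llu\n", static_cast<unsigned long long>(v));
  }
  in.PopLimit(limit);  // the rest of the buffer becomes reachable again
  return 0;
}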
...
paddle/fluid/operators/distributed/variable_response.h
View file @ 3a6213f4
...
@@ -22,18 +22,35 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/
bytebuffer_stream
.h"
#include "paddle/fluid/operators/distributed/
send_recv.pb
.h"
 namespace paddle {
 namespace operators {
 namespace distributed {
+// Source provides a way for a particular RPC implementation to provide
+// received data to ParseFrom.
+class Source {
+ public:
+  virtual ~Source() {}
+
+  // Return the stream that contains the data to be parsed.
+  // Note that this method might be invoked more than once if
+  // ParseFrom needs to fall back to a more expensive parsing method.
+  // Every call must return a stream pointing at the beginning of
+  // the serialized RecvTensorResponse.
+  //
+  // Note that a subsequent call to contents() invalidates previous
+  // results of contents().
+  //
+  // Ownership of the returned stream is retained by the Source and
+  // should not be deleted by the caller.
+  virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
+};
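Any transport can now feed VariableResponse by implementing this interface. A minimal in-memory example assuming only stock protobuf; the StringSource name is invented for illustration:

#include <memory>
#include <string>
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// Wraps a std::string so it can be parsed like a received payload.
class StringSource : public Source {
 public:
  explicit StringSource(const std::string& data) : data_(data) {}

  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    // Re-create the stream so every call starts at the beginning,
    // as the interface contract above requires.
    stream_.reset(new ::google::protobuf::io::ArrayInputStream(
        data_.data(), static_cast<int>(data_.size())));
    return stream_.get();
  }

 private:
  std::string data_;
  std::unique_ptr<::google::protobuf::io::ArrayInputStream> stream_;
};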
 class VariableResponse {
  public:
   VariableResponse(const framework::Scope* scope,
...
@@ -51,22 +68,19 @@ class VariableResponse {
     }
   }
-  // return:
-  // 0:ok.
-  // -1: unkown error.
-  // other: number of error field.
-  int Parse(Source* source);
+  int Parse(Source* source, const sendrecv::VariableMessage& meta) {
+    meta_ = meta;
+
+    return Parse(source);
+  }

   // return:
   // 0:ok.
   // -1: unkown error.
   // other: number of error field.
-  int Parse(const ::grpc::ByteBuffer& byte_buffer);
+  virtual int Parse(Source* source) = 0;
-  const framework::Scope& GetLocalScope() const { return *local_scope_; }
-  framework::Scope* GetMutableLocalScope() const { return local_scope_; }
+  inline const framework::Scope& GetLocalScope() const {
+    return *local_scope_;
+  }
+  inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
   inline std::string Varname() const { return meta_.varname(); }
   inline std::string OutVarname() const { return meta_.out_varname(); }
...
@@ -78,7 +92,11 @@ class VariableResponse {
     return scope_->FindVar(meta_.varname());
   }
- private:
+ protected:
+  bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
+               const platform::DeviceContext& dev_ctx, platform::Place place,
+               void* dest, int64_t size);
   bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
                                 const platform::DeviceContext& ctx,
                                 const framework::DDim& dims, int length);
...
@@ -90,12 +108,16 @@ class VariableResponse {
                     const platform::DeviceContext& ctx,
                     const framework::DDim& dims, int length);
+ private:
+  bool ProcSerializedField(int tag,
+                           ::google::protobuf::io::CodedInputStream* input,
+                           int64_t num_bytes);
+
+ protected:
   const framework::Scope* scope_;
   const platform::DeviceContext* dev_ctx_;
   bool create_scope_ = false;
   framework::Scope* local_scope_ = nullptr;
+  // only Skeleton
   sendrecv::VariableMessage meta_;
 };
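With Parse(Source*) now pure virtual and the copy helpers protected, each RPC stack plugs in its own subclass; the gRPC one ships in this commit, and a bRPC one can follow the same shape. An illustrative skeleton plus a caller-side return-code check; BRpcVariableResponse, CheckParse, and the two-argument constructor are assumptions for the example, not code from this commit:

// Sketch of a transport-specific response type.
class BRpcVariableResponse : public VariableResponse {
 public:
  BRpcVariableResponse(const framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx)
      : VariableResponse(scope, dev_ctx) {}

  int Parse(Source* source) override {
    // Walk source->contents() tag by tag, fill meta_, and hand each
    // serialized payload to the protected ProcSerializedField(...).
    return 0;
  }
};

// Caller-side check, matching the documented return contract above.
void CheckParse(VariableResponse* resp, Source* source) {
  int err = resp->Parse(source);
  if (err == 0) {
    return;  // parsed; the received variable is in resp->GetLocalScope()
  }
  if (err == -1) {
    LOG(ERROR) << "unknown wire error while parsing " << resp->Varname();
  } else {
    LOG(ERROR) << "failed to parse field number " << err;
  }
}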
...