Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3a6213f4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a6213f4
编写于
7月 20, 2018
作者:
G
gongweibao
提交者:
GitHub
7月 20, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change grpc interface to compatible with brpc. (#12164)
上级
b0630938
变更
19
显示空白变更内容
内联
并排
Showing
19 changed file
with
748 addition
and
550 deletion
+748
-550
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+30
-20
paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc
paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc
+1
-1
paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h
paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h
+1
-19
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+1
-0
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+3
-17
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+157
-0
paddle/fluid/operators/distributed/grpc_serde.h
paddle/fluid/operators/distributed/grpc_serde.h
+50
-0
paddle/fluid/operators/distributed/grpc_serde_test.cc
paddle/fluid/operators/distributed/grpc_serde_test.cc
+5
-3
paddle/fluid/operators/distributed/grpc_server.cc
paddle/fluid/operators/distributed/grpc_server.cc
+11
-10
paddle/fluid/operators/distributed/grpc_service.h
paddle/fluid/operators/distributed/grpc_service.h
+5
-5
paddle/fluid/operators/distributed/grpc_variable_response.cc
paddle/fluid/operators/distributed/grpc_variable_response.cc
+308
-0
paddle/fluid/operators/distributed/grpc_variable_response.h
paddle/fluid/operators/distributed/grpc_variable_response.h
+58
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+17
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+2
-3
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+2
-1
paddle/fluid/operators/distributed/sendrecvop_utils.cc
paddle/fluid/operators/distributed/sendrecvop_utils.cc
+13
-131
paddle/fluid/operators/distributed/sendrecvop_utils.h
paddle/fluid/operators/distributed/sendrecvop_utils.h
+7
-10
paddle/fluid/operators/distributed/variable_response.cc
paddle/fluid/operators/distributed/variable_response.cc
+38
-313
paddle/fluid/operators/distributed/variable_response.h
paddle/fluid/operators/distributed/variable_response.h
+39
-17
未找到文件。
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
3a6213f4
if
(
NOT WITH_DISTRIBUTE
)
return
()
endif
()
if
(
WITH_GRPC
)
set
(
cc_generic_services
"false"
)
else
()
set
(
cc_generic_services
"true"
)
endif
()
configure_file
(
send_recv.proto.in
${
CMAKE_CURRENT_SOURCE_DIR
}
/send_recv.proto @ONLY
)
if
(
WITH_GRPC
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
grpc_library
(
sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
cc_test
(
grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL
)
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
cc_test
(
grpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL
)
return
()
endif
()
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
brpc_library
(
sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
brpc_library
(
sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
find_library
(
OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so
)
ADD_LIBRARY
(
crypto SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET crypto PROPERTY IMPORTED_LOCATION
${
OPENSSL_CRYPTO_LIBRARY_STATIC
}
)
set
(
brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy
)
find_library
(
OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so
)
ADD_LIBRARY
(
ssl SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ssl PROPERTY IMPORTED_LOCATION
${
OPENSSL_SSL_LIBRARY_STATIC
}
)
cc_test
(
brpc_server_test SRCS rpc_server_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
cc_test
(
brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
brpc protobuf leveldb gflags glog
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL
)
cc_test
(
brpc_serde_test SRCS brpc_serde_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
paddle/fluid/operators/distributed/bytebuffer_stream.cc
→
paddle/fluid/operators/distributed/
grpc_
bytebuffer_stream.cc
浏览文件 @
3a6213f4
...
...
@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/
grpc_
bytebuffer_stream.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/distributed/bytebuffer_stream.h
→
paddle/fluid/operators/distributed/
grpc_
bytebuffer_stream.h
浏览文件 @
3a6213f4
...
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace
grpc
{
// A ZeroCopyInputStream that reads from grpc_byte_buffer
...
...
@@ -107,25 +108,6 @@ class GrpcBufferReader final
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class
Source
{
public:
virtual
~
Source
()
{}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
=
0
;
};
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class
GrpcByteBufferSource
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
3a6213f4
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h"
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
3a6213f4
...
...
@@ -38,7 +38,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...
...
@@ -46,23 +49,6 @@ namespace paddle {
namespace
operators
{
namespace
distributed
{
struct
VarHandle
{
// RPC endpoint.
std
::
string
ep
;
const
platform
::
DeviceContext
*
ctx
;
const
framework
::
Scope
*
scope
;
// Variable name.
std
::
string
name
;
// RPC method name.
std
::
string
method
;
std
::
string
String
()
const
{
std
::
ostringstream
s
;
s
<<
method
<<
" name:["
<<
name
<<
"], ep:["
<<
ep
<<
"]"
;
return
s
.
str
();
}
};
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
grpc
::
ByteBuffer
&
msg
);
class
BaseProcessor
{
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
0 → 100644
浏览文件 @
3a6213f4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if
(
platform
::
ShouldSendProfileState
())
{
if
(
platform
::
IsProfileEnabled
())
{
request
.
set_profile
(
platform
::
kEnableProfiler
);
}
else
{
request
.
set_profile
(
platform
::
kDisableProfiler
);
}
}
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
#ifdef PADDLE_WITH_CUDA
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
request
.
set_type
(
::
sendrecv
::
NCCL_ID
);
#endif
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CUDAPinnedPlace
cuda_pinned
;
memory
::
Free
(
cuda_pinned
,
backing
);
};
#endif
}
std
::
string
header
;
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
NCCL_UNIQUE_ID_BYTES
);
const
ncclUniqueId
&
uid
=
var
->
Get
<
ncclUniqueId
>
();
e
.
WriteRawBytes
(
std
::
string
(
uid
.
internal
,
NCCL_UNIQUE_ID_BYTES
));
// for serialize NCCL_ID
::
grpc
::
Slice
slices
(
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
.
begin
()),
e
.
data
(),
e
.
size
());
::
grpc
::
ByteBuffer
tmp
(
&
slices
,
1
);
msg
->
Swap
(
&
tmp
);
return
;
}
#endif
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
slices
[
0
]
=
::
grpc
::
Slice
(
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
0
].
begin
()),
e
.
data
(),
e
.
size
());
slices
[
1
]
=
::
grpc
::
Slice
(
grpc_slice_new_with_user_data
(
payload
,
payload_size
,
destroy_callback
,
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
slices
[
2
]
=
::
grpc
::
Slice
(
e2
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
2
].
begin
()),
e2
.
data
(),
e2
.
size
());
slices
[
3
]
=
::
grpc
::
Slice
(
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
[](
void
*
backing
)
{},
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
num_slices
=
4
;
}
::
grpc
::
ByteBuffer
tmp
(
&
slices
[
0
],
num_slices
);
msg
->
Swap
(
&
tmp
);
}
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
)
{
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
*
var
=
resp
.
GetVar
();
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/grpc_serde.h
0 → 100644
浏览文件 @
3a6213f4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
typedef
void
(
*
DestroyCallback
)(
void
*
);
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_varname
=
std
::
string
());
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
);
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/grpc_serde_test.cc
浏览文件 @
3a6213f4
...
...
@@ -21,8 +21,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
...
...
@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework
::
Scope
scope
;
scope
.
Var
(
"myvar"
);
operators
::
distributed
::
VariableResponse
resp
(
&
scope
,
&
ctx
);
operators
::
distributed
::
GRPC
VariableResponse
resp
(
&
scope
,
&
ctx
);
EXPECT_EQ
(
resp
.
Parse
(
msg
),
0
);
framework
::
Variable
*
var2
=
resp
.
GetVar
();
...
...
@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy
framework
::
Scope
scope
;
scope
.
Var
(
"myvar"
);
operators
::
distributed
::
VariableResponse
resp
(
&
scope
,
&
ctx
);
operators
::
distributed
::
GRPC
VariableResponse
resp
(
&
scope
,
&
ctx
);
if
(
from_type
==
0
)
{
EXPECT_EQ
(
resp
.
Parse
(
msg
),
0
);
}
else
{
...
...
paddle/fluid/operators/distributed/grpc_server.cc
浏览文件 @
3a6213f4
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <limits>
#include <string>
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
using
::
grpc
::
ServerAsyncResponseWriter
;
...
...
@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPC
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kSendVariable
);
...
...
@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase {
protected:
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
GRPC
VariableResponse
>
request_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
...
...
@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase {
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
),
local_scope_
(
nullptr
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPC
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
true
));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kPrefetchVariable
);
...
...
@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase {
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
GRPC
VariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
local_scope_
;
...
...
@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPC
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kCheckpointNotify
);
...
...
@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase {
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
GRPC
VariableResponse
>
request_
;
sendrecv
::
VoidMessage
reply_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
...
...
paddle/fluid/operators/distributed/grpc_service.h
浏览文件 @
3a6213f4
...
...
@@ -23,8 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow
...
...
@@ -42,17 +41,18 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template
<
>
class
SerializationTraits
<
paddle
::
operators
::
distributed
::
VariableResponse
>
{
class
SerializationTraits
<
paddle
::
operators
::
distributed
::
GRPCVariableResponse
>
{
public:
static
Status
Serialize
(
const
paddle
::
operators
::
distributed
::
VariableResponse
&
msg
,
const
paddle
::
operators
::
distributed
::
GRPC
VariableResponse
&
msg
,
grpc_byte_buffer
**
bp
,
bool
*
own_buffer
)
{
PADDLE_ENFORCE
(
false
,
"SerializationTraits::Serialize not implemented!"
);
return
Status
();
}
static
Status
Deserialize
(
grpc_byte_buffer
*
buffer
,
paddle
::
operators
::
distributed
::
VariableResponse
*
msg
,
paddle
::
operators
::
distributed
::
GRPC
VariableResponse
*
msg
,
int
max_message_size
=
INT_MAX
)
{
if
(
buffer
==
nullptr
)
{
return
Status
(
StatusCode
::
INTERNAL
,
"No payload"
);
...
...
paddle/fluid/operators/distributed/grpc_variable_response.cc
0 → 100644
浏览文件 @
3a6213f4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
enum
WireType
{
WIRETYPE_VARINT
=
0
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
};
inline
int
GetTagFieldNumber
(
uint32_t
tag
)
{
return
tag
>>
3
;
}
inline
WireType
GetTagWireType
(
uint32_t
tag
)
{
return
static_cast
<
WireType
>
(
tag
&
0x7
);
}
bool
ReadVarintSizeAsInt
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int
*
result
)
{
uint64_t
v
;
if
(
input
->
ReadVarint64
(
&
v
)
&&
v
<=
static_cast
<
uint64_t
>
(
INT_MAX
))
{
*
result
=
static_cast
<
int
>
(
v
);
return
true
;
}
else
{
return
false
;
}
}
int
GRPCVariableResponse
::
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
)
{
GrpcByteBufferSource
source
;
source
.
Init
(
byte_buffer
);
GrpcByteBufferSourceWrapper
r
(
&
source
);
return
Parse
(
&
r
);
}
bool
ParseLodData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
std
::
vector
<
int64_t
>*
lod
)
{
while
(
true
)
{
auto
p
=
input
->
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
return
(
tag
==
0
);
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage_LodData
::
kLodDataFieldNumber
:
{
uint64_t
v
;
if
(
wt
==
WIRETYPE_VARINT
)
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
}
lod
->
push_back
(
v
);
break
;
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
int
start_pos
=
input
->
CurrentPosition
();
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
tag
;
}
lod
->
push_back
(
v
);
}
break
;
}
return
false
;
}
default:
{
return
false
;
}
}
}
return
true
;
}
int
GRPCVariableResponse
::
Parse
(
Source
*
source
)
{
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
input_stream
=
source
->
contents
();
::
google
::
protobuf
::
io
::
CodedInputStream
input
(
input_stream
);
input
.
SetTotalBytesLimit
(
INT_MAX
,
INT_MAX
);
while
(
true
)
{
auto
p
=
input
.
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
if
(
tag
!=
0
)
{
return
-
1
;
}
return
0
;
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage
::
kVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_varname
(
temp
);
break
;
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint32_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint32
(
&
v
))
{
return
tag
;
}
meta_
.
set_type
(
static_cast
<::
sendrecv
::
VarType
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint32_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint32
(
&
v
))
{
return
tag
;
}
meta_
.
set_data_type
(
static_cast
<::
sendrecv
::
VariableMessage_Type
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDimsFieldNumber
:
{
// not packed
if
(
wt
==
WIRETYPE_VARINT
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
break
;
}
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_lod_level
(
static_cast
<
int64_t
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kLodFieldNumber
:
{
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
std
::
pair
<::
google
::
protobuf
::
io
::
CodedInputStream
::
Limit
,
int
>
p
=
input
.
IncrementRecursionDepthAndPushLimit
(
length
);
std
::
vector
<
int64_t
>
lod_data
;
if
(
p
.
second
<
0
||
!
ParseLodData
(
&
input
,
&
lod_data
))
{
return
tag
;
}
if
(
!
input
.
DecrementRecursionDepthAndPopLimit
(
p
.
first
))
{
return
tag
;
}
if
(
lod_data
.
size
()
==
0
)
{
break
;
}
auto
lod
=
meta_
.
add_lod
();
for
(
uint32_t
i
=
0
;
i
<
lod_data
.
size
();
i
++
)
{
lod
->
add_lod_data
(
lod_data
[
i
]);
}
break
;
}
case
sendrecv
::
VariableMessage
::
kSlrHeightFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_slr_height
(
static_cast
<
int64_t
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kSerializedFieldNumber
:
{
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
if
(
!
ProcSerializedField
(
tag
,
&
input
,
num_bytes
))
{
return
tag
;
}
break
;
}
case
sendrecv
::
VariableMessage
::
kRowsFieldNumber
:
{
PADDLE_ENFORCE
((
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
||
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
&&
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
}
break
;
}
case
sendrecv
::
VariableMessage
::
kOutVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_out_varname
(
temp
);
break
;
}
case
sendrecv
::
VariableMessage
::
kProfileFieldNumber
:
{
uint64_t
profiling
=
0
;
if
(
!
input
.
ReadVarint64
(
&
profiling
))
{
return
tag
;
}
meta_
.
set_profile
(
profiling
);
int64_t
listener_id
=
platform
::
ListenerId
();
if
(
listener_id
<=
0
)
{
break
;
}
if
(
profiling
==
platform
::
kEnableProfiler
&&
!
platform
::
IsProfileEnabled
())
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
}
else
if
(
profiling
==
platform
::
kDisableProfiler
&&
platform
::
IsProfileEnabled
())
{
// TODO(panyx0718): Should we allow to customize file dir.
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
string
::
Sprintf
(
"/tmp/profile_ps_%lld"
,
listener_id
));
}
break
;
}
default:
{
// Unknown tag, return unknown error.
return
-
1
;
}
}
}
return
0
;
}
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/grpc_variable_response.h
0 → 100644
浏览文件 @
3a6213f4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
class
GRPCVariableResponse
:
public
VariableResponse
{
public:
GRPCVariableResponse
(
const
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
bool
create_scope
=
false
)
:
VariableResponse
(
scope
,
dev_ctx
,
create_scope
)
{}
virtual
~
GRPCVariableResponse
()
{}
int
Parse
(
Source
*
source
)
override
;
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
);
};
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
3a6213f4
...
...
@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class
RPCServer
;
struct
VarHandle
{
// RPC endpoint.
std
::
string
ep
;
const
platform
::
DeviceContext
*
ctx
;
const
framework
::
Scope
*
scope
;
// Variable name.
std
::
string
name
;
// RPC method name.
std
::
string
method
;
std
::
string
String
()
const
{
std
::
ostringstream
s
;
s
<<
method
<<
" name:["
<<
name
<<
"], ep:["
<<
ep
<<
"]"
;
return
s
.
str
();
}
};
class
RequestHandler
{
public:
explicit
RequestHandler
(
bool
sync_mode
)
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
3a6213f4
...
...
@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv
batch barrier message
"
;
VLOG
(
3
)
<<
"sync: recv
BATCH_BARRIER_MESSAGE
"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
BEGIN_PASS_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv begin pass message"
;
...
...
@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
if
(
invar
==
nullptr
)
{
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
PADDLE_THROW
(
"sync: Can not find server side var"
);
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
...
...
paddle/fluid/operators/distributed/send_recv.proto
→
paddle/fluid/operators/distributed/send_recv.proto
.in
浏览文件 @
3a6213f4
/*
Copyright
(
c
)
2016
PaddlePaddle
Authors
.
All
Rights
Reserve
.
Licensed
under
the
Apache
License
,
Version
2.0
(
the
"License"
);
you
may
not
use
this
file
except
in
compliance
with
the
License
.
...
...
@@ -14,7 +15,7 @@ limitations under the License. */
syntax
=
"proto3"
;
package
sendrecv
;
// option cc_generic_services = true
;
option
cc_generic_services
=
@
cc_generic_services
@
;
service
SendRecvService
{
//
For
parameter
server
round
-
robin
like
hashing
,
do
not
split
tensors
.
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.cc
浏览文件 @
3a6213f4
...
...
@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -34,6 +28,11 @@ namespace distributed {
using
VarMsg
=
sendrecv
::
VariableMessage
;
void
*
GetVarPayLoad
(
const
std
::
string
varname
,
int64_t
size
)
{
platform
::
CUDAPinnedPlace
cuda_pinned
;
return
memory
::
Alloc
(
cuda_pinned
,
size
);
}
void
GetTensorPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
...
...
@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var,
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CUDAPinnedPlace
cuda_pinned
;
//
platform::CUDAPinnedPlace cuda_pinned;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
*
payload
=
memory
::
Alloc
(
cuda_pinned
,
copy_size
);
*
payload
=
GetVarPayLoad
(
request
->
varname
()
,
copy_size
);
platform
::
CUDAPinnedPlace
cuda_pinned
;
memory
::
Copy
(
cuda_pinned
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
#endif
}
else
{
...
...
@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var,
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CUDAPinnedPlace
cuda_pinned
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
*
payload
=
memory
::
Alloc
(
cuda_pinned
,
copy_size
);
*
payload
=
GetVarPayLoad
(
request
->
varname
(),
copy_size
);
platform
::
CUDAPinnedPlace
cuda_pinned
;
memory
::
Copy
(
cuda_pinned
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
...
...
@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var,
*
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
}
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if
(
platform
::
ShouldSendProfileState
())
{
if
(
platform
::
IsProfileEnabled
())
{
request
.
set_profile
(
platform
::
kEnableProfiler
);
}
else
{
request
.
set_profile
(
platform
::
kDisableProfiler
);
}
}
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
#ifdef PADDLE_WITH_CUDA
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
request
.
set_type
(
::
sendrecv
::
NCCL_ID
);
#endif
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CUDAPinnedPlace
cuda_pinned
;
memory
::
Free
(
cuda_pinned
,
backing
);
};
#endif
}
std
::
string
header
;
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
NCCL_UNIQUE_ID_BYTES
);
const
ncclUniqueId
&
uid
=
var
->
Get
<
ncclUniqueId
>
();
e
.
WriteRawBytes
(
std
::
string
(
uid
.
internal
,
NCCL_UNIQUE_ID_BYTES
));
// for serialize NCCL_ID
::
grpc
::
Slice
slices
(
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
.
begin
()),
e
.
data
(),
e
.
size
());
::
grpc
::
ByteBuffer
tmp
(
&
slices
,
1
);
msg
->
Swap
(
&
tmp
);
return
;
}
#endif
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
slices
[
0
]
=
::
grpc
::
Slice
(
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
0
].
begin
()),
e
.
data
(),
e
.
size
());
slices
[
1
]
=
::
grpc
::
Slice
(
grpc_slice_new_with_user_data
(
payload
,
payload_size
,
destroy_callback
,
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
slices
[
2
]
=
::
grpc
::
Slice
(
e2
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
2
].
begin
()),
e2
.
data
(),
e2
.
size
());
slices
[
3
]
=
::
grpc
::
Slice
(
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
[](
void
*
backing
)
{},
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
num_slices
=
4
;
}
::
grpc
::
ByteBuffer
tmp
(
&
slices
[
0
],
num_slices
);
msg
->
Swap
(
&
tmp
);
}
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
)
{
operators
::
distributed
::
VariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
*
var
=
resp
.
GetVar
();
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/sendrecvop_utils.h
浏览文件 @
3a6213f4
...
...
@@ -25,24 +25,21 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
typedef
void
(
*
DestroyCallback
)(
void
*
)
;
using
VarMsg
=
sendrecv
::
VariableMessage
;
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_varname
=
std
::
string
());
void
GetTensorPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
);
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
);
void
GetSelectedRowsPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
);
inline
std
::
type_index
ToTypeIndex
(
sendrecv
::
VariableMessage
::
Type
type
)
{
switch
(
type
)
{
...
...
paddle/fluid/operators/distributed/variable_response.cc
浏览文件 @
3a6213f4
...
...
@@ -13,50 +13,20 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/variable_response.h"
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
enum
WireType
{
WIRETYPE_VARINT
=
0
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
};
inline
int
GetTagFieldNumber
(
uint32_t
tag
)
{
return
tag
>>
3
;
}
inline
WireType
GetTagWireType
(
uint32_t
tag
)
{
return
static_cast
<
WireType
>
(
tag
&
0x7
);
}
bool
ReadVarintSizeAsInt
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int
*
result
)
{
uint64_t
v
;
if
(
input
->
ReadVarint64
(
&
v
)
&&
v
<=
static_cast
<
uint64_t
>
(
INT_MAX
))
{
*
result
=
static_cast
<
int
>
(
v
);
return
true
;
}
else
{
return
false
;
}
}
bool
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
int
size
)
{
bool
VariableResponse
::
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
int64_t
size
)
{
const
void
*
data
=
NULL
;
int
size_to_write
=
0
;
int
length
=
size
;
int
64_t
length
=
size
;
int
total_written
=
0
;
if
(
platform
::
is_gpu_place
(
place
))
{
...
...
@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData(
return
true
;
}
bool
ParseLodData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
std
::
vector
<
int64_t
>*
lod
)
{
while
(
true
)
{
auto
p
=
input
->
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
return
(
tag
==
0
);
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage_LodData
::
kLodDataFieldNumber
:
{
uint64_t
v
;
if
(
wt
==
WIRETYPE_VARINT
)
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
}
lod
->
push_back
(
v
);
break
;
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
int
start_pos
=
input
->
CurrentPosition
();
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
tag
;
}
lod
->
push_back
(
v
);
}
break
;
}
return
false
;
}
default:
{
return
false
;
}
}
}
return
true
;
}
int
VariableResponse
::
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
)
{
GrpcByteBufferSource
source
;
source
.
Init
(
byte_buffer
);
GrpcByteBufferSourceWrapper
r
(
&
source
);
return
Parse
(
&
r
);
}
int
VariableResponse
::
Parse
(
Source
*
source
)
{
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
input_stream
=
source
->
contents
();
::
google
::
protobuf
::
io
::
CodedInputStream
input
(
input_stream
);
input
.
SetTotalBytesLimit
(
INT_MAX
,
INT_MAX
);
while
(
true
)
{
auto
p
=
input
.
ReadTagWithCutoff
(
127
);
int
tag
=
GetTagFieldNumber
(
p
.
first
);
WireType
wt
=
GetTagWireType
(
p
.
first
);
if
(
!
p
.
second
)
{
if
(
tag
!=
0
)
{
return
-
1
;
}
return
0
;
}
switch
(
tag
)
{
case
sendrecv
::
VariableMessage
::
kVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_varname
(
temp
);
break
;
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint32_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint32
(
&
v
))
{
return
tag
;
}
meta_
.
set_type
(
static_cast
<::
sendrecv
::
VarType
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint32_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint32
(
&
v
))
{
return
tag
;
}
meta_
.
set_data_type
(
static_cast
<::
sendrecv
::
VariableMessage_Type
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kDimsFieldNumber
:
{
// not packed
if
(
wt
==
WIRETYPE_VARINT
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
break
;
}
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
}
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
add_dims
(
v
);
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_lod_level
(
static_cast
<
int64_t
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kLodFieldNumber
:
{
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
std
::
pair
<::
google
::
protobuf
::
io
::
CodedInputStream
::
Limit
,
int
>
p
=
input
.
IncrementRecursionDepthAndPushLimit
(
length
);
std
::
vector
<
int64_t
>
lod_data
;
if
(
p
.
second
<
0
||
!
ParseLodData
(
&
input
,
&
lod_data
))
{
return
tag
;
}
if
(
!
input
.
DecrementRecursionDepthAndPopLimit
(
p
.
first
))
{
return
false
;
}
if
(
lod_data
.
size
()
==
0
)
{
break
;
}
auto
lod
=
meta_
.
add_lod
();
for
(
uint32_t
i
=
0
;
i
<
lod_data
.
size
();
i
++
)
{
lod
->
add_lod_data
(
lod_data
[
i
]);
}
break
;
}
case
sendrecv
::
VariableMessage
::
kSlrHeightFieldNumber
:
{
uint64_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
}
meta_
.
set_slr_height
(
static_cast
<
int64_t
>
(
v
));
break
;
}
case
sendrecv
::
VariableMessage
::
kSerializedFieldNumber
:
{
bool
VariableResponse
::
ProcSerializedField
(
int
tag
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int64_t
num_bytes
)
{
PADDLE_ENFORCE
((
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
||
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
||
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
&&
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
if
(
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
{
#ifdef PADDLE_WITH_CUDA
auto
*
var
=
scope_
->
FindVar
(
meta_
.
varname
());
if
(
var
!=
nullptr
)
{
ncclUniqueId
*
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
if
(
!
ReadRaw
(
&
input
,
*
dev_ctx_
,
platform
::
CPUPlace
(),
id
->
internal
,
if
(
!
ReadRaw
(
input
,
*
dev_ctx_
,
platform
::
CPUPlace
(),
id
->
internal
,
num_bytes
))
{
return
tag
;
return
false
;
}
}
break
;
return
true
;
#else
PADDLE_THROW
(
"Not compiled with CUDA!"
);
return
false
;
#endif
}
framework
::
DDim
dims
=
GetDims
(
meta_
.
dims
());
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
false
;
}
break
;
return
true
;
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
}
break
;
}
return
tag
;
}
case
sendrecv
::
VariableMessage
::
kRowsFieldNumber
:
{
PADDLE_ENFORCE
((
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
||
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
&&
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
}
break
;
}
case
sendrecv
::
VariableMessage
::
kOutVarnameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_out_varname
(
temp
);
break
;
}
case
sendrecv
::
VariableMessage
::
kProfileFieldNumber
:
{
uint64_t
profiling
=
0
;
if
(
!
input
.
ReadVarint64
(
&
profiling
))
{
return
tag
;
}
meta_
.
set_profile
(
profiling
);
int64_t
listener_id
=
platform
::
ListenerId
();
if
(
listener_id
<=
0
)
{
break
;
}
if
(
profiling
==
platform
::
kEnableProfiler
&&
!
platform
::
IsProfileEnabled
())
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
}
else
if
(
profiling
==
platform
::
kDisableProfiler
&&
platform
::
IsProfileEnabled
())
{
// TODO(panyx0718): Should we allow to customize file dir.
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
string
::
Sprintf
(
"/tmp/profile_ps_%lld"
,
listener_id
));
}
break
;
}
default:
{
// Unknown tag, return unknown error.
return
-
1
;
}
if
(
!
CopySelectRowsTensorData
(
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
false
;
}
return
true
;
}
return
0
;
return
true
;
}
};
// namespace distributed
...
...
paddle/fluid/operators/distributed/variable_response.h
浏览文件 @
3a6213f4
...
...
@@ -22,18 +22,35 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/
bytebuffer_stream
.h"
#include "paddle/fluid/operators/distributed/
send_recv.pb
.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class
Source
{
public:
virtual
~
Source
()
{}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
=
0
;
};
class
VariableResponse
{
public:
VariableResponse
(
const
framework
::
Scope
*
scope
,
...
...
@@ -51,22 +68,19 @@ class VariableResponse {
}
}
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int
Parse
(
Source
*
source
);
int
Parse
(
Source
*
source
,
const
sendrecv
::
VariableMessage
&
meta
)
{
meta_
=
meta
;
return
Parse
(
source
);
}
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int
Parse
(
const
::
grpc
::
ByteBuffer
&
byte_buffer
);
const
framework
::
Scope
&
GetLocalScope
()
const
{
return
*
local_scope_
;
}
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
virtual
int
Parse
(
Source
*
source
)
=
0
;
inline
const
framework
::
Scope
&
GetLocalScope
()
const
{
return
*
local_scope_
;
}
inline
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
...
...
@@ -78,7 +92,11 @@ class VariableResponse {
return
scope_
->
FindVar
(
meta_
.
varname
());
}
private:
protected:
bool
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
int64_t
size
);
bool
CopySelectRowsTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
);
...
...
@@ -90,12 +108,16 @@ class VariableResponse {
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
);
private:
bool
ProcSerializedField
(
int
tag
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
int64_t
num_bytes
);
protected:
const
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
bool
create_scope_
=
false
;
framework
::
Scope
*
local_scope_
=
nullptr
;
// only Skeleton
sendrecv
::
VariableMessage
meta_
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录