未验证 提交 3a6213f4 编写于 作者: G gongweibao 提交者: GitHub

Change grpc interface to compatible with brpc. (#12164)

上级 b0630938
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC)
set(cc_generic_services "false")
else()
set(cc_generic_services "true")
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
if(WITH_GRPC) if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
selected_rows memory) PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(grpc_serde_test SRCS grpc_serde_test.cc
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc cc_test(grpc_server_test SRCS rpc_server_test.cc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
proto_desc lookup_table_op SERIAL)
return() return()
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto PROTO send_recv.proto
DEPS lod_tensor selected_rows memory) DEPS lod_tensor selected_rows memory)
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so) set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so) cc_test(brpc_server_test SRCS rpc_server_test.cc
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL) DEPS ${brpc_test_depends} SERIAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc cc_test(brpc_serde_test SRCS brpc_serde_test.cc
brpc protobuf leveldb gflags glog DEPS ${brpc_test_depends} SERIAL)
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace grpc { namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer // A ZeroCopyInputStream that reads from grpc_byte_buffer
...@@ -107,25 +108,6 @@ class GrpcBufferReader final ...@@ -107,25 +108,6 @@ class GrpcBufferReader final
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer. // A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource class GrpcByteBufferSource
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG #include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
......
...@@ -38,7 +38,10 @@ limitations under the License. */ ...@@ -38,7 +38,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...@@ -46,23 +49,6 @@ namespace paddle { ...@@ -46,23 +49,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
struct VarHandle {
// RPC endpoint.
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
// Variable name.
std::string name;
// RPC method name.
std::string method;
std::string String() const {
std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]";
return s.str();
}
};
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CUDAPinnedPlace cuda_pinned;
memory::Free(cuda_pinned, backing);
};
#endif
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var);
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -21,8 +21,10 @@ limitations under the License. */ ...@@ -21,8 +21,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::distributed::VariableResponse resp(&scope, &ctx); operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
...@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::distributed::VariableResponse resp(&scope, &ctx); operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
...@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase { ...@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), request_handler->dev_ctx(),
!request_handler->sync_mode())); !request_handler->sync_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
...@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase { ...@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase {
protected: protected:
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
...@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase {
: RequestBase(service, cq, request_handler, req_id), : RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_), responder_(&ctx_),
local_scope_(nullptr) { local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx(), true));
int method_id = int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable); static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
...@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase {
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* local_scope_; framework::Scope* local_scope_;
...@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx())); request_handler->dev_ctx()));
int method_id = int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify); static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
...@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase {
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
......
...@@ -23,8 +23,7 @@ ...@@ -23,8 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h> #include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow // NOTE: This method was originally created by tensorflow
...@@ -42,17 +41,18 @@ class ServerContext; ...@@ -42,17 +41,18 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse. // Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse. // Wire-format is identical to RecvVariableResponse.
template <> template <>
class SerializationTraits<paddle::operators::distributed::VariableResponse> { class SerializationTraits<
paddle::operators::distributed::GRPCVariableResponse> {
public: public:
static Status Serialize( static Status Serialize(
const paddle::operators::distributed::VariableResponse& msg, const paddle::operators::distributed::GRPCVariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status(); return Status();
} }
static Status Deserialize( static Status Deserialize(
grpc_byte_buffer* buffer, grpc_byte_buffer* buffer,
paddle::operators::distributed::VariableResponse* msg, paddle::operators::distributed::GRPCVariableResponse* msg,
int max_message_size = INT_MAX) { int max_message_size = INT_MAX) {
if (buffer == nullptr) { if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload"); return Status(StatusCode::INTERNAL, "No payload");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
enum WireType {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
std::vector<int64_t>* lod) {
while (true) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return tag;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int GRPCVariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return tag;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!ProcSerializedField(tag, &input, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
case sendrecv::VariableMessage::kProfileFieldNumber: {
uint64_t profiling = 0;
if (!input.ReadVarint64(&profiling)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
}
}
return 0;
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class GRPCVariableResponse : public VariableResponse {
public:
GRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~GRPCVariableResponse() {}
int Parse(Source* source) override;
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
...@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer; class RPCServer;
struct VarHandle {
// RPC endpoint.
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
// Variable name.
std::string name;
// RPC method name.
std::string method;
std::string String() const {
std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]";
return s.str();
}
};
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(bool sync_mode)
......
...@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Sync // Sync
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message"; VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == BEGIN_PASS_MESSAGE) { } else if (varname == BEGIN_PASS_MESSAGE) {
VLOG(3) << "sync: recv begin pass message"; VLOG(3) << "sync: recv begin pass message";
...@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: processing received var: " << varname; VLOG(3) << "sync: processing received var: " << varname;
if (invar == nullptr) { if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname; LOG(FATAL) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) { if (invar->IsType<framework::SelectedRows>()) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License. except in compliance with the License.
...@@ -14,7 +15,7 @@ limitations under the License. */ ...@@ -14,7 +15,7 @@ limitations under the License. */
syntax = "proto3"; syntax = "proto3";
package sendrecv; package sendrecv;
// option cc_generic_services = true; option cc_generic_services = @cc_generic_services@;
service SendRecvService { service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors. // For parameter server round-robin like hashing, do not split tensors.
......
...@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <sys/time.h> #include <sys/time.h>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,6 +28,11 @@ namespace distributed { ...@@ -34,6 +28,11 @@ namespace distributed {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
void* GetVarPayLoad(const std::string varname, int64_t size) {
platform::CUDAPinnedPlace cuda_pinned;
return memory::Alloc(cuda_pinned, size);
}
void GetTensorPayload(framework::Variable* var, void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request, const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) { void** payload, size_t* payload_size) {
...@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var, ...@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var,
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CUDAPinnedPlace cuda_pinned; // platform::CUDAPinnedPlace cuda_pinned;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
*payload = memory::Alloc(cuda_pinned, copy_size); *payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload, memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor.place()), boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size, reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
ctx.Wait(); ctx.Wait();
#endif #endif
} else { } else {
...@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var, ...@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var,
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::CUDAPinnedPlace cuda_pinned;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type()); auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = memory::Alloc(cuda_pinned, copy_size); *payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload, memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor->place()), boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()), copy_size, reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
...@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var, ...@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var,
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
} }
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CUDAPinnedPlace cuda_pinned;
memory::Free(cuda_pinned, backing);
};
#endif
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::distributed::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -25,24 +25,21 @@ limitations under the License. */ ...@@ -25,24 +25,21 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
typedef void (*DestroyCallback)(void*); using VarMsg = sendrecv::VariableMessage;
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx, VarMsg* request,
::grpc::ByteBuffer* msg, void** payload, size_t* payload_size);
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx, VarMsg* request,
const framework::Scope* scope, void** payload, size_t* payload_size);
framework::Variable** var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) { switch (type) {
......
...@@ -13,50 +13,20 @@ ...@@ -13,50 +13,20 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include <string>
#include <utility>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
enum WireType { bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
WIRETYPE_VARINT = 0, const platform::DeviceContext& dev_ctx,
WIRETYPE_LENGTH_DELIMITED = 2, platform::Place place, void* dest,
}; int64_t size) {
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int size) {
const void* data = NULL; const void* data = NULL;
int size_to_write = 0; int size_to_write = 0;
int length = size; int64_t length = size;
int total_written = 0; int total_written = 0;
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
...@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData( ...@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData(
return true; return true;
} }
bool ParseLodData(::google::protobuf::io::CodedInputStream* input, bool VariableResponse::ProcSerializedField(
std::vector<int64_t>* lod) { int tag, ::google::protobuf::io::CodedInputStream* input,
while (true) { int64_t num_bytes) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return tag;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
int VariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return false;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR || meta_.type() == sendrecv::LOD_TENSOR ||
meta_.type() == sendrecv::NCCL_ID) && meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (meta_.type() == sendrecv::NCCL_ID) { if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto* var = scope_->FindVar(meta_.varname()); auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) { if (var != nullptr) {
ncclUniqueId* id = var->GetMutable<ncclUniqueId>(); ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
num_bytes)) { num_bytes)) {
return tag; return false;
} }
} }
break; return true;
#else #else
PADDLE_THROW("Not compiled with CUDA!"); PADDLE_THROW("Not compiled with CUDA!");
return false;
#endif #endif
} }
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
"lod info should be got first!"); if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) { return false;
return tag;
} }
break; return true;
} }
if (meta_.type() == sendrecv::SELECTED_ROWS) { if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) { if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
return tag; return false;
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
case sendrecv::VariableMessage::kProfileFieldNumber: {
uint64_t profiling = 0;
if (!input.ReadVarint64(&profiling)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
} }
return true;
} }
return 0; return true;
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -22,18 +22,35 @@ ...@@ -22,18 +22,35 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
class VariableResponse { class VariableResponse {
public: public:
VariableResponse(const framework::Scope* scope, VariableResponse(const framework::Scope* scope,
...@@ -51,22 +68,19 @@ class VariableResponse { ...@@ -51,22 +68,19 @@ class VariableResponse {
} }
} }
// return: int Parse(Source* source, const sendrecv::VariableMessage& meta) {
// 0:ok. meta_ = meta;
// -1: unkown error. return Parse(source);
// other: number of error field. }
int Parse(Source* source);
// return: // return:
// 0:ok. // 0:ok.
// -1: unkown error. // -1: unkown error.
// other: number of error field. // other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer); virtual int Parse(Source* source) = 0;
const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline const framework::Scope& GetLocalScope() const { return *local_scope_; }
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
...@@ -78,7 +92,11 @@ class VariableResponse { ...@@ -78,7 +92,11 @@ class VariableResponse {
return scope_->FindVar(meta_.varname()); return scope_->FindVar(meta_.varname());
} }
private: protected:
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int64_t size);
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::DDim& dims, int length); const framework::DDim& dims, int length);
...@@ -90,12 +108,16 @@ class VariableResponse { ...@@ -90,12 +108,16 @@ class VariableResponse {
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::DDim& dims, int length); const framework::DDim& dims, int length);
private: bool ProcSerializedField(int tag,
::google::protobuf::io::CodedInputStream* input,
int64_t num_bytes);
protected:
const framework::Scope* scope_; const framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
bool create_scope_ = false; bool create_scope_ = false;
framework::Scope* local_scope_ = nullptr; framework::Scope* local_scope_ = nullptr;
// only Skeleton
sendrecv::VariableMessage meta_; sendrecv::VariableMessage meta_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册