From 3a6213f493e287339f7cfa86755e9d4ba4ee5703 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Fri, 20 Jul 2018 11:18:08 +0800 Subject: [PATCH] Change grpc interface to compatible with brpc. (#12164) --- .../operators/distributed/CMakeLists.txt | 50 ++- ...er_stream.cc => grpc_bytebuffer_stream.cc} | 2 +- ...ffer_stream.h => grpc_bytebuffer_stream.h} | 20 +- .../operators/distributed/grpc_client.cc | 1 + .../fluid/operators/distributed/grpc_client.h | 20 +- .../fluid/operators/distributed/grpc_serde.cc | 157 ++++++++ .../fluid/operators/distributed/grpc_serde.h | 50 +++ .../operators/distributed/grpc_serde_test.cc | 8 +- .../operators/distributed/grpc_server.cc | 21 +- .../operators/distributed/grpc_service.h | 10 +- .../distributed/grpc_variable_response.cc | 308 +++++++++++++++ .../distributed/grpc_variable_response.h | 58 +++ .../operators/distributed/request_handler.h | 17 + .../distributed/request_handler_impl.cc | 5 +- .../{send_recv.proto => send_recv.proto.in} | 3 +- .../operators/distributed/sendrecvop_utils.cc | 144 +------ .../operators/distributed/sendrecvop_utils.h | 17 +- .../distributed/variable_response.cc | 351 ++---------------- .../operators/distributed/variable_response.h | 56 ++- 19 files changed, 748 insertions(+), 550 deletions(-) rename paddle/fluid/operators/distributed/{bytebuffer_stream.cc => grpc_bytebuffer_stream.cc} (96%) rename paddle/fluid/operators/distributed/{bytebuffer_stream.h => grpc_bytebuffer_stream.h} (87%) create mode 100644 paddle/fluid/operators/distributed/grpc_serde.cc create mode 100644 paddle/fluid/operators/distributed/grpc_serde.h create mode 100644 paddle/fluid/operators/distributed/grpc_variable_response.cc create mode 100644 paddle/fluid/operators/distributed/grpc_variable_response.h rename paddle/fluid/operators/distributed/{send_recv.proto => send_recv.proto.in} (97%) diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 675ca36774b..6555b8101a9 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -1,33 +1,43 @@ +if(NOT WITH_DISTRIBUTE) + return() +endif() + +if(WITH_GRPC) + set(cc_generic_services "false") +else() + set(cc_generic_services "true") +endif() +configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY) + if(WITH_GRPC) - grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor - selected_rows memory) + grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc + request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc + PROTO send_recv.proto + DEPS lod_tensor selected_rows memory) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr - cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) - cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc - grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor - proto_desc lookup_table_op SERIAL) + cc_test(grpc_serde_test SRCS grpc_serde_test.cc + DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) + cc_test(grpc_server_test SRCS rpc_server_test.cc + DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL) return() endif() set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") -set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc + +set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc + brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc + brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc PROTO send_recv.proto DEPS lod_tensor selected_rows memory) -find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so) -ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC}) - +set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy) -find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so) -ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC}) +cc_test(brpc_server_test SRCS rpc_server_test.cc + DEPS ${brpc_test_depends} SERIAL) -cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc - brpc protobuf leveldb gflags glog - protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL) +cc_test(brpc_serde_test SRCS brpc_serde_test.cc + DEPS ${brpc_test_depends} SERIAL) diff --git a/paddle/fluid/operators/distributed/bytebuffer_stream.cc b/paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc similarity index 96% rename from paddle/fluid/operators/distributed/bytebuffer_stream.cc rename to paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc index 6e91b447db8..d192f54ee0c 100644 --- a/paddle/fluid/operators/distributed/bytebuffer_stream.cc +++ b/paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc @@ -17,7 +17,7 @@ limitations under the License. */ // file and did some modifications so that we can send gRPC // requests without too much copying of the tensor data. -#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/distributed/bytebuffer_stream.h b/paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h similarity index 87% rename from paddle/fluid/operators/distributed/bytebuffer_stream.h rename to paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h index e7de172c79c..e9074574cdd 100644 --- a/paddle/fluid/operators/distributed/bytebuffer_stream.h +++ b/paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "grpc++/grpc++.h" +#include "paddle/fluid/operators/distributed/variable_response.h" namespace grpc { // A ZeroCopyInputStream that reads from grpc_byte_buffer @@ -107,25 +108,6 @@ class GrpcBufferReader final namespace paddle { namespace operators { namespace distributed { -// Source provides a way for a particular RPC implementation to provide -// received data to ParseFrom. -class Source { - public: - virtual ~Source() {} - - // Return the stream that contains the data to be parsed. - // Note that this method might be invoked more than once if - // ParseFrom needs to fall back to a more expensive parsing method. - // Every call must return a stream pointing at the beginning of - // the serialized RecvTensorResponse. - // - // Note that a subsequent call to contents() invalidates previous - // results of contents(). - // - // Ownership of the returned stream is retained by the Source and - // should not be deleted by the caller. - virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0; -}; // A ZeroCopyInputStream that reads from a grpc::ByteBuffer. class GrpcByteBufferSource diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 4d60801b6a6..52c4bc1e796 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "glog/logging.h" // For VLOG #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/grpc_serde.h" #include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/platform/profiler.h" diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index d03a3e56aed..11de84d9e26 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -38,7 +38,10 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN @@ -46,23 +49,6 @@ namespace paddle { namespace operators { namespace distributed { -struct VarHandle { - // RPC endpoint. - std::string ep; - const platform::DeviceContext* ctx; - const framework::Scope* scope; - // Variable name. - std::string name; - // RPC method name. - std::string method; - - std::string String() const { - std::ostringstream s; - s << method << " name:[" << name << "], ep:[" << ep << "]"; - return s.str(); - } -}; - void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { diff --git a/paddle/fluid/operators/distributed/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc_serde.cc new file mode 100644 index 00000000000..3f8796713a6 --- /dev/null +++ b/paddle/fluid/operators/distributed/grpc_serde.cc @@ -0,0 +1,157 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_CUDA +#include +#endif +#include +#include // NOLINT + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/grpc_serde.h" +#include "paddle/fluid/operators/distributed/grpc_variable_response.h" +#include "paddle/fluid/operators/distributed/proto_encoder_helper.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { +namespace distributed { + +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg, + const std::string& out_name) { + // Default DestroyCallback does nothing, When using GPU + // the CPU buffer need to be freed. + DestroyCallback destroy_callback = [](void* backing) {}; + VarMsg request; + void* payload = nullptr; + size_t payload_size; + + request.set_varname(name); + // Note: normally the profiler is enabled in 1 trainer, hence only + // 1 trainer returns true for ShouldSendProfileState(). It tells PS + // servers the trainer's profiling state so that PS can follow the + // trainer. + if (platform::ShouldSendProfileState()) { + if (platform::IsProfileEnabled()) { + request.set_profile(platform::kEnableProfiler); + } else { + request.set_profile(platform::kDisableProfiler); + } + } + if (!out_name.empty()) { + request.set_out_varname(out_name); + } + if (var->IsType()) { + request.set_type(::sendrecv::LOD_TENSOR); + GetTensorPayload(var, ctx, &request, &payload, &payload_size); + } else if (var->IsType()) { + request.set_type(::sendrecv::SELECTED_ROWS); + GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size); +#ifdef PADDLE_WITH_CUDA + } else if (var->IsType()) { + request.set_type(::sendrecv::NCCL_ID); +#endif + } else { + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + // GPU data is copied to CPU buffer when sending, + // free the buffer when possible. + destroy_callback = [](void* backing) { + platform::CUDAPinnedPlace cuda_pinned; + memory::Free(cuda_pinned, backing); + }; +#endif + } + + std::string header; + request.AppendToString(&header); + auto buffer = std::unique_ptr(new char[1024]); + void* buf = buffer.get(); + ProtoEncodeHelper e(static_cast(buf), 1024); + e.WriteRawBytes(std::string(header.data(), header.size())); +// NCCLID is copied directly to the message, return bytebuffer +// with only one slice if serializing NCCLID. +#ifdef PADDLE_WITH_CUDA + if (var->IsType()) { + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, + NCCL_UNIQUE_ID_BYTES); + const ncclUniqueId& uid = var->Get(); + e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES)); + + // for serialize NCCL_ID + ::grpc::Slice slices(e.size()); + memcpy(const_cast(slices.begin()), e.data(), e.size()); + ::grpc::ByteBuffer tmp(&slices, 1); + msg->Swap(&tmp); + return; + } +#endif + + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + // steal reference of tensor data + ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows + int num_slices = 2; // only SelectedRows have rows buffer + slices[0] = ::grpc::Slice(e.size()); + memcpy(const_cast(slices[0].begin()), e.data(), e.size()); + slices[1] = ::grpc::Slice( + grpc_slice_new_with_user_data(payload, payload_size, destroy_callback, + static_cast(payload)), + ::grpc::Slice::STEAL_REF); + + if (var->IsType()) { + auto* slr = var->GetMutable(); + ProtoEncodeHelper e2(static_cast(buf), 128); + size_t rows_memory_size = + slr->rows().size() * framework::SizeOfType(typeid(int64_t)); + e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); + slices[2] = ::grpc::Slice(e2.size()); + memcpy(const_cast(slices[2].begin()), e2.data(), e2.size()); + + slices[3] = ::grpc::Slice( + grpc_slice_new_with_user_data( + const_cast( + reinterpret_cast(slr->rows().data())), + rows_memory_size, [](void* backing) {}, + const_cast( + reinterpret_cast(slr->rows().data()))), + ::grpc::Slice::STEAL_REF); + num_slices = 4; + } + + ::grpc::ByteBuffer tmp(&slices[0], num_slices); + msg->Swap(&tmp); +} + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + const framework::Scope* scope, + framework::Variable** var) { + operators::distributed::GRPCVariableResponse resp(scope, &ctx); + PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); + *var = resp.GetVar(); +} + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc_serde.h b/paddle/fluid/operators/distributed/grpc_serde.h new file mode 100644 index 00000000000..450c41dcd6b --- /dev/null +++ b/paddle/fluid/operators/distributed/grpc_serde.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" + +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" + +namespace paddle { +namespace operators { +namespace distributed { + +typedef void (*DestroyCallback)(void*); + +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg, + const std::string& out_varname = std::string()); + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + const framework::Scope* scope, + framework::Variable** var); + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc_serde_test.cc index 3d107b533bc..96ea05e74ed 100644 --- a/paddle/fluid/operators/distributed/grpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/grpc_serde_test.cc @@ -21,8 +21,10 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/detail/macros.h" +#include "paddle/fluid/operators/distributed/grpc_serde.h" +#include "paddle/fluid/operators/distributed/grpc_variable_response.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h" -#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" @@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::distributed::VariableResponse resp(&scope, &ctx); + operators::distributed::GRPCVariableResponse resp(&scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::distributed::VariableResponse resp(&scope, &ctx); + operators::distributed::GRPCVariableResponse resp(&scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index f35e268f6ad..8edb00276df 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/operators/distributed/grpc_serde.h" #include "paddle/fluid/operators/distributed/grpc_server.h" using ::grpc::ServerAsyncResponseWriter; @@ -84,9 +85,9 @@ class RequestSend final : public RequestBase { ::grpc::ServerCompletionQueue* cq, RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new VariableResponse(request_handler->scope(), - request_handler->dev_ctx(), - !request_handler->sync_mode())); + request_.reset(new GRPCVariableResponse(request_handler->scope(), + request_handler->dev_ctx(), + !request_handler->sync_mode())); int method_id = static_cast(distributed::GrpcMethod::kSendVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, @@ -109,7 +110,7 @@ class RequestSend final : public RequestBase { protected: sendrecv::VoidMessage reply_; - std::shared_ptr request_; + std::shared_ptr request_; ServerAsyncResponseWriter responder_; }; @@ -161,8 +162,8 @@ class RequestPrefetch final : public RequestBase { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_), local_scope_(nullptr) { - request_.reset(new VariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); + request_.reset(new GRPCVariableResponse(request_handler->scope(), + request_handler->dev_ctx(), true)); int method_id = static_cast(distributed::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary( @@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase { } protected: - std::shared_ptr request_; + std::shared_ptr request_; ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* local_scope_; @@ -206,8 +207,8 @@ class RequestCheckpointNotify final : public RequestBase { ::grpc::ServerCompletionQueue* cq, RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - request_.reset(new VariableResponse(request_handler->scope(), - request_handler->dev_ctx())); + request_.reset(new GRPCVariableResponse(request_handler->scope(), + request_handler->dev_ctx())); int method_id = static_cast(distributed::GrpcMethod::kCheckpointNotify); service_->RequestAsyncUnary( @@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase { } protected: - std::shared_ptr request_; + std::shared_ptr request_; sendrecv::VoidMessage reply_; ServerAsyncResponseWriter responder_; }; diff --git a/paddle/fluid/operators/distributed/grpc_service.h b/paddle/fluid/operators/distributed/grpc_service.h index cdc4e7b7927..9ae9a31a003 100644 --- a/paddle/fluid/operators/distributed/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc_service.h @@ -23,8 +23,7 @@ #include #include #include -#include "paddle/fluid/operators/distributed/variable_response.h" - +#include "paddle/fluid/operators/distributed/grpc_variable_response.h" #include "paddle/fluid/platform/profiler.h" // NOTE: This method was originally created by tensorflow @@ -42,17 +41,18 @@ class ServerContext; // Support parsing/unparsing of tensorflow::VariableResponse. // Wire-format is identical to RecvVariableResponse. template <> -class SerializationTraits { +class SerializationTraits< + paddle::operators::distributed::GRPCVariableResponse> { public: static Status Serialize( - const paddle::operators::distributed::VariableResponse& msg, + const paddle::operators::distributed::GRPCVariableResponse& msg, grpc_byte_buffer** bp, bool* own_buffer) { PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); return Status(); } static Status Deserialize( grpc_byte_buffer* buffer, - paddle::operators::distributed::VariableResponse* msg, + paddle::operators::distributed::GRPCVariableResponse* msg, int max_message_size = INT_MAX) { if (buffer == nullptr) { return Status(StatusCode::INTERNAL, "No payload"); diff --git a/paddle/fluid/operators/distributed/grpc_variable_response.cc b/paddle/fluid/operators/distributed/grpc_variable_response.cc new file mode 100644 index 00000000000..34d47f3ec0f --- /dev/null +++ b/paddle/fluid/operators/distributed/grpc_variable_response.cc @@ -0,0 +1,308 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#ifdef PADDLE_WITH_CUDA +#include +#endif + +#include "paddle/fluid/operators/distributed/grpc_variable_response.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { +namespace distributed { + +enum WireType { + WIRETYPE_VARINT = 0, + WIRETYPE_LENGTH_DELIMITED = 2, +}; + +inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; } + +inline WireType GetTagWireType(uint32_t tag) { + return static_cast(tag & 0x7); +} + +bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input, + int* result) { + uint64_t v; + if (input->ReadVarint64(&v) && v <= static_cast(INT_MAX)) { + *result = static_cast(v); + return true; + } else { + return false; + } +} + +int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { + GrpcByteBufferSource source; + source.Init(byte_buffer); + GrpcByteBufferSourceWrapper r(&source); + + return Parse(&r); +} + +bool ParseLodData(::google::protobuf::io::CodedInputStream* input, + std::vector* lod) { + while (true) { + auto p = input->ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + + if (!p.second) { + return (tag == 0); + } + + switch (tag) { + case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: { + uint64_t v; + if (wt == WIRETYPE_VARINT) { + if (!input->ReadVarint64(&v)) { + return false; + } + lod->push_back(v); + break; + } + + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int num_bytes = 0; + if (!input->ReadVarintSizeAsInt(&num_bytes)) { + return tag; + } + int start_pos = input->CurrentPosition(); + while (input->CurrentPosition() - start_pos < num_bytes) { + uint64_t v; + if (!input->ReadVarint64(&v)) { + return tag; + } + lod->push_back(v); + } + break; + } + + return false; + } + default: { return false; } + } + } + + return true; +} + +int GRPCVariableResponse::Parse(Source* source) { + ::google::protobuf::io::ZeroCopyInputStream* input_stream = + source->contents(); + ::google::protobuf::io::CodedInputStream input(input_stream); + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + + while (true) { + auto p = input.ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + if (!p.second) { + if (tag != 0) { + return -1; + } + return 0; + } + + switch (tag) { + case sendrecv::VariableMessage::kVarnameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_varname(temp); + break; + } + case sendrecv::VariableMessage::kTypeFieldNumber: { + uint32_t v; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { + return tag; + } + + meta_.set_type(static_cast<::sendrecv::VarType>(v)); + break; + } + case sendrecv::VariableMessage::kDataTypeFieldNumber: { + uint32_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { + return tag; + } + + meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v)); + break; + } + case sendrecv::VariableMessage::kDimsFieldNumber: { + // not packed + if (wt == WIRETYPE_VARINT) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + break; + } + + // packed + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int num_bytes = 0; + if (!input.ReadVarintSizeAsInt(&num_bytes)) { + return tag; + } + int start_pos = input.CurrentPosition(); + while (input.CurrentPosition() - start_pos < num_bytes) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + } + break; + } + return tag; + } + case sendrecv::VariableMessage::kLodLevelFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + meta_.set_lod_level(static_cast(v)); + break; + } + case sendrecv::VariableMessage::kLodFieldNumber: { + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p = + input.IncrementRecursionDepthAndPushLimit(length); + + std::vector lod_data; + if (p.second < 0 || !ParseLodData(&input, &lod_data)) { + return tag; + } + + if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { + return tag; + } + + if (lod_data.size() == 0) { + break; + } + + auto lod = meta_.add_lod(); + for (uint32_t i = 0; i < lod_data.size(); i++) { + lod->add_lod_data(lod_data[i]); + } + break; + } + case sendrecv::VariableMessage::kSlrHeightFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + meta_.set_slr_height(static_cast(v)); + break; + } + case sendrecv::VariableMessage::kSerializedFieldNumber: { + int num_bytes = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &num_bytes)) { + return tag; + } + + if (!ProcSerializedField(tag, &input, num_bytes)) { + return tag; + } + + break; + } + case sendrecv::VariableMessage::kRowsFieldNumber: { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR) && + meta_.varname() != "", + "meta info should be got first!"); + + int num_bytes = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &num_bytes)) { + return tag; + } + + if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) { + return tag; + } + break; + } + case sendrecv::VariableMessage::kOutVarnameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_out_varname(temp); + break; + } + case sendrecv::VariableMessage::kProfileFieldNumber: { + uint64_t profiling = 0; + if (!input.ReadVarint64(&profiling)) { + return tag; + } + meta_.set_profile(profiling); + int64_t listener_id = platform::ListenerId(); + if (listener_id <= 0) { + break; + } + if (profiling == platform::kEnableProfiler && + !platform::IsProfileEnabled()) { + platform::EnableProfiler(platform::ProfilerState::kCPU); + } else if (profiling == platform::kDisableProfiler && + platform::IsProfileEnabled()) { + // TODO(panyx0718): Should we allow to customize file dir. + platform::DisableProfiler( + platform::EventSortingKey::kDefault, + string::Sprintf("/tmp/profile_ps_%lld", listener_id)); + } + break; + } + default: { + // Unknown tag, return unknown error. + return -1; + } + } + } + + return 0; +} + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc_variable_response.h b/paddle/fluid/operators/distributed/grpc_variable_response.h new file mode 100644 index 00000000000..89df07c92cd --- /dev/null +++ b/paddle/fluid/operators/distributed/grpc_variable_response.h @@ -0,0 +1,58 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/var_type.h" + +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/variable_response.h" + +namespace paddle { +namespace operators { +namespace distributed { + +class GRPCVariableResponse : public VariableResponse { + public: + GRPCVariableResponse(const framework::Scope* scope, + const platform::DeviceContext* dev_ctx, + bool create_scope = false) + : VariableResponse(scope, dev_ctx, create_scope) {} + + virtual ~GRPCVariableResponse() {} + + int Parse(Source* source) override; + + // return: + // 0:ok. + // -1: unkown error. + // other: number of error field. + int Parse(const ::grpc::ByteBuffer& byte_buffer); +}; + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 271306d5d20..3d61171dff9 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; class RPCServer; +struct VarHandle { + // RPC endpoint. + std::string ep; + const platform::DeviceContext* ctx; + const framework::Scope* scope; + // Variable name. + std::string name; + // RPC method name. + std::string method; + + std::string String() const { + std::ostringstream s; + s << method << " name:[" << name << "], ep:[" << ep << "]"; + return s.str(); + } +}; + class RequestHandler { public: explicit RequestHandler(bool sync_mode) diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 5e6bff20f5f..f1f84072d47 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname, // Sync if (varname == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "sync: recv batch barrier message"; + VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; rpc_server_->IncreaseBatchBarrier(kRequestSend); } else if (varname == BEGIN_PASS_MESSAGE) { VLOG(3) << "sync: recv begin pass message"; @@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname, VLOG(3) << "sync: processing received var: " << varname; if (invar == nullptr) { - LOG(ERROR) << "sync: Can not find server side var: " << varname; - PADDLE_THROW("sync: Can not find server side var"); + LOG(FATAL) << "sync: Can not find server side var: " << varname; return false; } if (invar->IsType()) { diff --git a/paddle/fluid/operators/distributed/send_recv.proto b/paddle/fluid/operators/distributed/send_recv.proto.in similarity index 97% rename from paddle/fluid/operators/distributed/send_recv.proto rename to paddle/fluid/operators/distributed/send_recv.proto.in index e0902320cff..8b0a09abe1d 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto +++ b/paddle/fluid/operators/distributed/send_recv.proto.in @@ -1,3 +1,4 @@ + /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +15,7 @@ limitations under the License. */ syntax = "proto3"; package sendrecv; -// option cc_generic_services = true; +option cc_generic_services = @cc_generic_services@; service SendRecvService { // For parameter server round-robin like hashing, do not split tensors. diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc index 98129d9f101..98a5dcbbb87 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.cc @@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" - #ifdef PADDLE_WITH_CUDA #include #endif #include #include // NOLINT -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" -#include "paddle/fluid/operators/distributed/proto_encoder_helper.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/variable_response.h" -#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -34,6 +28,11 @@ namespace distributed { using VarMsg = sendrecv::VariableMessage; +void* GetVarPayLoad(const std::string varname, int64_t size) { + platform::CUDAPinnedPlace cuda_pinned; + return memory::Alloc(cuda_pinned, size); +} + void GetTensorPayload(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* request, void** payload, size_t* payload_size) { @@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var, if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); - platform::CUDAPinnedPlace cuda_pinned; + // platform::CUDAPinnedPlace cuda_pinned; auto& gpu_dev_ctx = static_cast(ctx); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); - *payload = memory::Alloc(cuda_pinned, copy_size); + *payload = GetVarPayLoad(request->varname(), copy_size); + platform::CUDAPinnedPlace cuda_pinned; memory::Copy(cuda_pinned, *payload, boost::get(tensor.place()), reinterpret_cast(tensor.data()), copy_size, gpu_dev_ctx.stream()); + ctx.Wait(); #endif } else { @@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var, auto* tensor = slr->mutable_value(); if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA - platform::CUDAPinnedPlace cuda_pinned; auto& gpu_dev_ctx = static_cast(ctx); auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type()); - *payload = memory::Alloc(cuda_pinned, copy_size); + *payload = GetVarPayLoad(request->varname(), copy_size); + + platform::CUDAPinnedPlace cuda_pinned; memory::Copy(cuda_pinned, *payload, boost::get(tensor->place()), reinterpret_cast(tensor->data()), copy_size, @@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var, *payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); } -void SerializeToByteBuffer(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg, - const std::string& out_name) { - // Default DestroyCallback does nothing, When using GPU - // the CPU buffer need to be freed. - DestroyCallback destroy_callback = [](void* backing) {}; - VarMsg request; - void* payload = nullptr; - size_t payload_size; - - request.set_varname(name); - // Note: normally the profiler is enabled in 1 trainer, hence only - // 1 trainer returns true for ShouldSendProfileState(). It tells PS - // servers the trainer's profiling state so that PS can follow the - // trainer. - if (platform::ShouldSendProfileState()) { - if (platform::IsProfileEnabled()) { - request.set_profile(platform::kEnableProfiler); - } else { - request.set_profile(platform::kDisableProfiler); - } - } - if (!out_name.empty()) { - request.set_out_varname(out_name); - } - if (var->IsType()) { - request.set_type(::sendrecv::LOD_TENSOR); - GetTensorPayload(var, ctx, &request, &payload, &payload_size); - } else if (var->IsType()) { - request.set_type(::sendrecv::SELECTED_ROWS); - GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size); -#ifdef PADDLE_WITH_CUDA - } else if (var->IsType()) { - request.set_type(::sendrecv::NCCL_ID); -#endif - } else { - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - } - - if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - // GPU data is copied to CPU buffer when sending, - // free the buffer when possible. - destroy_callback = [](void* backing) { - platform::CUDAPinnedPlace cuda_pinned; - memory::Free(cuda_pinned, backing); - }; -#endif - } - - std::string header; - request.AppendToString(&header); - auto buffer = std::unique_ptr(new char[1024]); - void* buf = buffer.get(); - ProtoEncodeHelper e(static_cast(buf), 1024); - e.WriteRawBytes(std::string(header.data(), header.size())); -// NCCLID is copied directly to the message, return bytebuffer -// with only one slice if serializing NCCLID. -#ifdef PADDLE_WITH_CUDA - if (var->IsType()) { - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, - NCCL_UNIQUE_ID_BYTES); - const ncclUniqueId& uid = var->Get(); - e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES)); - - // for serialize NCCL_ID - ::grpc::Slice slices(e.size()); - memcpy(const_cast(slices.begin()), e.data(), e.size()); - ::grpc::ByteBuffer tmp(&slices, 1); - msg->Swap(&tmp); - return; - } -#endif - - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); - // steal reference of tensor data - ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows - int num_slices = 2; // only SelectedRows have rows buffer - slices[0] = ::grpc::Slice(e.size()); - memcpy(const_cast(slices[0].begin()), e.data(), e.size()); - slices[1] = ::grpc::Slice( - grpc_slice_new_with_user_data(payload, payload_size, destroy_callback, - static_cast(payload)), - ::grpc::Slice::STEAL_REF); - - if (var->IsType()) { - auto* slr = var->GetMutable(); - ProtoEncodeHelper e2(static_cast(buf), 128); - size_t rows_memory_size = - slr->rows().size() * framework::SizeOfType(typeid(int64_t)); - e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); - slices[2] = ::grpc::Slice(e2.size()); - memcpy(const_cast(slices[2].begin()), e2.data(), e2.size()); - - slices[3] = ::grpc::Slice( - grpc_slice_new_with_user_data( - const_cast( - reinterpret_cast(slr->rows().data())), - rows_memory_size, [](void* backing) {}, - const_cast( - reinterpret_cast(slr->rows().data()))), - ::grpc::Slice::STEAL_REF); - num_slices = 4; - } - - ::grpc::ByteBuffer tmp(&slices[0], num_slices); - msg->Swap(&tmp); -} - -void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var) { - operators::distributed::VariableResponse resp(scope, &ctx); - PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); - *var = resp.GetVar(); -} - } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.h b/paddle/fluid/operators/distributed/sendrecvop_utils.h index fe25e73fa60..4d08d3c77af 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.h +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.h @@ -25,24 +25,21 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h" namespace paddle { namespace operators { namespace distributed { -typedef void (*DestroyCallback)(void*); +using VarMsg = sendrecv::VariableMessage; -void SerializeToByteBuffer(const std::string& name, framework::Variable* var, - const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg, - const std::string& out_varname = std::string()); +void GetTensorPayload(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* request, + void** payload, size_t* payload_size); -void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, - const platform::DeviceContext& ctx, - const framework::Scope* scope, - framework::Variable** var); +void GetSelectedRowsPayload(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* request, + void** payload, size_t* payload_size); inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { switch (type) { diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 45832c60bf9..466bce18af7 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,50 +13,20 @@ // limitations under the License. #include "paddle/fluid/operators/distributed/variable_response.h" - -#include -#include #include -#ifdef PADDLE_WITH_CUDA -#include -#endif -#include "paddle/fluid/platform/profiler.h" - -#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h" namespace paddle { namespace operators { namespace distributed { -enum WireType { - WIRETYPE_VARINT = 0, - WIRETYPE_LENGTH_DELIMITED = 2, -}; - -inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; } - -inline WireType GetTagWireType(uint32_t tag) { - return static_cast(tag & 0x7); -} - -bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input, - int* result) { - uint64_t v; - if (input->ReadVarint64(&v) && v <= static_cast(INT_MAX)) { - *result = static_cast(v); - return true; - } else { - return false; - } -} - -bool ReadRaw(::google::protobuf::io::CodedInputStream* input, - const platform::DeviceContext& dev_ctx, platform::Place place, - void* dest, int size) { +bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& dev_ctx, + platform::Place place, void* dest, + int64_t size) { const void* data = NULL; int size_to_write = 0; - int length = size; + int64_t length = size; int total_written = 0; if (platform::is_gpu_place(place)) { @@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData( return true; } -bool ParseLodData(::google::protobuf::io::CodedInputStream* input, - std::vector* lod) { - while (true) { - auto p = input->ReadTagWithCutoff(127); - int tag = GetTagFieldNumber(p.first); - WireType wt = GetTagWireType(p.first); - - if (!p.second) { - return (tag == 0); - } - - switch (tag) { - case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: { - uint64_t v; - if (wt == WIRETYPE_VARINT) { - if (!input->ReadVarint64(&v)) { - return false; - } - lod->push_back(v); - break; - } - - if (wt == WIRETYPE_LENGTH_DELIMITED) { - int num_bytes = 0; - if (!input->ReadVarintSizeAsInt(&num_bytes)) { - return tag; - } - int start_pos = input->CurrentPosition(); - while (input->CurrentPosition() - start_pos < num_bytes) { - uint64_t v; - if (!input->ReadVarint64(&v)) { - return tag; - } - lod->push_back(v); - } - break; - } +bool VariableResponse::ProcSerializedField( + int tag, ::google::protobuf::io::CodedInputStream* input, + int64_t num_bytes) { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR || + meta_.type() == sendrecv::NCCL_ID) && + meta_.varname() != "", + "meta info should be got first!"); + if (meta_.type() == sendrecv::NCCL_ID) { +#ifdef PADDLE_WITH_CUDA + auto* var = scope_->FindVar(meta_.varname()); + if (var != nullptr) { + ncclUniqueId* id = var->GetMutable(); + if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal, + num_bytes)) { return false; } - default: { return false; } } - } - - return true; -} - -int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { - GrpcByteBufferSource source; - source.Init(byte_buffer); - GrpcByteBufferSourceWrapper r(&source); - - return Parse(&r); -} - -int VariableResponse::Parse(Source* source) { - ::google::protobuf::io::ZeroCopyInputStream* input_stream = - source->contents(); - ::google::protobuf::io::CodedInputStream input(input_stream); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); - - while (true) { - auto p = input.ReadTagWithCutoff(127); - int tag = GetTagFieldNumber(p.first); - WireType wt = GetTagWireType(p.first); - if (!p.second) { - if (tag != 0) { - return -1; - } - return 0; - } - - switch (tag) { - case sendrecv::VariableMessage::kVarnameFieldNumber: { - uint32_t length; - if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { - return tag; - } - - std::string temp; - if (!input.ReadString(&temp, length)) { - return tag; - } - - meta_.set_varname(temp); - break; - } - case sendrecv::VariableMessage::kTypeFieldNumber: { - uint32_t v; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { - return tag; - } - - meta_.set_type(static_cast<::sendrecv::VarType>(v)); - break; - } - case sendrecv::VariableMessage::kDataTypeFieldNumber: { - uint32_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) { - return tag; - } - - meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v)); - break; - } - case sendrecv::VariableMessage::kDimsFieldNumber: { - // not packed - if (wt == WIRETYPE_VARINT) { - uint64_t v; - if (!input.ReadVarint64(&v)) { - return tag; - } - meta_.add_dims(v); - break; - } - - // packed - if (wt == WIRETYPE_LENGTH_DELIMITED) { - int num_bytes = 0; - if (!input.ReadVarintSizeAsInt(&num_bytes)) { - return tag; - } - int start_pos = input.CurrentPosition(); - while (input.CurrentPosition() - start_pos < num_bytes) { - uint64_t v; - if (!input.ReadVarint64(&v)) { - return tag; - } - meta_.add_dims(v); - } - break; - } - return tag; - } - case sendrecv::VariableMessage::kLodLevelFieldNumber: { - uint64_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { - return tag; - } - meta_.set_lod_level(static_cast(v)); - break; - } - case sendrecv::VariableMessage::kLodFieldNumber: { - int length = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &length)) { - return tag; - } - - std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p = - input.IncrementRecursionDepthAndPushLimit(length); - - std::vector lod_data; - if (p.second < 0 || !ParseLodData(&input, &lod_data)) { - return tag; - } - - if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { - return false; - } - - if (lod_data.size() == 0) { - break; - } - - auto lod = meta_.add_lod(); - for (uint32_t i = 0; i < lod_data.size(); i++) { - lod->add_lod_data(lod_data[i]); - } - break; - } - case sendrecv::VariableMessage::kSlrHeightFieldNumber: { - uint64_t v = 0; - if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { - return tag; - } - meta_.set_slr_height(static_cast(v)); - break; - } - case sendrecv::VariableMessage::kSerializedFieldNumber: { - PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR || - meta_.type() == sendrecv::NCCL_ID) && - meta_.varname() != "", - "meta info should be got first!"); - - int num_bytes = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &num_bytes)) { - return tag; - } - - if (meta_.type() == sendrecv::NCCL_ID) { -#ifdef PADDLE_WITH_CUDA - auto* var = scope_->FindVar(meta_.varname()); - if (var != nullptr) { - ncclUniqueId* id = var->GetMutable(); - if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, - num_bytes)) { - return tag; - } - } - break; + return true; #else - PADDLE_THROW("Not compiled with CUDA!"); + PADDLE_THROW("Not compiled with CUDA!"); + return false; #endif - } - - framework::DDim dims = GetDims(meta_.dims()); - if (meta_.type() == sendrecv::LOD_TENSOR) { - PADDLE_ENFORCE(meta_.lod_size() >= 0, - "lod info should be got first!"); - if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) { - return tag; - } - break; - } - - if (meta_.type() == sendrecv::SELECTED_ROWS) { - if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) { - return tag; - } - break; - } - - return tag; - } - case sendrecv::VariableMessage::kRowsFieldNumber: { - PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || - meta_.type() == sendrecv::LOD_TENSOR) && - meta_.varname() != "", - "meta info should be got first!"); - - int num_bytes = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &num_bytes)) { - return tag; - } - - if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) { - return tag; - } - break; - } - case sendrecv::VariableMessage::kOutVarnameFieldNumber: { - uint32_t length; - if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { - return tag; - } + } - std::string temp; - if (!input.ReadString(&temp, length)) { - return tag; - } + framework::DDim dims = GetDims(meta_.dims()); + if (meta_.type() == sendrecv::LOD_TENSOR) { + PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); + if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { + return false; + } + return true; + } - meta_.set_out_varname(temp); - break; - } - case sendrecv::VariableMessage::kProfileFieldNumber: { - uint64_t profiling = 0; - if (!input.ReadVarint64(&profiling)) { - return tag; - } - meta_.set_profile(profiling); - int64_t listener_id = platform::ListenerId(); - if (listener_id <= 0) { - break; - } - if (profiling == platform::kEnableProfiler && - !platform::IsProfileEnabled()) { - platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (profiling == platform::kDisableProfiler && - platform::IsProfileEnabled()) { - // TODO(panyx0718): Should we allow to customize file dir. - platform::DisableProfiler( - platform::EventSortingKey::kDefault, - string::Sprintf("/tmp/profile_ps_%lld", listener_id)); - } - break; - } - default: { - // Unknown tag, return unknown error. - return -1; - } + if (meta_.type() == sendrecv::SELECTED_ROWS) { + if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) { + return false; } + return true; } - return 0; + return true; } }; // namespace distributed diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index 1db4a0a5226..6aec52ca00f 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -22,18 +22,35 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/distributed/send_recv.pb.h" - #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" namespace paddle { namespace operators { namespace distributed { +// Source provides a way for a particular RPC implementation to provide +// received data to ParseFrom. +class Source { + public: + virtual ~Source() {} + + // Return the stream that contains the data to be parsed. + // Note that this method might be invoked more than once if + // ParseFrom needs to fall back to a more expensive parsing method. + // Every call must return a stream pointing at the beginning of + // the serialized RecvTensorResponse. + // + // Note that a subsequent call to contents() invalidates previous + // results of contents(). + // + // Ownership of the returned stream is retained by the Source and + // should not be deleted by the caller. + virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0; +}; + class VariableResponse { public: VariableResponse(const framework::Scope* scope, @@ -51,22 +68,19 @@ class VariableResponse { } } - // return: - // 0:ok. - // -1: unkown error. - // other: number of error field. - int Parse(Source* source); + int Parse(Source* source, const sendrecv::VariableMessage& meta) { + meta_ = meta; + return Parse(source); + } // return: // 0:ok. // -1: unkown error. // other: number of error field. - int Parse(const ::grpc::ByteBuffer& byte_buffer); - - const framework::Scope& GetLocalScope() const { return *local_scope_; } - - framework::Scope* GetMutableLocalScope() const { return local_scope_; } + virtual int Parse(Source* source) = 0; + inline const framework::Scope& GetLocalScope() const { return *local_scope_; } + inline framework::Scope* GetMutableLocalScope() const { return local_scope_; } inline std::string Varname() const { return meta_.varname(); } inline std::string OutVarname() const { return meta_.out_varname(); } @@ -78,7 +92,11 @@ class VariableResponse { return scope_->FindVar(meta_.varname()); } - private: + protected: + bool ReadRaw(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& dev_ctx, platform::Place place, + void* dest, int64_t size); + bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length); @@ -90,12 +108,16 @@ class VariableResponse { const platform::DeviceContext& ctx, const framework::DDim& dims, int length); - private: + bool ProcSerializedField(int tag, + ::google::protobuf::io::CodedInputStream* input, + int64_t num_bytes); + + protected: const framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; bool create_scope_ = false; framework::Scope* local_scope_ = nullptr; - // only Skeleton + sendrecv::VariableMessage meta_; }; -- GitLab