未验证 提交 990d6396 编写于 作者: G gongweibao 提交者: GitHub

Reuduce memory copy when communication between trainer and pserver. (#9271)

上级 b594251f
......@@ -18,12 +18,13 @@ import sys
import time
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.profiler as profiler
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
import argparse
import functools
import os
from paddle.fluid import debuger
def str2bool(v):
......@@ -182,7 +183,6 @@ def main():
start_time = time.time()
num_samples = 0
train_pass_acc.reset()
with profiler.profiler("CPU", 'total') as prof:
for batch_id, data in enumerate(train_reader()):
ts = time.time()
img_data = np.array(
......@@ -254,9 +254,7 @@ def main():
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
print("starting server side startup")
exe.run(pserver_startup)
print("starting parameter server...")
exe.run(pserver_prog)
elif training_role == "TRAINER":
# Parameter initialization
......
......@@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server):
return np.mean(test_accs)
config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1,
log_device_placement=True)
config.gpu_options.allow_growth = True
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
with tf.train.MonitoredTrainingSession(
master=server.target, is_chief=(args.task_index == 0),
hooks=hooks) as sess:
master=server.target,
is_chief=(args.task_index == 0),
hooks=hooks,
config=config) as sess:
iters, num_samples, start_time = 0, 0, 0.0
for pass_id in range(args.num_passes):
# train
......
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
endif()
......@@ -23,9 +23,107 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer
class GrpcBufferReader final
: public ::google::protobuf::io::ZeroCopyInputStream {
typedef void (CoreCodegenInterface::*OldReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
typedef int (CoreCodegenInterface::*NewReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
(g_core_codegen_interface->*ptr)(reader, buffer);
}
void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
int result = (g_core_codegen_interface->*ptr)(reader, buffer);
(void)result;
}
public:
explicit GrpcBufferReader(grpc_byte_buffer* buffer)
: byte_count_(0), backup_count_(0) {
ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_,
buffer);
}
~GrpcBufferReader() override {
g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_);
}
bool Next(const void** data, int* size) override {
if (backup_count_ > 0) {
*data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) -
backup_count_;
GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX);
*size = (int)backup_count_;
backup_count_ = 0;
return true;
}
if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_,
&slice_)) {
return false;
}
g_core_codegen_interface->grpc_slice_unref(slice_);
*data = GRPC_SLICE_START_PTR(slice_);
// On win x64, int is only 32bit
GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_);
return true;
}
void BackUp(int count) override { backup_count_ = count; }
bool Skip(int count) override {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
// error or we have too large count;
return false;
}
::google::protobuf::int64 ByteCount() const override {
return byte_count_ - backup_count_;
}
private:
int64_t byte_count_;
int64_t backup_count_;
grpc_byte_buffer_reader reader_;
grpc_slice slice_;
};
}; // namespace grpc
namespace paddle {
namespace operators {
namespace detail {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource
......@@ -46,6 +144,42 @@ class GrpcByteBufferSource
::google::protobuf::int64 byte_count_;
};
class GrpcByteBufferSourceWrapper : public Source {
public:
GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) : source_(source) {}
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() override {
return source_;
}
private:
GrpcByteBufferSource* source_;
};
class GrpcByteSource : public Source {
public:
explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {}
~GrpcByteSource() override { DeleteStream(); }
typedef ::grpc::GrpcBufferReader Reader;
::google::protobuf::io::ZeroCopyInputStream* contents() override {
DeleteStream();
stream_ = new (&space_) Reader(buffer_);
return stream_;
}
private:
void DeleteStream() {
if (stream_) {
stream_->~Reader();
}
}
grpc_byte_buffer* buffer_; // Not owned
Reader* stream_ = nullptr; // Points into space_ if non-nullptr
char space_[sizeof(Reader)];
};
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
#include <sys/time.h>
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
......@@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage req;
SerializeToMessage(var_name_val, var, *p_ctx, &req);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
// varhandle
VarHandle var_h;
......@@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
auto call = std::move(s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
&cq_));
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
});
req_count_++;
......@@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
}
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
// const sendrecv::VariableMessage& ret_msg) {
const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = NULL;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar);
}
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
proto.SerializeWithCachedSizesToArray(
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
::grpc::ByteBuffer tmp(&slice, 1);
result->Swap(&tmp);
}
bool RPCClient::AsyncGetVariable(const std::string& ep,
......@@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
auto call = std::move(s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_));
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
});
req_count_++;
......
......@@ -25,6 +25,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include <grpc++/generic/generic_stub.h>
#include <grpc++/grpc++.h>
#include <grpc++/support/byte_buffer.h>
#include <grpc++/support/slice.h>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -49,15 +54,11 @@ struct VarHandle {
}
};
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg);
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL;
}
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }
virtual ~BaseProcessor() {}
......@@ -82,19 +83,18 @@ class BaseProcessor {
virtual void Process() = 0;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_;
VarHandle var_h_;
};
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
RequestSendCallBack;
class SendProcessor : public BaseProcessor {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
: BaseProcessor(ch), stub_g_(ch) {}
virtual ~SendProcessor() {}
......@@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor {
}
}
sendrecv::VoidMessage reply_;
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
RequestSendCallBack response_call_back_ = NULL;
};
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
RequestGetCallBack;
class GetProcessor : public BaseProcessor {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
: BaseProcessor(ch), stub_g_(ch) {}
virtual ~GetProcessor() {}
......@@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor {
}
}
sendrecv::VariableMessage reply_;
::grpc::ByteBuffer reply_;
::grpc::GenericStub stub_g_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~BatchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~FetchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class RPCClient {
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
using grpc::ServerAsyncResponseWriter;
using ::grpc::ServerAsyncResponseWriter;
namespace paddle {
namespace operators {
......@@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH };
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {
explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
......@@ -42,55 +43,58 @@ class RequestBase {
}
protected:
grpc::ServerContext ctx_;
sendrecv::SendRecvService::AsyncService* service_;
grpc::ServerCompletionQueue* cq_;
::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_;
CallStatus status_;
const platform::DeviceContext* dev_ctx_;
};
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestSend final : public RequestBase {
public:
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq,
SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
this);
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}
virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
responder_.Finish(reply_, grpc::Status::OK, this);
queue_->Push(std::make_pair(request_->Varname(), request_));
sendrecv::VoidMessage reply;
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
sendrecv::VariableMessage request_;
sendrecv::VoidMessage reply_;
SimpleBlockQueue<MessageWithName>* queue_;
std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq),
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx),
queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
virtual ~RequestGet() {}
......@@ -101,24 +105,26 @@ class RequestGet final : public RequestBase {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
::grpc::ByteBuffer reply;
if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
}
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
MessageWithName msg_with_name =
// request name reply
std::make_pair(var_name, std::move(reply_));
if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg;
MessageWithName msg_with_name = std::make_pair(var_name, msg);
queue_->Push(msg_with_name);
}
}
protected:
sendrecv::VariableMessage request_;
sendrecv::VariableMessage reply_;
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<MessageWithName>* queue_;
};
......@@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
}
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);
......@@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
if (is_shut_down_) {
return;
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
&var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status();
}
......@@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
......
......@@ -14,28 +14,35 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <thread>
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
//#include <grpc/support/log.h>
namespace paddle {
namespace operators {
namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
class AsyncGRPCServer final {
public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
......@@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
void Push(const std::string &msg_name) {
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
}
void ShutDown();
protected:
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
......@@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
private:
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
sendrecv::SendRecvService::AsyncService service_;
std::unique_ptr<grpc::Server> server_;
GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_;
std::string address_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<MessageWithName> var_get_queue_;
ReceivedQueue var_recv_queue_;
// condition of the sub program
std::mutex barrier_mutex_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"
// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC
// requests without too much copying of the tensor data.
namespace grpc {
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> {
public:
static Status Serialize(
const paddle::operators::detail::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status();
}
static Status Deserialize(grpc_byte_buffer* buffer,
paddle::operators::detail::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload");
}
Status result = g_core_codegen_interface->ok();
if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer);
int ret = msg->Parse(&source);
if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
}
}
g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
return result;
}
};
} // namespace grpc
namespace paddle {
namespace operators {
namespace detail {
enum class GrpcMethod {
kSendVariable,
kGetVariable,
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
case GrpcMethod::kSendVariable:
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
}
// Shouldn't be reached.
PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
return nullptr;
}
class GrpcService final {
public:
class AsyncService : public ::grpc::Service {
public:
AsyncService() {
for (int i = 0; i < kGrpcNumMethods; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
GrpcMethodName(static_cast<GrpcMethod>(i)),
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
::grpc::Service::MarkMethodAsync(i);
}
}
virtual ~AsyncService() {}
// Make RequestAsyncUnary public for grpc_call.h
using ::grpc::Service::RequestAsyncUnary;
};
};
} // namespace detail
} // namespace operator
} // namespace paddle
......@@ -32,6 +32,9 @@ enum VarType {
SELECTED_ROWS = 1;
}
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
......@@ -45,7 +48,6 @@ message VariableMessage {
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
......@@ -64,3 +66,5 @@ message VariableMessage {
}
message VoidMessage {}
message TestMessage { int64 test_1 = 1; }
......@@ -13,61 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include <sys/time.h>
#include <thread>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle {
namespace operators {
namespace detail {
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg) {
msg->set_varname(name);
std::ostringstream oss;
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break;
case framework::proto::VarType_Type_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx);
break;
default: {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
msg->set_serialized(oss.str());
}
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var) {
std::istringstream iss(msg.serialized());
switch (msg.type()) {
case sendrecv::VarType::LOD_TENSOR:
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
break;
case sendrecv::VarType::SELECTED_ROWS: {
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
ctx);
break;
}
default: {
PADDLE_THROW("Deserialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
......@@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.memory_size();
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()),
......@@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = tensor.data<void>();
......@@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
framework::Variable* var) {
sendrecv::VariableMessage meta;
GrpcByteBufferSource source;
source.Init(msg);
::google::protobuf::io::CodedInputStream input(&source);
// do zerocopy parsing
PADDLE_ENFORCE(meta.ParseFromCodedStream(&input));
PADDLE_ENFORCE(input.ConsumedEntireMessage());
// dims is needed by both tensor and selectedrows
std::vector<int> vecdims;
for (auto& d : meta.dims()) {
vecdims.push_back(d);
}
framework::DDim dims = framework::make_ddim(vecdims);
if (meta.type() == sendrecv::LOD_TENSOR) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
framework::LoD lod;
for (int i = 0; i < meta.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) {
v.push_back(meta.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
// How to avoid copying and use the message buffer directly?
// Maybe need to find a way to release all memory except tensor content.
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
} else if (meta.type() == sendrecv::SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
int64_t* rows_data = slr->mutable_rows()->data();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
// copy rows CPU data, GPU data will be copied lazly
memcpy(rows_data, reinterpret_cast<const void*>(meta.rows().data()),
meta.rows().size());
}
const framework::Scope* scope,
framework::Variable*& var) {
operators::detail::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
var = resp.GetVar();
}
} // namespace detail
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
......@@ -36,21 +37,14 @@ namespace detail {
typedef void (*DestroyCallback)(void*);
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
const framework::Scope* scope,
framework::Variable*& var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) {
......
......@@ -16,11 +16,13 @@ limitations under the License. */
#include <string>
#include <thread>
#include <google/protobuf/text_format.h>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
......@@ -31,19 +33,21 @@ namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestTensor(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
int tensor_numel = 2 * 10;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
......@@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) {
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
EXPECT_EQ(varmsg.type(), 1);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 31.9);
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto tensor2 = var2.Get<framework::LoDTensor>();
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope;
scope.Var("myvar");
operators::detail::TensorResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
}
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
......@@ -126,61 +135,83 @@ void RunSerdeTestSelectedRows(platform::Place place) {
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
EXPECT_FLOAT_EQ(tensor_data[i], 31.9);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
// message binary
std::string str;
varmsg.SerializeToString(&str);
// message bytebuffer
::grpc::Slice slices_2[1];
int num_slices = 1;
slices_2[0] = ::grpc::Slice(str.length());
memcpy(const_cast<uint8_t*>(slices_2[0].begin()), str.c_str(), str.length());
::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices);
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope;
scope.Var("myvar");
operators::detail::TensorResponse resp(&scope, &ctx);
if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0);
} else {
EXPECT_EQ(resp.Parse(bytebuffer2), 0);
}
auto* slr2 = var2.GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
TEST(SelectedRows, CPU) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
TEST(LodTensor, GPU) {
platform::CUDAPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
}
TEST(SelectedRows, GPU) {
platform::CUDAPlace place;
RunSerdeTestSelectedRows(place);
TEST(LodTensor, CPU) {
platform::CPUPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
}
TEST(Tensor, CPU) {
TEST(SelectedRows, CPU) {
platform::CPUPlace place;
RunSerdeTestTensor(place);
RunSerdeTestSelectedRows(place);
}
TEST(Tensor, GPU) {
TEST(SelectedRows, GPU) {
platform::CUDAPlace place;
RunSerdeTestTensor(place);
RunSerdeTestSelectedRows(place);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h>
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
enum WireType {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int size) {
const void* data = NULL;
int size_to_write = 0;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
p += size_to_write;
size -= size_to_write;
input->Skip(size_to_write);
}
gpu_dev_ctx.Wait();
#else
PADDLE_THROW("Unexpected branch");
#endif
return true;
}
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu;
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
size -= size_to_write;
input->Skip(size_to_write);
}
return true;
}
bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
framework::LoD lod;
for (int i = 0; i < meta_.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) {
v.push_back(meta_.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
return true;
}
inline framework::DDim GetDims(
const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
std::vector<int> vecdims;
for (auto& d : dims) {
vecdims.push_back(d);
}
return framework::make_ddim(vecdims);
}
bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
return true;
}
bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
platform::CPUPlace cpu;
if (!ReadRaw(input, ctx, cpu, rows_data, length)) {
return false;
}
return true;
}
bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
std::vector<int64_t>* lod) {
while (true) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input->ReadVarintSizeAsInt(&length)) {
return tag;
}
for (int i = 0; i < length; i++) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
int VariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint64_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input.ReadVarintSizeAsInt(&length)) {
return tag;
}
for (int i = 0; i < length; i++) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return false;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
"lod info should be got first!");
if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) {
return tag;
}
break;
}
if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) {
return tag;
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, length)) {
return tag;
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
}
}
return 0;
}
}; // namespace detail
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
class VariableResponse {
public:
VariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: scope_(scope), dev_ctx_(dev_ctx){};
virtual ~VariableResponse(){};
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int Parse(Source* source);
// return:
// 0:ok.
// -1: unkown error.
// other: number of error field.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
inline std::string Varname() { return meta_.varname(); }
// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
private:
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length);
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
private:
const framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
// only Skeleton
sendrecv::VariableMessage meta_;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
......@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase {
}
void Stop() override {
detail::MessageWithName term_msg;
term_msg.first = LISTEN_TERMINATE_MESSAGE;
rpc_service_->Push(term_msg);
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown();
server_thread_->join();
}
......@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
size_t recv_var_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
......@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase {
} else {
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto *var = recv_scope.FindVar(recv_var_name);
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
......
......@@ -16,7 +16,6 @@ import sys
import re
from graphviz import GraphPreviewGenerator
import proto.framework_pb2 as framework_pb2
import paddle.fluid.core as core
_vartype2str_ = [
"UNK",
......@@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False):
def is_var_backward(var_desc):
return "@GRAD" in var_desc.name
#print(type(block_desc))
if type(block_desc) is not framework_pb2.BlockDesc:
block_desc = framework_pb2.BlockDesc.FromString(
block_desc.serialize_to_string())
......
......@@ -20,6 +20,7 @@ from layer_helper import LayerHelper
from distributed_spliter import *
import math
from . import core
import debuger
class VarBlock:
......@@ -289,6 +290,7 @@ class DistributeTranspiler:
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(var)
# step3
optimize_block = pserver_program.create_block(0)
# step 4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册