Unverified commit 990d6396, authored by gongweibao, committed via GitHub

Reduce memory copies when communicating between trainer and pserver. (#9271)

Parent: b594251f
@@ -18,12 +18,13 @@ import sys
 import time
 import numpy as np
 import paddle.v2 as paddle
-import paddle.v2.fluid as fluid
-import paddle.v2.fluid.core as core
-import paddle.v2.fluid.profiler as profiler
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.profiler as profiler
 import argparse
 import functools
 import os
+from paddle.fluid import debuger


 def str2bool(v):
@@ -182,28 +183,27 @@ def main():
         start_time = time.time()
         num_samples = 0
         train_pass_acc.reset()
-        with profiler.profiler("CPU", 'total') as prof:
-            for batch_id, data in enumerate(train_reader()):
-                ts = time.time()
-                img_data = np.array(
-                    map(lambda x: x[0].reshape(data_shape), data)).astype(
-                        "float32")
-                y_data = np.array(map(lambda x: x[1], data)).astype("int64")
-                y_data = y_data.reshape([-1, 1])
+        for batch_id, data in enumerate(train_reader()):
+            ts = time.time()
+            img_data = np.array(
+                map(lambda x: x[0].reshape(data_shape), data)).astype(
+                    "float32")
+            y_data = np.array(map(lambda x: x[1], data)).astype("int64")
+            y_data = y_data.reshape([-1, 1])

-                loss, acc, b_size = exe.run(
-                    trainer_prog,
-                    feed={"pixel": img_data,
-                          "label": y_data},
-                    fetch_list=[avg_cost, batch_acc, batch_size])
-                iters += 1
-                num_samples += len(data)
-                train_pass_acc.add(value=acc, weight=b_size)
-                print(
-                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
-                    % (pass_id, iters, loss, acc,
-                       len(data) / (time.time() - ts))
-                )  # The accuracy is the accumulation of batches, but not the current batch.
+            loss, acc, b_size = exe.run(
+                trainer_prog,
+                feed={"pixel": img_data,
+                      "label": y_data},
+                fetch_list=[avg_cost, batch_acc, batch_size])
+            iters += 1
+            num_samples += len(data)
+            train_pass_acc.add(value=acc, weight=b_size)
+            print(
+                "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
+                % (pass_id, iters, loss, acc,
+                   len(data) / (time.time() - ts))
+            )  # The accuracy is the accumulation of batches, but not the current batch.

         pass_elapsed = time.time() - start_time
         pass_train_acc = train_pass_acc.eval()
@@ -254,9 +254,7 @@ def main():
         pserver_prog = t.get_pserver_program(current_endpoint)
         pserver_startup = t.get_startup_program(current_endpoint,
                                                 pserver_prog)
-        print("starting server side startup")
         exe.run(pserver_startup)
-        print("starting parameter server...")
         exe.run(pserver_prog)
     elif training_role == "TRAINER":
         # Parameter initialization
......
@@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server):
         return np.mean(test_accs)

     config = tf.ConfigProto(
-        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
+        intra_op_parallelism_threads=1,
+        inter_op_parallelism_threads=1,
+        log_device_placement=True)
     config.gpu_options.allow_growth = True

     hooks = [tf.train.StopAtStepHook(last_step=1000000)]

     with tf.train.MonitoredTrainingSession(
-            master=server.target, is_chief=(args.task_index == 0),
-            hooks=hooks) as sess:
+            master=server.target,
+            is_chief=(args.task_index == 0),
+            hooks=hooks,
+            config=config) as sess:
         iters, num_samples, start_time = 0, 0, 0.0
         for pass_id in range(args.num_passes):
             # train
......
 if(WITH_DISTRIBUTE)
-  grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
+  grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
+      grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-  cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
+  cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
+      cares zlib protobuf sendrecvop_grpc)
 endif()
@@ -23,9 +23,107 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer
class GrpcBufferReader final
: public ::google::protobuf::io::ZeroCopyInputStream {
typedef void (CoreCodegenInterface::*OldReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
typedef int (CoreCodegenInterface::*NewReaderInitAPI)(
grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
(g_core_codegen_interface->*ptr)(reader, buffer);
}
void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
grpc_byte_buffer* buffer) {
int result = (g_core_codegen_interface->*ptr)(reader, buffer);
(void)result;
}
public:
explicit GrpcBufferReader(grpc_byte_buffer* buffer)
: byte_count_(0), backup_count_(0) {
ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_,
buffer);
}
~GrpcBufferReader() override {
g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_);
}
bool Next(const void** data, int* size) override {
if (backup_count_ > 0) {
*data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) -
backup_count_;
GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX);
*size = (int)backup_count_;
backup_count_ = 0;
return true;
}
if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_,
&slice_)) {
return false;
}
g_core_codegen_interface->grpc_slice_unref(slice_);
*data = GRPC_SLICE_START_PTR(slice_);
// On win x64, int is only 32bit
GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_);
return true;
}
void BackUp(int count) override { backup_count_ = count; }
bool Skip(int count) override {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
    // error, or count is larger than the remaining data;
return false;
}
::google::protobuf::int64 ByteCount() const override {
return byte_count_ - backup_count_;
}
private:
int64_t byte_count_;
int64_t backup_count_;
grpc_byte_buffer_reader reader_;
grpc_slice slice_;
};
}; // namespace grpc
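GrpcBufferReader above implements protobuf's ZeroCopyInputStream over a gRPC byte buffer. The Next/BackUp/ByteCount contract it follows can be exercised in isolation with protobuf's stock ArrayInputStream; this is only an illustrative sketch, and the buffer contents and block size are arbitrary values, not anything from this commit.

#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <iostream>

int main() {
  const char buf[] = "0123456789";
  // Hand out the 10 payload bytes in blocks of 4, like a sliced byte buffer.
  google::protobuf::io::ArrayInputStream stream(buf, 10, /*block_size=*/4);

  const void* data = nullptr;
  int size = 0;
  while (stream.Next(&data, &size)) {
    std::cout << "got block of " << size << " bytes\n";
    if (stream.ByteCount() >= 6) {
      // A reader that consumed too much hands the excess back, exactly as
      // GrpcBufferReader::BackUp does with backup_count_.
      stream.BackUp(size - 2);
      break;
    }
  }
  std::cout << "consumed " << stream.ByteCount() << " bytes\n";  // 6
  return 0;
}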
namespace paddle {
namespace operators {
namespace detail {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
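The Source interface above only promises that every contents() call returns a fresh stream positioned at the start of the serialized message, owned by the Source itself. As a hedged illustration (not part of this commit), a hypothetical in-memory source could satisfy that contract like this:

#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <memory>
#include <string>

// Hypothetical helper for tests: serves a std::string through the Source API.
class StringSource : public Source {
 public:
  explicit StringSource(std::string data) : data_(std::move(data)) {}

  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    // Re-point the stream at the beginning on every call, as the contract
    // requires; the stream handed out previously becomes invalid.
    stream_.reset(new ::google::protobuf::io::ArrayInputStream(
        data_.data(), static_cast<int>(data_.size())));
    return stream_.get();  // ownership stays with the Source
  }

 private:
  std::string data_;
  std::unique_ptr<::google::protobuf::io::ArrayInputStream> stream_;
};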
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource
@@ -46,6 +144,42 @@ class GrpcByteBufferSource
  ::google::protobuf::int64 byte_count_;
};
class GrpcByteBufferSourceWrapper : public Source {
public:
GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) : source_(source) {}
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() override {
return source_;
}
private:
GrpcByteBufferSource* source_;
};
class GrpcByteSource : public Source {
public:
explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {}
~GrpcByteSource() override { DeleteStream(); }
typedef ::grpc::GrpcBufferReader Reader;
::google::protobuf::io::ZeroCopyInputStream* contents() override {
DeleteStream();
stream_ = new (&space_) Reader(buffer_);
return stream_;
}
private:
void DeleteStream() {
if (stream_) {
stream_->~Reader();
}
}
grpc_byte_buffer* buffer_; // Not owned
Reader* stream_ = nullptr; // Points into space_ if non-nullptr
char space_[sizeof(Reader)];
};
} // namespace detail
} // namespace operators
} // namespace paddle
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "grpc_client.h"
+
+#include <sys/time.h>
 #include "paddle/fluid/framework/threadpool.h"

 namespace paddle {
 namespace operators {
 namespace detail {
@@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
     auto* var = p_scope->FindVar(var_name_val);
-    sendrecv::VariableMessage req;
-    SerializeToMessage(var_name_val, var, *p_ctx, &req);
+
+    ::grpc::ByteBuffer req;
+    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);

     // varhandle
     VarHandle var_h;
@@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
     s->Prepare(var_h, time_out);
     s->response_call_back_ = NULL;

-    auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
-    rpc->Finish(&s->reply_, &s->status_, (void*)s);
+    auto call = std::move(s->stub_g_.PrepareUnaryCall(
+        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
+        &cq_));
+    call->StartCall();
+    call->Finish(&s->reply_, &s->status_, (void*)s);
   });

   req_count_++;
@@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
 }
 void ProcGetResponse(const VarHandle& var_h,
-                     const sendrecv::VariableMessage& ret_msg) {
-  auto* outvar = var_h.scope->FindVar(var_h.name);
-  DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
+                     // const sendrecv::VariableMessage& ret_msg) {
+                     const ::grpc::ByteBuffer& ret_msg) {
+  framework::Variable* outvar = NULL;
+  DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar);
+}
+
+template <typename T>
+void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
+  ::grpc::Slice slice(proto.ByteSizeLong());
+  proto.SerializeWithCachedSizesToArray(
+      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
+  ::grpc::ByteBuffer tmp(&slice, 1);
+  result->Swap(&tmp);
 }
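RequestToByteBuffer packs an already-sized protobuf message into a single-slice ::grpc::ByteBuffer, avoiding a second serialization copy. A minimal usage sketch, assuming the generated sendrecv::VariableMessage from send_recv.proto is linked in; the variable name is a made-up example:

#include <grpc++/support/byte_buffer.h>

#include "paddle/fluid/operators/detail/send_recv.pb.h"

void Example() {
  sendrecv::VariableMessage req;
  req.set_varname("w@grad");  // hypothetical variable name, field 1 of the proto
  ::grpc::ByteBuffer buf;
  RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
  // buf now holds req.ByteSizeLong() bytes in one slice and can be handed
  // straight to GenericStub::PrepareUnaryCall, as AsyncGetVariable does below.
}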
 bool RPCClient::AsyncGetVariable(const std::string& ep,
@@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
     s->Prepare(var_h, time_out);
     s->response_call_back_ = ProcGetResponse;

-    auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
-    rpc->Finish(&s->reply_, &s->status_, (void*)s);
+    ::grpc::ByteBuffer buf;
+    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
+
+    auto call = std::move(s->stub_g_.PrepareUnaryCall(
+        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_));
+    call->StartCall();
+    call->Finish(&s->reply_, &s->status_, (void*)s);
   });

   req_count_++;
......
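For reference, the GenericStub-based async unary pattern used by AsyncSendVariable/AsyncGetVariable, reduced to a stand-alone sketch. This is not code from the commit: the endpoint is a placeholder, the payload is an arbitrary string rather than a serialized VariableMessage, and without a listening server the call is expected to complete with a non-OK status.

#include <grpc++/generic/generic_stub.h>
#include <grpc++/grpc++.h>

#include <cstdint>
#include <cstring>
#include <iostream>

int main() {
  auto ch = grpc::CreateChannel("127.0.0.1:6174",  // placeholder endpoint
                                grpc::InsecureChannelCredentials());
  grpc::GenericStub stub(ch);
  grpc::CompletionQueue cq;
  grpc::ClientContext ctx;

  const char payload[] = "hello";  // stands in for a serialized VariableMessage
  grpc::Slice slice(sizeof(payload));
  std::memcpy(const_cast<uint8_t*>(slice.begin()), payload, sizeof(payload));
  grpc::ByteBuffer request(&slice, 1);

  grpc::ByteBuffer reply;
  grpc::Status status;
  auto call = stub.PrepareUnaryCall(
      &ctx, "/sendrecv.SendRecvService/SendVariable", request, &cq);
  call->StartCall();
  call->Finish(&reply, &status, reinterpret_cast<void*>(1));

  void* tag = nullptr;
  bool ok = false;
  cq.Next(&tag, &ok);  // blocks until the RPC completes (or fails)
  std::cout << "status code: " << status.error_code() << "\n";
  return 0;
}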
@@ -25,6 +25,11 @@ limitations under the License. */
 #include <string>
 #include <vector>

+#include <grpc++/generic/generic_stub.h>
+#include <grpc++/grpc++.h>
+#include <grpc++/support/byte_buffer.h>
+#include <grpc++/support/slice.h>
+
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
@@ -49,15 +54,11 @@ struct VarHandle {
   }
 };

-void ProcGetResponse(const VarHandle& var_h,
-                     const sendrecv::VariableMessage& msg);
+void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);

 class BaseProcessor {
  public:
-  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
-    stub_ = sendrecv::SendRecvService::NewStub(ch);
-    context_ = NULL;
-  }
+  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { context_ = NULL; }

   virtual ~BaseProcessor() {}

@@ -82,19 +83,18 @@ class BaseProcessor {

   virtual void Process() = 0;

-  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
   std::unique_ptr<grpc::ClientContext> context_;
   grpc::Status status_;
   VarHandle var_h_;
 };

-typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
+typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
     RequestSendCallBack;

 class SendProcessor : public BaseProcessor {
  public:
   explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
-      : BaseProcessor(ch) {}
+      : BaseProcessor(ch), stub_g_(ch) {}

   virtual ~SendProcessor() {}

@@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor {
     }
   }

-  sendrecv::VoidMessage reply_;
+  ::grpc::GenericStub stub_g_;
+  ::grpc::ByteBuffer reply_;
   RequestSendCallBack response_call_back_ = NULL;
 };

-typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
+typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
     RequestGetCallBack;

 class GetProcessor : public BaseProcessor {
  public:
   explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
-      : BaseProcessor(ch) {}
+      : BaseProcessor(ch), stub_g_(ch) {}

   virtual ~GetProcessor() {}

@@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor {
     }
   }

-  sendrecv::VariableMessage reply_;
+  ::grpc::ByteBuffer reply_;
+  ::grpc::GenericStub stub_g_;
   RequestGetCallBack response_call_back_ = ProcGetResponse;
 };

 class BatchBarrierProcessor : public BaseProcessor {
  public:
   explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
-      : BaseProcessor(ch) {}
+      : BaseProcessor(ch) {
+    stub_ = sendrecv::SendRecvService::NewStub(ch);
+  }

   virtual ~BatchBarrierProcessor() {}

   virtual void Process() {}
   sendrecv::VoidMessage reply_;
+  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
 };

 class FetchBarrierProcessor : public BaseProcessor {
  public:
   explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
-      : BaseProcessor(ch) {}
+      : BaseProcessor(ch) {
+    stub_ = sendrecv::SendRecvService::NewStub(ch);
+  }

   virtual ~FetchBarrierProcessor() {}

   virtual void Process() {}
   sendrecv::VariableMessage reply_;
+  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
 };

 class RPCClient {
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
using grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH }; ...@@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH };
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server // https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase { class RequestBase {
public: public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, explicit RequestBase(GrpcService::AsyncService* service,
grpc::ServerCompletionQueue* cq) ::grpc::ServerCompletionQueue* cq,
: service_(service), cq_(cq), status_(PROCESS) { const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE(cq_);
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
...@@ -42,55 +43,58 @@ class RequestBase { ...@@ -42,55 +43,58 @@ class RequestBase {
} }
protected: protected:
grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
sendrecv::SendRecvService::AsyncService* service_; GrpcService::AsyncService* service_;
grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
CallStatus status_; CallStatus status_;
const platform::DeviceContext* dev_ctx_;
}; };
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestSend final : public RequestBase { class RequestSend final : public RequestBase {
public: public:
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
SimpleBlockQueue<MessageWithName>* queue) framework::Scope* scope, ReceivedQueue* queue,
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) { const platform::DeviceContext* dev_ctx)
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_, : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
this); request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
} }
virtual ~RequestSend() {} virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_.varname(); } virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() { virtual void Process() {
MessageWithName msg_with_name = queue_->Push(std::make_pair(request_->Varname(), request_));
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name)); sendrecv::VoidMessage reply;
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
} }
protected: protected:
sendrecv::VariableMessage request_; std::shared_ptr<VariableResponse> request_;
sendrecv::VoidMessage reply_; ReceivedQueue* queue_;
SimpleBlockQueue<MessageWithName>* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, explicit RequestGet(GrpcService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<MessageWithName>* queue) SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq), : RequestBase(service, cq, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
dev_ctx_(dev_ctx),
queue_(queue) { queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
} }
virtual ~RequestGet() {} virtual ~RequestGet() {}
...@@ -101,24 +105,26 @@ class RequestGet final : public RequestBase { ...@@ -101,24 +105,26 @@ class RequestGet final : public RequestBase {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name); auto* var = scope_->FindVar(var_name);
::grpc::ByteBuffer reply;
if (var_name != FETCH_BARRIER_MESSAGE) { if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
} }
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
MessageWithName msg_with_name =
// request name reply if (var_name == FETCH_BARRIER_MESSAGE) {
std::make_pair(var_name, std::move(reply_)); sendrecv::VariableMessage msg;
queue_->Push(msg_with_name); MessageWithName msg_with_name = std::make_pair(var_name, msg);
queue_->Push(msg_with_name);
}
} }
protected: protected:
sendrecv::VariableMessage request_; sendrecv::VariableMessage request_;
sendrecv::VariableMessage reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_; framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<MessageWithName>* queue_; SimpleBlockQueue<MessageWithName>* queue_;
}; };
...@@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) { ...@@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
} }
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder; ::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_); builder.RegisterService(&service_);
...@@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
if (is_shut_down_) { if (is_shut_down_) {
return; return;
} }
RequestSend* send = RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
...@@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
} }
// FIXME(typhoonzero): change cq_name to enum. // FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
std::string cq_name, std::string cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne(); TryToRegisterNewOne();
......
...@@ -14,28 +14,35 @@ limitations under the License. */ ...@@ -14,28 +14,35 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h" #include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include <grpc++/grpc++.h> #include "paddle/fluid/operators/detail/grpc_service.h"
#include <grpc/support/log.h>
#include <thread> //#include <grpc/support/log.h>
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName; typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase; class RequestBase;
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { class AsyncGRPCServer final {
public: public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {} explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
...@@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
const MessageWithName Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } void Push(const std::string &msg_name) {
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
}
void ShutDown(); void ShutDown();
protected: protected:
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne();
...@@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
private: private:
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
sendrecv::SendRecvService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
std::string address_; std::string address_;
framework::Scope *scope_; framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<MessageWithName> var_get_queue_; SimpleBlockQueue<MessageWithName> var_get_queue_;
ReceivedQueue var_recv_queue_;
// condition of the sub program // condition of the sub program
std::mutex barrier_mutex_; std::mutex barrier_mutex_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"
// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/); we borrowed it and made
// some modifications so that we can parse gRPC requests without too
// much copying of the tensor data.
namespace grpc {
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> {
public:
static Status Serialize(
const paddle::operators::detail::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status();
}
static Status Deserialize(grpc_byte_buffer* buffer,
paddle::operators::detail::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload");
}
Status result = g_core_codegen_interface->ok();
if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer);
int ret = msg->Parse(&source);
if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
}
}
g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
return result;
}
};
} // namespace grpc
namespace paddle {
namespace operators {
namespace detail {
enum class GrpcMethod {
kSendVariable,
kGetVariable,
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
case GrpcMethod::kSendVariable:
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
}
// Shouldn't be reached.
PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
return nullptr;
}
class GrpcService final {
public:
class AsyncService : public ::grpc::Service {
public:
AsyncService() {
for (int i = 0; i < kGrpcNumMethods; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
GrpcMethodName(static_cast<GrpcMethod>(i)),
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
::grpc::Service::MarkMethodAsync(i);
}
}
virtual ~AsyncService() {}
// Make RequestAsyncUnary public for grpc_call.h
using ::grpc::Service::RequestAsyncUnary;
};
};
} // namespace detail
} // namespace operators
} // namespace paddle
@@ -32,6 +32,9 @@ enum VarType {
   SELECTED_ROWS = 1;
 }

+// NOTICE(gongwb): don't modify this proto if you are not
+// familiar with how we serialize it in sendrecvop_utils.h
+// and deserialize it in variable_response.h.
 message VariableMessage {
   enum Type {
     // Pod Types
@@ -45,7 +48,6 @@ message VariableMessage {
   }
   message LodData { repeated int64 lod_data = 1; }
-
   string varname = 1;
   // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
   VarType type = 2;
@@ -64,3 +66,5 @@ message VariableMessage {
 }

 message VoidMessage {}
+
+message TestMessage { int64 test_1 = 1; }
@@ -13,61 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+
+#include <sys/time.h>
+#include <thread>
+
 #include "google/protobuf/io/coded_stream.h"
 #include "google/protobuf/io/zero_copy_stream.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/detail/bytebuffer_stream.h"
 #include "paddle/fluid/operators/detail/proto_encoder_helper.h"
+#include "paddle/fluid/operators/detail/variable_response.h"

 namespace paddle {
 namespace operators {
 namespace detail {
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg) {
msg->set_varname(name);
std::ostringstream oss;
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break;
case framework::proto::VarType_Type_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx);
break;
default: {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
msg->set_serialized(oss.str());
}
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var) {
std::istringstream iss(msg.serialized());
switch (msg.type()) {
case sendrecv::VarType::LOD_TENSOR:
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
break;
case sendrecv::VarType::SELECTED_ROWS: {
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
ctx);
break;
}
default: {
PADDLE_THROW("Deserialize does not support type: %s",
typeid(var->Type()).name());
break;
}
}
}
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg) {
@@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
           static_cast<const platform::CUDADeviceContext&>(ctx);
       auto copy_size = tensor.memory_size();
       payload = memory::Alloc(cpu, copy_size);
+
       memory::Copy(cpu, payload,
                    boost::get<platform::CUDAPlace>(tensor.place()),
                    reinterpret_cast<const void*>(tensor.data<void>()),
@@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
         platform::CPUPlace cpu;
         memory::Free(cpu, backing);
       };
+
 #endif
     } else {
       payload = tensor.data<void>();
@@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,

 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
-                               framework::Variable* var) {
-  sendrecv::VariableMessage meta;
-  GrpcByteBufferSource source;
-  source.Init(msg);
-  ::google::protobuf::io::CodedInputStream input(&source);
+                               const framework::Scope* scope,
+                               framework::Variable*& var) {
+  operators::detail::VariableResponse resp(scope, &ctx);
+  PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
+  var = resp.GetVar();
// do zerocopy parsing
PADDLE_ENFORCE(meta.ParseFromCodedStream(&input));
PADDLE_ENFORCE(input.ConsumedEntireMessage());
// dims is needed by both tensor and selectedrows
std::vector<int> vecdims;
for (auto& d : meta.dims()) {
vecdims.push_back(d);
}
framework::DDim dims = framework::make_ddim(vecdims);
if (meta.type() == sendrecv::LOD_TENSOR) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
framework::LoD lod;
for (int i = 0; i < meta.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) {
v.push_back(meta.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
// How to avoid copying and use the message buffer directly?
// Maybe need to find a way to release all memory except tensor content.
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
} else if (meta.type() == sendrecv::SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
int64_t* rows_data = slr->mutable_rows()->data();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
// copy rows CPU data, GPU data will be copied lazly
memcpy(rows_data, reinterpret_cast<const void*>(meta.rows().data()),
meta.rows().size());
}
}
} // namespace detail
......
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
@@ -36,21 +37,14 @@ namespace detail {

 typedef void (*DestroyCallback)(void*);

-void SerializeToMessage(const std::string& name, const framework::Variable* var,
-                        const platform::DeviceContext& ctx,
-                        sendrecv::VariableMessage* msg);
-
-void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
-                            const platform::DeviceContext& ctx,
-                            framework::Variable* var);
-
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg);

 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
-                               framework::Variable* var);
+                               const framework::Scope* scope,
+                               framework::Variable*& var);

 inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
   switch (type) {
......
...@@ -16,11 +16,13 @@ limitations under the License. */ ...@@ -16,11 +16,13 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> #include <thread>
#include <google/protobuf/text_format.h>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -31,19 +33,21 @@ namespace operators = paddle::operators; ...@@ -31,19 +33,21 @@ namespace operators = paddle::operators;
namespace math = paddle::operators::math; namespace math = paddle::operators::math;
namespace memory = paddle::memory; namespace memory = paddle::memory;
void RunSerdeTestTensor(platform::Place place) { void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->mutable_data<float>(place); tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9); int tensor_numel = 2 * 10;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
...@@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) { ...@@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) {
for (const auto& s : slices) { for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size()); tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
} }
sendrecv::VariableMessage varmsg; sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp)); EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar"); EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0); EXPECT_EQ(varmsg.type(), 1);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data = const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data()); reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) { for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 31.9); EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
} }
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
// deserialize zero-copy // deserialize zero-copy
framework::Variable var2; // framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto tensor2 = var2.Get<framework::LoDTensor>(); framework::Scope scope;
scope.Var("myvar");
operators::detail::TensorResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr; float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor; framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu; platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor); framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>(); tensor_data2 = tmp_tensor.data<float>();
} else { } else {
tensor_data2 = const_cast<float*>(tensor2.data<float>()); tensor_data2 = const_cast<float*>(tensor2->data<float>());
} }
const int64_t* rows_data2 = rows2->data();
EXPECT_EQ(varmsg.lod_level(), 1); for (int i = 0; i < tensor_numel; ++i) {
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); }
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); EXPECT_EQ(rows_data2[0], 3);
for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); EXPECT_EQ(rows_data2[1], 10);
} }
void RunSerdeTestSelectedRows(platform::Place place) { void RunTestLodTensor(platform::Place place, int from_type = 0) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer // serialize var to ByteBuffer
framework::Variable var; framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>(); auto* tensor = var.GetMutable<framework::LoDTensor>();
auto* tensor = slr->mutable_value(); tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
auto* rows = slr->mutable_rows(); framework::LoD lod;
tensor->Resize(framework::make_ddim({2, 10})); lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place); tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10; math::set_constant(ctx, tensor, 31.9);
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
...@@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) {
} }
sendrecv::VariableMessage varmsg; sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp)); EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar"); EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1); EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data = const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data()); reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) { for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7); EXPECT_FLOAT_EQ(tensor_data[i], 31.9);
} }
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10); // message binary
std::string str;
varmsg.SerializeToString(&str);
// message bytebuffer
::grpc::Slice slices_2[1];
int num_slices = 1;
slices_2[0] = ::grpc::Slice(str.length());
memcpy(const_cast<uint8_t*>(slices_2[0].begin()), str.c_str(), str.length());
::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices);
// deserialize zero-copy // deserialize zero-copy
framework::Variable var2; framework::Scope scope;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); scope.Var("myvar");
operators::detail::TensorResponse resp(&scope, &ctx);
if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0);
} else {
EXPECT_EQ(resp.Parse(bytebuffer2), 0);
}
auto* slr2 = var2.GetMutable<framework::SelectedRows>(); framework::Variable* var2 = resp.GetVar();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows(); auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr; float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor; framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu; platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor); framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>(); tensor_data2 = tmp_tensor.data<float>();
} else { } else {
tensor_data2 = const_cast<float*>(tensor2->data<float>()); tensor_data2 = const_cast<float*>(tensor2.data<float>());
} }
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) { EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
} EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(rows_data2[0], 3); EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
EXPECT_EQ(rows_data2[1], 10); for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
TEST(LodTensor, GPU) {
platform::CUDAPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
}
TEST(LodTensor, CPU) {
platform::CPUPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
} }
TEST(SelectedRows, CPU) { TEST(SelectedRows, CPU) {
...@@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) { ...@@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) {
platform::CUDAPlace place; platform::CUDAPlace place;
RunSerdeTestSelectedRows(place); RunSerdeTestSelectedRows(place);
} }
TEST(Tensor, CPU) {
platform::CPUPlace place;
RunSerdeTestTensor(place);
}
TEST(Tensor, GPU) {
platform::CUDAPlace place;
RunSerdeTestTensor(place);
}
\ No newline at end of file
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h>
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
enum WireType {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
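A quick sanity check of the tag arithmetic above, using the field numbers from send_recv.proto (varname is field 1, a length-delimited string; type is field 2, a varint enum). The values are just worked examples, not code from the commit:

#include <cstdint>
#include <iostream>

int main() {
  // A protobuf tag is a varint holding (field_number << 3) | wire_type.
  const uint32_t varname_tag = (1u << 3) | 2u;  // = 10, WIRETYPE_LENGTH_DELIMITED
  const uint32_t type_tag = (2u << 3) | 0u;     // = 16, WIRETYPE_VARINT

  std::cout << "varname: field " << (varname_tag >> 3)
            << ", wire type " << (varname_tag & 0x7) << "\n";  // field 1, type 2
  std::cout << "type:    field " << (type_tag >> 3)
            << ", wire type " << (type_tag & 0x7) << "\n";     // field 2, type 0
  return 0;
}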
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
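A small round-trip showing the varint encoding that ReadVarintSizeAsInt consumes; the 300-byte length is an arbitrary example value, and the streams are plain protobuf utilities rather than anything gRPC-specific:

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

#include <climits>
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  std::string wire;
  {
    google::protobuf::io::StringOutputStream raw(&wire);
    google::protobuf::io::CodedOutputStream out(&raw);
    out.WriteVarint64(300);  // e.g. the length prefix of a serialized field
  }  // CodedOutputStream flushes into `wire` on destruction

  google::protobuf::io::ArrayInputStream raw_in(wire.data(),
                                                static_cast<int>(wire.size()));
  google::protobuf::io::CodedInputStream in(&raw_in);
  uint64_t v = 0;
  if (in.ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
    std::cout << "decoded length = " << static_cast<int>(v) << "\n";  // 300
  }
  return 0;
}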
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int size) {
const void* data = NULL;
int size_to_write = 0;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
p += size_to_write;
size -= size_to_write;
input->Skip(size_to_write);
}
gpu_dev_ctx.Wait();
#else
PADDLE_THROW("Unexpected branch");
#endif
return true;
}
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu;
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
size -= size_to_write;
input->Skip(size_to_write);
}
return true;
}
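The chunked read pattern used by ReadRaw can be exercised on its own with a plain CodedInputStream. This sketch (illustration only; the buffer and block size are made up) copies a payload into dest four bytes at a time, mirroring the CPU branch above, which uses memory::Copy instead of memcpy:

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

#include <cstring>
#include <iostream>
#include <string>

int main() {
  const char src[] = "abcdefghij";  // 10 payload bytes
  char dest[10] = {};
  google::protobuf::io::ArrayInputStream raw(src, 10, /*block_size=*/4);
  google::protobuf::io::CodedInputStream input(&raw);

  int size = 10;
  char* p = dest;
  const void* data = nullptr;
  int chunk = 0;
  while (size > 0 && input.GetDirectBufferPointer(&data, &chunk)) {
    if (chunk > size) chunk = size;  // don't copy past what we need
    std::memcpy(p, data, chunk);     // destination advances chunk by chunk
    p += chunk;
    size -= chunk;
    input.Skip(chunk);               // advance past the bytes we copied
  }
  std::cout << "copied: " << std::string(dest, 10) << "\n";  // abcdefghij
  return 0;
}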
bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
framework::LoD lod;
for (int i = 0; i < meta_.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) {
v.push_back(meta_.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
return true;
}
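// Converts the repeated int64 dims field of the message into a framework::DDim.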
inline framework::DDim GetDims(
const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
std::vector<int> vecdims;
for (auto& d : dims) {
vecdims.push_back(d);
}
return framework::make_ddim(vecdims);
}
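// Resizes the value tensor of a SelectedRows variable and fills it with
// `length` bytes of raw data read from the stream.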
bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
}
return true;
}
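// Reads the rows index of a SelectedRows variable. The rows are always
// copied into CPU memory; moving them to the GPU is deferred until they
// are actually used.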
bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
platform::CPUPlace cpu;
if (!ReadRaw(input, ctx, cpu, rows_data, length)) {
return false;
}
return true;
}
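// Parses one VariableMessage_LodData message, accepting both the varint
// (unpacked) and the length-delimited (packed) encodings of the repeated
// lod_data field.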
bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
std::vector<int64_t>* lod) {
while (true) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input->ReadVarintSizeAsInt(&length)) {
return false;
}
for (int i = 0; i < length; i++) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
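// Parses a variable directly from a grpc::ByteBuffer by wrapping it in a
// zero-copy Source, so the payload is read without an extra copy into a
// temporary buffer.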
int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
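// Hand-written parsing loop over the VariableMessage wire format. Metadata
// fields are accumulated in meta_, while the serialized tensor payload and
// the rows field are streamed directly into the target Variable in scope_,
// so no intermediate protobuf message is materialized.
// Returns 0 on success, -1 on an unknown tag or a truncated stream,
// otherwise the number of the field that failed to parse.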
int VariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint64_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input.ReadVarintSizeAsInt(&length)) {
return tag;
}
for (int i = 0; i < length; i++) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return tag;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
"lod info should be got first!");
if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) {
return tag;
}
break;
}
if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) {
return tag;
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, length)) {
return tag;
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
}
}
return 0;
}
}  // namespace detail
}  // namespace operators
}  // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
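// Parses a VariableMessage from the wire and writes its payload directly
// into the matching Variable inside the given scope, avoiding an
// intermediate protobuf message copy.
//
// Illustrative server-side usage (a sketch only, not part of this diff;
// the surrounding `scope`, `dev_ctx` and `byte_buffer` objects are assumed):
//
//   VariableResponse resp(&scope, &dev_ctx);
//   if (resp.Parse(byte_buffer) != 0) {
//     // handle the field that failed to parse
//   }
//   framework::Variable* var = resp.GetVar();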
class VariableResponse {
public:
VariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: scope_(scope), dev_ctx_(dev_ctx) {}
virtual ~VariableResponse() {}
// Return value:
//   0: ok.
//   -1: unknown error.
//   other: the number of the field that failed to parse.
int Parse(Source* source);
// Return value:
//   0: ok.
//   -1: unknown error.
//   other: the number of the field that failed to parse.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
inline std::string Varname() { return meta_.varname(); }
// Should be called only after Parse() succeeds.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
private:
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length);
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
private:
const framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
// Holds only the metadata; the payload is written directly into the
// Variable in scope_.
sendrecv::VariableMessage meta_;
};
}  // namespace detail
}  // namespace operators
}  // namespace paddle
...@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase {
} }
void Stop() override { void Stop() override {
detail::MessageWithName term_msg; rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
term_msg.first = LISTEN_TERMINATE_MESSAGE;
rpc_service_->Push(term_msg);
rpc_service_->ShutDown(); rpc_service_->ShutDown();
server_thread_->join(); server_thread_->join();
} }
...@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
size_t recv_var_cnt = 0; size_t recv_var_cnt = 0;
int batch_barrier = 0; int batch_barrier = 0;
while (batch_barrier != fan_in) { while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get(); const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first; auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit"; LOG(INFO) << "received terminate message and exit";
...@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase {
} else { } else {
VLOG(3) << "received grad: " << recv_var_name; VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++; recv_var_cnt++;
auto *var = recv_scope.FindVar(recv_var_name); auto var = v.second->GetVar();
if (var == nullptr) { if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name; LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var"); PADDLE_THROW("Can not find server side var");
} }
detail::DeserializeFromMessage(v.second, dev_ctx, var);
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var); sparse_vars.push_back(var);
} }
......
...@@ -16,7 +16,6 @@ import sys ...@@ -16,7 +16,6 @@ import sys
import re import re
from graphviz import GraphPreviewGenerator from graphviz import GraphPreviewGenerator
import proto.framework_pb2 as framework_pb2 import proto.framework_pb2 as framework_pb2
import paddle.fluid.core as core
_vartype2str_ = [ _vartype2str_ = [
"UNK", "UNK",
...@@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False): ...@@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False):
def is_var_backward(var_desc): def is_var_backward(var_desc):
return "@GRAD" in var_desc.name return "@GRAD" in var_desc.name
#print(type(block_desc))
if type(block_desc) is not framework_pb2.BlockDesc: if type(block_desc) is not framework_pb2.BlockDesc:
block_desc = framework_pb2.BlockDesc.FromString( block_desc = framework_pb2.BlockDesc.FromString(
block_desc.serialize_to_string()) block_desc.serialize_to_string())
......
...@@ -20,6 +20,7 @@ from layer_helper import LayerHelper ...@@ -20,6 +20,7 @@ from layer_helper import LayerHelper
from distributed_spliter import * from distributed_spliter import *
import math import math
from . import core from . import core
import debuger
class VarBlock: class VarBlock:
...@@ -289,6 +290,7 @@ class DistributeTranspiler: ...@@ -289,6 +290,7 @@ class DistributeTranspiler:
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
recv_inputs.append(var) recv_inputs.append(var)
# step3 # step3
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# step 4 # step 4
......