提交 c3580eae 编写于 作者: Y Yancey1989

Add prefetch interface on server side

上级 95658767
...@@ -2,7 +2,8 @@ if(WITH_DISTRIBUTE) ...@@ -2,7 +2,8 @@ if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(test_serde.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc) cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
endif() endif()
...@@ -150,7 +150,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -150,7 +150,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, (void*)s);
}); });
......
...@@ -128,6 +128,47 @@ class RequestGet final : public RequestBase { ...@@ -128,6 +128,47 @@ class RequestGet final : public RequestBase {
SimpleBlockQueue<MessageWithName>* queue_; SimpleBlockQueue<MessageWithName>* queue_;
}; };
class RequestPrefetch final : public RequestBase {
public:
explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program, int blkid)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
blkid_(blkid) {
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer relay;
// TODO(Yancey1989): execute the Block which containers prefetch ops
responder_.Finish(relay, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
sendrecv::VariableMessage request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
int blkid_;
};
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
int fetch_barriers = 0; int fetch_barriers = 0;
while (fetch_barriers < count) { while (fetch_barriers < count) {
...@@ -147,6 +188,7 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -147,6 +188,7 @@ void AsyncGRPCServer::RunSyncUpdate() {
cq_send_ = builder.AddCompletionQueue(); cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue(); cq_get_ = builder.AddCompletionQueue();
cq_prefetch_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl; LOG(INFO) << "Server listening on " << address_ << std::endl;
...@@ -155,6 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -155,6 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
std::function<void()> get_register = std::function<void()> get_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
std::function<void()> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
t_send_.reset( t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
...@@ -163,11 +207,14 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -163,11 +207,14 @@ void AsyncGRPCServer::RunSyncUpdate() {
t_get_.reset( t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register))); cq_get_.get(), "cq_get", get_register)));
t_prefetch_.reset(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
// wait server // wait server
server_->Wait(); server_->Wait();
t_send_->join(); t_send_->join();
t_get_->join(); t_get_->join();
t_prefetch_->join();
} }
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
...@@ -203,6 +250,18 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -203,6 +250,18 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "Create RequestGet status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
return;
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
// FIXME(typhoonzero): change cq_name to enum. // FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
std::string cq_name, std::string cq_name,
......
...@@ -17,7 +17,9 @@ limitations under the License. */ ...@@ -17,7 +17,9 @@ limitations under the License. */
#include <grpc++/grpc++.h> #include <grpc++/grpc++.h>
#include <thread> #include <thread>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
...@@ -53,6 +55,12 @@ class AsyncGRPCServer final { ...@@ -53,6 +55,12 @@ class AsyncGRPCServer final {
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetPrefetchBlkdId(int blkid) { prefetch_blk_id_ = blkid; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const std::string &msg_name) { void Push(const std::string &msg_name) {
...@@ -66,6 +74,7 @@ class AsyncGRPCServer final { ...@@ -66,6 +74,7 @@ class AsyncGRPCServer final {
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne();
void TryToRegisterNewPrefetchOne();
void ShutdownQueue(); void ShutdownQueue();
private: private:
...@@ -73,6 +82,7 @@ class AsyncGRPCServer final { ...@@ -73,6 +82,7 @@ class AsyncGRPCServer final {
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
...@@ -92,6 +102,11 @@ class AsyncGRPCServer final { ...@@ -92,6 +102,11 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_send_; std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_; std::unique_ptr<std::thread> t_get_;
std::unique_ptr<std::thread> t_prefetch_;
int prefetch_blk_id_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
}; };
}; // namespace detail }; // namespace detail
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
}
TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string var_name("tmp_0");
auto var = scope.Var(var_name);
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize({10, 10});
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
server_thread.join();
rpc_service_.reset(nullptr);
}
...@@ -76,6 +76,7 @@ namespace detail { ...@@ -76,6 +76,7 @@ namespace detail {
enum class GrpcMethod { enum class GrpcMethod {
kSendVariable, kSendVariable,
kGetVariable, kGetVariable,
kPrefetchVariable,
}; };
static const int kGrpcNumMethods = static const int kGrpcNumMethods =
...@@ -87,6 +88,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -87,6 +88,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/SendVariable"; return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable: case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable"; return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendREcvService/PrefetchVariable";
} }
// Shouldn't be reached. // Shouldn't be reached.
......
...@@ -21,6 +21,8 @@ service SendRecvService { ...@@ -21,6 +21,8 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {} rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname. // Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
} }
// VariableMessage is serialized paddle variable message. // VariableMessage is serialized paddle variable message.
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "cuda_runtime.h" #include "cuda_runtime.h"
#endif
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(Event, CpuElapsedTime) { TEST(Event, CpuElapsedTime) {
...@@ -159,6 +161,7 @@ TEST(RecordEvent, RecordEvent) { ...@@ -159,6 +161,7 @@ TEST(RecordEvent, RecordEvent) {
DisableProfiler(EventSortingKey::kTotal, "/tmp/profiler"); DisableProfiler(EventSortingKey::kTotal, "/tmp/profiler");
} }
#ifdef PADDLE_WITH_CUDA
TEST(TMP, stream_wait) { TEST(TMP, stream_wait) {
cudaStream_t stream; cudaStream_t stream;
cudaStreamCreate(&stream); cudaStreamCreate(&stream);
...@@ -166,3 +169,4 @@ TEST(TMP, stream_wait) { ...@@ -166,3 +169,4 @@ TEST(TMP, stream_wait) {
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
} }
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册