未验证 提交 09fcf5f2 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9555 from jacquesqiao/improve-prefetch-on-server

Improve prefetch on server
...@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE) ...@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc) cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
......
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include <sys/time.h>
#include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
...@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h, ...@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
template <typename T> template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong()); ::grpc::Slice slice(proto.ByteSizeLong());
proto.SerializeWithCachedSizesToArray( proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
::grpc::ByteBuffer tmp(&slice, 1); ::grpc::ByteBuffer tmp(&slice, 1);
result->Swap(&tmp); result->Swap(&tmp);
} }
...@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_); &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++; req_count_++;
} }
...@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++; req_count_++;
} }
......
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
...@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase { ...@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply; ::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which contains prefetch ops // TODO(Yancey1989): execute the Block which contains prefetch ops
VLOG(3) << "RequestPrefetch Process in";
responder_.Finish(reply, ::grpc::Status::OK, this); responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
} }
...@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() { ...@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown(); cq_send_->Shutdown();
cq_get_->Shutdown(); cq_get_->Shutdown();
cq_prefetch_->Shutdown();
} }
// This URL explains why shutdown is complicated: // This URL explains why shutdown is complicated:
...@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() { ...@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
void AsyncGRPCServer::TryToRegisterNewSendOne() { void AsyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
...@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
void AsyncGRPCServer::TryToRegisterNewGetOne() { void AsyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
...@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return; return;
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
...@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
void* tag = NULL; void* tag = NULL;
bool ok = false; bool ok = false;
while (true) { while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
if (!cq->Next(&tag, &ok)) { if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!"; LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op // FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag; RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) { if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name" LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName(); << base->GetReqName() << "]";
TryToRegisterNewOne(); TryToRegisterNewOne();
delete base; delete base;
continue; continue;
......
...@@ -15,7 +15,8 @@ limitations under the License. */ ...@@ -15,7 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h> #include <grpc++/grpc++.h>
#include <thread> #include <string>
#include <utility>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -93,6 +94,7 @@ class AsyncGRPCServer final { ...@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_get_queue_; SimpleBlockQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_; ReceivedQueue var_recv_queue_;
// condition of the sub program // condition of the sub program
......
...@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_; ...@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void StartServer(const std::string& endpoint) { void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_->RunSyncUpdate();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
...@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) { ...@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope // create var on local scope
std::string var_name("tmp_0"); std::string in_var_name("in");
auto var = scope.Var(var_name); std::string out_var_name("out");
auto tensor = var->GetMutable<framework::LoDTensor>(); auto* in_var = scope.Var(in_var_name);
tensor->Resize({10, 10}); auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client; detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, ""); client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
rpc_service_->ShutDown();
server_thread.join(); server_thread.join();
rpc_service_.reset(nullptr); rpc_service_.reset(nullptr);
} }
...@@ -80,7 +80,7 @@ enum class GrpcMethod { ...@@ -80,7 +80,7 @@ enum class GrpcMethod {
}; };
static const int kGrpcNumMethods = static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1; static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) { inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) { switch (id) {
...@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
case GrpcMethod::kGetVariable: case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable"; return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable: case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendREcvService/PrefetchVariable"; return "/sendrecv.SendRecvService/PrefetchVariable";
} }
// Shouldn't be reached. // Shouldn't be reached.
...@@ -117,5 +117,5 @@ class GrpcService final { ...@@ -117,5 +117,5 @@ class GrpcService final {
}; };
} // namespace detail } // namespace detail
} // namespace operator } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <sys/stat.h>
#include <ostream> #include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
rpc_service_->SetProgram(program);
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
// Record received sparse variables, so that // Record received sparse variables, so that
...@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
} }
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts; VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
<< "(ms)";
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
// sum the input sparse variables whose rows are empty at the next // sum the input sparse variables whose rows are empty at the next
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册