未验证 提交 09fcf5f2 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9555 from jacquesqiao/improve-prefetch-on-server

Improve prefetch on server
......@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
......
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
#include <sys/time.h>
#include "paddle/fluid/operators/detail/grpc_client.h"
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
......@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
proto.SerializeWithCachedSizesToArray(
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
::grpc::ByteBuffer tmp(&slice, 1);
result->Swap(&tmp);
}
......@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
});
req_count_++;
......@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++;
}
......@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++;
}
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter;
namespace paddle {
......@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG(3) << "RequestPrefetch Process in";
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
......@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown();
cq_get_->Shutdown();
cq_prefetch_->Shutdown();
}
// This URL explains why shutdown is complicate:
......@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
void AsyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
......@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
void AsyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
......@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return;
}
RequestPrefetch* prefetch =
......@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
void* tag = NULL;
bool ok = false;
while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag;
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]";
TryToRegisterNewOne();
delete base;
continue;
......
......@@ -15,7 +15,8 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include <string>
#include <utility>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program
......
......@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_->RunSyncUpdate();
}
TEST(PREFETCH, CPU) {
......@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string var_name("tmp_0");
auto var = scope.Var(var_name);
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize({10, 10});
std::string in_var_name("in");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
}
......@@ -80,7 +80,7 @@ enum class GrpcMethod {
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1;
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
......@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendREcvService/PrefetchVariable";
return "/sendrecv.SendRecvService/PrefetchVariable";
}
// Shouldn't be reached.
......@@ -117,5 +117,5 @@ class GrpcService final {
};
} // namespace detail
} // namespace operator
} // namespace operators
} // namespace paddle
......@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Executor executor(dev_place);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
rpc_service_->SetProgram(program);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
// Record received sparse variables, so that
......@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
<< "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册