diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 26bef375cb3493b4b6a428677986e005654b6be3..407fa5ef5aea51626b2ac372c096156be0aaa144 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_server.h" +#include using ::grpc::ServerAsyncResponseWriter; @@ -156,6 +157,8 @@ class RequestPrefetch final : public RequestBase { ::grpc::ByteBuffer relay; // TODO(Yancey1989): execute the Block which containers prefetch ops + VLOG(3) << "RequestPrefetch Process in"; + responder_.Finish(relay, ::grpc::Status::OK, this); status_ = FINISH; } @@ -251,6 +254,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { } void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { + VLOG(4) << "TryToRegisterNewPrefetchOne in"; std::unique_lock lock(cq_mutex_); if (is_shut_down_) { return; @@ -287,8 +291,8 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I if (!ok) { - LOG(WARNING) << cq_name << " recv no regular event:argument name" - << base->GetReqName(); + LOG(WARNING) << cq_name << " recv no regular event:argument name[" + << base->GetReqName() << "]"; TryToRegisterNewOne(); delete base; continue; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 577374810696c039b8794fc151083ca7ddf43a10..1ad62863a1a98c28cb08f47dfa8a5bfae463ba91 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -28,6 +28,7 @@ std::unique_ptr rpc_service_; void StartServer(const std::string& endpoint) { rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + rpc_service_->RunSyncUpdate(); } TEST(PREFETCH, CPU) { @@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) { platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope - std::string var_name("tmp_0"); - auto var = scope.Var(var_name); - auto tensor = var->GetMutable(); - tensor->Resize({10, 10}); + std::string in_var_name("in"); + std::string out_var_name("out"); + auto* in_var = scope.Var(in_var_name); + auto* in_tensor = in_var->GetMutable(); + in_tensor->Resize({10, 10}); + VLOG(3) << "before mutable_data"; + in_tensor->mutable_data(place); + scope.Var(out_var_name); + + VLOG(3) << "before fetch"; detail::RPCClient client; - client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, ""); + client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, + out_var_name); + client.Wait(); + + rpc_service_->ShutDown(); server_thread.join(); rpc_service_.reset(nullptr); } diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index 879e21933b452363c3fccacffb4d16ac1bfd6020..1ec8cf11c5167ae69edd7b30d7d5581518c0e823 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -80,7 +80,7 @@ enum class GrpcMethod { }; static const int kGrpcNumMethods = - static_cast(GrpcMethod::kGetVariable) + 1; + static_cast(GrpcMethod::kPrefetchVariable) + 1; inline const char* GrpcMethodName(GrpcMethod id) { switch (id) { diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index d5eae2be79f95c78f66ca348261a3460790dca4a..c9455fd35cf9ae396e4848ad817313b1693c09a8 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -112,6 +112,10 @@ class ListenAndServOp : public framework::OperatorBase { framework::Executor executor(dev_place); + rpc_service_->SetExecutor(&executor); + rpc_service_->SetPrefetchBlkdId(0); + rpc_service_->SetProgram(program); + // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; // Record received sparse variables, so that