From e84f353e1a4b7f1d00eaeb3f0c9814f1f7a433d3 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 19 Apr 2018 13:18:48 +0800 Subject: [PATCH] optimize --- paddle/fluid/operators/async_listen_and_serv_op.cc | 6 +++++- paddle/fluid/operators/detail/async_grpc_server.cc | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/async_listen_and_serv_op.cc b/paddle/fluid/operators/async_listen_and_serv_op.cc index ec0ddedf3d1..bdb20240f69 100644 --- a/paddle/fluid/operators/async_listen_and_serv_op.cc +++ b/paddle/fluid/operators/async_listen_and_serv_op.cc @@ -96,6 +96,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, block_list.push_back(blkid); } } + PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(), + "grad num should be equal to optimize block num"); auto optimize_prepared = executor.Prepare(*program, block_list); std::unordered_mapSetScope(&recv_scope); rpc_service_->SetDevCtx(&dev_ctx); - // TODO(qiao) set proper fields for table lookup and update + + // set proper fields for table lookup and update rpc_service_->SetExecutor(&executor); VLOG(3) << "prefetch block id is " << prefetch_block->ID(); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); @@ -142,6 +145,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, } AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(), &recv_scope); + // TODO(qiao): explain why if (var->IsType()) { var->GetMutable()->mutable_rows()->clear(); } diff --git a/paddle/fluid/operators/detail/async_grpc_server.cc b/paddle/fluid/operators/detail/async_grpc_server.cc index 7e45bc6b2f5..b9a228e1d2f 100644 --- a/paddle/fluid/operators/detail/async_grpc_server.cc +++ b/paddle/fluid/operators/detail/async_grpc_server.cc @@ -91,7 +91,7 @@ class RequestGet final : public RequestBase { framework::Scope* scope, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) { - int method_id = static_cast(detail::GrpcMethod::kGetVariable); + auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, cq_, this); } -- GitLab