diff --git a/paddle/fluid/operators/async_listen_and_serv_op.cc b/paddle/fluid/operators/async_listen_and_serv_op.cc index ec0ddedf3d13b319c54b00de593fa43ac945ee9f..bdb20240f69449d44c2ec89d2f3cb0d206f0b9b2 100644 --- a/paddle/fluid/operators/async_listen_and_serv_op.cc +++ b/paddle/fluid/operators/async_listen_and_serv_op.cc @@ -96,6 +96,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, block_list.push_back(blkid); } } + PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(), + "grad num should be equal to optimize block num"); auto optimize_prepared = executor.Prepare(*program, block_list); std::unordered_mapSetScope(&recv_scope); rpc_service_->SetDevCtx(&dev_ctx); - // TODO(qiao) set proper fields for table lookup and update + + // set proper fields for table lookup and update rpc_service_->SetExecutor(&executor); VLOG(3) << "prefetch block id is " << prefetch_block->ID(); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); @@ -142,6 +145,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, } AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(), &recv_scope); + // TODO(qiao): explain why if (var->IsType()) { var->GetMutable()->mutable_rows()->clear(); } diff --git a/paddle/fluid/operators/detail/async_grpc_server.cc b/paddle/fluid/operators/detail/async_grpc_server.cc index 7e45bc6b2f5dda6bd9cb91e6b860acf877bae67a..b9a228e1d2fc6d4b6c2ad0d449c19769b8eadc83 100644 --- a/paddle/fluid/operators/detail/async_grpc_server.cc +++ b/paddle/fluid/operators/detail/async_grpc_server.cc @@ -91,7 +91,7 @@ class RequestGet final : public RequestBase { framework::Scope* scope, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) { - int method_id = static_cast(detail::GrpcMethod::kGetVariable); + auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, cq_, this); }