提交 e84f353e 编写于 作者: Q qiaolongfei

optimize

上级 1a438287
......@@ -96,6 +96,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
block_list.push_back(blkid);
}
}
PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(),
"grad num should be equal to optimize block num");
auto optimize_prepared = executor.Prepare(*program, block_list);
std::unordered_map<std::string,
......@@ -107,7 +109,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
// set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
......@@ -142,6 +145,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
}
AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(),
&recv_scope);
// TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
......
......@@ -91,7 +91,7 @@ class RequestGet final : public RequestBase {
framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) {
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册