提交 e84f353e 编写于 作者: Q qiaolongfei

optimize

上级 1a438287
...@@ -96,6 +96,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -96,6 +96,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
block_list.push_back(blkid); block_list.push_back(blkid);
} }
} }
PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(),
"grad num should be equal to optimize block num");
auto optimize_prepared = executor.Prepare(*program, block_list); auto optimize_prepared = executor.Prepare(*program, block_list);
std::unordered_map<std::string, std::unordered_map<std::string,
...@@ -107,7 +109,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -107,7 +109,8 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->SetScope(&recv_scope); rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx); rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
// set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor); rpc_service_->SetExecutor(&executor);
VLOG(3) << "prefetch block id is " << prefetch_block->ID(); VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
...@@ -142,6 +145,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -142,6 +145,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
} }
AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(), AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(),
&recv_scope); &recv_scope);
// TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
......
...@@ -91,7 +91,7 @@ class RequestGet final : public RequestBase { ...@@ -91,7 +91,7 @@ class RequestGet final : public RequestBase {
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) { : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) {
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this); cq_, this);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册