提交 580f55fa 编写于 作者: Y Yancey1989

update by comment

上级 6edfae42
......@@ -167,9 +167,8 @@ void ListenAndServOp::RunSyncLoop(
recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reuse it in the next mini-batch
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
!rpc_service_->NeedResetAllVars());
rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
......@@ -179,7 +178,7 @@ void ListenAndServOp::RunSyncLoop(
void ListenAndServOp::ResetReceivedVars(
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
platform::DeviceContext *dev_ctx, bool reset_all) const {
for (auto &varname : recv_varnames) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
......@@ -187,9 +186,11 @@ void ListenAndServOp::ResetReceivedVars(
continue;
}
if (var->IsType<framework::SelectedRows>()) {
VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
if (!only_sparse_vars) {
if (UNLIKELY(reset_all)) {
VLOG(3) << "reset dense var: " << varname;
if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0));
......
......@@ -70,7 +70,7 @@ class ListenAndServOp : public framework::OperatorBase {
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool only_sparse_vars = true) const;
bool reset_all = false) const;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册