提交 580f55fa 编写于 作者: Y Yancey1989

update by comment

上级 6edfae42
...@@ -167,9 +167,8 @@ void ListenAndServOp::RunSyncLoop( ...@@ -167,9 +167,8 @@ void ListenAndServOp::RunSyncLoop(
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reuse it in the next mini-batch
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx, ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
!rpc_service_->NeedResetAllVars()); rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
...@@ -179,7 +178,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -179,7 +178,7 @@ void ListenAndServOp::RunSyncLoop(
void ListenAndServOp::ResetReceivedVars( void ListenAndServOp::ResetReceivedVars(
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope, const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const { platform::DeviceContext *dev_ctx, bool reset_all) const {
for (auto &varname : recv_varnames) { for (auto &varname : recv_varnames) {
auto var = recv_scope->FindVar(varname); auto var = recv_scope->FindVar(varname);
if (var == nullptr) { if (var == nullptr) {
...@@ -187,9 +186,11 @@ void ListenAndServOp::ResetReceivedVars( ...@@ -187,9 +186,11 @@ void ListenAndServOp::ResetReceivedVars(
continue; continue;
} }
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
if (!only_sparse_vars) { if (UNLIKELY(reset_all)) {
VLOG(3) << "reset dense var: " << varname;
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(), math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0)); static_cast<float>(0));
......
...@@ -70,7 +70,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -70,7 +70,7 @@ class ListenAndServOp : public framework::OperatorBase {
void ResetReceivedVars(const std::vector<std::string>& recv_varnames, void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
framework::Scope* recv_scope, framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
bool only_sparse_vars = true) const; bool reset_all = false) const;
protected: protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册