提交 bb17604b 编写于 作者: T tangwei12

bug fix

上级 b089b809
...@@ -240,7 +240,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -240,7 +240,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); req.set_notify_type(CHECKPOINT_SAVE_MESSAGE);
req.set_checkpoint_dir(dir); req.set_checkpoint_dir(dir);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
} }
......
...@@ -123,7 +123,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -123,7 +123,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar, framework::Variable** outvar,
const std::string& out_var_name) {} const std::string& out_var_name) {
executor_->RunPreparedContext(checkpoint_prepared_ctx_);
return true;
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
......
...@@ -96,7 +96,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -96,7 +96,7 @@ class SaveOp : public framework::OperatorBase {
} }
} }
SaveLodTensor(const string &filename, const platform::Place &place, SaveLodTensor(const std::string &filename, const platform::Place &place,
Variable *var) { Variable *var) {
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
...@@ -127,7 +127,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -127,7 +127,7 @@ class SaveOp : public framework::OperatorBase {
fout.close() fout.close()
} }
SaveSelectedRows(const string &filename, const platform::Place &place, SaveSelectedRows(const std::string &filename, const platform::Place &place,
Variable *var) { Variable *var) {
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
...@@ -141,7 +141,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -141,7 +141,7 @@ class SaveOp : public framework::OperatorBase {
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write", PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename); filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx); framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close() fout.close();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册