diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 8898438675687e06fc4389ddcd634dc04e8583bd..1dff3bfa3cbb9709e40e526523b5467b650268ff 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -240,7 +240,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); req.set_checkpoint_dir(dir); - auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq); + auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index ffad66700e589db48b0ab652fe28ffb2517aff25..de6ce72d4dcd30e53dfa770ce615335f3a2ecfc3 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -123,7 +123,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, framework::Variable** outvar, - const std::string& out_var_name) {} + const std::string& out_var_name) { + executor_->RunPreparedContext(checkpoint_prepared_ctx_); + return true; +} } // namespace detail } // namespace operators diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 410796eeb6cd1e13de2e2699f639033d8525f9ed..3d114538eb881bc526e51e1aa442f85e00c7647d 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -96,7 +96,7 @@ class SaveOp : public framework::OperatorBase { } } - SaveLodTensor(const string &filename, const platform::Place &place, + SaveLodTensor(const std::string &filename, const platform::Place &place, Variable *var) { auto &tensor = var->Get(); @@ -127,7 +127,7 @@ class SaveOp : public framework::OperatorBase { fout.close() } - SaveSelectedRows(const string &filename, const platform::Place &place, + SaveSelectedRows(const std::string &filename, const platform::Place &place, Variable *var) { auto &selectedRows = var->Get(); @@ -141,7 +141,7 @@ class SaveOp : public framework::OperatorBase { PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", filename); framework::SerializeToStream(fout, selectedRows, dev_ctx); - fout.close() + fout.close(); } };