From 59580a7f691f8301779c40e9514897d1fa8842dc Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 14 Aug 2018 14:38:46 +0800 Subject: [PATCH] bug fix --- paddle/fluid/operators/distributed/request_handler_impl.cc | 4 ++-- paddle/fluid/operators/save_op.cc | 2 ++ python/paddle/fluid/io.py | 7 +------ 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index de1a50315..2af3a8229 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -130,12 +130,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, checkpoint_notify_id != -1, "when checkpoint_notify_id = -1, there should be no RPC invoke."); - auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); + auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " << out_var_name; - executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); + executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_); return true; } diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 201a51130..85de37416 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase { std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; + MkDirRecursively(DirName(filename).c_str()); + auto &selectedRows = var->Get(); // get device context from pool diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 87c91475b..362577a72 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -680,8 +680,6 @@ def load_inference_model(dirname, executor, model_filename=None, params_filename=None, - training_role=None, - role_id=None, pserver_endpoints=None): """ Load inference model from a directory @@ -733,9 +731,6 @@ def load_inference_model(dirname, if not os.path.isdir(dirname): raise ValueError("There is no directory named '%s'", dirname) - if training_role == "PSERVER": - _load_lookup_table_vars(executor, dirname, program, role_id) - if model_filename is not None: model_filename = os.path.basename(model_filename) else: @@ -800,7 +795,7 @@ def _save_lookup_tables_by_notify(executor, dirname, lookup_table, pserver_notify_block = pserver_notify_program.global_block() attrs = {} - attrs['epmap'] = pserver_endpoints.split(",") + attrs['epmap'] = pserver_endpoints attrs['dir'] = dirname attrs['lookup_table'] = lookup_table -- GitLab