提交 59580a7f 编写于 作者: T tangwei12

bug fix

上级 3972ba32
...@@ -130,12 +130,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -130,12 +130,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id != -1, checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); "when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name; << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
return true; return true;
} }
......
...@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase { ...@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
std::string filename = lt_var->data(); std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool // get device context from pool
......
...@@ -680,8 +680,6 @@ def load_inference_model(dirname, ...@@ -680,8 +680,6 @@ def load_inference_model(dirname,
executor, executor,
model_filename=None, model_filename=None,
params_filename=None, params_filename=None,
training_role=None,
role_id=None,
pserver_endpoints=None): pserver_endpoints=None):
""" """
Load inference model from a directory Load inference model from a directory
...@@ -733,9 +731,6 @@ def load_inference_model(dirname, ...@@ -733,9 +731,6 @@ def load_inference_model(dirname,
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
if training_role == "PSERVER":
_load_lookup_table_vars(executor, dirname, program, role_id)
if model_filename is not None: if model_filename is not None:
model_filename = os.path.basename(model_filename) model_filename = os.path.basename(model_filename)
else: else:
...@@ -800,7 +795,7 @@ def _save_lookup_tables_by_notify(executor, dirname, lookup_table, ...@@ -800,7 +795,7 @@ def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_notify_block = pserver_notify_program.global_block() pserver_notify_block = pserver_notify_program.global_block()
attrs = {} attrs = {}
attrs['epmap'] = pserver_endpoints.split(",") attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table attrs['lookup_table'] = lookup_table
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册