diff --git a/paddle/fluid/operators/save_op.h b/paddle/fluid/operators/save_op.h index b41c70560812c57e89196525289e828c4a91e7f2..b59421cb9e08e343a507210316be0d9b06192c49 100644 --- a/paddle/fluid/operators/save_op.h +++ b/paddle/fluid/operators/save_op.h @@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel { void SaveSelectedRows(const framework::ExecutionContext &ctx, const platform::Place &place, const framework::Variable *var) const { - framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH); - auto file_path = ctx.Attr("file_path"); auto overwrite = ctx.Attr("overwrite"); std::string filename = file_path; + VLOG(4) << "SaveSelectedRows output file_path: " << file_path; + framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH); if (out_put_var != nullptr) { auto *lt_var = out_put_var->GetMutable(); - filename = *lt_var; + if (lt_var->length() > 0) { + VLOG(4) << "SaveSelectedRows output var name: " << *lt_var; + filename = *lt_var; + } } if (FileExists(filename) && !overwrite) { diff --git a/python/paddle/fluid/contrib/utils/lookup_table_utils.py b/python/paddle/fluid/contrib/utils/lookup_table_utils.py index b15ee94f63512dcca91a8aab33d216db0fc24ed5..402850e70ab269e3b2e7d1993ad43686c02e9eb9 100644 --- a/python/paddle/fluid/contrib/utils/lookup_table_utils.py +++ b/python/paddle/fluid/contrib/utils/lookup_table_utils.py @@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program, }) sums.append(param_var) global_block.append_op( - type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) + type='merge_sparse_lookup_table', + inputs={"X": sums}, + outputs={'Out': emb_var}, + attrs={}) + global_block.append_op( + type='save', + inputs={"X": [emb_var]}, + outputs={}, + attrs={ + 'file_path': os.path.join(lookup_table_dirname, emb_var.name) + }) global_block.append_op(type='delete_var', inputs={'X': sums}) executor.run(convert_program) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 64afa7e856fb88e535265726d76214a621604e84..a0573881b79ba84d10f65a843996bc14812683c8 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -86,13 +86,21 @@ def is_persistable(var): def _clone_var_in_block_(block, var): assert isinstance(var, Variable) - return block.create_var( - name=var.name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - lod_level=var.lod_level, - persistable=True) + if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + else: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + persistable=True) def save_vars(executor,