From 6b3d96254d03c514be375a13d9f95686cc964a5c Mon Sep 17 00:00:00 2001 From: songhao Date: Fri, 21 Jun 2019 15:56:20 +0800 Subject: [PATCH] fix some bug when merge sparse embedding parameters, test=develop (#18223) 1. fix the bug that out_put_var in SaveSelectedRows would be empty string 2. use merge_sparse_lookup_table to replace sum op for load_persistables_for_inference 3. fix the bug in _clone_var_in_block_ when the var is SELECTED_ROWS. --- paddle/fluid/operators/save_op.h | 9 +++++--- .../fluid/contrib/utils/lookup_table_utils.py | 12 +++++++++- python/paddle/fluid/io.py | 22 +++++++++++++------ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/save_op.h b/paddle/fluid/operators/save_op.h index b41c7056081..b59421cb9e0 100644 --- a/paddle/fluid/operators/save_op.h +++ b/paddle/fluid/operators/save_op.h @@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel { void SaveSelectedRows(const framework::ExecutionContext &ctx, const platform::Place &place, const framework::Variable *var) const { - framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH); - auto file_path = ctx.Attr("file_path"); auto overwrite = ctx.Attr("overwrite"); std::string filename = file_path; + VLOG(4) << "SaveSelectedRows output file_path: " << file_path; + framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH); if (out_put_var != nullptr) { auto *lt_var = out_put_var->GetMutable(); - filename = *lt_var; + if (lt_var->length() > 0) { + VLOG(4) << "SaveSelectedRows output var name: " << *lt_var; + filename = *lt_var; + } } if (FileExists(filename) && !overwrite) { diff --git a/python/paddle/fluid/contrib/utils/lookup_table_utils.py b/python/paddle/fluid/contrib/utils/lookup_table_utils.py index b15ee94f635..402850e70ab 100644 --- a/python/paddle/fluid/contrib/utils/lookup_table_utils.py +++ b/python/paddle/fluid/contrib/utils/lookup_table_utils.py @@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program, }) sums.append(param_var) global_block.append_op( - type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) + type='merge_sparse_lookup_table', + inputs={"X": sums}, + outputs={'Out': emb_var}, + attrs={}) + global_block.append_op( + type='save', + inputs={"X": [emb_var]}, + outputs={}, + attrs={ + 'file_path': os.path.join(lookup_table_dirname, emb_var.name) + }) global_block.append_op(type='delete_var', inputs={'X': sums}) executor.run(convert_program) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 64afa7e856f..a0573881b79 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -86,13 +86,21 @@ def is_persistable(var): def _clone_var_in_block_(block, var): assert isinstance(var, Variable) - return block.create_var( - name=var.name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - lod_level=var.lod_level, - persistable=True) + if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + else: + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + persistable=True) def save_vars(executor, -- GitLab