提交 6b3d9625 编写于 作者: S songhao 提交者: tangwei12

fix some bug when merge sparse embedding parameters, test=develop (#18223)

1. fix the bug that out_put_var in SaveSelectedRows would be empty string
2. use merge_sparse_lookup_table to replace sum op for load_persistables_for_inference
3. fix the bug in _clone_var_in_block_ when the var is SELECTED_ROWS.
上级 3f8031e2
...@@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
void SaveSelectedRows(const framework::ExecutionContext &ctx, void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place, const platform::Place &place,
const framework::Variable *var) const { const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
auto file_path = ctx.Attr<std::string>("file_path"); auto file_path = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite"); auto overwrite = ctx.Attr<bool>("overwrite");
std::string filename = file_path; std::string filename = file_path;
VLOG(4) << "SaveSelectedRows output file_path: " << file_path;
framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH);
if (out_put_var != nullptr) { if (out_put_var != nullptr) {
auto *lt_var = out_put_var->GetMutable<std::string>(); auto *lt_var = out_put_var->GetMutable<std::string>();
filename = *lt_var; if (lt_var->length() > 0) {
VLOG(4) << "SaveSelectedRows output var name: " << *lt_var;
filename = *lt_var;
}
} }
if (FileExists(filename) && !overwrite) { if (FileExists(filename) && !overwrite) {
......
...@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program, ...@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program,
}) })
sums.append(param_var) sums.append(param_var)
global_block.append_op( global_block.append_op(
type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) type='merge_sparse_lookup_table',
inputs={"X": sums},
outputs={'Out': emb_var},
attrs={})
global_block.append_op(
type='save',
inputs={"X": [emb_var]},
outputs={},
attrs={
'file_path': os.path.join(lookup_table_dirname, emb_var.name)
})
global_block.append_op(type='delete_var', inputs={'X': sums}) global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program) executor.run(convert_program)
......
...@@ -86,13 +86,21 @@ def is_persistable(var): ...@@ -86,13 +86,21 @@ def is_persistable(var):
def _clone_var_in_block_(block, var): def _clone_var_in_block_(block, var):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
name=var.name, return block.create_var(
shape=var.shape, name=var.name,
dtype=var.dtype, shape=var.shape,
type=var.type, dtype=var.dtype,
lod_level=var.lod_level, type=var.type,
persistable=True) lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def save_vars(executor, def save_vars(executor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册