提交 6b3d9625 编写于 作者: S songhao 提交者: tangwei12

fix some bug when merge sparse embedding parameters, test=develop (#18223)

1. fix the bug that out_put_var in SaveSelectedRows would be empty string
2. use merge_sparse_lookup_table to replace sum op for load_persistables_for_inference
3. fix the bug in _clone_var_in_block_ when the var is SELECTED_ROWS.
上级 3f8031e2
......@@ -102,17 +102,20 @@ class SaveOpKernel : public framework::OpKernel<T> {
void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
auto file_path = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
std::string filename = file_path;
VLOG(4) << "SaveSelectedRows output file_path: " << file_path;
framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH);
if (out_put_var != nullptr) {
auto *lt_var = out_put_var->GetMutable<std::string>();
if (lt_var->length() > 0) {
VLOG(4) << "SaveSelectedRows output var name: " << *lt_var;
filename = *lt_var;
}
}
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
......
......@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program,
})
sums.append(param_var)
global_block.append_op(
type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={})
type='merge_sparse_lookup_table',
inputs={"X": sums},
outputs={'Out': emb_var},
attrs={})
global_block.append_op(
type='save',
inputs={"X": [emb_var]},
outputs={},
attrs={
'file_path': os.path.join(lookup_table_dirname, emb_var.name)
})
global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program)
......
......@@ -86,6 +86,7 @@ def is_persistable(var):
def _clone_var_in_block_(block, var):
assert isinstance(var, Variable)
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
return block.create_var(
name=var.name,
shape=var.shape,
......@@ -93,6 +94,13 @@ def _clone_var_in_block_(block, var):
type=var.type,
lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def save_vars(executor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册