提交 af0a6a14 编写于 作者: T tangwei12

checkpoint notify

上级 ae12281d
...@@ -126,11 +126,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -126,11 +126,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable** outvar, framework::Variable** outvar,
const std::string& out_var_name) { const std::string& out_var_name) {
auto lt_varname = string::Sprintf("%s.path", varname); auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
auto *lt_var = scope->FindVar(lt_varname)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update " << lt_varname << " to: " << out_var_name; VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
} }
......
...@@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o ...@@ -182,9 +182,32 @@ This operator will serialize and write a tensor/selected rows variable to file o
} }
}; };
} // namespace operators class SaveOpVarTypeInference : public framework::VarTypeInference {
} // namespace paddle public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("loopup_table_path").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
}
}
// namespace operators
// namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker); REGISTER_OPERATOR(save, ops::SaveOp,
paddle::framework::EmptyGradOpMaker,
ops::SaveOpProtoMaker,
ops::SaveOpVarTypeInference,
ops::SaveOpShapeInference);
...@@ -838,7 +838,7 @@ class DistributeTranspiler: ...@@ -838,7 +838,7 @@ class DistributeTranspiler:
""" """
import os import os
pserver_program.global_block().create_var(name="%s.path"%self.table_name, persistable=True, type=core.VarDesc.VarType.RAW) pserver_program.global_block().create_var(name="loopup_table_path", persistable=True, type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx) checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op( checkpoint_save_block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册