提交 229964e4 编写于 作者: X xiexionghang

add force-dump in startup

上级 19b00fb2
......@@ -16,6 +16,7 @@ namespace feed {
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
int ret = Process::initialize(context_ptr);
auto& config = _context_ptr->trainer_config;
_startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
if (config["executor"]) {
_executors.resize(config["executor"].size());
for (size_t i = 0; i < _executors.size(); ++i) {
......@@ -26,7 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return 0;
}
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get();
......@@ -34,7 +35,7 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0;
}
if (!epoch_accessor->need_save_model(epoch_id, way)) {
if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
return 0;
}
paddle::platform::Timer timer;
......@@ -112,8 +113,8 @@ int LearnerProcess::run() {
CHECK(load_model(epoch_id) == 0);
environment->barrier(EnvironmentRole::WORKER);
//判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
//判断是否先dump出base TODO
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
environment->barrier(EnvironmentRole::WORKER);
while (true) {
......
......@@ -20,10 +20,11 @@ public:
protected:
// 加载所有模型
virtual int load_model(uint64_t epoch_id);
// 同步保存所有模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
private:
bool _startup_dump_inference_base; //启动立即dump base
std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册