提交 ae0b57ba 编写于 作者: L linan17

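update for save way=3: add ModelSaveTrainCheckpointBase, the train checkpoint saved at the last epoch (same batch_model path as ModelSaveTrainCheckpoint)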

Change-Id: I3b0c3753bf1dea211bd41931db19cb2550f03f6e
上级 c00354af
...@@ -150,6 +150,8 @@ namespace feed { ...@@ -150,6 +150,8 @@ namespace feed {
return true; return true;
} }
return delta_id(epoch_id) % 24 == 0; return delta_id(epoch_id) % 24 == 0;
case ModelSaveWay::ModelSaveTrainCheckpointBase:
return is_last_epoch(epoch_id);
} }
return false; return false;
} }
...@@ -165,6 +167,7 @@ namespace feed { ...@@ -165,6 +167,7 @@ namespace feed {
case ModelSaveWay::ModelSaveInferenceBase: case ModelSaveWay::ModelSaveInferenceBase:
return _trainer_context->file_system->path_join(_inference_model_path, return _trainer_context->file_system->path_join(_inference_model_path,
string::format_string("%s/base", date.c_str())); string::format_string("%s/base", date.c_str()));
case ModelSaveWay::ModelSaveTrainCheckpointBase:
case ModelSaveWay::ModelSaveTrainCheckpoint: case ModelSaveWay::ModelSaveTrainCheckpoint:
return _trainer_context->file_system->path_join(_model_root_path, return _trainer_context->file_system->path_join(_model_root_path,
string::format_string("batch_model/%s", date_with_hour.c_str())); string::format_string("batch_model/%s", date_with_hour.c_str()));
......
...@@ -175,7 +175,11 @@ int LearnerProcess::run() { ...@@ -175,7 +175,11 @@ int LearnerProcess::run() {
{ {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); if (epoch_accessor->is_last_epoch(epoch_id)) {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
} else {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
}
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
if (epoch_accessor->is_last_epoch(epoch_id) && if (epoch_accessor->is_last_epoch(epoch_id) &&
environment->is_master_node(EnvironmentRole::WORKER)) { environment->is_master_node(EnvironmentRole::WORKER)) {
......
...@@ -26,7 +26,8 @@ const uint32_t SecondsPerDay = 24 * 3600; ...@@ -26,7 +26,8 @@ const uint32_t SecondsPerDay = 24 * 3600;
enum class ModelSaveWay { enum class ModelSaveWay {
ModelSaveTrainCheckpoint = 0, ModelSaveTrainCheckpoint = 0,
ModelSaveInferenceDelta = 1, ModelSaveInferenceDelta = 1,
ModelSaveInferenceBase = 2 ModelSaveInferenceBase = 2,
ModelSaveTrainCheckpointBase = 3,
}; };
enum class TrainerStatus { enum class TrainerStatus {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册