提交 3d2c41e1 编写于 作者: X xiexionghang

fix xbox-base-dump && add checkpoint config

上级 57882670
...@@ -96,6 +96,7 @@ namespace feed { ...@@ -96,6 +96,7 @@ namespace feed {
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
_time_zone_seconds = config["time_zone_seconds"].as<int>(); _time_zone_seconds = config["time_zone_seconds"].as<int>();
_train_time_interval = config["train_time_interval"].as<int>(); _train_time_interval = config["train_time_interval"].as<int>();
_checkpoint_time_interval = config["checkpoint_time_interval"].as<int>(3600 * 18); // 默认每18小时dump一个
CHECK(_train_time_interval > 0 && (_train_time_interval % SecondsPerMin) == 0); CHECK(_train_time_interval > 0 && (_train_time_interval % SecondsPerMin) == 0);
_train_num_per_day = SecondsPerDay / _train_time_interval; _train_num_per_day = SecondsPerDay / _train_time_interval;
return EpochAccessor::initialize(config, context_ptr); return EpochAccessor::initialize(config, context_ptr);
...@@ -123,7 +124,7 @@ namespace feed { ...@@ -123,7 +124,7 @@ namespace feed {
bool TimelyEpochAccessor::is_last_epoch(uint64_t epoch_id) { bool TimelyEpochAccessor::is_last_epoch(uint64_t epoch_id) {
auto delta = delta_id(epoch_id); auto delta = delta_id(epoch_id);
return delta == _train_num_per_day; return delta == 0; // 最后一个delta恰好整除
} }
uint64_t TimelyEpochAccessor::epoch_time_interval() { uint64_t TimelyEpochAccessor::epoch_time_interval() {
...@@ -149,7 +150,7 @@ namespace feed { ...@@ -149,7 +150,7 @@ namespace feed {
if (is_last_epoch(epoch_id)) { if (is_last_epoch(epoch_id)) {
return true; return true;
} }
return delta_id(epoch_id) % 78 == 0; return delta_id(epoch_id) % (_checkpoint_time_interval / _train_time_interval) == 0;
case ModelSaveWay::ModelSaveTrainCheckpointBase: case ModelSaveWay::ModelSaveTrainCheckpointBase:
return is_last_epoch(epoch_id); return is_last_epoch(epoch_id);
} }
......
...@@ -103,6 +103,7 @@ private: ...@@ -103,6 +103,7 @@ private:
uint32_t _time_zone_seconds; // 相对UTC时差(秒) uint32_t _time_zone_seconds; // 相对UTC时差(秒)
uint32_t _train_time_interval; // 训练时间间隔(秒) uint32_t _train_time_interval; // 训练时间间隔(秒)
uint32_t _train_num_per_day; // 天级训练总轮数 uint32_t _train_num_per_day; // 天级训练总轮数
uint32_t _checkpoint_time_interval; // 每隔n秒,dump出CheckPoint
}; };
} // namespace feed } // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册