diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc index e2e9565f60c7e60770ebfaedc4b7c9d7dd8a0134..cd73f918eaa44fe645a2efc0314bc577e6845864 100755 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -96,6 +96,7 @@ namespace feed { std::shared_ptr context_ptr) { _time_zone_seconds = config["time_zone_seconds"].as(); _train_time_interval = config["train_time_interval"].as(); + _checkpoint_time_interval = config["checkpoint_time_interval"].as(3600 * 18); // 默认每18小时dump一个 CHECK(_train_time_interval > 0 && (_train_time_interval % SecondsPerMin) == 0); _train_num_per_day = SecondsPerDay / _train_time_interval; return EpochAccessor::initialize(config, context_ptr); @@ -123,7 +124,7 @@ namespace feed { bool TimelyEpochAccessor::is_last_epoch(uint64_t epoch_id) { auto delta = delta_id(epoch_id); - return delta == _train_num_per_day; + return delta == 0; // 最后一个delta恰好整除 } uint64_t TimelyEpochAccessor::epoch_time_interval() { @@ -149,7 +150,7 @@ namespace feed { if (is_last_epoch(epoch_id)) { return true; } - return delta_id(epoch_id) % 78 == 0; + return delta_id(epoch_id) % (_checkpoint_time_interval / _train_time_interval) == 0; case ModelSaveWay::ModelSaveTrainCheckpointBase: return is_last_epoch(epoch_id); } diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h index 9c9f2058b0c30a16a0bf5943bbc30cff71dded79..53df71f3712d5194def0645d482ecac1bdd0121a 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h @@ -103,6 +103,7 @@ private: uint32_t _time_zone_seconds; // 相对UTC时差(秒) uint32_t _train_time_interval; // 训练时间间隔(秒) uint32_t _train_num_per_day; // 天级训练总轮数 + uint32_t _checkpoint_time_interval; // 每隔n秒,dump出CheckPoint }; } // namespace feed