From 3d2c41e1542014042babce1bc6b1b753f96faadb Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Thu, 12 Sep 2019 10:16:31 +0800 Subject: [PATCH] fix xbox-base-dump && add checkpoint config --- .../train/custom_trainer/feed/accessor/epoch_accessor.cc | 5 +++-- .../train/custom_trainer/feed/accessor/epoch_accessor.h | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc index e2e9565f..cd73f918 100755 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -96,6 +96,7 @@ namespace feed { std::shared_ptr context_ptr) { _time_zone_seconds = config["time_zone_seconds"].as(); _train_time_interval = config["train_time_interval"].as(); + _checkpoint_time_interval = config["checkpoint_time_interval"].as(3600 * 18); // 默认每18小时dump一个 CHECK(_train_time_interval > 0 && (_train_time_interval % SecondsPerMin) == 0); _train_num_per_day = SecondsPerDay / _train_time_interval; return EpochAccessor::initialize(config, context_ptr); @@ -123,7 +124,7 @@ namespace feed { bool TimelyEpochAccessor::is_last_epoch(uint64_t epoch_id) { auto delta = delta_id(epoch_id); - return delta == _train_num_per_day; + return delta == 0; // 最后一个delta恰好整除 } uint64_t TimelyEpochAccessor::epoch_time_interval() { @@ -149,7 +150,7 @@ namespace feed { if (is_last_epoch(epoch_id)) { return true; } - return delta_id(epoch_id) % 78 == 0; + return delta_id(epoch_id) % (_checkpoint_time_interval / _train_time_interval) == 0; case ModelSaveWay::ModelSaveTrainCheckpointBase: return is_last_epoch(epoch_id); } diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h index 9c9f2058..53df71f3 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h @@ -103,6 +103,7 @@ private: uint32_t _time_zone_seconds; // 相对UTC时差(秒) uint32_t _train_time_interval; // 训练时间间隔(秒) uint32_t _train_num_per_day; // 天级训练总轮数 + uint32_t _checkpoint_time_interval; // 每隔n秒,dump出CheckPoint }; } // namespace feed -- GitLab