diff --git a/paddle/fluid/operators/distributed/large_scale_kv.h b/paddle/fluid/operators/distributed/large_scale_kv.h index cb59fbc66e18e6f08990e5f84a22fe425d042200..4c017e88732cfb05fa500034ae80b6c1ce041d6a 100644 --- a/paddle/fluid/operators/distributed/large_scale_kv.h +++ b/paddle/fluid/operators/distributed/large_scale_kv.h @@ -246,7 +246,7 @@ struct VALUE { std::vector names_; int count_; - bool seen_after_save_; + bool seen_after_last_save_; int unseen_days_; bool is_entry_; std::vector> values_; @@ -323,7 +323,7 @@ class ValueBlock { auto value = new VALUE(value_names_); value->set(values); - value->seen_after_save_ = true; + value->seen_after_last_save_ = true; value->count_ = count; values_[id] = value; } @@ -629,18 +629,22 @@ class SparseVariable { for (auto &block : shard_blocks_) { for (auto value : block->values_) { - bool id_need_save = false; - // save all params if (mode == 0) { - id_need_save = true; + ids.push_back(value.first); } else { - id_need_save = value.second.seen_after_save_; - } + bool id_need_save = false; + // save all params + if (mode == 1) { + id_need_save = true; + } else { + id_need_save = value.second.seen_after_last_save_; + } - if (id_need_save) { - ids.push_back(value.first); + if (id_need_save) { + ids.push_back(value.first); + } + value.second.seen_after_last_save_ = false; } - value.second.seen_after_save_ = false; } } diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index d00faac838504f5d68e9d44d9ffa9f25c7bf2ee5..575e5720d7765cc52b7e52dce3b44a3d6350f208 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -516,7 +516,7 @@ class Fleet(object): executor, dirname, feeded_var_names, target_vars, main_program, export_for_deployment) - def save_persistables(self, executor, dirname, main_program=None): + def save_persistables(self, executor, dirname, main_program=None, mode=1): """ saves all persistable variables from :code:`main_program` to @@ -557,7 +557,8 @@ class Fleet(object): """ - self._runtime_handle._save_persistables(executor, dirname, main_program) + self._runtime_handle._save_persistables(executor, dirname, main_program, + mode) def distributed_optimizer(self, optimizer, strategy=None): """