提交 99078d30 编写于 作者: S seiriosPlus

add mode for save

上级 f70096a2
...@@ -246,7 +246,7 @@ struct VALUE { ...@@ -246,7 +246,7 @@ struct VALUE {
std::vector<std::string> names_; std::vector<std::string> names_;
int count_; int count_;
bool seen_after_save_; bool seen_after_last_save_;
int unseen_days_; int unseen_days_;
bool is_entry_; bool is_entry_;
std::vector<std::vector<float>> values_; std::vector<std::vector<float>> values_;
...@@ -323,7 +323,7 @@ class ValueBlock { ...@@ -323,7 +323,7 @@ class ValueBlock {
auto value = new VALUE(value_names_); auto value = new VALUE(value_names_);
value->set(values); value->set(values);
value->seen_after_save_ = true; value->seen_after_last_save_ = true;
value->count_ = count; value->count_ = count;
values_[id] = value; values_[id] = value;
} }
...@@ -629,18 +629,22 @@ class SparseVariable { ...@@ -629,18 +629,22 @@ class SparseVariable {
for (auto &block : shard_blocks_) { for (auto &block : shard_blocks_) {
for (auto value : block->values_) { for (auto value : block->values_) {
bool id_need_save = false;
// save all params
if (mode == 0) { if (mode == 0) {
id_need_save = true; ids.push_back(value.first);
} else { } else {
id_need_save = value.second.seen_after_save_; bool id_need_save = false;
} // save all params
if (mode == 1) {
id_need_save = true;
} else {
id_need_save = value.second.seen_after_last_save_;
}
if (id_need_save) { if (id_need_save) {
ids.push_back(value.first); ids.push_back(value.first);
}
value.second.seen_after_last_save_ = false;
} }
value.second.seen_after_save_ = false;
} }
} }
......
...@@ -516,7 +516,7 @@ class Fleet(object): ...@@ -516,7 +516,7 @@ class Fleet(object):
executor, dirname, feeded_var_names, target_vars, main_program, executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment) export_for_deployment)
def save_persistables(self, executor, dirname, main_program=None): def save_persistables(self, executor, dirname, main_program=None, mode=1):
""" """
saves all persistable variables from :code:`main_program` to saves all persistable variables from :code:`main_program` to
...@@ -557,7 +557,8 @@ class Fleet(object): ...@@ -557,7 +557,8 @@ class Fleet(object):
""" """
self._runtime_handle._save_persistables(executor, dirname, main_program) self._runtime_handle._save_persistables(executor, dirname, main_program,
mode)
def distributed_optimizer(self, optimizer, strategy=None): def distributed_optimizer(self, optimizer, strategy=None):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册