提交 99078d30 编写于 作者: S seiriosPlus

add mode for save

上级 f70096a2
......@@ -246,7 +246,7 @@ struct VALUE {
std::vector<std::string> names_;
int count_;
bool seen_after_save_;
bool seen_after_last_save_;
int unseen_days_;
bool is_entry_;
std::vector<std::vector<float>> values_;
......@@ -323,7 +323,7 @@ class ValueBlock {
auto value = new VALUE(value_names_);
value->set(values);
value->seen_after_save_ = true;
value->seen_after_last_save_ = true;
value->count_ = count;
values_[id] = value;
}
......@@ -629,18 +629,22 @@ class SparseVariable {
for (auto &block : shard_blocks_) {
for (auto value : block->values_) {
if (mode == 0) {
ids.push_back(value.first);
} else {
bool id_need_save = false;
// save all params
if (mode == 0) {
if (mode == 1) {
id_need_save = true;
} else {
id_need_save = value.second.seen_after_save_;
id_need_save = value.second.seen_after_last_save_;
}
if (id_need_save) {
ids.push_back(value.first);
}
value.second.seen_after_save_ = false;
value.second.seen_after_last_save_ = false;
}
}
}
......
......@@ -516,7 +516,7 @@ class Fleet(object):
executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment)
def save_persistables(self, executor, dirname, main_program=None):
def save_persistables(self, executor, dirname, main_program=None, mode=1):
"""
saves all persistable variables from :code:`main_program` to
......@@ -557,7 +557,8 @@ class Fleet(object):
"""
self._runtime_handle._save_persistables(executor, dirname, main_program)
self._runtime_handle._save_persistables(executor, dirname, main_program,
mode)
def distributed_optimizer(self, optimizer, strategy=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册