提交 3199d72d 编写于 作者: S seiriosPlus

add mode for save

上级 99078d30
...@@ -637,13 +637,13 @@ class SparseVariable { ...@@ -637,13 +637,13 @@ class SparseVariable {
if (mode == 1) { if (mode == 1) {
id_need_save = true; id_need_save = true;
} else { } else {
id_need_save = value.second.seen_after_last_save_; id_need_save = value.second->seen_after_last_save_;
} }
if (id_need_save) { if (id_need_save) {
ids.push_back(value.first); ids.push_back(value.first);
} }
value.second.seen_after_last_save_ = false; value.second->seen_after_last_save_ = false;
} }
} }
} }
...@@ -661,7 +661,7 @@ class SparseVariable { ...@@ -661,7 +661,7 @@ class SparseVariable {
auto *slr = var->GetMutable<framework::SelectedRows>(); auto *slr = var->GetMutable<framework::SelectedRows>();
auto *src_t = slr->mutable_value(); auto *src_t = slr->mutable_value();
src_t->Resize({ids.size(), dim}); src_t->Resize({static_cast<int64_t>(ids.size()), dim});
auto *value = src_t->mutable_data<float>(place); auto *value = src_t->mutable_data<float>(place);
dims.push_back(dim); dims.push_back(dim);
...@@ -669,12 +669,11 @@ class SparseVariable { ...@@ -669,12 +669,11 @@ class SparseVariable {
tensors.push_back(value); tensors.push_back(value);
} }
std::vector<std::vector<std::vector<float> *>> *values; std::vector<std::vector<std::vector<float> *>> values;
Get(ids, variables, values); Get(ids, valuenames, &values);
int64_t offset = 0; int64_t offset = 0;
for (auto *value : values) { for (auto &vss : values) {
auto vss = value;
for (int i = 0; i < static_cast<int>(vss.size()); i++) { for (int i = 0; i < static_cast<int>(vss.size()); i++) {
auto &vs = vss[i]; auto &vs = vss[i];
std::memcpy(tensors[i] + offset * dims[i], vs->data(), std::memcpy(tensors[i] + offset * dims[i], vs->data(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册