未验证 提交 0d53f5af 编写于 作者: M mapingshuo 提交者: GitHub

avoid saving shared params repeatedly (#3561) (#3823)

* avoid saving shared params repeatedly (#3561)

* test=develop
Co-authored-by: Nzhupengyang <zhu_py@qq.com>
Co-authored-by: Nhuzhiqiang <912790387@qq.com>
上级 089b2c38
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <set> #include <set>
#include <unordered_set>
#include "lite/core/scope.h" #include "lite/core/scope.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/core/variable.h" #include "lite/core/variable.h"
...@@ -528,12 +529,16 @@ void SaveCombinedParamsNaive(const std::string &path, ...@@ -528,12 +529,16 @@ void SaveCombinedParamsNaive(const std::string &path,
auto prog = cpp_prog; auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::unordered_set<std::string> unique_var_names;
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i); auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable() ||
unique_var_names.count(var.Name()) > 0)
continue; continue;
naive_buffer::ParamDesc param_desc(desc.AddParam()); naive_buffer::ParamDesc param_desc(desc.AddParam());
SetParamInfoNaive(&param_desc, exec_scope, var.Name()); SetParamInfoNaive(&param_desc, exec_scope, var.Name());
unique_var_names.emplace(var.Name());
} }
pt_desc.Save(); pt_desc.Save();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册