未验证 提交 e4128757 编写于 作者: 石晓伟 提交者: GitHub

fix save interfaces of flatbuffers, test=develop (#4193)

* fix save interfaces of flatbuffers, test=develop

* fix the header file, test=develop
上级 557c0599
...@@ -35,11 +35,10 @@ std::vector<char> LoadFile(const std::string& path) { ...@@ -35,11 +35,10 @@ std::vector<char> LoadFile(const std::string& path) {
return buf; return buf;
} }
void SaveFile(const std::string& path, const void* src, size_t byte_size) { void SaveFile(const std::string& path, const std::vector<char>& cache) {
CHECK(src);
FILE* file = fopen(path.c_str(), "wb"); FILE* file = fopen(path.c_str(), "wb");
CHECK(file); CHECK(file);
CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size); CHECK(fwrite(cache.data(), sizeof(char), cache.size(), file) == cache.size());
fclose(file); fclose(file);
} }
......
...@@ -27,7 +27,7 @@ namespace lite { ...@@ -27,7 +27,7 @@ namespace lite {
namespace fbs { namespace fbs {
std::vector<char> LoadFile(const std::string& path); std::vector<char> LoadFile(const std::string& path);
void SaveFile(const std::string& path, const void* src, size_t byte_size); void SaveFile(const std::string& path, const std::vector<char>& cache);
void SetScopeWithCombinedParams(lite::Scope* scope, void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params); const CombinedParamsDescReadAPI& params);
......
...@@ -74,13 +74,8 @@ TEST(CombinedParamsDesc, Scope) { ...@@ -74,13 +74,8 @@ TEST(CombinedParamsDesc, Scope) {
}; };
check_params(combined_param); check_params(combined_param);
/* --------- Cache scope ---------- */
std::vector<char> cache;
cache.resize(combined_param.buf_size());
std::memcpy(cache.data(), combined_param.data(), combined_param.buf_size());
/* --------- View scope ---------- */ /* --------- View scope ---------- */
check_params(CombinedParamsDescView(std::move(cache))); check_params(CombinedParamsDescView(combined_param.data()));
} }
#endif // LITE_WITH_FLATBUFFERS_DESC #endif // LITE_WITH_FLATBUFFERS_DESC
......
...@@ -186,14 +186,12 @@ class CombinedParamsDesc : public CombinedParamsDescAPI { ...@@ -186,14 +186,12 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
return params_[params_.size() - 1].get(); return params_[params_.size() - 1].get();
} }
const void* data() { std::vector<char> data() {
SyncBuffer(); SyncBuffer();
return buf_.data(); std::vector<char> cache;
} cache.resize(buf_.size());
std::memcpy(cache.data(), buf_.data(), buf_.size());
size_t buf_size() { return cache;
SyncBuffer();
return buf_.size();
} }
private: private:
......
...@@ -137,14 +137,12 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -137,14 +137,12 @@ class ProgramDesc : public ProgramDescAPI {
desc_.version->version = version_in; desc_.version->version = version_in;
} }
const void* data() { std::vector<char> data() {
SyncBuffer(); SyncBuffer();
return buf_.data(); std::vector<char> cache;
} cache.resize(buf_.size());
std::memcpy(cache.data(), buf_.data(), buf_.size());
size_t buf_size() { return cache;
SyncBuffer();
return buf_.size();
} }
private: private:
......
...@@ -73,10 +73,7 @@ inline std::vector<char> GenerateProgramCache() { ...@@ -73,10 +73,7 @@ inline std::vector<char> GenerateProgramCache() {
op_b1.SetAttr<bool>("Attr1", true); op_b1.SetAttr<bool>("Attr1", true);
/* --------- Cache Program ---------- */ /* --------- Cache Program ---------- */
std::vector<char> cache; return program.data();
cache.resize(program.buf_size());
std::memcpy(cache.data(), program.data(), program.buf_size());
return cache;
} }
inline void CheckProgramCache(ProgramDesc* program) { inline void CheckProgramCache(ProgramDesc* program) {
......
...@@ -603,7 +603,7 @@ void SaveModelFbs(const std::string &model_dir, ...@@ -603,7 +603,7 @@ void SaveModelFbs(const std::string &model_dir,
const std::string prog_path = model_dir + "/model.fbs"; const std::string prog_path = model_dir + "/model.fbs";
fbs::ProgramDesc fbs_prog; fbs::ProgramDesc fbs_prog;
TransformProgramDescCppToAny(cpp_prog, &fbs_prog); TransformProgramDescCppToAny(cpp_prog, &fbs_prog);
fbs::SaveFile(prog_path, fbs_prog.data(), fbs_prog.buf_size()); fbs::SaveFile(prog_path, fbs_prog.data());
/* 2. Get param names from cpp::ProgramDesc */ /* 2. Get param names from cpp::ProgramDesc */
auto &main_block_desc = *cpp_prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *cpp_prog.GetBlock<cpp::BlockDesc>(0);
...@@ -621,7 +621,7 @@ void SaveModelFbs(const std::string &model_dir, ...@@ -621,7 +621,7 @@ void SaveModelFbs(const std::string &model_dir,
const std::string params_path = model_dir + "/params.fbs"; const std::string params_path = model_dir + "/params.fbs";
fbs::CombinedParamsDesc params_prog; fbs::CombinedParamsDesc params_prog;
fbs::SetCombinedParamsWithScope(exec_scope, unique_var_names, &params_prog); fbs::SetCombinedParamsWithScope(exec_scope, unique_var_names, &params_prog);
fbs::SaveFile(params_path, params_prog.data(), params_prog.buf_size()); fbs::SaveFile(params_path, params_prog.data());
} }
#endif // LITE_ON_TINY_PUBLISH #endif // LITE_ON_TINY_PUBLISH
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册