diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc
index 3b58019db6e55fa8198d2f77731095c6cf356266..78d2f16746cf478c4424df929bd1f62b08f8a67c 100644
--- a/paddle/fluid/inference/io.cc
+++ b/paddle/fluid/inference/io.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/inference/io.h"
 
+#include <algorithm>
 #include <fstream>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
@@ -27,14 +28,14 @@ namespace inference {
 // linking the inference shared library.
 void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
 
-void ReadBinaryFile(const std::string& filename, std::string& contents) {
+void ReadBinaryFile(const std::string& filename, std::string* contents) {
   std::ifstream fin(filename, std::ios::in | std::ios::binary);
   PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
   fin.seekg(0, std::ios::end);
-  contents.clear();
-  contents.resize(fin.tellg());
+  contents->clear();
+  contents->resize(fin.tellg());
   fin.seekg(0, std::ios::beg);
-  fin.read(&contents[0], contents.size());
+  fin.read(&(contents->at(0)), contents->size());
   fin.close();
 }
 
@@ -47,7 +48,7 @@ bool IsPersistable(const framework::VarDesc* var) {
   return false;
 }
 
-void LoadPersistables(framework::Executor& executor, framework::Scope& scope,
+void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
                       const std::string& dirname,
                       const std::string& param_filename) {
@@ -92,18 +93,18 @@ void LoadPersistables(framework::Executor& executor, framework::Scope& scope,
     op->CheckAttrs();
   }
 
-  executor.Run(*load_program, &scope, 0, true, true);
+  executor->Run(*load_program, scope, 0, true, true);
 
   delete load_program;
 }
 
-std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
-                                             framework::Scope& scope,
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
+                                             framework::Scope* scope,
                                              const std::string& dirname) {
   std::string model_filename = dirname + "/__model__";
   std::string program_desc_str;
   VLOG(3) << "loading model from " << model_filename;
-  ReadBinaryFile(model_filename, program_desc_str);
+  ReadBinaryFile(model_filename, &program_desc_str);
 
   std::unique_ptr<framework::ProgramDesc> main_program(
       new framework::ProgramDesc(program_desc_str));
@@ -113,11 +114,11 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
 }
 
 std::unique_ptr<framework::ProgramDesc> Load(
-    framework::Executor& executor, framework::Scope& scope,
+    framework::Executor* executor, framework::Scope* scope,
     const std::string& prog_filename, const std::string& param_filename) {
   std::string model_filename = prog_filename;
   std::string program_desc_str;
-  ReadBinaryFile(model_filename, program_desc_str);
+  ReadBinaryFile(model_filename, &program_desc_str);
 
   std::unique_ptr<framework::ProgramDesc> main_program(
       new framework::ProgramDesc(program_desc_str));
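The io.cc hunks above convert every mutable parameter from a non-const reference to a pointer, the convention cpplint enforces per the Google C++ Style Guide: a `&` at the call site now signals that the argument may be modified. A minimal caller sketch against the new signatures follows; `LoadModelExample` is an illustrative name, and the sketch assumes the usual Paddle headers and an `Executor` constructible from a `platform::Place`, as in the fluid codebase of this era. The matching header and test updates come after it.

```cpp
#include <memory>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"

// Hypothetical driver (not part of the patch) exercising the new
// pointer-based API.
void LoadModelExample(const std::string& dirname) {
  paddle::inference::Init(/*init_p2p=*/false);
  paddle::framework::Executor executor(paddle::platform::CPUPlace());
  paddle::framework::Scope scope;
  // Mutable arguments now travel by pointer, so the call site makes the
  // mutation explicit: &executor, &scope.
  std::unique_ptr<paddle::framework::ProgramDesc> program =
      paddle::inference::Load(&executor, &scope, dirname);
}
```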
diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h
index 756c936b33ad55e2994542b171b945e248ba2e21..ba3e45099ae7c1626bf11d9527d4fa4c7f772fec 100644
--- a/paddle/fluid/inference/io.h
+++ b/paddle/fluid/inference/io.h
@@ -27,17 +27,17 @@ namespace inference {
 
 void Init(bool init_p2p);
 
-void LoadPersistables(framework::Executor& executor, framework::Scope& scope,
+void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
                       const std::string& dirname,
                       const std::string& param_filename);
 
-std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
-                                             framework::Scope& scope,
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
+                                             framework::Scope* scope,
                                              const std::string& dirname);
 
-std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
-                                             framework::Scope& scope,
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
+                                             framework::Scope* scope,
                                              const std::string& prog_filename,
                                              const std::string& param_filename);
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index c3a8d0889c6a6dd9591837ccc523da56f8d13661..117472599f7c4874ab05e29c6ecb46fd61d0db9c 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -133,12 +133,12 @@ void TestInference(const std::string& dirname,
       std::string prog_filename = "__model_combined__";
       std::string param_filename = "__params_combined__";
       inference_program = paddle::inference::Load(
-          executor, *scope, dirname + "/" + prog_filename,
+          &executor, scope, dirname + "/" + prog_filename,
           dirname + "/" + param_filename);
     } else {
       // Parameters are saved in separate files sited in the specified
       // `dirname`.
-      inference_program = paddle::inference::Load(executor, *scope, dirname);
+      inference_program = paddle::inference::Load(&executor, scope, dirname);
     }
   }
   // Disable the profiler and print the timing information
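Beyond the cpplint fix, the rewritten `ReadBinaryFile` is the standard out-parameter idiom for filling a caller-owned buffer. A self-contained sketch of the same pattern follows, with `std::runtime_error` standing in for the Paddle-specific `PADDLE_ENFORCE` (an assumption made for portability) and a guard added for the zero-byte-file case:

```cpp
#include <fstream>
#include <stdexcept>
#include <string>

// Out-parameter sketch mirroring the patched ReadBinaryFile; the function
// name and exception-based error handling are illustrative, not Paddle API.
void ReadBinaryFileSketch(const std::string& filename, std::string* contents) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  if (!fin) {
    throw std::runtime_error("Cannot open file " + filename);
  }
  // Size the buffer to the full file, then rewind and read it in one shot.
  fin.seekg(0, std::ios::end);
  contents->clear();
  contents->resize(fin.tellg());
  fin.seekg(0, std::ios::beg);
  if (!contents->empty()) {
    // Guard the empty file: at(0) on an empty string throws
    // std::out_of_range, unlike the unchecked &contents[0] it replaced.
    fin.read(&(contents->at(0)), contents->size());
  }
  fin.close();
}
```

Note the trade-off in the patch's switch from `&contents[0]` to `&(contents->at(0))`: `at()` adds bounds checking, but for a zero-byte model file it raises `std::out_of_range` where the old code read nothing; the guard in the sketch above sidesteps that edge case.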