From 48b7b543213d2f1584efca610d6d50ccb1ee56e0 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 24 Apr 2018 14:52:36 +0800 Subject: [PATCH] Refine code. --- paddle/fluid/framework/init.cc | 10 ++++------ paddle/fluid/framework/init.h | 2 +- paddle/fluid/inference/io.cc | 6 +----- paddle/fluid/inference/io.h | 4 +--- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc index 3ce37041cb..642b892105 100644 --- a/paddle/fluid/framework/init.cc +++ b/paddle/fluid/framework/init.cc @@ -20,7 +20,6 @@ limitations under the License. */ #include "paddle/fluid/framework/init.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" @@ -35,10 +34,8 @@ std::once_flag p2p_init_flag; using paddle::platform::DeviceContextPool; -void Init(int argc, char **argv) { - std::call_once(gflags_init_flag, - [&]() { google::ParseCommandLineFlags(&argc, &argv, true); }); - +void Init(std::vector &argv) { + InitGflags(argv); // init devices std::vector devices; std::string token; @@ -51,6 +48,7 @@ void Init(int argc, char **argv) { void InitGflags(std::vector &argv) { std::call_once(gflags_init_flag, [&]() { + argv.push_back("dummy"); int argc = argv.size(); char **arr = new char *[argv.size()]; std::string line; @@ -151,7 +149,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { #endif for (size_t i = 0; i < devices.size(); ++i) { - if (devices[i] >= count) { + if (devices[i] >= count || devices[i] < 0) { LOG(WARNING) << "Invalid devices id."; continue; } diff --git a/paddle/fluid/framework/init.h b/paddle/fluid/framework/init.h index 38604d232c..cf792f18b7 100644 --- a/paddle/fluid/framework/init.h +++ b/paddle/fluid/framework/init.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void Init(int argc, char **argv); +void Init(std::vector &argv); void InitGflags(std::vector &argv); diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 74068d9dbe..9c37e0178a 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -24,11 +24,7 @@ limitations under the License. */ namespace paddle { namespace inference { -// Temporarily add this function for exposing framework::InitDevices() when -// linking the inference shared library. -void Init(bool init_p2p) { framework::InitDevices(init_p2p); } - -void Init(int argc, char** argv) { framework::Init(argc, argv); } +void Init(std::vector &argv) { framework::Init(argv); } void ReadBinaryFile(const std::string& filename, std::string* contents) { std::ifstream fin(filename, std::ios::in | std::ios::binary); diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index 988b8aebbe..799693b0c5 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -25,9 +25,7 @@ limitations under the License. */ namespace paddle { namespace inference { -void Init(bool init_p2p); - -void Init(int argc, char** argv); +void Init(std::vector &argv); void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, -- GitLab