提交 48b7b543 编写于 作者: W wanghaoshuang

Refine code.

上级 1bdea0a8
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/init.h" #include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/piece.h"
...@@ -35,10 +34,8 @@ std::once_flag p2p_init_flag; ...@@ -35,10 +34,8 @@ std::once_flag p2p_init_flag;
using paddle::platform::DeviceContextPool; using paddle::platform::DeviceContextPool;
void Init(int argc, char **argv) { void Init(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, InitGflags(argv);
[&]() { google::ParseCommandLineFlags(&argc, &argv, true); });
// init devices // init devices
std::vector<int> devices; std::vector<int> devices;
std::string token; std::string token;
...@@ -51,6 +48,7 @@ void Init(int argc, char **argv) { ...@@ -51,6 +48,7 @@ void Init(int argc, char **argv) {
void InitGflags(std::vector<std::string> &argv) { void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() { std::call_once(gflags_init_flag, [&]() {
argv.push_back("dummy");
int argc = argv.size(); int argc = argv.size();
char **arr = new char *[argv.size()]; char **arr = new char *[argv.size()];
std::string line; std::string line;
...@@ -151,7 +149,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -151,7 +149,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif #endif
for (size_t i = 0; i < devices.size(); ++i) { for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i] >= count) { if (devices[i] >= count || devices[i] < 0) {
LOG(WARNING) << "Invalid devices id."; LOG(WARNING) << "Invalid devices id.";
continue; continue;
} }
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Init(int argc, char **argv); void Init(std::vector<std::string> &argv);
void InitGflags(std::vector<std::string> &argv); void InitGflags(std::vector<std::string> &argv);
......
...@@ -24,11 +24,7 @@ limitations under the License. */ ...@@ -24,11 +24,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
// Temporarily add this function for exposing framework::InitDevices() when void Init(std::vector<std::string> &argv) { framework::Init(argv); }
// linking the inference shared library.
void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
void Init(int argc, char** argv) { framework::Init(argc, argv); }
void ReadBinaryFile(const std::string& filename, std::string* contents) { void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary); std::ifstream fin(filename, std::ios::in | std::ios::binary);
......
...@@ -25,9 +25,7 @@ limitations under the License. */ ...@@ -25,9 +25,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
void Init(bool init_p2p); void Init(std::vector<std::string> &argv);
void Init(int argc, char** argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope, void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program, const framework::ProgramDesc& main_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册