From a186b53dfbc46963904f790077244a10ea1cb60d Mon Sep 17 00:00:00 2001 From: QI JUN Date: Mon, 30 Oct 2017 10:37:44 -0700 Subject: [PATCH] add init_gflags interface (#5193) * add init_gflags interface * refine code * follow comments --- paddle/pybind/pybind.cc | 21 +++++++++++++++++++++ python/paddle/v2/framework/__init__.py | 10 ++++++++++ 2 files changed, 31 insertions(+) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index bf6e122642..4baff895da 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" +#include // for call_once +#include "gflags/gflags.h" #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" #include "paddle/framework/feed_fetch_method.h" @@ -45,6 +47,24 @@ static size_t UniqueIntegerGenerator() { return generator.fetch_add(1); } +std::once_flag gflags_init_flag; + +// TODO(qijun) move init gflags to init.cc +void InitGflags(std::vector &argv) { + std::call_once(gflags_init_flag, [&]() { + int argc = argv.size(); + char **arr = new char *[argv.size()]; + std::string line; + for (size_t i = 0; i < argv.size(); i++) { + arr[i] = &argv[i][0]; + line += argv[i]; + line += ' '; + } + google::ParseCommandLineFlags(&argc, &arr, true); + VLOG(1) << "Init commandline: " << line; + }); +} + bool IsCompileGPU() { #ifndef PADDLE_WITH_CUDA return false; @@ -483,6 +503,7 @@ All parameter, weight, gradient are variables in Paddle. }); m.def("unique_integer", UniqueIntegerGenerator); + m.def("init_gflags", InitGflags); m.def("is_compile_gpu", IsCompileGPU); m.def("set_feed_variable", framework::SetFeedVariable); diff --git a/python/paddle/v2/framework/__init__.py b/python/paddle/v2/framework/__init__.py index c942373c66..5df612bf35 100644 --- a/python/paddle/v2/framework/__init__.py +++ b/python/paddle/v2/framework/__init__.py @@ -1 +1,11 @@ +import sys +import core __all__ = ['proto'] +argv = [] +if core.is_compile_gpu(): + argv = list(sys.argv) + [ + "--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory" + ] +else: + argv = list(sys.argv) + ["--tryfromenv=use_pinned_memory"] +core.init_gflags(argv) -- GitLab