未验证 提交 a186b53d 编写于 作者: Q QI JUN 提交者: GitHub

add init_gflags interface (#5193)

* add init_gflags interface

* refine code

* follow comments
上级 6c8dce9c
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h" #include "paddle/pybind/protobuf.h"
#include <mutex> // for call_once
#include "gflags/gflags.h"
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
...@@ -45,6 +47,24 @@ static size_t UniqueIntegerGenerator() { ...@@ -45,6 +47,24 @@ static size_t UniqueIntegerGenerator() {
return generator.fetch_add(1); return generator.fetch_add(1);
} }
std::once_flag gflags_init_flag;
// TODO(qijun) move init gflags to init.cc
void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() {
int argc = argv.size();
char **arr = new char *[argv.size()];
std::string line;
for (size_t i = 0; i < argv.size(); i++) {
arr[i] = &argv[i][0];
line += argv[i];
line += ' ';
}
google::ParseCommandLineFlags(&argc, &arr, true);
VLOG(1) << "Init commandline: " << line;
});
}
bool IsCompileGPU() { bool IsCompileGPU() {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
return false; return false;
...@@ -483,6 +503,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -483,6 +503,7 @@ All parameter, weight, gradient are variables in Paddle.
}); });
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
m.def("init_gflags", InitGflags);
m.def("is_compile_gpu", IsCompileGPU); m.def("is_compile_gpu", IsCompileGPU);
m.def("set_feed_variable", framework::SetFeedVariable); m.def("set_feed_variable", framework::SetFeedVariable);
......
import sys
import core
__all__ = ['proto'] __all__ = ['proto']
argv = []
if core.is_compile_gpu():
argv = list(sys.argv) + [
"--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"
]
else:
argv = list(sys.argv) + ["--tryfromenv=use_pinned_memory"]
core.init_gflags(argv)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册