提交 b138d29c 编写于 作者: X Xin Pan

Avoid init_p2p all the times

上级 09b4a1a3
......@@ -105,7 +105,7 @@ static void BuildVar(const std::string& param_name,
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
InitDevices();
InitDevices(true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
......
......@@ -64,7 +64,7 @@ void InitP2P(int count) {
#endif
}
void InitDevices() {
void InitDevices(bool init_p2p) {
/*Init all avaiable devices by default */
std::vector<platform::Place> places;
......@@ -85,7 +85,9 @@ void InitDevices() {
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
InitP2P(count);
if (init_p2p) {
InitP2P(count);
}
platform::DeviceContextPool::Init(places);
}
......
......@@ -24,7 +24,7 @@ void InitGflags(std::vector<std::string> &argv);
void InitGLOG(const std::string &prog_name);
void InitDevices();
void InitDevices(bool init_p2p);
} // namespace framework
} // namespace paddle
......@@ -21,7 +21,7 @@ TEST(InitDevices, CPU) {
using paddle::platform::DeviceContextPool;
#ifndef PADDLE_WITH_CUDA
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U);
#endif
......@@ -33,7 +33,7 @@ TEST(InitDevices, CUDA) {
#ifdef PADDLE_WITH_CUDA
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U + static_cast<unsigned>(count));
#endif
......
......@@ -30,7 +30,7 @@ __global__ void test(size_t* a, int size) {
}
TEST(LoD, data) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
......@@ -46,7 +46,7 @@ TEST(LoD, data) {
}
TEST(LoDTensor, LoDInGPU) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoDTensor lod_tensor;
paddle::platform::CUDAPlace place(0);
......
......@@ -72,7 +72,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator,
paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
......@@ -198,7 +198,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
......@@ -228,7 +228,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
......@@ -269,7 +269,7 @@ class OperatorClone : public paddle::framework::OperatorBase {
};
TEST(Operator, Clone) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
......
......@@ -423,7 +423,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_devices", &framework::InitDevices);
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
#ifdef PADDLE_WITH_CUDA
......
......@@ -41,6 +41,6 @@ int main(int argc, char** argv) {
paddle::memory::Used(paddle::platform::CUDAPlace(0));
#endif
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
return RUN_ALL_TESTS();
}
......@@ -84,6 +84,8 @@ def __bootstrap__():
import core
import os
in_test = 'unittest' in sys.modules
try:
num_threads = int(os.getenv('OMP_NUM_THREADS', '1'))
except ValueError:
......@@ -108,8 +110,11 @@ def __bootstrap__():
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
core.init_devices()
# don't init_p2p when in unittest to save time.
core.init_devices(not in_test)
# TODO(panyx0718): Avoid doing complex initialization logic in __init__.py.
# Consider paddle.init(args) or paddle.main(args)
layers.monkey_patch_variable()
__bootstrap__()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册