diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index a003995ae3f8e111881b4681554aa8eb17b60cc1..e8bf53e160e7382122c3c2f92a152fea058a032e 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -48,7 +48,14 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, void AllReduceOpHandle::RunImpl() { platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second); +// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR, +// this is a distributed or inter-process call, find a better way. +#ifdef PADDLE_WITH_CUDA + if (NoDummyInputSize() == 1 && + local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) { +#else if (NoDummyInputSize() == 1) { +#endif return; // No need to all reduce when GPU count = 1; } else { // Wait input done diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 523f9eadf2d7e2e08504c5920372fb7cdb0d7aba..1e1b945f63cf480308c05ffc0f9a3b9f0b0da02b 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -62,6 +62,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { auto multi_devices_pass = AppendPass("multi_devices_pass"); multi_devices_pass->SetNotOwned("strategy", &strategy_); + multi_devices_pass->Set("num_trainers", + new int(strategy_.num_trainers_)); // Add a graph print pass to record a graph with device info. if (!strategy_.debug_graphviz_path_.empty()) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 03f5f2e73a8d3e1cd1816f47a92ccfdb9bba4850..cbae5321d9a3fdcaebca2bd9111fff6a10b235a8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -133,6 +133,7 @@ static const char kPlaces[] = "places"; static const char kParams[] = "params"; static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; +static const char kNumTrainers[] = "num_trainers"; void MultiDevSSAGraphBuilder::Init() const { all_vars_.clear(); @@ -299,6 +300,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; + int num_trainers = Get(kNumTrainers); + for (auto &node : nodes) { if (node->IsVar() && node->Var()) { all_vars_.emplace(node->Name(), node->Var()); @@ -383,7 +386,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( CreateComputationalOps(&result, node, places_.size()); } - if (!is_forwarding && places_.size() > 1) { + if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. if (static_cast(boost::get(node->Op()->GetAttr( @@ -895,4 +898,5 @@ REGISTER_PASS(multi_devices_pass, .RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kParams) .RequirePassAttr(paddle::framework::details::kLocalScopes) - .RequirePassAttr(paddle::framework::details::kStrategy); + .RequirePassAttr(paddle::framework::details::kStrategy) + .RequirePassAttr(paddle::framework::details::kNumTrainers); diff --git a/paddle/fluid/memory/allocation/legacy_allocator.cc b/paddle/fluid/memory/allocation/legacy_allocator.cc index 05b9a2cc08387c545ec45e38351dcd6707e20372..64aa63ffe9705d75e70c8d9d9cbc433dd6358596 100644 --- a/paddle/fluid/memory/allocation/legacy_allocator.cc +++ b/paddle/fluid/memory/allocation/legacy_allocator.cc @@ -14,11 +14,13 @@ #include "paddle/fluid/memory/allocation/legacy_allocator.h" #include +#include #include "glog/logging.h" #include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/split.h" DEFINE_bool(init_allocated_mem, false, "It is a mistake that the values of the memory allocated by " @@ -110,19 +112,21 @@ size_t Used(const platform::CPUPlace &place) { BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) { static std::once_flag init_flag; static detail::BuddyAllocator **a_arr = nullptr; + static std::vector devices; std::call_once(init_flag, [gpu_id]() { - int gpu_num = platform::GetCUDADeviceCount(); - PADDLE_ENFORCE(gpu_id < gpu_num, "gpu_id:%d should < gpu_num:%d", gpu_id, - gpu_num); + devices = platform::GetSelectedDevices(); + int gpu_num = devices.size(); a_arr = new BuddyAllocator *[gpu_num]; - for (int i = 0; i < gpu_num; i++) { + for (size_t i = 0; i < devices.size(); ++i) { + int dev_id = devices[i]; a_arr[i] = nullptr; - platform::SetDeviceId(i); - a_arr[i] = new BuddyAllocator( - std::unique_ptr(new detail::GPUAllocator(i)), - platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()); + platform::SetDeviceId(dev_id); + a_arr[i] = new BuddyAllocator(std::unique_ptr( + new detail::GPUAllocator(dev_id)), + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize()); VLOG(10) << "\n\nNOTE: each GPU device use " << FLAGS_fraction_of_gpu_memory_to_use * 100 @@ -134,7 +138,9 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) { }); platform::SetDeviceId(gpu_id); - return a_arr[gpu_id]; + auto pos = std::distance(devices.begin(), + std::find(devices.begin(), devices.end(), gpu_id)); + return a_arr[pos]; } #endif diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 6954e4c6a9df8dea01ec2b0f193965d835503b17..ca89d91aadb2d3e9005e6dd06cef124428d7e250 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/split.h" #ifndef _WIN32 constexpr static float fraction_of_gpu_memory_to_use = 0.92f; @@ -45,6 +46,15 @@ DEFINE_bool( "input and output must be half precision) and recurrent neural networks " "(RNNs)."); +DEFINE_string(selected_gpus, "", + "A list of device ids separated by comma, like: 0,1,2,3. " + "This option is useful when doing multi process training and " + "each process have only one device (GPU). If you want to use " + "all visible devices, set this to empty string. NOTE: the " + "reason of doing this is that we want to use P2P communication" + "between GPU devices, use CUDA_VISIBLE_DEVICES can only use" + "share-memory only."); + namespace paddle { namespace platform { @@ -121,6 +131,24 @@ int GetCurrentDeviceId() { return device_id; } +//! Get a list of device ids from environment variable or use all. +std::vector GetSelectedDevices() { + // use user specified GPUs in single-node multi-process mode. + std::vector devices; + if (!FLAGS_selected_gpus.empty()) { + auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ','); + for (auto id : devices_str) { + devices.push_back(atoi(id.c_str())); + } + } else { + int count = GetCUDADeviceCount(); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + return devices; +} + void SetDeviceId(int id) { // TODO(qijun): find a better way to cache the cuda device count PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 6a0b3c8e02d49068c2dbe14c7feea7e139947694..1e1ab2503f53fe20bbe62c48f65d8535947f1aa8 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include namespace paddle { namespace platform { @@ -47,6 +48,9 @@ int GetCUDAMaxThreadsPerMultiProcessor(int i); //! Get the current GPU device id in system. int GetCurrentDeviceId(); +//! Get a list of device ids from environment variable or use all. +std::vector GetSelectedDevices(); + //! Set the GPU device id for next execution. void SetDeviceId(int device_id); diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 258779ba51026d0cc418257a37b78f346fa48efa..51b46450e419c6641b415a58b7551f7dd56627b2 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/string/split.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -82,10 +83,8 @@ void InitDevices(bool init_p2p) { std::vector devices; #ifdef PADDLE_WITH_CUDA try { - int count = platform::GetCUDADeviceCount(); - for (int i = 0; i < count; ++i) { - devices.push_back(i); - } + // use user specified GPUs in single-node multi-process mode. + devices = platform::GetSelectedDevices(); } catch (const std::exception &exp) { LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; } @@ -95,20 +94,15 @@ void InitDevices(bool init_p2p) { void InitDevices(bool init_p2p, const std::vector devices) { std::vector places; - int count = 0; -#ifdef PADDLE_WITH_CUDA - try { - count = platform::GetCUDADeviceCount(); - } catch (const std::exception &exp) { - LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; - } -#endif for (size_t i = 0; i < devices.size(); ++i) { - if (devices[i] >= count || devices[i] < 0) { + // In multi process multi gpu mode, we may have gpuid = 7 + // but count = 1. + if (devices[i] < 0) { LOG(WARNING) << "Invalid devices id."; continue; } + places.emplace_back(platform::CUDAPlace(devices[i])); } if (init_p2p) { diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index fc903b548c70e9b72c6121dd24c014973e3cd1d4..7c539d25f6dd02fc09aa1234d7bf0164b54a610f 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -97,7 +97,7 @@ struct NCCLContextMap { order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - if (places.size() <= 1) { + if (places.size() <= 1 && num_trainers == 1) { return; } std::unique_ptr comms(new ncclComm_t[order_.size()]); @@ -111,12 +111,19 @@ struct NCCLContextMap { { int nranks = num_trainers * order_.size(); NCCLGroupGuard gurad; - for (auto &gpu_id : order_) { - int rank = trainer_id * order_.size() + gpu_id; - VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks; + for (size_t i = 0; i < order_.size(); ++i) { + int gpu_id = order_[i]; + int rank; + if (order_.size() > 1) { + rank = trainer_id * order_.size() + i; + } else { + rank = trainer_id; + } + VLOG(30) << "init nccl rank: " << rank << " nranks: " << nranks + << "gpu id: " << gpu_id; PADDLE_ENFORCE(cudaSetDevice(gpu_id)); PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( - comms.get() + gpu_id, nranks, *nccl_id, rank)); + comms.get() + i, nranks, *nccl_id, rank)); } } } diff --git a/paddle/fluid/string/CMakeLists.txt b/paddle/fluid/string/CMakeLists.txt index 8572dc1e8e543b552e3ed5a180ec942faf90a624..169a925d12328e7d1df744635445b5674c19b125 100644 --- a/paddle/fluid/string/CMakeLists.txt +++ b/paddle/fluid/string/CMakeLists.txt @@ -3,3 +3,4 @@ cc_library(pretty_log SRCS pretty_log.cc) cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) cc_test(to_string_test SRCS to_string_test.cc) +cc_test(split_test SRCS split_test.cc) diff --git a/paddle/fluid/string/split.h b/paddle/fluid/string/split.h new file mode 100644 index 0000000000000000000000000000000000000000..ccb96b8a9cb68f03acbca592a2149ba5001f34d2 --- /dev/null +++ b/paddle/fluid/string/split.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace string { + +static inline std::vector Split(std::string const& original, + char separator) { + std::vector results; + std::string token; + std::istringstream is(original); + while (std::getline(is, token, separator)) { + if (!token.empty()) { + results.push_back(token); + } + } + return results; +} + +} // namespace string +} // namespace paddle diff --git a/paddle/fluid/string/split_test.cc b/paddle/fluid/string/split_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c85dc1eed40dbe25d922c0f4810a747d1bd2d60f --- /dev/null +++ b/paddle/fluid/string/split_test.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/string/split.h" + +#include + +#include "gtest/gtest.h" + +TEST(StringSplit, StringSplit) { + std::string to_split = "0,1,2,3,4,5"; + int i = 0; + for (auto s : paddle::string::Split(to_split, ',')) { + EXPECT_EQ(atoi(s.c_str()), i); + i++; + } +} diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index a1ffbf42622bcda44ec038edf20811fb0032891f..2a53519188e7454b54424cfdd4a713ae37a2326b 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -147,7 +147,7 @@ def __bootstrap__(): read_env_flags += [ 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', - 'cudnn_exhaustive_search' + 'cudnn_exhaustive_search', 'selected_gpus' ] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index bdcd045341212d6cf9dbfbc3cebc72f320e37e9d..dc27a8eabb5e3671250958b67ea7212ad9bdf08b 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -95,7 +95,14 @@ class ParallelExecutor(object): self._places = [] self._act_places = [] if use_cuda: - for i in six.moves.range(core.get_cuda_device_count()): + gpus = [] + gpus_env = os.getenv("FLAGS_selected_gpus") + if gpus_env: + gpus = [int(s) for s in gpus_env.split(",")] + else: + for i in six.moves.range(core.get_cuda_device_count()): + gpus.append(i) + for i in gpus: p = core.Place() self._act_places.append(core.CUDAPlace(i)) p.set_place(self._act_places[-1])