diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc
index 317d3900ef50def47dcfd7708c42ce8e6a937146..76877cdfae741e159c2604670afea4137c26f6f4 100644
--- a/paddle/fluid/pybind/communication.cc
+++ b/paddle/fluid/pybind/communication.cc
@@ -28,6 +28,7 @@ limitations under the License. */
 #include <chrono>
 #include <string>
 
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
 #include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/distributed/store/tcp_store.h"
 
@@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) {
       py::arg("world_size"),
       py::arg("timeout") = 900,
       py::call_guard<py::gil_scoped_release>());
+
+  m->def("create_or_get_global_tcp_store",
+         &phi::distributed::CreateOrGetGlobalTCPStore);
 }
 
 }  // namespace pybind
diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
index d4af259a5906cd2b297f8caadea2cfcd5e154b3a..503fbebe2595751d05ba7b8d0b63350cd80cbbb2 100644
--- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
 set(DISTRIBUTED_SRCS "")
 
 if(WITH_DISTRIBUTE)
-  list(
-    APPEND
-    DISTRIBUTED_SRCS
-    dist_tensor.cc
-    reshard_function.cc
-    reshard_split_functor.cc
-    reshard_utils.cc
-    r_to_s_reshard_function.cc)
+  list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
+       reshard_split_functor.cc r_to_s_reshard_function.cc)
 endif()
 
 collect_srcs(
@@ -20,4 +14,5 @@ collect_srcs(
   process_mesh.cc
   dist_attr.cc
   dist_mapper.cc
+  reshard_utils.cc
   ${DISTRIBUTED_SRCS})
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
index b777b53c2304384e8eedb33d6206a638f86efce9..cdbe9d981715fc9a8f77c0eb9a9f874671c1ec07 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -20,6 +20,7 @@
 
 namespace phi {
 namespace distributed {
+using auto_parallel::str_split;
 
 bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
   return std::any_of(dims_mapping.begin(),
@@ -33,15 +34,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
                      [](int64_t value) { return value == -1; });
 }
 
-int64_t GetCurGlobalRank() {
-  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
-  PADDLE_ENFORCE_NOT_NULL(
-      cur_rank,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
-  return std::atoi(cur_rank);
-}
-
 std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
   const auto& process_shape = process_mesh.shape();
   const auto& process_ids = process_mesh.process_ids();
@@ -80,5 +72,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
   return split_axis_to_mesh_axis;
 }
 
+int64_t GetCurGlobalRank() {
+  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
+  PADDLE_ENFORCE_NOT_NULL(
+      cur_rank,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
+  return std::atoi(cur_rank);
+}
+
+int64_t GetGlobalWorldSize() {
+  const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
+  PADDLE_ENFORCE_NOT_NULL(
+      world_size,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
+  return std::atoi(world_size);
+}
+
+namespace {
+std::string GetMasterEndpoint() {
+  const char* master_endpoint = std::getenv("PADDLE_MASTER");
+  if (!master_endpoint) {
+    const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
+    PADDLE_ENFORCE_NOT_NULL(
+        trainer_endpoints,
+        phi::errors::NotFound("The environment variable "
+                              "'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
+    return str_split(trainer_endpoints, ",")[0];
+  }
+
+  PADDLE_ENFORCE_NOT_NULL(
+      master_endpoint,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_MASTER' cannot be found."));
+  return master_endpoint;
+}
+
+}  // namespace
+
+std::string GetMasterAddr() {
+  std::string master_endpoint = GetMasterEndpoint();
+  return str_split(master_endpoint, ":")[0];
+}
+
+uint16_t GetMasterPort() {
+  std::string master_endpoint = GetMasterEndpoint();
+  return std::stoi(str_split(master_endpoint, ":")[1]);
+}
+
+std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
+  std::string host = GetMasterAddr();
+  uint16_t port = GetMasterPort();
+  int64_t cur_rank = GetCurGlobalRank();
+  int64_t world_size = GetGlobalWorldSize();
+  bool is_master = (cur_rank == 0);
+
+  static std::shared_ptr<TCPStore> store =
+      std::make_shared<TCPStore>(host, port, is_master, world_size);
+  return store;
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
index dceaa5150a6b0a4835939536c5e86882d44e3dca..52403d0c560267b2f88615badd7e4bd10c64429e 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -16,8 +16,12 @@
 
 #include <cstdint>
 #include <map>
+#include <memory>
+#include <string>
 #include <vector>
 
+#include "paddle/phi/core/distributed/store/tcp_store.h"
+
 namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
 
 bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);
 
-int64_t GetCurGlobalRank();
-
 // Get the coordinate of cur rank in process mesh. For example, the process mesh
 // is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
 // return [2, 0]; if the current rank is 3, then will return [1, 1].
@@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
 std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
     const std::vector<int64_t>& dims_mapping);
 
+int64_t GetCurGlobalRank();
+
+std::string GetMasterAddr();
+
+int64_t GetGlobalWorldSize();
+
+uint16_t GetMasterPort();
+
+std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 96725ed7cad41c2359ee10125e0e36af2379f5d0..a14d67429e91789a489a315796d4aea30039ca57 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import datetime
-import os
 
 import paddle
 
@@ -320,32 +319,18 @@ def is_available():
 
 
 def _init_parallel_env(backend):
-    master_endpoint = os.getenv("PADDLE_MASTER", None)
-    if master_endpoint is None:
-        master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
-    assert (
-        master_endpoint is not None
-    ), "Please set PADDLE_MASTER enviroment variable."
-    if master_endpoint:
-        master_addr = master_endpoint.split(":")[0]
-        master_port = int(master_endpoint.split(":")[1])
-        global_env = _get_global_env()
-        rank = global_env.rank
-        world_size = global_env.world_size
-        dev_id = global_env.device_id
-        is_master = rank == 0
-        store = core.TCPStore(
-            master_addr,
-            master_port,
-            is_master,
-            world_size,
+    store = core.create_or_get_global_tcp_store()
+    global_env = _get_global_env()
+    rank = global_env.rank
+    world_size = global_env.world_size
+    dev_id = global_env.device_id
+
+    if backend == "gloo":
+        core.CommContextManager.create_gloo_comm_context(
+            store, "0", rank, world_size
+        )
+    elif backend == "nccl":
+        core.CommContextManager.set_cuda_device_id(dev_id)
+        core.CommContextManager.create_nccl_comm_context(
+            store, "0", rank, world_size
         )
-    if backend == "gloo":
-        core.CommContextManager.create_gloo_comm_context(
-            store, "0", rank, world_size
-        )
-    elif backend == "nccl":
-        core.CommContextManager.set_cuda_device_id(dev_id)
-        core.CommContextManager.create_nccl_comm_context(
-            store, "0", rank, world_size
-        )