未验证 提交 0434b828 编写于 作者: L LiYuRio 提交者: GitHub

Make tcp store as a global instance (#55956)

* make tcp store a global instance

* fix windows compile error
上级 02e6347d
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h" #include "paddle/phi/core/distributed/store/tcp_store.h"
...@@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) { ...@@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) {
py::arg("world_size"), py::arg("world_size"),
py::arg("timeout") = 900, py::arg("timeout") = 900,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m->def("create_or_get_global_tcp_store",
&phi::distributed::CreateOrGetGlobalTCPStore);
} }
} // namespace pybind } // namespace pybind
......
...@@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto) ...@@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "") set(DISTRIBUTED_SRCS "")
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
list( list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
APPEND reshard_split_functor.cc r_to_s_reshard_function.cc)
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_utils.cc
r_to_s_reshard_function.cc)
endif() endif()
collect_srcs( collect_srcs(
...@@ -20,4 +14,5 @@ collect_srcs( ...@@ -20,4 +14,5 @@ collect_srcs(
process_mesh.cc process_mesh.cc
dist_attr.cc dist_attr.cc
dist_mapper.cc dist_mapper.cc
reshard_utils.cc
${DISTRIBUTED_SRCS}) ${DISTRIBUTED_SRCS})
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
namespace phi { namespace phi {
namespace distributed { namespace distributed {
using auto_parallel::str_split;
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) { bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(), return std::any_of(dims_mapping.begin(),
...@@ -33,15 +34,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) { ...@@ -33,15 +34,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
[](int64_t value) { return value == -1; }); [](int64_t value) { return value == -1; });
} }
// Returns the global rank of the current process, read from the
// PADDLE_TRAINER_ID environment variable set by the Paddle launcher.
// Raises a NotFound error if the variable is absent.
int64_t GetCurGlobalRank() {
  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
  PADDLE_ENFORCE_NOT_NULL(
      cur_rank,
      phi::errors::NotFound(
          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
  // NOTE: std::atoi silently returns 0 on non-numeric input — assumes the
  // launcher always writes a valid integer here.
  return std::atoi(cur_rank);
}
std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) { std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape(); const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids(); const auto& process_ids = process_mesh.process_ids();
...@@ -80,5 +72,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping( ...@@ -80,5 +72,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
return split_axis_to_mesh_axis; return split_axis_to_mesh_axis;
} }
// Global rank of the calling process. The Paddle launcher publishes it via
// the PADDLE_TRAINER_ID environment variable; a NotFound error is raised
// when that variable is not set.
int64_t GetCurGlobalRank() {
  const char* rank_env = std::getenv("PADDLE_TRAINER_ID");
  PADDLE_ENFORCE_NOT_NULL(
      rank_env,
      phi::errors::NotFound(
          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
  return std::atoi(rank_env);
}
// Total number of trainer processes in the job, published by the Paddle
// launcher via the PADDLE_TRAINERS_NUM environment variable; raises a
// NotFound error when the variable is not set.
int64_t GetGlobalWorldSize() {
  const char* num_trainers = std::getenv("PADDLE_TRAINERS_NUM");
  PADDLE_ENFORCE_NOT_NULL(
      num_trainers,
      phi::errors::NotFound(
          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
  return std::atoi(num_trainers);
}
namespace {
// Resolves the "host:port" endpoint of the master (rank-0) process.
// PADDLE_MASTER takes precedence; when it is unset, falls back to the first
// entry of the comma-separated PADDLE_TRAINER_ENDPOINTS list. Raises a
// NotFound error only when neither variable is available.
std::string GetMasterEndpoint() {
  const char* master_endpoint = std::getenv("PADDLE_MASTER");
  if (master_endpoint) {
    return master_endpoint;
  }
  // Fallback path: derive the master endpoint from the trainer list.
  const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
  PADDLE_ENFORCE_NOT_NULL(
      trainer_endpoints,
      phi::errors::NotFound("The environment variable "
                            "'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
  return str_split(trainer_endpoints, ",")[0];
  // NOTE(review): the original also re-checked master_endpoint for null after
  // the fallback branch; that check was unreachable (the null case had
  // already returned above), so it has been removed.
}
}  // namespace
std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}
uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);
static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -16,8 +16,12 @@ ...@@ -16,8 +16,12 @@
#include <cstdint> #include <cstdint>
#include <map> #include <map>
#include <memory>
#include <string>
#include <vector> #include <vector>
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi { namespace phi {
namespace distributed { namespace distributed {
namespace auto_parallel { namespace auto_parallel {
...@@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping); ...@@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping); bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);
int64_t GetCurGlobalRank();
// Get the coordinate of cur rank in process mesh. For example, the process mesh // Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will // is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1]. // return [2, 0]; if the current rank is 3, then will return [1, 1].
...@@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh); ...@@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping( std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping); const std::vector<int64_t>& dims_mapping);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
int64_t GetGlobalWorldSize();
uint16_t GetMasterPort();
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import os
import paddle import paddle
...@@ -320,32 +319,18 @@ def is_available(): ...@@ -320,32 +319,18 @@ def is_available():
def _init_parallel_env(backend): def _init_parallel_env(backend):
master_endpoint = os.getenv("PADDLE_MASTER", None) store = core.create_or_get_global_tcp_store()
if master_endpoint is None: global_env = _get_global_env()
master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0] rank = global_env.rank
assert ( world_size = global_env.world_size
master_endpoint is not None dev_id = global_env.device_id
), "Please set PADDLE_MASTER enviroment variable."
if master_endpoint: if backend == "gloo":
master_addr = master_endpoint.split(":")[0] core.CommContextManager.create_gloo_comm_context(
master_port = int(master_endpoint.split(":")[1]) store, "0", rank, world_size
global_env = _get_global_env() )
rank = global_env.rank elif backend == "nccl":
world_size = global_env.world_size core.CommContextManager.set_cuda_device_id(dev_id)
dev_id = global_env.device_id core.CommContextManager.create_nccl_comm_context(
is_master = rank == 0 store, "0", rank, world_size
store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
) )
if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, "0", rank, world_size
)
elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册