Unverified commit 0434b828, authored by LiYuRio, committed by GitHub

Make tcp store as a global instance (#55956)

* make tcp store a global instance

* fix windows compile error
Parent 02e6347d
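Before the diff, here is a minimal, self-contained C++ sketch of the "create or get" pattern this commit introduces: a function-local static shared_ptr built once from the same environment variables the commit reads (PADDLE_MASTER, PADDLE_TRAINER_ENDPOINTS, PADDLE_TRAINER_ID, PADDLE_TRAINERS_NUM) and returned unchanged on every later call. FakeStore is a stand-in for phi::distributed::TCPStore, and the fallback endpoint, rank, and world size are assumptions added only so the sketch runs without a distributed launcher; none of them are part of the commit.

// Minimal standalone sketch (not the Paddle implementation) of a process-wide
// "create or get" store. FakeStore stands in for phi::distributed::TCPStore.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct FakeStore {
  // The real TCPStore constructor in the diff takes the same four arguments
  // (host, port, is_master, world_size); this stub only records them.
  FakeStore(std::string host, uint16_t port, bool is_master, int64_t world_size)
      : host(std::move(host)),
        port(port),
        is_master(is_master),
        world_size(world_size) {}
  std::string host;
  uint16_t port;
  bool is_master;
  int64_t world_size;
};

// Resolve the master endpoint the way the commit does: prefer PADDLE_MASTER,
// otherwise fall back to the first entry of PADDLE_TRAINER_ENDPOINTS.
std::string GetMasterEndpoint() {
  if (const char* master = std::getenv("PADDLE_MASTER")) {
    return master;
  }
  const char* endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
  std::string eps = endpoints ? endpoints : "127.0.0.1:6170";  // assumed default
  return eps.substr(0, eps.find(','));  // keep the first "ip:port" entry
}

std::shared_ptr<FakeStore> CreateOrGetGlobalStore() {
  const std::string endpoint = GetMasterEndpoint();
  const auto colon = endpoint.find(':');
  const std::string host = endpoint.substr(0, colon);
  const auto port = static_cast<uint16_t>(std::stoi(endpoint.substr(colon + 1)));
  const char* rank_env = std::getenv("PADDLE_TRAINER_ID");
  const char* size_env = std::getenv("PADDLE_TRAINERS_NUM");
  const int64_t rank = rank_env ? std::atoi(rank_env) : 0;        // assumed 0
  const int64_t world_size = size_env ? std::atoi(size_env) : 1;  // assumed 1
  // Function-local static: constructed once, on first use, thread-safely
  // (C++11 magic statics); every later call returns the same shared_ptr.
  static std::shared_ptr<FakeStore> store = std::make_shared<FakeStore>(
      host, port, /*is_master=*/rank == 0, world_size);
  return store;
}

int main() {
  auto first = CreateOrGetGlobalStore();
  auto second = CreateOrGetGlobalStore();
  std::cout << "same instance: " << std::boolalpha << (first == second)
            << std::endl;  // prints: same instance: true
}

Because the initialization of a function-local static is guaranteed to run exactly once, every caller observes the same instance; that is the property the commit relies on when it exposes create_or_get_global_tcp_store to Python instead of constructing a fresh TCPStore inside _init_parallel_env.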
@@ -28,6 +28,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
@@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) {
py::arg("world_size"),
py::arg("timeout") = 900,
py::call_guard<py::gil_scoped_release>());
m->def("create_or_get_global_tcp_store",
&phi::distributed::CreateOrGetGlobalTCPStore);
}
} // namespace pybind
@@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")
if(WITH_DISTRIBUTE)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_utils.cc
r_to_s_reshard_function.cc)
list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
reshard_split_functor.cc r_to_s_reshard_function.cc)
endif()
collect_srcs(
@@ -20,4 +14,5 @@ collect_srcs(
process_mesh.cc
dist_attr.cc
dist_mapper.cc
reshard_utils.cc
${DISTRIBUTED_SRCS})
@@ -20,6 +20,7 @@
namespace phi {
namespace distributed {
using auto_parallel::str_split;
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
@@ -33,15 +34,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
[](int64_t value) { return value == -1; });
}
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids();
@@ -80,5 +72,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
return split_axis_to_mesh_axis;
}
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
int64_t GetGlobalWorldSize() {
const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
PADDLE_ENFORCE_NOT_NULL(
world_size,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
return std::atoi(world_size);
}
namespace {
std::string GetMasterEndpoint() {
const char* master_endpoint = std::getenv("PADDLE_MASTER");
if (!master_endpoint) {
const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
PADDLE_ENFORCE_NOT_NULL(
trainer_endpoints,
phi::errors::NotFound("The environment variable "
"'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
return str_split(trainer_endpoints, ",")[0];
}
PADDLE_ENFORCE_NOT_NULL(
master_endpoint,
phi::errors::NotFound(
"The environment variable 'PADDLE_MASTER' cannot be found."));
return master_endpoint;
}
} // namespace
std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}
uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);
static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}
} // namespace distributed
} // namespace phi
@@ -16,8 +16,12 @@
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
@@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);
int64_t GetCurGlobalRank();
// Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1].
@@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
int64_t GetGlobalWorldSize();
uint16_t GetMasterPort();
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
} // namespace distributed
} // namespace phi
@@ -13,7 +13,6 @@
# limitations under the License.
import datetime
import os
import paddle
@@ -320,32 +319,18 @@ def is_available():
def _init_parallel_env(backend):
master_endpoint = os.getenv("PADDLE_MASTER", None)
if master_endpoint is None:
master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
assert (
master_endpoint is not None
), "Please set PADDLE_MASTER enviroment variable."
if master_endpoint:
master_addr = master_endpoint.split(":")[0]
master_port = int(master_endpoint.split(":")[1])
global_env = _get_global_env()
rank = global_env.rank
world_size = global_env.world_size
dev_id = global_env.device_id
is_master = rank == 0
store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
store = core.create_or_get_global_tcp_store()
global_env = _get_global_env()
rank = global_env.rank
world_size = global_env.world_size
dev_id = global_env.device_id
if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, "0", rank, world_size
)
elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)
if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, "0", rank, world_size
)
elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)