From 31673a92b5aeea02374bea941865b0d118d50cf9 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Thu, 11 Nov 2021 11:04:03 +0800 Subject: [PATCH] Get global cluster information (#37084) --- .../distributed/fleet_executor/fleet_executor.cc | 4 +++- .../fleet_executor/fleet_executor_desc.proto | 8 +++++++- python/paddle/fluid/executor.py | 12 ++++++++++++ .../fluid/tests/unittests/test_fleet_executor.py | 11 +++++++++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index b184ea8a716..a2cde8bdd51 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -20,7 +20,9 @@ namespace paddle { namespace distributed { FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { - // Initialize Executor + bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); + PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( + "Error occurs while parsing string to proto")); } FleetExecutor::~FleetExecutor() { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto index 3db8984b5dc..c817f743227 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto @@ -15,7 +15,13 @@ syntax = "proto2"; package paddle.distributed; +message RankInfo { + required int64 rank = 1; + required string ip_port = 2; +} + message FleetExecutorDesc { optional string grain = 1 [ default = "coarse" ]; - repeated string addrs = 2; // "ip:port" of all ranks + optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor + repeated RankInfo cluster_info = 3; } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index ad59ff65797..02c0806ff84 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1851,7 +1851,19 @@ class Executor(object): use_program_cache) from ..distributed.fleet.proto import fleet_executor_desc_pb2 from google.protobuf import text_format + cur_rank = os.getenv("PADDLE_TRAINER_ID") + trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS") fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc() + if cur_rank and trainer_endpoints_str: + fleet_exe_desc.cur_rank = int(cur_rank) + trainer_endpoints = trainer_endpoints_str.split(',') + for rank, endpoint in enumerate(trainer_endpoints): + rank_info = fleet_executor_desc_pb2.RankInfo() + rank_info.rank = rank + rank_info.ip_port = endpoint + fleet_exe_desc.cluster_info.append(rank_info) + else: + logging.warning("Fleet Executor will run on single device only.") fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) fleet_exe.init(program._pipeline_opt["section_program"].desc) fleet_exe.run() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py index 1d042547e20..538f7bb8750 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import os import paddle import paddle.fluid as fluid @@ -38,6 +39,16 @@ class TestFleetExecutor(unittest.TestCase): for place in places: self.run_fleet_executor(place) + def test_dist_executor_on_multi_devices(self): + os.environ["PADDLE_TRAINER_ID"] = "0" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002" + places = [fluid.CPUPlace()] + if fluid.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.run_fleet_executor(place) + if __name__ == "__main__": unittest.main() -- GitLab