diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index b184ea8a7160192ce40633d76160c70acdf13fa3..a2cde8bdd51ad3d65879f69fcf46bef3bb4da762 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -20,7 +20,9 @@ namespace paddle {
 namespace distributed {
 
 FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
-  // Initialize Executor
+  bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
+  PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
+                                 "Error occurs while parsing string to proto"));
 }
 
 FleetExecutor::~FleetExecutor() {
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
index 3db8984b5dcff1a4714fbb5352d615d59340d8b2..c817f7432271e628ff45f21b50ef6c857313ee02 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
@@ -15,7 +15,13 @@
 syntax = "proto2";
 package paddle.distributed;
 
+message RankInfo {
+  required int64 rank = 1;
+  required string ip_port = 2;
+}
+
 message FleetExecutorDesc {
   optional string grain = 1 [ default = "coarse" ];
-  repeated string addrs = 2; // "ip:port" of all ranks
+  optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
+  repeated RankInfo cluster_info = 3;
 }
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index ad59ff657971c21659c4f1d9965a4024ff177ea5..02c0806ff849997ec28e3b34beeb5bd3a6952a11 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1851,7 +1851,19 @@ class Executor(object):
                 use_program_cache)
             from ..distributed.fleet.proto import fleet_executor_desc_pb2
             from google.protobuf import text_format
+            cur_rank = os.getenv("PADDLE_TRAINER_ID")
+            trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
             fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
+            if cur_rank and trainer_endpoints_str:
+                fleet_exe_desc.cur_rank = int(cur_rank)
+                trainer_endpoints = trainer_endpoints_str.split(',')
+                for rank, endpoint in enumerate(trainer_endpoints):
+                    rank_info = fleet_executor_desc_pb2.RankInfo()
+                    rank_info.rank = rank
+                    rank_info.ip_port = endpoint
+                    fleet_exe_desc.cluster_info.append(rank_info)
+            else:
+                logging.warning("Fleet Executor will run on single device only.")
             fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
             fleet_exe.init(program._pipeline_opt["section_program"].desc)
             fleet_exe.run()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor.py
index 1d042547e2067a195b713d9901e4bebd18ace906..538f7bb8750f0f2cb14fae23cc111fac165b4404 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_executor.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+import os
 import paddle
 import paddle.fluid as fluid
 
@@ -38,6 +39,16 @@ class TestFleetExecutor(unittest.TestCase):
         for place in places:
             self.run_fleet_executor(place)
 
+    def test_dist_executor_on_multi_devices(self):
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+        os.environ[
+            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002"
+        places = [fluid.CPUPlace()]
+        if fluid.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            self.run_fleet_executor(place)
+
 
 if __name__ == "__main__":
     unittest.main()
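
The Python side now fills the new FleetExecutorDesc fields (cur_rank, cluster_info) from the PADDLE_TRAINER_ID and PADDLE_TRAINER_ENDPOINTS environment variables and hands the serialized proto to core.FleetExecutor, whose C++ constructor parses it with ParseFromString under PADDLE_ENFORCE. Below is a minimal sketch of that round trip, assuming a paddle build in which the generated module paddle.distributed.fleet.proto.fleet_executor_desc_pb2 is importable; the endpoint values are placeholders.

    import os

    from paddle.distributed.fleet.proto import fleet_executor_desc_pb2

    # Same environment variables the executor reads; placeholder values here.
    os.environ.setdefault("PADDLE_TRAINER_ID", "0")
    os.environ.setdefault("PADDLE_TRAINER_ENDPOINTS",
                          "127.0.0.1:7000,127.0.0.1:7001")

    # Mirror the executor.py change: rank id plus one RankInfo per endpoint.
    desc = fleet_executor_desc_pb2.FleetExecutorDesc()
    desc.cur_rank = int(os.environ["PADDLE_TRAINER_ID"])
    for rank, endpoint in enumerate(
            os.environ["PADDLE_TRAINER_ENDPOINTS"].split(',')):
        rank_info = fleet_executor_desc_pb2.RankInfo()
        rank_info.rank = rank
        rank_info.ip_port = endpoint
        desc.cluster_info.append(rank_info)

    # core.FleetExecutor receives desc.SerializeToString(); the new C++ code
    # parses it back. Round-trip in Python to illustrate the same parse.
    serialized = desc.SerializeToString()
    parsed = fleet_executor_desc_pb2.FleetExecutorDesc()
    parsed.ParseFromString(serialized)  # raises DecodeError on malformed input
    print(parsed.cur_rank, [info.ip_port for info in parsed.cluster_info])

When the environment variables are unset, the executor falls back to a default-constructed FleetExecutorDesc and logs that the Fleet Executor will run on a single device only, which is the path the existing single-device test exercises.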