Unverified · Commit 31673a92 authored by LiYuRio, committed by GitHub

Get global cluster information (#37084)
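This change replaces the flat `addrs` list in `FleetExecutorDesc` with a structured `RankInfo` message (rank id plus `ip_port`), has the C++ `FleetExecutor` constructor parse the serialized descriptor it receives, and fills in the cluster information on the Python side from the `PADDLE_TRAINER_ID` and `PADDLE_TRAINER_ENDPOINTS` environment variables, falling back to single-device execution when they are unset.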

Parent 6c183a8e
...@@ -20,7 +20,9 @@ namespace paddle {
 namespace distributed {
 
 FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
-  // Initialize Executor
+  bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
+  PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
+                     "Error occurs while parsing string to proto"));
 }
 
 FleetExecutor::~FleetExecutor() {
...
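The constructor now fails fast: if the serialized `FleetExecutorDesc` cannot be parsed, `PADDLE_ENFORCE` raises a `PreconditionNotMet` error instead of continuing with an uninitialized descriptor.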
...@@ -15,7 +15,13 @@
 syntax = "proto2";
 package paddle.distributed;
 
+message RankInfo {
+  required int64 rank = 1;
+  required string ip_port = 2;
+}
+
 message FleetExecutorDesc {
   optional string grain = 1 [ default = "coarse" ];
-  repeated string addrs = 2; // "ip:port" of all ranks
+  optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
+  repeated RankInfo cluster_info = 3;
 }
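For reference, a minimal sketch of populating the new messages from Python, assuming the generated module is importable as `paddle.distributed.fleet.proto.fleet_executor_desc_pb2` (the relative import in `executor.py` below suggests this path); the endpoints are placeholder values:

```python
from paddle.distributed.fleet.proto import fleet_executor_desc_pb2

# Build the descriptor the same way executor.py does below.
desc = fleet_executor_desc_pb2.FleetExecutorDesc()
desc.cur_rank = 0  # rank id of the current processor
for rank, endpoint in enumerate(["127.0.0.1:7000", "127.0.0.1:7001"]):
    rank_info = fleet_executor_desc_pb2.RankInfo()
    rank_info.rank = rank
    rank_info.ip_port = endpoint
    desc.cluster_info.append(rank_info)

# The serialized bytes are what core.FleetExecutor consumes.
serialized = desc.SerializeToString()
```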
...@@ -1851,7 +1851,19 @@ class Executor(object):
                                   use_program_cache)
         from ..distributed.fleet.proto import fleet_executor_desc_pb2
         from google.protobuf import text_format
+        cur_rank = os.getenv("PADDLE_TRAINER_ID")
+        trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
         fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
+        if cur_rank and trainer_endpoints_str:
+            fleet_exe_desc.cur_rank = int(cur_rank)
+            trainer_endpoints = trainer_endpoints_str.split(',')
+            for rank, endpoint in enumerate(trainer_endpoints):
+                rank_info = fleet_executor_desc_pb2.RankInfo()
+                rank_info.rank = rank
+                rank_info.ip_port = endpoint
+                fleet_exe_desc.cluster_info.append(rank_info)
+        else:
+            logging.warning("Fleet Executor will run on single device only.")
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
         fleet_exe.init(program._pipeline_opt["section_program"].desc)
         fleet_exe.run()
...
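When either `PADDLE_TRAINER_ID` or `PADDLE_TRAINER_ENDPOINTS` is unset, the descriptor keeps its proto defaults (`cur_rank = 0`, empty `cluster_info`), a warning is logged, and the Fleet Executor runs on a single device; otherwise each comma-separated endpoint is paired with its index as the rank id.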
...@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+import os
 import paddle
 import paddle.fluid as fluid
...@@ -38,6 +39,16 @@ class TestFleetExecutor(unittest.TestCase):
         for place in places:
             self.run_fleet_executor(place)
 
+    def test_dist_executor_on_multi_devices(self):
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+        os.environ[
+            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002"
+        places = [fluid.CPUPlace()]
+        if fluid.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            self.run_fleet_executor(place)
+
 
 if __name__ == "__main__":
     unittest.main()