未验证 提交 31673a92 编写于 作者: L LiYuRio 提交者: GitHub

Get global cluster information (#37084)

上级 6c183a8e
......@@ -20,7 +20,9 @@ namespace paddle {
namespace distributed {
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
// Initialize Executor
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
}
FleetExecutor::~FleetExecutor() {
......
......@@ -15,7 +15,13 @@
syntax = "proto2";
package paddle.distributed;
message RankInfo {
required int64 rank = 1;
required string ip_port = 2;
}
message FleetExecutorDesc {
optional string grain = 1 [ default = "coarse" ];
repeated string addrs = 2; // "ip:port" of all ranks
optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
repeated RankInfo cluster_info = 3;
}
......@@ -1851,7 +1851,19 @@ class Executor(object):
use_program_cache)
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from google.protobuf import text_format
cur_rank = os.getenv("PADDLE_TRAINER_ID")
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
if cur_rank and trainer_endpoints_str:
fleet_exe_desc.cur_rank = int(cur_rank)
trainer_endpoints = trainer_endpoints_str.split(',')
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
else:
logging.warning("Fleet Executor will run on single device only.")
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
fleet_exe.init(program._pipeline_opt["section_program"].desc)
fleet_exe.run()
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
import os
import paddle
import paddle.fluid as fluid
......@@ -38,6 +39,16 @@ class TestFleetExecutor(unittest.TestCase):
for place in places:
self.run_fleet_executor(place)
def test_dist_executor_on_multi_devices(self):
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7000,127.0.0.1:7001,127.0.0.1:7002"
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.run_fleet_executor(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册