diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py
index bf75be074664da4e4a9c3d8f4c84b93af4dffdb8..ab68a5248cf7f4b87b822fc5574228636e531768 100644
--- a/python/paddle/fluid/incubate/fleet/base/role_maker.py
+++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py
@@ -334,45 +334,44 @@ class PaddleCloudRoleMaker(RoleMakerBase):
     def generate_role(self):
         if not self._role_is_generated:
             if not self._is_collective:
-                self.port = os.getenv("PADDLE_PORT",
-                                      "6174")  # port of current server
-                self.pserver_ips = os.getenv("PADDLE_PSERVERS",
-                                             "")  # ip of server
-
-                if "," in self.port:
-                    ports = self.port.split(",")
-                else:
-                    ports = [self.port for i in self.pserver_ips.split(",")]
-                eplist = []
-                # note that, we usually assign the same port to different ips
-                # if we run parameter server training in local mode
-                # port should be different in environment variables
-                for i, ip in enumerate(self.pserver_ips.split(",")):
-                    eplist.append(':'.join([ip, ports[i]]))
-                self.endpoints = ",".join(eplist)
-                self._trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
-                # ip of current node, either a worker or a pserver
-                current_ip = os.getenv("POD_IP", "")
-                if current_ip == "":
-                    self._current_endpoint = os.getenv("CURRENT_ENDPOINT")
-                else:
-                    self._current_endpoint = current_ip + ports[0]
-                self.role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
-                # for trainer, only POD_IP and current trainer id is needed
-                # we usually do not need to know other trainer ips
-                self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
-                self.eplist = eplist
-                self.endpoints = self.endpoints.split(",")
-                self._server_endpoints = self.endpoints
-                self._worker_endpoints = self.endpoints
-                if self.role.upper() == "PSERVER":
-                    # current endpoint index among all pservers
-                    self._current_id = self.endpoints.index(
-                        self._current_endpoint)
-                    self._role = Role.SERVER
-                else:
-                    self._current_id = self.trainer_id
-                    self._role = Role.WORKER
+                try:
+                    port = os.environ["PADDLE_PORT"]
+                    pserver_ips = os.environ["PADDLE_PSERVERS"].split(",")
+                    if "," in port:
+                        ports = port.split(",")
+                    else:
+                        ports = [port] * len(pserver_ips)
+                    eplist = []
+                    # we usually assign the same port to every ip when running
+                    # parameter server training in local mode; in a real
+                    # cluster the ports in the environment should differ
+                    for i, ip in enumerate(pserver_ips):
+                        eplist.append(':'.join([ip, ports[i]]))
+
+                    trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
+                    training_role = os.environ["TRAINING_ROLE"]
+
+                    if training_role == "TRAINER":
+                        role = Role.WORKER
+                        current_id = int(os.environ["PADDLE_TRAINER_ID"])
+                    elif training_role == "PSERVER":
+                        role = Role.SERVER
+                        cur_ip = os.environ["POD_IP"]
+                        cur_idx = pserver_ips.index(cur_ip)
+                        current_id = eplist.index(":".join(
+                            [cur_ip, ports[cur_idx]]))
+                    else:
+                        raise ValueError(
+                            "TRAINING_ROLE must be PSERVER or TRAINER")
+                except ValueError:
+                    raise ValueError(
+                        "something went wrong with the PaddleCloud environment, please check it"
+                    )
+
+                self._trainers_num = trainers_num
+                self._server_endpoints = eplist
+                self._role = role
+                self._current_id = current_id
             else:
                 self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
                 self._training_role = os.getenv("PADDLE_TRAINING_ROLE",
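For readers tracing the patch, the endpoint-assembly logic above is small enough to exercise outside the class. The sketch below is illustrative only: the helper name resolve_ps_endpoints and the sample values are invented for this note and are not part of the patch.

    # Illustrative sketch only: mirrors the endpoint logic in the patch above.
    def resolve_ps_endpoints(pservers, port, pod_ip):
        pserver_ips = pservers.split(",")
        # pair one port per ip when a comma-separated port list is given,
        # otherwise reuse the single port for every ip (local mode)
        ports = port.split(",") if "," in port else [port] * len(pserver_ips)
        eplist = [":".join([ip, p]) for ip, p in zip(pserver_ips, ports)]
        # a pserver identifies itself via POD_IP and takes its endpoint's index
        cur_idx = pserver_ips.index(pod_ip)
        return eplist, eplist.index(":".join([pod_ip, ports[cur_idx]]))

    eplist, current_id = resolve_ps_endpoints(
        "127.0.0.1,127.0.0.2", "36001", "127.0.0.2")
    assert eplist == ["127.0.0.1:36001", "127.0.0.2:36001"]
    assert current_id == 1

Note that os.environ[...] is used in the patch instead of os.getenv(...) with defaults, so a missing variable now fails loudly rather than silently producing a half-configured role.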
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8ee00a3188d1f4efb50c9cea70d3a55ff311ed0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import os
+import unittest
+
+import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+
+
+class TestCloudRoleMaker(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_PORT"] = "36001"
+        os.environ["PADDLE_PSERVERS"] = "127.0.0.1,127.0.0.2"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+
+    def test_tr_rolemaker(self):
+        os.environ["TRAINING_ROLE"] = "TRAINER"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+
+        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
+        ro.generate_role()
+
+        self.assertTrue(ro.is_worker())
+        self.assertFalse(ro.is_server())
+        self.assertEqual(ro.worker_num(), 2)
+
+    def test_ps_rolemaker(self):
+        os.environ["TRAINING_ROLE"] = "PSERVER"
+        os.environ["POD_IP"] = "127.0.0.1"
+
+        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
+        ro.generate_role()
+        self.assertFalse(ro.is_worker())
+        self.assertTrue(ro.is_server())
+        self.assertEqual(ro.worker_num(), 2)
+
+    def test_training_role(self):
+        os.environ["TRAINING_ROLE"] = "TEST"
+        ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
+        self.assertRaises(ValueError, ro.generate_role)
+
+
+if __name__ == "__main__":
+    unittest.main()
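Since the new suite is plain unittest, it can be run on its own while iterating on the role maker. One plausible invocation (shown for convenience, not the repo's official ctest entry point) is:

    import unittest

    # Discover and run only the new role-maker tests from the current directory.
    suite = unittest.defaultTestLoader.discover(
        ".", pattern="test_fleet_rolemaker.py")
    unittest.TextTestRunner(verbosity=2).run(suite)

Running python test_fleet_rolemaker.py directly works as well, via the __main__ guard at the bottom of the file.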