role_maker.py 38.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defination of Role Makers."""
15
import os
16
import time
17
import numpy as np
18
import warnings
19
from multiprocessing import Process, Manager
20

21
import paddle
22
import paddle.fluid as fluid
23
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
24

25 26
# ``from <this module> import *`` exports no names.
__all__ = list()


class Role:
    """Integer codes for the role a process plays in distributed training."""

    # Trainer process (TRAINING_ROLE == "TRAINER").
    WORKER = 1
    # Parameter-server process.
    SERVER = 2
    # Heterogeneous trainer process (heter parameter-server mode).
    HETER_WORKER = 3
    # Pseudo-role used by Gloo for the communicator spanning workers and servers.
    ALL = 4


class Gloo(object):
    """
    Gloo is a universal class for barrier and collective communication.

    ``init`` bootstraps ``fluid.core.Gloo`` communicators through a
    rendezvous store (HDFS, a file-system path, or HTTP). Until ``init``
    succeeds, ``barrier``/``all_reduce``/``all_gather`` only emit a warning
    and do not communicate with other nodes.
    """

    class RENDEZVOUS:
        # Rendezvous (key-value store) backends accepted by ``Gloo.init``.
        HDFS = 1
        FILE = 2
        HTTP = 3

    def __init__(self):
        # Communicators; which of them are created depends on the role and
        # on ``need_init_all`` (see ``init``).
        self._worker_comm = None
        self._server_comm = None
        self._nodes_comm = None

        self._comm_world = ["worker", "server", "all"]
        self._err_init = "gloo is not initialized, will not communicate with other nodes"
        self._err_type = "gloo initialized error, please check arguments"
        self._err_world = "argument error, comm_world must in {}".format(
            self._comm_world)

        self._is_initialized = False
        self._init_timeout_seconds = 3600
        self._run_timeout_seconds = 9999999

        self._rendezvous = None
        self._role = None
        self._iface = None
        # Assigned by ``init``; defined here so the attributes always exist.
        self._prefix = ""
        self._http_server = None

        self._role_id = -1
        self._worker_num = -1
        self._server_num = -1
        self._need_init_all = False

    def init(self,
             rendezvous,
             role,
             role_id,
             worker_num,
             server_num,
             need_init_all=False,
             kwargs=None):
        """
        Create the gloo communicators for the current process.

        Args:
            rendezvous(int): one of ``Gloo.RENDEZVOUS.HDFS/FILE/HTTP``.
            role(int): role of the current process (compared with
                ``Role.WORKER``; any other value is treated as a server).
            role_id(int): rank of the current process within its role.
            worker_num(int): total number of workers.
            server_num(int): total number of servers.
            need_init_all(bool): also create the communicator spanning all
                nodes (honoured by the HDFS and FILE rendezvous only).
            kwargs(dict|None): rendezvous options. Recognised keys:
                ``store.prefix``; ``dfs.name``/``dfs.ugi``/``dfs.path`` (HDFS);
                ``dfs.path`` (FILE); ``http.host``/``http.port``/
                ``start_http_server``/``http_server_d`` (HTTP).

        Raises:
            ValueError: if ``rendezvous`` is unknown or a required option is
                missing or empty.
        """
        # ``kwargs`` defaults to None: treat that as "no options" so missing
        # options surface as the documented ValueError rather than an
        # AttributeError on ``None.get``.
        if kwargs is None:
            kwargs = {}

        self._rendezvous = rendezvous
        self._role = role
        self._role_id = role_id
        self._worker_num = worker_num
        self._server_num = server_num
        self._need_init_all = need_init_all
        self._iface = ""
        self._prefix = kwargs.get("store.prefix", "")

        http_server = None
        if self._rendezvous == Gloo.RENDEZVOUS.HDFS:
            dfs_name = kwargs.get("dfs.name", "")
            dfs_ugi = kwargs.get("dfs.ugi", "")
            dfs_path = kwargs.get("dfs.path", "")

            if not dfs_name or not dfs_ugi or not dfs_path:
                raise ValueError(self._err_type)
            self._init_dfs(dfs_name, dfs_ugi, dfs_path, self._prefix)

        elif self._rendezvous == Gloo.RENDEZVOUS.FILE:
            fs_path = kwargs.get("dfs.path", "")

            if not fs_path:
                raise ValueError(self._err_type)
            self._init_fs(fs_path, self._prefix)

        elif self._rendezvous == Gloo.RENDEZVOUS.HTTP:
            ip = kwargs.get("http.host", "")
            port = kwargs.get("http.port", "")
            start_http_server = kwargs.get("start_http_server", False)
            http_server_d = kwargs.get("http_server_d")

            if not ip or not port:
                raise ValueError(self._err_type)
            http_server = self._init_http(ip, port, self._prefix,
                                          start_http_server, http_server_d)
        else:
            raise ValueError(self._err_type)

        self._is_initialized = True
        # NOTE(review): ``_init_http`` returns nothing, so this is always
        # None; kept as-is for compatibility with existing readers.
        self._http_server = http_server

    def _create_gloo(self, rank, nodes, prefix):
        """Return a ``fluid.core.Gloo`` configured with everything except its store."""
        gloo = fluid.core.Gloo()
        gloo.set_rank(rank)
        gloo.set_size(nodes)
        gloo.set_prefix(prefix)
        gloo.set_iface(self._iface)
        gloo.set_timeout_seconds(self._init_timeout_seconds,
                                 self._run_timeout_seconds)
        return gloo

    def _init_comms_by_role(self, init):
        """
        Create the role communicator, plus the all-nodes one if requested.

        Args:
            init(callable): ``init(rank, nodes, role_name)`` returning an
                initialized gloo communicator.
        """
        if self._role == Role.WORKER:
            rank, nodes = self._get_rank_nodes(Role.WORKER)
            self._worker_comm = init(rank, nodes, "WORKER")
        else:
            rank, nodes = self._get_rank_nodes(Role.SERVER)
            self._server_comm = init(rank, nodes, "SERVER")

        if self._need_init_all:
            rank, nodes = self._get_rank_nodes(Role.ALL)
            self._nodes_comm = init(rank, nodes, "ALL")

    def _init_fs(self, fs_path, prefix):
        """Initialize communicators with a file-system store rooted at ``fs_path``."""

        def init(rank, nodes, role):
            gloo = self._create_gloo(rank, nodes, prefix)
            # Empty fs name/ugi: the "hdfs" store is used on a plain path.
            gloo.set_hdfs_store(os.path.join(fs_path, role), "", "")
            gloo.init()
            return gloo

        self._init_comms_by_role(init)

    def _init_dfs(self, dfs_name, dfs_ugi, dfs_path, prefix):
        """Initialize communicators with an HDFS store rooted at ``dfs_path``."""

        def init(rank, nodes, role):
            gloo = self._create_gloo(rank, nodes, prefix)
            gloo.set_hdfs_store(os.path.join(dfs_path, role), dfs_name, dfs_ugi)
            gloo.init()
            return gloo

        self._init_comms_by_role(init)

    def _init_http(self, ip, port, prefix, start_http_server, http_server_d):
        """
        Initialize the worker communicator through an HTTP key-value store.

        When ``start_http_server`` is True, a ``KVServer`` is started in a
        daemon child process. ``http_server_d`` is a dict shared with that
        process (presumably a ``multiprocessing.Manager().dict()`` — confirm
        with callers) whose ``running`` flag keeps the server alive. After
        the worker communicator is set up, the flag is cleared and this
        method waits (``join``) for the child process to exit.

        Only workers get a communicator here; server and all-node
        communicators are not created yet (see TODO below).
        """

        def _start_kv_server(http_server_d, size_d):
            print("start http_server: {}, {}".format(port, size_d))
            from paddle.distributed.fleet.utils.http_server import KVServer
            http_server = KVServer(port, size_d)
            http_server.start()
            wait_seconds = 5
            # Serve while the parent still flags us as running, or until the
            # server itself reports it may stop.
            while http_server_d.get("running",
                                    False) or not http_server.should_stop():
                time.sleep(wait_seconds)
            http_server.stop()

        def init_kv_server(http_server_d):
            worker_key = prefix + '_' + 'worker'
            size_d = {worker_key: self._worker_num, }
            print("worker_key:{}, size: {}".format(worker_key, size_d))

            # Raise the running flag before the child starts polling it.
            http_server_d["running"] = True
            # Child process hosting the HTTP key-value server.
            _http_server = Process(
                target=_start_kv_server, args=(http_server_d, size_d))
            _http_server.daemon = True
            _http_server.start()
            return _http_server

        def init(rank, nodes, role):
            gloo = self._create_gloo(rank, nodes, prefix)
            gloo.set_http_store(ip, port, 'worker')
            ep = ":".join([ip, str(port)])
            wait_server_ready([ep])
            gloo.init()
            return gloo

        # The closures above read ``port`` lazily, so they see the int value.
        port = int(port)

        if start_http_server:
            print("to start http_server")
            http_server = init_kv_server(http_server_d)

        if self._role == Role.WORKER:
            rank, nodes = self._get_rank_nodes(Role.WORKER)
            gloo = init(rank, nodes, "WORKER")
            self._worker_comm = gloo
        # TODO (sandyhouse): initialize gloo for server and all

        if start_http_server:
            http_server_d["running"] = False
            http_server.join()

    def _get_rank_nodes(self, role):
        """
        Return ``(rank, nodes)`` of the current process for ``role``.

        For ``Role.ALL``, workers take ranks ``[0, worker_num)`` and servers
        follow at ``worker_num + role_id``.

        Raises:
            ValueError: if ``role`` is not WORKER, SERVER or ALL.
        """
        nodes = 0
        rank = -1

        if role == Role.WORKER:
            nodes = self._worker_num
            rank = self._role_id
        elif role == Role.SERVER:
            nodes = self._server_num
            rank = self._role_id
        elif role == Role.ALL:
            nodes = self._worker_num + self._server_num

            if self._role == Role.WORKER:
                rank = self._role_id
            else:
                rank = self._worker_num + self._role_id
        else:
            # Previously the exception was constructed but never raised.
            raise ValueError(self._err_type)

        return rank, nodes

    def __get_default_iface(self):
        """
        Return the default physical interface, preferring the gateway route
        and falling back to the first broadcast-capable interface when the
        route lookup yields "lo".
        """
        default1 = self.__get_default_iface_from_gateway()
        default2 = self.__get_default_iface_from_interfaces()
        return default2 if default1 == "lo" else default1

    def __get_default_iface_from_gateway(self):
        """
        Return the interface of the first route with a real gateway, parsed
        from ``route -A inet``; "lo" if none is found.
        """
        # Close the pipe deterministically instead of leaking it.
        with os.popen("route -A inet") as pipe:
            res = pipe.read().strip().split("\n")

        gateway_idx = None
        iface_idx = None
        for item in res:
            item = item.split()
            if "Gateway" in item and "Iface" in item:
                # Header row: remember which columns to read.
                gateway_idx = item.index("Gateway")
                iface_idx = item.index("Iface")
            elif gateway_idx is not None and iface_idx is not None:
                gateway = None
                if len(item) > gateway_idx:
                    gateway = item[gateway_idx]
                if gateway and gateway != '*' and gateway != "0.0.0.0" and len(
                        item) > iface_idx:
                    return item[iface_idx]
        return "lo"

    def __get_default_iface_from_interfaces(self):
        """
        Return the first interface whose ``ip -f inet addr`` header line
        mentions BROADCAST; "lo" if none is found.
        """
        with os.popen("ip -f inet addr | awk NR%3==1") as pipe:
            res = pipe.read().strip().split("\n")
        for item in res:
            if "BROADCAST" in item:
                return item.split(":")[1].strip()
        return "lo"

    def barrier(self, comm_world):
        """
        Barrier across ``comm_world`` ("worker", "server" or "all").

        Warns and returns immediately if gloo is not initialized.

        Raises:
            ValueError: if ``comm_world`` is not a known world.
        """
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        if comm_world == "worker":
            self._worker_comm.barrier()
        elif comm_world == "server":
            self._server_comm.barrier()
        else:
            self._nodes_comm.barrier()

    def all_reduce(self, input, mode="sum", comm_world="worker"):
        """
        All-reduce ``input`` element-wise across ``comm_world``.

        Args:
            input(list|numpy.ndarray): values to reduce; flattened for the
                call and reshaped back to the original shape afterwards.
            mode(str): reduction passed to gloo, e.g. "sum", "min" or "max".
            comm_world(str): "worker", "server" or "all".

        Returns:
            numpy.ndarray shaped like ``input``, or ``input`` unchanged (with
            a warning) if gloo is not initialized.

        Raises:
            ValueError: if ``comm_world`` is not a known world.
        """
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return input

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        input = np.array(input)
        input_shape = input.shape
        input_list = input.reshape(-1).tolist()

        self.barrier(comm_world)

        if comm_world == "worker":
            ans = self._worker_comm.all_reduce(input_list, mode)
        elif comm_world == "server":
            ans = self._server_comm.all_reduce(input_list, mode)
        else:
            ans = self._nodes_comm.all_reduce(input_list, mode)

        output = np.array(ans).reshape(input_shape)
        return output

    def all_gather(self, input, comm_world="worker"):
        """
        Gather ``input`` from every member of ``comm_world``.

        Args:
            input(any): object contributed by the current process.
            comm_world(str): "worker", "server" or "all".

        Returns:
            The gloo ``all_gather`` result, or ``input`` unchanged (with a
            warning) if gloo is not initialized.

        Raises:
            ValueError: if ``comm_world`` is not a known world.
        """
        if not self._is_initialized:
            warnings.warn(self._err_init)
            return input

        if comm_world not in self._comm_world:
            raise ValueError(self._err_world)

        if comm_world == "worker":
            output = self._worker_comm.all_gather(input)
        elif comm_world == "server":
            output = self._server_comm.all_gather(input)
        else:
            output = self._nodes_comm.all_gather(input)

        return output


class RoleMakerBase(object):
    """
    RoleMakerBase is a base class for assigning a role to current process
    in distributed training.
    A paddle developer can implement RoleMakerBase to design a role maker
    for worker or pserver assignment.

    The role-query methods raise ``NotImplementedError`` and must be
    overridden. The communication hooks (``_barrier``, ``_all_gather``,
    ``_all_reduce``) only print a warning here.
    """

    def __init__(self):
        self._worker_endpoints = []
        self._server_endpoints = []
        self._role_is_generated = False
        self._role = None
        self._current_id = -1

    def _is_worker(self):
        """
        Return True if the current process is a worker.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _is_server(self):
        """
        Return True if the current process is a server.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
        """
        raise NotImplementedError("Please implement this method in child class")

    def _worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker number
        """
        raise NotImplementedError("Please implement this method in child class")

    def _server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number
        """
        raise NotImplementedError("Please implement this method in child class")

    def _worker_index(self):
        """
        Get current worker id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _server_index(self):
        """
        Get current server id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _role_id(self):
        """
        Get current id.

        Returns:
            int: node id
        """
        raise NotImplementedError("Please implement this method in child class")

    def _node_num(self):
        """
        Get the training node number.

        Returns:
            int: node num
        """
        raise NotImplementedError("Please implement this method in child class")

    def _get_trainer_endpoints(self):
        """
        Return the list of trainer endpoints.
        """
        return self._worker_endpoints

    def _get_pserver_endpoints(self):
        """
        Return the list of pserver endpoints.
        """
        return self._server_endpoints

    def to_string(self):
        """
        Return a one-line summary of the role, id and endpoints.
        """
        return "role: {}, current_id: {}, worker_endpoints: {}, server_endpoints: {}".format(
            self._role, self._current_id, self._worker_endpoints,
            self._server_endpoints)

    def _all_gather(self, input, comm_world="worker"):
        """
        Placeholder all-gather: prints a warning and returns None.
        """
        print("warning: RoleMakerBase does not have all gather worker.")
        return None

    def _all_reduce(self, input, mode="sum", comm_world="worker"):
        """
        Placeholder all-reduce: prints a warning and returns None.

        Args:
            input(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
            comm_world(str): communication world, e.g. "worker"
        """
        print("warning: RoleMakerBase does not have all reduce worker.")
        return None

    def _barrier(self, comm_world):
        """
        Placeholder barrier: prints a warning and does not synchronize.
        """
        print("warning: RoleMakerBase does not have barrier worker.")

    # NOTE: heter-worker queries (_is_heter_worker, _heter_worker_num,
    # _get_heter_worker_endpoints, _get_heter_worker_endpoint) are currently
    # only provided by subclasses such as PaddleCloudRoleMaker.


class PaddleCloudRoleMaker(RoleMakerBase):
516
    def __init__(self, is_collective=False, **kwargs):
        """
        Role maker whose role information is filled in lazily (by
        ``_generate_role``) the first time a query method needs it.

        Args:
            is_collective(bool): whether the job runs in collective mode.
            **kwargs: extra options, kept in ``self._kwargs``.
        """
        super(PaddleCloudRoleMaker, self).__init__()
        self._is_collective = is_collective
        self._non_distributed = False

        self._kwargs = kwargs
        self._role_is_generated = False

        # Heterogeneous parameter-server ("heterps") defaults: one stage,
        # heter trainers on CPU, no known heter endpoints yet.
        self._stage_id = 1
        self._stage_num = 1
        self._next_heter_trainer_endpoints = []
        self._previous_heter_trainer_endpoints = []
        self._heter_trainer_endpoints = []
        self._heter_trainer_device = "CPU"
        self._is_heter_parameter_server_mode = False
        self._stage_trainers = []

        # Endpoint lists start empty; they are populated later from the
        # environment (e.g. by ``_ps_env``).
        self._server_endpoints = []
        self._worker_endpoints = []

        # Gloo backend behind _barrier / _all_gather / _all_reduce.
        self._gloo = Gloo()

    def _barrier(self, comm_world):
        """Run a gloo barrier over ``comm_world`` ("worker", "server" or "all")."""
        gloo = self._gloo
        gloo.barrier(comm_world)

    def _all_gather(self, input, comm_world="worker"):
        """Gather ``input`` from every member of ``comm_world`` via gloo."""
        gloo = self._gloo
        gathered = gloo.all_gather(input, comm_world)
        return gathered

    def _all_reduce(self, input, mode="sum", comm_world="worker"):
        """Reduce ``input`` with ``mode`` across ``comm_world`` via gloo."""
        gloo = self._gloo
        reduced = gloo.all_reduce(input, mode, comm_world)
        return reduced

    def _heter_device_type(self):
        """
        Return the device type used by the current heter worker ("CPU"
        unless configured otherwise).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        device = self._heter_trainer_device
        return device

    def _get_stage_id(self):
        """
        Return the stage id of the current process in heter PS mode
        (1 by default).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        stage_id = self._stage_id
        return stage_id

    def _get_stage_trainers(self):
        """
        Return the trainer counts of all stages in heter PS mode (empty list
        by default).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        stage_trainers = self._stage_trainers
        return stage_trainers

    def _get_num_stage(self):
        """
        Return the number of stages in heter PS mode (1 by default).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        stage_num = self._stage_num
        return stage_num

    def _is_worker(self):
        """
        Return True if the current process plays the worker role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        is_worker = self._role == Role.WORKER
        return is_worker

    def _is_server(self):
        """
        Return True if the current process plays the server role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        is_server = self._role == Role.SERVER
        return is_server

    def _is_first_worker(self):
        """
        Return True only for the worker whose id is 0.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        if self._role != Role.WORKER:
            return False
        return self._current_id == 0

    def _worker_index(self):
        """
        Return the id of the current process (meaningful for workers).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        index = self._current_id
        return index

    def _server_index(self):
        """
        Return the id of the current process (meaningful for servers).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        index = self._current_id
        return index

    def _role_id(self):
        """
        Return the id of the current process within its role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        role_id = self._current_id
        return role_id

    def _worker_num(self):
        """
        Return the total number of workers (trainers).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        trainers = self._trainers_num
        return trainers

    def _server_num(self):
        """
        Return the number of servers, i.e. the number of pserver endpoints
        (0 when no endpoint list is available).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        endpoints = self._get_pserver_endpoints()
        if endpoints is None:
            return 0
        return len(endpoints)

    def _node_num(self):
        """
        Return the number of training nodes.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        nodes = self._nodes_num
        return nodes

    def _get_node_num(self):
        """
        Return the number of training nodes (same value as ``_node_num``).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        nodes = self._nodes_num
        return nodes

    def _get_local_rank(self):
        """
        Return ``self._local_rank`` (presumably the rank on the local node;
        set by ``_generate_role`` — confirm).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        local_rank = self._local_rank
        return local_rank

    def _get_local_device_ids(self):
        """
        Return ``self._local_device_ids`` (presumably the device ids visible
        on this node; set by ``_generate_role`` — confirm).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        device_ids = self._local_device_ids
        return device_ids

    def _get_world_device_ids(self):
        """
        Return ``self._world_device_ids`` (presumably the device ids of all
        nodes; set by ``_generate_role`` — confirm).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        device_ids = self._world_device_ids
        return device_ids

    def _get_trainer_endpoints(self):
        """
        Return the endpoints of all trainers.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        endpoints = self._worker_endpoints
        return endpoints

    def _get_trainer_endpoint(self):
        """
        Return this trainer's own endpoint.

        Raises:
            AssertionError: if the current process is not a worker.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        is_trainer = self._role == Role.WORKER
        assert is_trainer, "get_trainer_endpoint should be called by trainer"
        return self._cur_endpoint

    def _get_heter_worker_endpoints(self):
        """
        Return the endpoints of all heter trainers.

        Raises:
            AssertionError: if the endpoint list is empty.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        endpoints = self._heter_trainer_endpoints
        assert endpoints != [], "Heter Worker Endpoints Not initialized"
        return endpoints

    def _get_heter_worker_endpoint(self):
        """
        Return the endpoint of the current heter worker (presumably an
        "ip:port" string, as for trainers — confirm).

        Raises:
            AssertionError: if the current process is not a heter worker.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        is_heter = self._role == Role.HETER_WORKER
        assert is_heter, "_get_heter_worker_endpoint should be invoked by heter worker"
        return self._cur_endpoint

    def _get_pserver_endpoints(self):
        """
        Return the endpoints of all parameter servers.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        endpoints = self._server_endpoints
        return endpoints

    def _get_previous_trainers(self):
        """
        Return the endpoints of the previous-stage heter trainers.

        Only valid for trainers and heter workers.

        Raises:
            AssertionError: for any other role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        allowed_roles = (Role.WORKER, Role.HETER_WORKER)
        assert self._role in allowed_roles, "_get_previous_trainers should be invoked by trainer or heter worker"
        return self._previous_heter_trainer_endpoints

    def _get_next_trainers(self):
        """
        Return the endpoints of the next-stage heter trainers.

        Only valid for trainers and heter workers.

        Raises:
            AssertionError: for any other role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        allowed_roles = (Role.WORKER, Role.HETER_WORKER)
        assert self._role in allowed_roles, "_get_next_trainers should be invoked by trainer or heter worker"
        return self._next_heter_trainer_endpoints

    def _is_non_distributed(self):
        """
        Return True when the fleetrun environment was not found, i.e. the
        fleet code was launched directly with python.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        non_distributed = self._non_distributed
        return non_distributed

    def _heter_worker_num(self):
        """
        Return the number of heter workers (0 outside heter PS mode).
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        heter_num = self._heter_trainers_num
        return heter_num

    def _is_heter_worker(self):
        """
        Return True if the current process plays the heter-worker role.
        """
        if not self._role_is_generated:
            # Role information is resolved lazily on first use.
            self._generate_role()
        is_heter = self._role == Role.HETER_WORKER
        return is_heter

    def _ps_env(self):
766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
        # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
        # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)

        if self._server_endpoints is None:
            # back to non_distributed execution.
            self._server_endpoints = ""
            self._trainers_num = 1
            self._role = Role.WORKER
            self._current_id = 0
            self._nodes_num = 1
            self._heter_trainers_num = 0
            self._heter_trainer_endpoints = None
            self._non_distributed = True
            return
781

782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800
        self._server_endpoints = self._server_endpoints.split(",")

        self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
        if self._worker_endpoints != None:
            self._worker_endpoints = self._worker_endpoints.split(",")
        else:
            self._worker_endpoints = []

        trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
        if trainers_num == None:
            raise ValueError(
                "Can not find PADDLE_TRAINERS_NUM, please check your environment."
            )
        trainers_num = int(trainers_num)

        training_role = os.getenv("TRAINING_ROLE", None)
        if training_role == None:
            raise ValueError(
                "Can not find TRAINING_ROLE, please check your environment.")
801

802 803 804 805 806
        if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
            raise ValueError(
                "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
                format(training_role))

807 808 809 810 811 812 813
        # For Heter Parameter Server env setting
        next_heter_trainer_eplist = os.getenv(
            "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST", "")
        previous_heter_trainer_eplist = os.getenv(
            "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST", "")
        all_heter_trainer_eplist = os.getenv(
            "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST", "")
814

815 816
        if all_heter_trainer_eplist != "":
            self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
817
            self._is_heter_parameter_server_mode = True
818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848
            self._heter_trainers_num = len(self._heter_trainer_endpoints)

            if previous_heter_trainer_eplist == "":
                assert training_role in (
                    "TRAINER", "PSERVER"
                ), "training_role should be trainer or pserver"
            else:
                try:
                    self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist.split(
                        ",")
                except:
                    raise ValueError(
                        "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
                    )

            if next_heter_trainer_eplist == "":
                assert training_role in (
                    "HETER_TRAINER", "PSERVER"
                ), "training_role should be heter trainer or pserver"
            else:
                try:
                    self._next_heter_trainer_endpoints = next_heter_trainer_eplist.split(
                        ",")
                except:
                    raise ValueError(
                        "Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
                    )

            #self._is_heter_parameter_server_mode = True
            #heter_trainers_num = len(all_heter_trainer_eplist.split(","))
            #self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
849 850
        else:
            self._is_heter_parameter_server_mode = False
851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867
            self._heter_trainers_num = 0

            #if previous_heter_trainer_eplist == "":
            #    self._is_heter_parameter_server_mode = False
            #    heter_trainers_num = 0
            #else:  ## for the last heter worker
            #    try:
            #        previous_heter_trainer_eplist = os.environ[
            #            "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"].split(",")
            #        self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist
            #    except:
            #        raise ValueError(
            #            "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
            #        )
            #    self._is_heter_parameter_server_mode = True
            #    heter_trainers_num = len(all_heter_trainer_eplist.split(","))
            #    self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
868 869 870 871 872

        if training_role == "TRAINER":
            role = Role.WORKER
            current_id = os.getenv("PADDLE_TRAINER_ID", None)
            if current_id == None:
873
                raise ValueError(
874 875 876
                    "Can not find PADDLE_TRAINER_ID, please check your environment."
                )
            current_id = int(current_id)
877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904
            if self._is_heter_parameter_server_mode:
                self._stage_id = os.getenv("STAGE_ID", None)
                if self._stage_id == None:
                    raise ValueError(
                        "Can not find STAGE_ID, please check your environment.")
                self._stage_id = int(self._stage_id)
                self._stage_num = os.getenv("STAGE_NUM", None)
                if self._stage_num == None:
                    raise ValueError(
                        "Can not find STAGE_NUM, please check your environment.")
                self._stage_num = int(self._stage_num)
                self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM",
                                                 None)
                if self._stage_trainers == None:
                    raise ValueError(
                        "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
                    )
                self._stage_trainers = eval(self._stage_trainers)
            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port == None:
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment.")
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip == None:
                raise ValueError(
                    "Can not find POD_IP, please check your environment.")
            curr_endpoint = ":".join([cur_ip, cur_port])
            self._cur_endpoint = curr_endpoint
905 906
        elif training_role == "PSERVER":
            role = Role.SERVER
907 908
            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port == None:
909 910
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment.")
911 912
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip == None:
913 914
                raise ValueError(
                    "Can not find POD_IP, please check your environment.")
915 916
            curr_endpoint = ":".join([cur_ip, cur_port])
            self._cur_endpoint = curr_endpoint
917 918 919
            current_id = self._server_endpoints.index(self._cur_endpoint)
        elif training_role == "HETER_TRAINER":
            role = Role.HETER_WORKER
920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945
            self._stage_id = os.getenv("STAGE_ID", None)
            if self._stage_id == None:
                raise ValueError(
                    "Can not find STAGE_ID, please check your environment.")
            self._stage_id = int(self._stage_id)
            self._stage_num = os.getenv("STAGE_NUM", None)
            if self._stage_num == None:
                raise ValueError(
                    "Can not find STAGE_NUM, please check your environment.")
            self._stage_num = int(self._stage_num)

            self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM", None)
            if self._stage_trainers == None:
                raise ValueError(
                    "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment."
                )
            self._stage_trainers = eval(self._stage_trainers)

            self._heter_trainer_device = os.getenv("HETER_DEVICE_TYPE", None)
            if self._heter_trainer_device == None:
                raise ValueError(
                    "Can not find HETER_DEVICE_TYPE, please check your environment."
                )
            assert self._heter_trainer_device in (
                "cpu", "gpu", "xpu"
            ), "HETER_DEVICE_TYPE should be cpu,gpu or xpu"
946 947 948 949 950 951 952 953 954
            cur_port = os.getenv("PADDLE_PORT", None)
            if cur_port == None:
                raise ValueError(
                    "Can not find PADDLE_PORT, please check your environment.")
            cur_ip = os.getenv("POD_IP", None)
            if cur_ip == None:
                raise ValueError(
                    "Can not find POD_IP, please check your environment.")
            curr_endpoint = ":".join([cur_ip, cur_port])
955 956 957
            self._cur_endpoint = curr_endpoint
            current_id = all_heter_trainer_eplist.split(",").index(
                curr_endpoint) + trainers_num
958 959 960 961

        self._trainers_num = trainers_num
        self._role = role
        self._current_id = current_id
962
        self._nodes_num = len(
963
            set([x.split(':')[0] for x in self._worker_endpoints]))
964 965 966 967 968

    def _collective_env(self):
        """
        Fill in role information for collective training from the environment.

        Every process is a worker here: PADDLE_TRAINING_ROLE (default
        "TRAINER") is asserted to be "TRAINER". The worker list comes from the
        comma-separated PADDLE_TRAINER_ENDPOINTS and this process's endpoint
        from PADDLE_CURRENT_ENDPOINT. When PADDLE_TRAINER_ENDPOINTS is unset,
        execution falls back to non-distributed mode on the single endpoint
        "127.0.0.1:6170". PADDLE_RANK_IN_NODE, PADDLE_LOCAL_DEVICE_IDS and
        PADDLE_WORLD_DEVICE_IDS are stored as-is (None when unset).
        """
        self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
        assert (self._training_role == "TRAINER")
        self._role = Role.WORKER

        endpoint_list = os.getenv("PADDLE_TRAINER_ENDPOINTS")
        own_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
        if endpoint_list is None:
            # back to non_distributed execution on one local endpoint.
            endpoint_list = own_endpoint = "127.0.0.1:6170"
            self._non_distributed = True
        self._cur_endpoint = own_endpoint

        self._worker_endpoints = endpoint_list.split(",")
        self._trainers_num = len(self._worker_endpoints)
        # Node count is the number of distinct hosts among worker endpoints.
        distinct_hosts = {ep.split(':')[0] for ep in self._worker_endpoints}
        self._nodes_num = len(distinct_hosts)

        self._local_rank = os.getenv("PADDLE_RANK_IN_NODE")
        self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
        self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS")

985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011
    def _gloo_init(self):
        """
        Initialize gloo communication according to environment variables.

        Environment variables read here:
            PADDLE_WITH_GLOO: 1 -> trainer barrier, 2 -> all barrier
                (``need_init_all``); any other value skips gloo entirely.
            PADDLE_GLOO_RENDEZVOUS: rendezvous store type,
                1: HDFS, 2: FILE, 3: HTTP.
            SYS_JOB_ID: used as the rendezvous store prefix.
            PADDLE_GLOO_FS_NAME / PADDLE_GLOO_FS_UGI / PADDLE_GLOO_FS_PATH:
                HDFS settings (FILE rendezvous only uses the path).
            PADDLE_GLOO_HTTP_ENDPOINT: "IP:PORT" of the HTTP store in
                parameter-server mode; collective mode uses the first worker
                endpoint instead.

        Raises:
            ValueError: if PADDLE_GLOO_RENDEZVOUS is not a supported type, or
                the HTTP rendezvous endpoint is not in 'IP:PORT' format.
        """
        # PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier
        use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
        if use_gloo not in (1, 2):
            return

        # PADDLE_GLOO_RENDEZVOUS 1: HDFS 2: FILE 3: HTTP
        rendezvous_type = int(os.getenv("PADDLE_GLOO_RENDEZVOUS", "0"))
        prefix = os.getenv("SYS_JOB_ID", "")
        if rendezvous_type not in (Gloo.RENDEZVOUS.HDFS, Gloo.RENDEZVOUS.HTTP,
                                   Gloo.RENDEZVOUS.FILE):
            raise ValueError(self._gloo._err_type)

        need_init_all = use_gloo == 2

        if rendezvous_type == Gloo.RENDEZVOUS.HDFS:
            rendezvous_name = "HDFS"
            kwargs = {
                "dfs.name": os.getenv("PADDLE_GLOO_FS_NAME", ""),
                "dfs.ugi": os.getenv("PADDLE_GLOO_FS_UGI", ""),
                "dfs.path": os.getenv("PADDLE_GLOO_FS_PATH", ""),
                "store.prefix": prefix,
            }
        elif rendezvous_type == Gloo.RENDEZVOUS.HTTP:
            rendezvous_name = "HTTP"
            # The HTTP store is hosted by the first worker in collective mode
            # and by server 0 in parameter-server mode.
            start_http_server = False
            if self._is_collective:
                ep_rank_0 = self._worker_endpoints[0]
                if self._is_first_worker():
                    start_http_server = True
            else:
                ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
                if self._is_server() and self._server_index() == 0:
                    start_http_server = True

            # Validate before unpacking (and before spawning the Manager
            # process): an unset PADDLE_GLOO_HTTP_ENDPOINT would otherwise
            # fail with an opaque "not enough values to unpack" error.
            host_port = ep_rank_0.split(':')
            if len(host_port) != 2:
                raise ValueError(
                    "gloo HTTP rendezvous endpoint '{}' doesn't match the "
                    "requirement 'IP:PORT', please check "
                    "PADDLE_GLOO_HTTP_ENDPOINT (parameter-server mode) or the "
                    "trainer endpoints (collective mode).".format(ep_rank_0))
            ip, port = host_port

            # Process-shared dict handed to gloo init; "running" is reset to
            # False once init returns (end of this method).
            # NOTE(review): presumably this tells the HTTP store process when
            # it may stop -- confirm against Gloo.init.
            manager = Manager()
            http_server_d = manager.dict()
            http_server_d["running"] = False
            kwargs = {
                "http.host": ip,
                "http.port": port,
                "store.prefix": prefix,
                'start_http_server': start_http_server,
                'http_server_d': http_server_d,
            }
        else:
            rendezvous_name = "FILE"
            kwargs = {
                "dfs.path": os.getenv("PADDLE_GLOO_FS_PATH", ""),
                "store.prefix": prefix,
            }

        print("Gloo init with {}: need_init_all: {}, args: {}".format(
            rendezvous_name, need_init_all, kwargs))

        self._gloo.init(
            rendezvous=rendezvous_type,
            role=self._role,
            role_id=self._role_id(),
            worker_num=self._worker_num(),
            server_num=self._server_num(),
            need_init_all=need_init_all,
            kwargs=kwargs)

        if rendezvous_type == Gloo.RENDEZVOUS.HTTP:
            http_server_d['running'] = False

1060
    def _generate_role(self):
        """
        Generate the role for this role maker, at most once.

        Collective mode reads its setup through ``_collective_env``; otherwise
        ``_ps_env`` is used. After marking the role as generated, gloo is
        initialized unless paddle is running in dygraph mode.
        """
        if self._role_is_generated:
            return

        if self._is_collective:
            self._collective_env()
        else:
            self._ps_env()
        self._role_is_generated = True

        # gloo setup is skipped in dygraph mode.
        if not paddle.fluid.framework.in_dygraph_mode():
            self._gloo_init()


class UserDefinedRoleMaker(PaddleCloudRoleMaker):
    def __init__(self, is_collective=False, init_gloo=False, **kwargs):
        """
        Build a role maker whose topology is given by keyword arguments.

        Args:
            is_collective (bool): selects the collective setup when True,
                the parameter-server setup otherwise.
            init_gloo (bool): forwarded to the parent constructor and also
                kept on this instance as ``self._init_gloo``.
            **kwargs: forwarded to the parent constructor. The role setup
                methods read values such as ``worker_endpoints`` and
                ``current_id`` from ``self._kwargs`` (presumably stored there
                by the parent -- confirm).
        """
        parent = super(UserDefinedRoleMaker, self)
        parent.__init__(
            is_collective=is_collective, init_gloo=init_gloo, **kwargs)
        self._init_gloo = init_gloo

    def _user_defined_ps_env(self):
        """
        Fill in parameter-server role information from the user kwargs.

        Reads ``server_endpoints``, ``worker_endpoints`` (default ``[]``),
        ``worker_num`` (default 0), ``role`` and ``current_id``. A worker_num
        of 0 means "derive it from worker_endpoints", which must then be
        non-empty. The current endpoint is set for a server, and for a worker
        whose ``current_id`` indexes into ``worker_endpoints``.
        """
        user_args = self._kwargs
        self._server_endpoints = user_args.get("server_endpoints")
        self._worker_endpoints = user_args.get("worker_endpoints", [])
        self._trainers_num = user_args.get("worker_num", 0)

        if self._trainers_num == 0:
            # No explicit worker count: fall back to the endpoint list size.
            assert (len(self._worker_endpoints) > 0)
            self._trainers_num = len(self._worker_endpoints)

        self._role = user_args.get("role")
        self._current_id = user_args.get("current_id")

        worker_with_endpoint = (
            self._role == Role.WORKER and
            len(self._worker_endpoints) > self._current_id)
        if worker_with_endpoint:
            self._cur_endpoint = self._worker_endpoints[self._current_id]
        elif self._role == Role.SERVER:
            self._cur_endpoint = self._server_endpoints[self._current_id]

        # Node count is the number of distinct hosts among worker endpoints.
        worker_hosts = {ep.split(':')[0] for ep in self._worker_endpoints}
        self._nodes_num = len(worker_hosts)

    def _user_defined_collective_env(self):
        """
        Fill in collective role information from the user kwargs.

        Reads ``worker_endpoints`` (a sequence of "IP:PORT"-style strings)
        and ``current_id``. Every process is a worker in collective mode.
        """
        self._worker_endpoints = self._kwargs.get("worker_endpoints")
        self._current_id = self._kwargs.get("current_id")
        self._trainers_num = len(self._worker_endpoints)
        # Collective mode only has workers. Set ``_role`` too -- the env-based
        # ``_collective_env`` does the same -- so code comparing ``self._role``
        # with Role.WORKER (e.g. the role handed to gloo) sees a worker
        # instead of an unset role. ``_training_role`` is kept for backward
        # compatibility.
        self._role = Role.WORKER
        self._training_role = Role.WORKER
        self._nodes_num = len(
            {x.split(':')[0] for x in self._worker_endpoints})

1108
    def _generate_role(self):
        """
        Generate the role for this role maker from the user-supplied kwargs.

        The setup only runs while ``self._role_is_generated`` is false:
        collective mode uses ``_user_defined_collective_env``, parameter-server
        mode uses ``_user_defined_ps_env``, and the flag is then set to True.
        """
        if not self._role_is_generated:
            setup_env = (self._user_defined_collective_env
                         if self._is_collective else self._user_defined_ps_env)
            setup_env()
            self._role_is_generated = True