#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .helper import MPIHelper


class PaddlePSInstance(object):
    """
    Describe this MPI process's role (parameter server or worker) in a
    distributed PaddlePaddle job and provide role-scoped communicators.

    Args:
        server_worker_mode: role-assignment scheme, 0 or 1 (default is 1).
            With 1, server and worker processes alternate within each node;
            with 0, roles are split by contiguous rank ranges.
        proc_per_node: number of processes launched per node, default is 2.

    Examples:
        instance = PaddlePSInstance(1, 2)
    """

    def __init__(self, server_worker_mode, proc_per_node):
        self.dh = MPIHelper()
        self._rankid = self.dh.get_rank()
        self._server_worker_mode = server_worker_mode
        self._proc_per_node = proc_per_node
        self._nodes = self.dh.get_size()

        self._ip = 0
        # Half of all processes are workers, half are servers.  Floor
        # division keeps these counts ints under Python 3 (plain `/`
        # would make every derived rank index a float).
        self._worker_num = self._nodes * self._proc_per_node // 2
        self._server_num = self._nodes * self._proc_per_node // 2
        self._total_server_worker = self._worker_num + self._server_num
        self._node_type = None  #IDLE=-1, WORKER=1, SERVER=0
        self._set_nodetype()
        self._comm = None
        self._split_comm()

    def _set_nodetype(self):
        """Assign this rank's role; see the class docstring for the modes."""
        if self._server_worker_mode == 0:
            # Contiguous split by rank range.
            if self._rankid < self._server_num:
                self._node_type = 1
            elif self._rankid < self._total_server_worker:
                self._node_type = 0
            else:
                self._node_type = -1
        elif self._server_worker_mode == 1:
            # Interleaved: even local rank parity -> server, odd -> worker.
            if self._rankid < self._total_server_worker:
                if 0 == self._rankid % self._proc_per_node % 2:
                    self._node_type = 0
                else:
                    self._node_type = 1
            else:
                self._node_type = -1
        else:
            # Unknown mode: mark this process idle.
            self._node_type = -1

    def _split_comm(self):
        """Split the world communicator into per-role communicators."""
        # Idle ranks (_node_type == -1) keep self._comm as None.
        if self.is_server() or self.is_worker():
            self._comm = self.dh.comm.Split(self._node_type)

    def get_worker_id(self):
        """
        Return this process's worker index.
        """
        if self._server_worker_mode == 0:
            # BUGFIX: original was `self._rankid == self.server_num`,
            # which compared (returning a bool) against a nonexistent
            # attribute.  Arithmetic offset was clearly intended; this
            # assumes workers follow servers in rank order -- TODO
            # confirm layout against _set_nodetype.
            return self._rankid - self._server_num
        else:
            # Floor division keeps the index an int under Python 3.
            return self._rankid // self._proc_per_node

    def get_server_id(self):
        """
        Return this process's server index.
        """
        if self._server_worker_mode == 0:
            # BUGFIX: original referenced nonexistent `self.rank_id`
            # (AttributeError at runtime).
            return self._rankid
        else:
            return self._rankid // self._proc_per_node

    def is_worker(self):
        """
        Return whether this process acts as a worker.
        """
        return self._node_type == 1

    def is_server(self):
        """
        Return whether this process acts as a server.
        """
        return self._node_type == 0

    def is_first_worker(self):
        """
        Return whether this process is worker number 0.
        """
        return self.is_worker() and 0 == self.get_worker_id()

    def set_ip(self, ip):
        """
        Record this server's ip address.
        """
        self._ip = ip

    def gather_ips(self):
        """
        Return all servers' and workers' ips through MPI allgather.
        """
        self._ips = self.dh.comm.allgather(self._ip)
        return self._ips

    def get_node_cnt(self):
        """
        Return the total number of MPI ranks.
        """
        return self._nodes

    def get_worker_num(self):
        """
        Return worker num
        """
        return self._worker_num

    def get_server_num(self):
        """
        Return server num
        """
        return self._server_num

    def barrier_all(self):
        """
        Barrier across all workers and servers.
        """
        self.dh.comm.barrier()

    def barrier_worker(self):
        """
        Barrier across workers only (uses the split communicator).
        """
        if self.is_worker():
            self._comm.barrier()

    def finalize(self):
        """
        MPI finalize
        """
        self.dh.finalize()


if __name__ == "__main__":
    # Smoke test.  BUGFIX: the original called PaddlePSInstance(1, 1, 2, 50),
    # passing four arguments to a two-parameter constructor (TypeError).
    # Use the documented form: mode=1 (interleaved), 2 processes per node.
    instance = PaddlePSInstance(1, 2)
    instance.barrier_all()