#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import sys
import math
from functools import reduce

import collections
import six
import logging

import numpy as np

from .. import core, unique_name
from ..framework import Program, default_main_program, default_startup_program
from .details import wait_server_ready

__all__ = ['GradAllReduce', 'LocalSGD']

OpRole = core.op_proto_and_checker_maker.OpRole


class Collective(object):
    '''
    The base class of collective transpilers. It rewrites the startup
    program to initialize the communicators and broadcast the parameters,
    and leaves the main-program rewrite to its subclasses.
    '''

    def __init__(self, nrings):
        self.nrings = nrings
        self.endpoints = None
        self.current_endpoint = None
        self.nranks = None
        self.rank = None
        self.startup_program = None
        self.main_program = None
        op_maker = core.op_proto_and_checker_maker
        self.op_role_key = op_maker.kOpRoleAttrName()
        self.op_role_var_key = op_maker.kOpRoleVarAttrName()

    def transpile(self, startup_program, main_program, rank, endpoints,
                  current_endpoint, wait_port):
        # in case of '127.0.0.1:6700,127.0.0.1:6701,...'
        if isinstance(endpoints, str):
            endpoints = endpoints.split(',')

        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = default_startup_program()

        self.main_program = main_program
        if main_program is None:
            self.main_program = default_main_program()

        self.nranks = len(endpoints)
        if self.nranks == 1 and self.mode != "single_process_multi_thread":
            raise ValueError('the number of endpoints must be > 1')

        if rank < 0:
            raise ValueError('rank must be >= 0')
        self.rank = rank

        if current_endpoint not in endpoints:
            raise ValueError('current endpoint %s is not in %s' %
                             (current_endpoint, str(endpoints)))

        self.endpoints = endpoints
        self.current_endpoint = current_endpoint

        self.wait_port = wait_port

        self.startup_program._origin_program = self.startup_program.clone()
        self._transpile_startup_program()

        self.main_program._origin_program = self.main_program.clone()
        self._transpile_main_program()

    def _transpile_main_program(self):
        raise NotImplementedError('should be implemented by subclasses')

    def _transpile_startup_program(self):
        for ring_id in range(self.nrings):
            self._init_communicator(self.startup_program, self.current_endpoint,
                                    self.endpoints, self.rank, ring_id,
                                    self.wait_port)
        self._broadcast_params()

    def _init_communicator(self, program, current_endpoint, endpoints, rank,
                           ring_id, wait_port):
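        # Rank 0 optionally waits for the other trainers' ports to be ready;
        # then two ops are appended to the program: c_gen_nccl_id, which
        # generates the NCCL unique id and exchanges it with the other
        # endpoints, and c_comm_init, which creates the communicator of this
        # ring (ring_id) with nranks participants.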
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        if rank == 0 and wait_port:
            wait_server_ready(other_endpoints)

        block = program.global_block()
        nccl_id_var = block.create_var(
            name=unique_name.generate('nccl_id'),
            persistable=True,
            type=core.VarDesc.VarType.RAW)
        block.append_op(
            type='c_gen_nccl_id',
            inputs={},
            outputs={'Out': nccl_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                self.op_role_key: OpRole.Forward
            })
        block.append_op(
            type='c_comm_init',
            inputs={'X': nccl_id_var},
            outputs={},
            attrs={
                'nranks': nranks,
                'rank': rank,
                'ring_id': ring_id,
                self.op_role_key: OpRole.Forward
            })

    def _broadcast_params(self):
        block = self.startup_program.global_block()
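        # Broadcast every non-distributed parameter from rank 0 so that all
        # trainers start from identical weights, rotating the broadcasts over
        # the available rings.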
        ring_id = -1
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            ring_id = (ring_id + 1) % self.nrings
            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    self.op_role_key: OpRole.Forward
                })

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={'ring_id': ring_id,
                       self.op_role_key: OpRole.Forward})

    def _is_loss_grad_op(self, op):
        if self.op_role_key not in op.attr_names:
            return False
        op_role = int(op.all_attrs()[self.op_role_key])
        return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)

    def _is_backward_op(self, op):
        return self.op_role_key in op.attr_names and \
                int(op.all_attrs()[self.op_role_key]) & int(OpRole.Backward)

    def _is_update_op(self, op):
        return 'Param' in op.input_names and 'Grad' in op.input_names and \
                "LearningRate" in op.input_names

    def _is_optimizer_op(self, op):
        return self.op_role_key in op.attr_names and \
                int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)


class GradAllReduce(Collective):
    '''
    Transpiler that averages the gradients across all trainers: the loss
    gradient is scaled by 1 / nranks and every parameter gradient is
    summed with c_allreduce_sum before the optimizer ops run.
    '''

    def __init__(self, nrings=2):
        Collective.__init__(self, nrings)
        self.mode = "grad_allreduce"

    def _transpile_main_program(self):
        self._insert_scale_loss_grad_ops()
        self._insert_allreduce_ops()

    def _insert_scale_loss_grad_ops(self):
        '''
        In order to keep the learning rate consistent across different
        numbers of training workers, the loss gradient is scaled by
        1 / nranks (the number of workers).
        '''
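        # Illustrative example: with nranks = 4, every worker multiplies its
        # loss gradient by 0.25, so that the subsequent c_allreduce_sum over
        # the gradients yields the average gradient across the 4 workers.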
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if self._is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        self.op_role_key: OpRole.Backward
                    })

    def _insert_allreduce_ops(self):
        block = self.main_program.global_block()
        ring_id = -1
        grad = None
        for idx, op in reversed(list(enumerate(block.ops))):
            if self._is_backward_op(op) and \
                    self.op_role_var_key in op.attr_names:
                op_role_var = op.all_attrs()[self.op_role_var_key]

                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0

                offset = idx
                for i in range(0, len(op_role_var), 2):
                    param = block.vars[op_role_var[i]]
                    grad = block.vars[op_role_var[i + 1]]
                    if param.is_distributed:
                        continue

                    if offset == idx:
                        offset += 1
                        block._insert_op(
                            offset,
                            type='c_sync_calc_stream',
                            inputs={'X': grad},
                            outputs={'Out': grad},
                            attrs={self.op_role_key: OpRole.Backward})
                        offset += 1

                    # Since the ops are traversed in reverse, insert the
                    # c_allreduce_sum ops in the same reverse order so that
                    # the ring_id still alternates across gradients.
                    ring_id = (ring_id + 1) % self.nrings
                    block._insert_op(
                        offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            self.op_role_key: OpRole.Backward
                        })

        if grad is None:
            return

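        # Before the first optimizer op runs, synchronize all communication
        # streams so that the allreduced gradients are visible to the
        # parameter updates.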
        for idx, op in enumerate(block.ops):
            if self._is_optimizer_op(op):
                for ring_id in range(self.nrings):
                    block._insert_op(
                        idx + ring_id,
                        type='c_sync_comm_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            self.op_role_key: OpRole.Backward
                        })
                break


class LocalSGD(Collective):
    '''
    Transpiler for Local SGD: every trainer updates its parameters locally,
    and the updated parameters are averaged across trainers by exchanging
    the differences against per-parameter snapshots.
    '''

    def __init__(self, nrings=2):
        Collective.__init__(self, nrings)
        self.snapshot_key = '@SNAPSHOT'
        self.mode = "local_sgd"

    def _transpile_startup_program(self):
        Collective._transpile_startup_program(self)

        block = self.startup_program.global_block()
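        # Create a persistable snapshot variable for every non-distributed
        # parameter and initialize it with the parameter's current value.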
        non_dist_params = []
        for param in block.iter_parameters():
            if not param.is_distributed:
                non_dist_params.append(param)

        for param in non_dist_params:
            snapshot = block.create_var(
                name=self.snapshot_name(param.name),
                shape=param.shape,
                persistable=True,
                stop_gradient=True)
            block.append_op(
                type='assign',
                inputs={'X': [param]},
                outputs={'Out': [snapshot]},
                attrs={self.op_role_key: OpRole.Forward})

    def snapshot_name(self, param_name):
        return param_name + self.snapshot_key

    def _transpile_main_program(self):
        block = self.main_program.global_block()
        ordered_param_snapshot = []
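        # For every locally updated parameter p with snapshot s, the ops
        # inserted below compute p := s - p (the local update delta),
        # allreduce and scale the deltas by 1 / nranks, and then set
        # p := s - avg_delta, which equals the average of the trainers'
        # updated parameters; finally the snapshot is refreshed from p.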
        ring_id = -1
        for idx, op in reversed(list(enumerate(block.ops))):
            if self._is_update_op(op):
                param = block.vars[op.input('Param')[0]]
                if param.is_distributed:
                    continue

                snapshot = block.create_var(
                    name=self.snapshot_name(param.name),
                    shape=param.shape,
                    persistable=True,
                    stop_gradient=True)

                block._insert_op(
                    idx + 1,
                    type='elementwise_sub',
                    inputs={'X': [snapshot],
                            'Y': [param]},
                    outputs={'Out': [param]},
                    attrs={self.op_role_key: OpRole.Optimize})
                block._insert_op(
                    idx + 2,
                    type='c_sync_calc_stream',
                    inputs={'X': param},
                    outputs={'Out': param},
                    attrs={self.op_role_key: OpRole.Optimize})
                ring_id = (ring_id + 1) % self.nrings
                block._insert_op(
                    idx + 3,
                    type='c_allreduce_sum',
                    inputs={'X': [param]},
                    outputs={'Out': [param]},
                    attrs={
                        'ring_id': ring_id,
                        self.op_role_key: OpRole.Optimize
                    })

                ordered_param_snapshot.append((param, snapshot))

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={'ring_id': ring_id,
                       self.op_role_key: OpRole.Optimize})

        for param_snapshot in reversed(ordered_param_snapshot):
            param = param_snapshot[0]
            snapshot = param_snapshot[1]
            block.append_op(
                type='scale',
                inputs={'X': [param]},
                outputs={'Out': [param]},
                attrs={
                    'scale': 1.0 / self.nranks,
                    self.op_role_key: OpRole.Optimize
                })
            block.append_op(
                type='elementwise_sub',
                inputs={'X': [snapshot],
                        'Y': [param]},
                outputs={'Out': [param]},
                attrs={self.op_role_key: OpRole.Optimize})
            block.append_op(
                type='assign',
                inputs={'X': [param]},
                outputs={'Out': [snapshot]},
                attrs={self.op_role_key: OpRole.Optimize})


class SingleProcessMultiThread(GradAllReduce):
    '''
    Gradient allreduce for the single-process multi-thread mode: a single
    ring is used and the communicators are created by one c_comm_init_all
    op in the startup program.
    '''

    def __init__(self):
        GradAllReduce.__init__(self, 1)
        self.mode = "single_process_multi_thread"

    def _transpile_startup_program(self):
        block = self.startup_program.global_block()
        block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
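

# A minimal usage sketch (illustrative only; the rank and endpoints below are
# placeholder values, not part of this module):
#
#     t = GradAllReduce(nrings=2)
#     t.transpile(
#         startup_program=None,   # None falls back to the default startup program
#         main_program=None,      # None falls back to the default main program
#         rank=0,
#         endpoints='127.0.0.1:6170,127.0.0.1:6171',
#         current_endpoint='127.0.0.1:6170',
#         wait_port=True)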