# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle.fluid import core, framework, unique_name
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import Variable, name_scope
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.optimizer import Optimizer
from paddle.nn import ClipGradByGlobalNorm


def init_communicator(block, rank, ranks, ring_id):
    """Append ops to ``block`` that create the communicator ring ``ring_id``.

    The appended ops:

    1. generate a communicator id (``c_gen_nccl_id`` on CUDA builds,
       ``c_gen_bkcl_id`` on XPU builds);
    2. initialize the communicator for the group ``ranks`` (``c_comm_init``);
    3. run ``c_allreduce_sum`` on a freshly filled temporary variable followed
       by ``c_sync_calc_stream`` -- presumably a warm-up / sanity check of the
       new ring (NOTE(review): intent inferred, confirm).

    Args:
        block: Program block the ops are appended to.
        rank: Global rank of this process; used to index the comma-separated
            ``PADDLE_TRAINER_ENDPOINTS`` environment variable.
        ranks: Global ranks that form the group; must contain ``rank``.
        ring_id: Ring id assigned to the new communicator.

    Returns:
        ``ring_id``, unchanged.

    Raises:
        RuntimeError: if Paddle was compiled with neither CUDA nor XPU, since
            no communicator-id generation op is available.
        KeyError: if ``PADDLE_TRAINER_ENDPOINTS`` is not set.
        ValueError: if ``rank`` is not in ``ranks``.
    """
    # Pick the id-generation op first. Without one, c_comm_init would be
    # appended with a comm-id variable that nothing ever writes.
    if core.is_compiled_with_cuda():
        gen_comm_id_op = 'c_gen_nccl_id'
    elif core.is_compiled_with_xpu():
        gen_comm_id_op = 'c_gen_bkcl_id'
    else:
        raise RuntimeError(
            "init_communicator requires Paddle to be compiled with CUDA or XPU"
        )

    eps = os.environ['PADDLE_TRAINER_ENDPOINTS']
    eps = [ep.strip() for ep in eps.split(",") if ep.strip()]
    cur_ep = eps[rank]
    other_eps = [eps[r] for r in ranks if r != rank]

    # Position of this process inside the group, not its global rank.
    local_rank = ranks.index(rank)
    comm_var_name = unique_name.generate('comm_id')
    comm_id_var = block.create_var(
        name=comm_var_name, persistable=True, type=core.VarDesc.VarType.RAW
    )
    block.append_op(
        type=gen_comm_id_op,
        inputs={},
        outputs={'Out': comm_id_var},
        attrs={
            'rank': local_rank,
            'endpoint': cur_ep,
            'other_endpoints': other_eps,
            'ring_id': ring_id,
        },
    )
    block.append_op(
        type='c_comm_init',
        inputs={'X': comm_id_var},
        outputs={},
        attrs={'nranks': len(ranks), 'rank': local_rank, 'ring_id': ring_id},
    )

    # Exercise the new ring once with an all-reduce on the calc stream,
    # then synchronize that stream.
    tmp_var = block.create_var(name=unique_name.generate('tmp'))
    block.append_op(
        type='fill_constant', outputs={'Out': tmp_var}, attrs={'value': 1}
    )
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': tmp_var},
        outputs={'Out': tmp_var},
        attrs={'ring_id': ring_id, 'use_calc_stream': True},
    )
    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': tmp_var},
        outputs={'Out': tmp_var},
    )
    return ring_id


def broadcast_parameters(block, parameters, ring_id):
    """Append one in-place ``c_broadcast`` op per parameter to ``block``.

    Every op reads and writes the same variable, runs on communicator ring
    ``ring_id`` and uses the calculation stream (``use_calc_stream=True``).

    Args:
        block: Program block the ops are appended to.
        parameters: Iterable of variables to broadcast, in order.
        ring_id: Ring id of the communicator to broadcast over.
    """

    def _broadcast_spec(var):
        # A fresh attrs dict per op, so no two ops share mutable state.
        return {
            'type': 'c_broadcast',
            'inputs': {'X': var},
            'outputs': {'Out': var},
            'attrs': {'ring_id': ring_id, 'use_calc_stream': True},
        }

    for var in parameters:
        block.append_op(**_broadcast_spec(var))


class DistributedFusedLamb(Optimizer):
    """LAMB optimizer that updates all parameters with one fused static-graph op.

    When :meth:`apply_gradients` is called, the optimizer

    * appends a ``distributed_fused_lamb_init`` op to the startup program.
      Its outputs include fused FP32/FP16 parameter and gradient buffers,
      moments, beta powers, master weights and offset tables;
    * appends a single ``distributed_fused_lamb`` op to the main program that
      performs the update;
    * when more than one rank is present, appends communicator-initialization
      ops and an in-place parameter broadcast over ring 0 to the startup
      program.

    Only static-graph mode is supported. The only accepted ``grad_clip`` is
    :class:`paddle.nn.ClipGradByGlobalNorm`. Its ``clip_norm`` is forwarded to
    the fused op as ``max_global_grad_norm``; the base :class:`Optimizer` is
    constructed with ``grad_clip=None``.

    Args:
        learning_rate: Learning rate forwarded to the base ``Optimizer``. All
            parameters must end up sharing one learning-rate variable.
        lamb_weight_decay: Weight-decay coefficient; ``None`` becomes 0.0.
        beta1: Forwarded to both fused ops as ``beta1``.
        beta2: Forwarded to both fused ops as ``beta2``.
        epsilon: Forwarded to the update op as ``epsilon``.
        parameters: Accepted for signature compatibility but not read by this
            class; the parameters come from ``params_grads``.
        grad_clip: ``None`` or a ``ClipGradByGlobalNorm`` instance.
        exclude_from_weight_decay_fn: Optional predicate on a parameter. When
            it returns a truthy value, that parameter gets
            ``apply_weight_decay = 0``.
        clip_after_allreduce: Forwarded as the ``clip_after_allreduce`` attr.
        is_grad_scaled_by_nranks: Forwarded as ``is_grad_scaled_by_nranks``.
        alignment: Fused-buffer alignment forwarded to the init op; ``None``
            becomes -1.
        use_master_param_norm: Forwarded as ``use_master_param_norm``.
        gradient_accumulation_steps: Must be >= 1. Values > 1 allocate
            accumulation buffers, an accumulation counter and a
            ``stop_update`` flag.
        use_master_acc_grad: Forwarded as ``use_master_acc_grad``.
        nproc_per_node: Number of ranks per node. ``None`` treats all ranks as
            one node. The world size must be divisible by it.
        use_hierarchical_allreduce: Request an extra cross-node ring. It is
            only honored when there is more than one node.
        name: Optional name forwarded to the base ``Optimizer``.
    """

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        clip_after_allreduce=True,
        is_grad_scaled_by_nranks=True,
        alignment=128,
        use_master_param_norm=True,
        gradient_accumulation_steps=1,
        use_master_acc_grad=True,
        nproc_per_node=None,
        use_hierarchical_allreduce=False,
        name=None,
    ):
        assert (
            not framework._non_static_mode()
        ), "DistributedFusedLamb does not support dygraph mode"
        # Clipping is delegated to the fused op (max_global_grad_norm below),
        # so the base class is told not to clip.
        super().__init__(learning_rate=learning_rate, grad_clip=None, name=name)

        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._weight_decay = (
            lamb_weight_decay if lamb_weight_decay is not None else 0.0
        )
        if grad_clip is not None:
            assert isinstance(
                grad_clip, ClipGradByGlobalNorm
            ), "Only ClipGradByGlobalNorm is supported in DistributedFusedLamb"
            max_global_grad_norm = grad_clip.clip_norm
        else:
            # NOTE(review): -1.0 presumably tells the fused op "no clipping";
            # confirm against the op definition.
            max_global_grad_norm = -1.0
        self._max_global_grad_norm = max_global_grad_norm
        self._alignment = alignment if alignment is not None else -1
        self._clip_after_allreduce = clip_after_allreduce
        self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        self._scale = None
        self._use_master_param_norm = use_master_param_norm
        self._gradient_accumulation_steps = gradient_accumulation_steps
        self._use_master_acc_grad = use_master_acc_grad
        self._nproc_per_node = nproc_per_node
        self._use_hierarchical_allreduce = use_hierarchical_allreduce
        assert self._gradient_accumulation_steps >= 1

        self.helper = LayerHelper('distributed_fused_lamb')
        self._supports_check_nan_inf = True  # very important flag for AMP

        main_block = self.helper.main_program.global_block()
        # Bool [1] variable written by the fused op's 'FoundInf' output.
        self._found_inf = main_block.create_var(
            name=unique_name.generate('found_inf'),
            shape=[1],
            dtype=core.VarDesc.VarType.BOOL,
        )
        self._step = None

        # The 'StopUpdate' flag only exists when gradients are accumulated.
        if self._gradient_accumulation_steps > 1:
            self._stop_update = main_block.create_var(
                name=unique_name.generate('stop_update'),
                shape=[1],
                dtype=core.VarDesc.VarType.BOOL,
            )
        else:
            self._stop_update = None

        # Parameter name -> master-weight variable name; filled when the
        # update op is built and read by _get_parameter.
        self._param_to_master_param = {}

    def _get_stop_update_var(self):
        """Return the ``stop_update`` variable, or ``False`` when not accumulating."""
        return self._stop_update if self._stop_update is not None else False

    def _set_step(self, step):
        """Use an externally created step variable instead of creating one."""
        self._step = step

    def _get_or_create_step(self):
        """Return the step variable, creating a persistable int64 one on first use."""
        if self._step is None:
            self._step = self._create_persistable_var('step', dtype='int64')
        return self._step

    def _set_scale(self, scale):
        """Set the global scale; non-``Variable`` values become a float32 global var."""
        assert scale is not None
        if not isinstance(scale, Variable):
            scale = self._create_scale_from_constant(scale)
        self._scale = scale

    def _create_scale_from_constant(self, value):
        """Create a persistable float32 global variable of shape [1] holding ``value``."""
        name = unique_name.generate('global_scale')
        return paddle.static.create_global_var(
            name=name,
            shape=[1],
            dtype='float32',
            value=float(value),
            persistable=True,
        )

    def _get_or_create_scale(self):
        """Return the global scale, defaulting to a constant 1.0 on first use."""
        if self._scale is None:
            self._scale = self._create_scale_from_constant(1.0)
        return self._scale

    def _create_persistable_var(self, name=None, shape=None, dtype='float32'):
        """Create a persistable, stop-gradient variable in startup and main programs.

        The variable is first created in the startup program's global block.
        A variable with the same name, shape and dtype is then created in the
        main program's global block.

        Args:
            name: Name prefix made unique via ``unique_name.generate``. With
                ``None``, the startup block chooses the name.
            shape: Variable shape; ``None`` (the default) means ``[-1]``.
            dtype: Variable dtype.

        Returns:
            The main-program variable.
        """
        # None sentinel instead of a shared mutable `[-1]` default list.
        if shape is None:
            shape = [-1]
        startup_block = self.helper.startup_program.global_block()
        if name is not None:
            name = unique_name.generate(name)
        startup_var = startup_block.create_var(
            name=name,
            shape=shape,
            dtype=dtype,
            persistable=True,
            stop_gradient=True,
        )
        main_block = self.helper.main_program.global_block()
        main_var = main_block.create_var(
            name=startup_var.name,
            shape=startup_var.shape,
            dtype=startup_var.dtype,
            persistable=True,
            stop_gradient=True,
        )
        return main_var

    def _get_parameter(self, name, scope=None):
        """Return the runtime tensors for parameter ``name``.

        Args:
            name: Parameter name; it must already have a master weight
                registered by :meth:`apply_gradients`.
            scope: Scope to look tensors up in; defaults to ``global_scope()``.

        Returns:
            ``(param_tensor, None)`` for FP32 parameters. In that case the
            master weight is asserted to share storage with the parameter.
            ``(param_tensor, master_tensor)`` for FP16 parameters, where the
            master weight is a separate FP32 tensor of the same shape.
        """
        if scope is None:
            scope = global_scope()

        master_param = self._param_to_master_param.get(name)
        assert master_param is not None

        master_param_t = scope.find_var(master_param).get_tensor()
        assert master_param_t._dtype() == core.VarDesc.VarType.FP32

        param_t = scope.find_var(name).get_tensor()
        if param_t._dtype() == core.VarDesc.VarType.FP32:
            assert param_t._ptr() == master_param_t._ptr()
            return param_t, None
        else:
            assert param_t._dtype() == core.VarDesc.VarType.FP16
            assert param_t.shape() == master_param_t.shape()
            return param_t, master_param_t

    def apply_optimize(
        self, *args, loss=None, startup_program=None, params_grads=None
    ):
        """Append the optimization ops for ``params_grads`` and return them.

        Two call forms are accepted:

        * ``apply_optimize(params_grads)`` -- the original form of this class;
        * ``apply_optimize(loss, startup_program, params_grads)`` (positional
          or keyword) -- the base ``Optimizer.apply_optimize`` signature. In
          this form the ops are appended under
          ``program_guard(loss.block.program, startup_program)``.

        Returns:
            The op list returned by :meth:`apply_gradients`.

        Raises:
            TypeError: if ``params_grads`` is missing or too many positional
                arguments are given.
        """
        if len(args) == 1 and params_grads is None:
            (params_grads,) = args
        elif len(args) == 1:
            (loss,) = args
        elif len(args) == 2:
            loss, startup_program = args
        elif len(args) == 3:
            loss, startup_program, params_grads = args
        elif args:
            raise TypeError(
                "apply_optimize() takes at most 3 positional arguments "
                f"but {len(args)} were given"
            )
        if params_grads is None:
            raise TypeError(
                "apply_optimize() missing required argument 'params_grads'"
            )
        if loss is None:
            return self.apply_gradients(params_grads)
        with framework.program_guard(loss.block.program, startup_program):
            return self.apply_gradients(params_grads)

    def apply_gradients(self, params_grads):
        """Append init/update ops for ``(param, grad)`` pairs; return ``[lamb_op]``.

        ``params_grads`` must be non-empty. The block of its first parameter
        provides the program whose ``_optimized_guard`` wraps op creation.
        """
        flattened = [var for p, g in params_grads for var in (p, g)]
        with flattened[0].block.program._optimized_guard(
            flattened
        ), name_scope("optimizer"):
            return self._apply_gradients_impl(params_grads)

    def _init_communicators(self, startup_block, rank, nranks, nproc_per_node):
        """Append communicator-initialization ops to ``startup_block``.

        Returns:
            ``(ring_ids, use_hierarchical_allreduce)``:

            * ring 0 spans all ranks and is created only when ``nranks > 1``;
            * with more than one node, ring 1 spans the ranks of this node;
            * if hierarchical allreduce was also requested, one more ring
              spans the ranks with this rank's local index on every node, and
              ``use_hierarchical_allreduce`` becomes ``True``.
        """
        shard_inside_node = nranks > nproc_per_node
        # Integer division; the previous int(a / b) went through a float.
        node_id = rank // nproc_per_node
        node_num = nranks // nproc_per_node

        ring_ids = []
        if nranks > 1:
            ring_ids.append(
                init_communicator(startup_block, rank, list(range(nranks)), 0)
            )

        use_hierarchical_allreduce = False
        if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
            local_group_ranks = list(
                range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node)
            )
            ring_ids.append(
                init_communicator(startup_block, rank, local_group_ranks, 1)
            )

            if self._use_hierarchical_allreduce and nranks > nproc_per_node:
                use_hierarchical_allreduce = True
                outer_group_ranks = list(
                    range(rank % nproc_per_node, nranks, nproc_per_node)
                )
                ring_ids.append(
                    init_communicator(
                        startup_block,
                        rank,
                        outer_group_ranks,
                        ring_ids[-1] + 1,
                    )
                )
        return ring_ids, use_hierarchical_allreduce

    def _apply_gradients_impl(self, params_grads):
        """Create fused state, append init (startup) and update (main) ops.

        Returns:
            ``[lamb_op]``, the single ``distributed_fused_lamb`` op appended to
            the main program.
        """
        for p, g in params_grads:
            assert (
                g.type == core.VarDesc.VarType.LOD_TENSOR
            ), "Only support dense gradient"
            g.persistable = True  # the gradient must be persistable for fusion

        fp32_fused_param = self._create_persistable_var('fp32_fused_param')
        fp32_fused_grad = self._create_persistable_var('fp32_fused_grad')
        fp16_fused_param = self._create_persistable_var(
            'fp16_fused_param', dtype='float16'
        )
        fp16_fused_grad = self._create_persistable_var(
            'fp16_fused_grad', dtype='float16'
        )

        # One master-weight variable per parameter, in params_grads order.
        master_params = []
        for p, g in params_grads:
            master_p = self._create_persistable_var('master_weight')
            self._param_to_master_param[p.name] = master_p.name
            master_params.append(master_p)

        moment1 = self._create_persistable_var('moment1')
        moment1.is_distributed = True
        moment2 = self._create_persistable_var('moment2')
        moment2.is_distributed = True
        beta1pow = self._create_persistable_var('beta1pow')
        beta2pow = self._create_persistable_var('beta2pow')

        param_info = self._create_persistable_var('param_info', dtype='int32')
        param_info.is_distributed = True

        fused_offsets = self._create_persistable_var(
            'fused_offsets', dtype='int32'
        )

        fp32_partial_fused_offsets = self._create_persistable_var(
            'fp32_partial_fused_offsets', dtype='int32'
        )
        fp32_partial_fused_offsets.is_distributed = True

        fp16_partial_fused_offsets = self._create_persistable_var(
            'fp16_partial_fused_offsets', dtype='int32'
        )
        fp16_partial_fused_offsets.is_distributed = True

        param_order = self._create_persistable_var('param_order', dtype='int32')
        param_order.is_distributed = True

        # Accumulation buffers are optional op outputs: pass empty lists when
        # gradient accumulation is disabled.
        if self._gradient_accumulation_steps > 1:
            fp32_acc_fused_grad = [
                self._create_persistable_var('fp32_acc_fused_grad')
            ]
            fp16_acc_fused_grad = [
                self._create_persistable_var(
                    'fp16_acc_fused_grad', dtype='float16'
                )
            ]
            acc_step = [self._create_persistable_var('acc_step', dtype='int64')]
        else:
            fp32_acc_fused_grad = []
            fp16_acc_fused_grad = []
            acc_step = []

        step = self._get_or_create_step()

        rank = paddle.distributed.get_rank()
        nranks = paddle.distributed.get_world_size()
        if self._nproc_per_node is None:
            nproc_per_node = nranks
        else:
            nproc_per_node = self._nproc_per_node
        assert (
            nranks % nproc_per_node == 0
        ), "nranks should be exactly divided by nproc_per_node"

        # With several nodes, the init op shards within a node, so it gets
        # the node-local rank/size instead of the global ones.
        shard_inside_node = nranks > nproc_per_node
        local_rank = rank % nproc_per_node
        startup_block = self.helper.startup_program.global_block()
        ring_ids, use_hierarchical_allreduce = self._init_communicators(
            startup_block, rank, nranks, nproc_per_node
        )

        scale = self._get_or_create_scale()

        params = [p for p, _ in params_grads]
        grads = [g for _, g in params_grads]
        apply_weight_decay = [1] * len(params)
        if self._exclude_from_weight_decay_fn is not None:
            for i, p in enumerate(params):
                if self._exclude_from_weight_decay_fn(p):
                    apply_weight_decay[i] = 0

        # Mirror every gradient variable into the startup block so the init
        # op can list it as an input and output there.
        for g in grads:
            startup_block.create_var(
                name=g.name,
                type=g.type,
                dtype=g.dtype,
                persistable=g.persistable,
                shape=g.shape,
            )

        if nranks > 1:
            broadcast_parameters(startup_block, params, ring_ids[0])

        startup_block.append_op(
            type='distributed_fused_lamb_init',
            inputs={
                'Param': params,
                'Grad': grads,
            },
            outputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'ParamOut': params,
                'MasterParamOut': master_params,
                'GradOut': grads,
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'FusedParamOffsets': [fused_offsets],
                'ParamOrder': [param_order],
                'Step': [step],
            },
            attrs={
                'alignment': self._alignment,
                'rank': local_rank if shard_inside_node else rank,
                'nranks': nproc_per_node if shard_inside_node else nranks,
                'apply_weight_decay': apply_weight_decay,
                'moment1': 0.0,
                'moment2': 0.0,
                'beta1': self._beta1,
                'beta2': self._beta2,
            },
        )

        main_block = self.helper.main_program.global_block()
        self._create_global_learning_rate()
        # The fused op takes one learning rate, so every parameter must
        # resolve to the very same learning-rate variable.
        lr = None
        for p_g in params_grads:
            if lr is None:
                lr = self._create_param_lr(p_g)
            else:
                new_lr = self._create_param_lr(p_g)
                assert (
                    lr is new_lr
                ), "The learning rate for each parameter should be the same"
        assert lr is not None

        lamb_op = main_block.append_op(
            type='distributed_fused_lamb',
            inputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'LearningRate': [lr],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'Param': params,
                'Grad': grads,
                'FusedParamOffsets': [fused_offsets],
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'ParamOrder': [param_order],
            },
            outputs={
                'FP32FusedParamOut': [fp32_fused_param],
                'FP16FusedParamOut': [fp16_fused_param],
                'Moment1Out': [moment1],
                'Moment2Out': [moment2],
                'Beta1PowOut': [beta1pow],
                'Beta2PowOut': [beta2pow],
                'ParamOut': params,
                'GradOut': grads,
                'FoundInf': [self._found_inf],
                'FP32AccFusedGrad': fp32_acc_fused_grad,
                'FP16AccFusedGrad': fp16_acc_fused_grad,
                'AccStep': acc_step,
                'StopUpdate': self._stop_update
                if self._stop_update is not None
                else [],
                'Step': [step],
            },
            attrs={
                'weight_decay': self._weight_decay,
                'beta1': self._beta1,
                'beta2': self._beta2,
                'epsilon': self._epsilon,
                'max_global_grad_norm': self._max_global_grad_norm,
                'clip_after_allreduce': self._clip_after_allreduce,
                'rank': rank,
                'nranks': nranks,
                'ring_id': ring_ids,
                'use_master_param_norm': self._use_master_param_norm,
                'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                'acc_steps': self._gradient_accumulation_steps,
                'use_master_acc_grad': self._use_master_acc_grad,
                'use_hierarchical_allreduce': use_hierarchical_allreduce,
            },
        )
        return [lamb_op]