“a1224621b54d3338ebe5db6926a45c0a9033876f”上不存在“src/pass/pass_base.h”
distributed_fused_lamb.py 18.0 KB
Newer Older
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16

W
wuhuachaocoding 已提交
17
import paddle
18
from paddle.fluid import core, unique_name
19 20
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import Variable, name_scope
21
from paddle.fluid.layer_helper import LayerHelper
22
from paddle.nn import ClipGradByGlobalNorm
23
from paddle.optimizer import Optimizer
24 25


26 27 28 29 30 31 32 33
def init_communicator(block, rank, ranks, ring_id):
    """Append ops to ``block`` that create the collective ring ``ring_id``.

    The endpoint list is read from the ``PADDLE_TRAINER_ENDPOINTS``
    environment variable (comma separated, blanks ignored) and indexed by
    global rank.

    Args:
        block: Program block (typically the startup block) to append to.
        rank: Global rank of the current process; must be in ``ranks``.
        ranks: Global ranks that form this ring, in ring order.
        ring_id: Integer id assigned to the new ring.

    Returns:
        ``ring_id``, unchanged.

    Raises:
        KeyError: if ``PADDLE_TRAINER_ENDPOINTS`` is not set.
        ValueError: if ``rank`` is not contained in ``ranks``.
    """
    raw_endpoints = os.environ['PADDLE_TRAINER_ENDPOINTS']
    endpoints = [
        endpoint.strip()
        for endpoint in raw_endpoints.split(",")
        if endpoint.strip()
    ]
    current_endpoint = endpoints[rank]
    peer_endpoints = [endpoints[peer] for peer in ranks if peer != rank]

    # Position of this process inside the ring (not its global rank).
    rank_in_ring = ranks.index(rank)
    comm_id_var = block.create_var(
        name=unique_name.generate('comm_id'),
        persistable=True,
        type=core.VarDesc.VarType.RAW,
    )

    # Select the comm-id generation op for the backend this build targets.
    if core.is_compiled_with_cuda():
        gen_id_op_type = 'c_gen_nccl_id'
    elif core.is_compiled_with_xpu():
        gen_id_op_type = 'c_gen_bkcl_id'
    elif (
        paddle.distributed.ParallelEnv().device_type
        in paddle.device.get_all_custom_device_type()
    ):
        gen_id_op_type = 'c_gen_xccl_id'
    else:
        # NOTE(review): for any other backend no id-generation op is
        # appended, so c_comm_init below consumes a comm id that nothing
        # writes. Confirm this path is unreachable in practice.
        gen_id_op_type = None

    if gen_id_op_type is not None:
        block.append_op(
            type=gen_id_op_type,
            inputs={},
            outputs={'Out': comm_id_var},
            attrs={
                'rank': rank_in_ring,
                'endpoint': current_endpoint,
                'other_endpoints': peer_endpoints,
                'ring_id': ring_id,
            },
        )

    block.append_op(
        type='c_comm_init',
        inputs={'X': comm_id_var},
        outputs={},
        attrs={
            'nranks': len(ranks),
            'rank': rank_in_ring,
            'ring_id': ring_id,
        },
    )

    # Run one all-reduce over a scratch value on the new ring and sync the
    # calc stream (presumably a warm-up / sanity check of the ring).
    scratch_var = block.create_var(name=unique_name.generate('tmp'))
    block.append_op(
        type='fill_constant',
        outputs={'Out': scratch_var},
        attrs={'value': 1},
    )
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': scratch_var},
        outputs={'Out': scratch_var},
        attrs={'ring_id': ring_id, 'use_calc_stream': True},
    )
    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': scratch_var},
        outputs={'Out': scratch_var},
    )
    return ring_id


def broadcast_parameters(block, parameters, ring_id):
    """Append one in-place ``c_broadcast`` op per parameter to ``block``.

    Args:
        block: Program block to append the ops to.
        parameters: Iterable of parameter variables; each is both the input
            and the output of its broadcast op.
        ring_id: Id of the communication ring the broadcast runs on.
    """
    for param in parameters:
        broadcast_attrs = {'ring_id': ring_id, 'use_calc_stream': True}
        block.append_op(
            type='c_broadcast',
            inputs={'X': param},
            outputs={'Out': param},
            attrs=broadcast_attrs,
        )
108 109


110
class DistributedFusedLamb(Optimizer):
    """Static-graph LAMB optimizer built on fused ``distributed_fused_lamb`` ops.

    When gradients are applied, the optimizer does the following:

    * creates persistable fused buffers (FP32/FP16 parameters and gradients,
      moments, beta powers, offsets, ...);
    * appends a ``distributed_fused_lamb_init`` op to the startup program;
    * appends a single ``distributed_fused_lamb`` op to the main program.

    With more than one rank it also initializes the communication rings and
    broadcasts the parameters in the startup program. Dygraph mode is not
    supported.

    Args:
        learning_rate: Learning rate forwarded to the base ``Optimizer``.
        lamb_weight_decay: Weight-decay coefficient; ``None`` means 0.0.
        beta1: First-moment decay rate, forwarded to the fused ops.
        beta2: Second-moment decay rate, forwarded to the fused ops.
        epsilon: Numerical-stability term, forwarded to the fused op.
        parameters: Accepted for API compatibility; not used by this class.
        grad_clip: ``None`` or a ``ClipGradByGlobalNorm``. Its ``clip_norm``
            becomes ``max_global_grad_norm`` (-1.0 when ``None``). Clipping
            happens inside the fused op, not in the base class.
        exclude_from_weight_decay_fn: Optional callable taking a parameter.
            A truthy result disables weight decay for that parameter.
        clip_after_allreduce: Forwarded as an attribute to the fused op.
        is_grad_scaled_by_nranks: Forwarded as an attribute to the fused op.
        alignment: Fusion alignment for the init op; ``None`` means -1.
        use_master_param_norm: Forwarded as an attribute to the fused op.
        gradient_accumulation_steps: Must be >= 1. Values > 1 create
            accumulation buffers and a ``stop_update`` flag variable.
        use_master_acc_grad: Forwarded as an attribute to the fused op.
        nproc_per_node: Processes per node; ``None`` means the world size.
            The world size must be divisible by it.
        use_hierarchical_allreduce: Requests an extra cross-node ring. It
            only takes effect when there is more than one node.
        name: Optional name forwarded to the base ``Optimizer``.
    """

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        clip_after_allreduce=True,
        is_grad_scaled_by_nranks=True,
        alignment=128,
        use_master_param_norm=True,
        gradient_accumulation_steps=1,
        use_master_acc_grad=True,
        nproc_per_node=None,
        use_hierarchical_allreduce=False,
        name=None,
    ):
        assert (
            not paddle.in_dynamic_mode()
        ), "DistributedFusedLamb does not support dygraph mode"
        # grad_clip is deliberately not handed to the base class: global-norm
        # clipping is done by the fused op via max_global_grad_norm.
        super().__init__(learning_rate=learning_rate, grad_clip=None, name=name)

        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._weight_decay = (
            lamb_weight_decay if lamb_weight_decay is not None else 0.0
        )
        if grad_clip is not None:
            assert isinstance(
                grad_clip, ClipGradByGlobalNorm
            ), "Only ClipGradByGlobalNorm is supported in DistributedFusedLamb"
            max_global_grad_norm = grad_clip.clip_norm
        else:
            # -1.0 is what gets passed to the fused op when no clip is set.
            max_global_grad_norm = -1.0
        self._max_global_grad_norm = max_global_grad_norm
        self._alignment = alignment if alignment is not None else -1
        self._clip_after_allreduce = clip_after_allreduce
        self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        self._scale = None
        self._use_master_param_norm = use_master_param_norm
        self._gradient_accumulation_steps = gradient_accumulation_steps
        self._use_master_acc_grad = use_master_acc_grad
        self._nproc_per_node = nproc_per_node
        self._use_hierarchical_allreduce = use_hierarchical_allreduce

        assert self._gradient_accumulation_steps >= 1

        self.helper = LayerHelper('distributed_fused_lamb')
        self._supports_check_nan_inf = True  # very important flag for AMP

        main_block = self.helper.main_program.global_block()
        # Output of the fused op (FoundInf).
        self._found_inf = main_block.create_var(
            name=unique_name.generate('found_inf'),
            shape=[1],
            dtype=core.VarDesc.VarType.BOOL,
        )
        self._step = None

        # The stop-update flag only exists when gradients are accumulated.
        if self._gradient_accumulation_steps > 1:
            self._stop_update = main_block.create_var(
                name=unique_name.generate('stop_update'),
                shape=[1],
                dtype=core.VarDesc.VarType.BOOL,
            )
        else:
            self._stop_update = None

        # Parameter name -> name of its FP32 master-weight variable; filled
        # in by _apply_gradients_impl.
        self._param_to_master_param = {}

    def _get_stop_update_var(self):
        """Return the stop-update variable, or ``False`` if it was not created."""
        return self._stop_update if self._stop_update is not None else False

    def _set_step(self, step):
        """Use ``step`` as the step-counter variable instead of creating one."""
        self._step = step

    def _get_or_create_step(self):
        """Return the step variable, lazily creating a persistable int64 one."""
        if self._step is None:
            self._step = self._create_persistable_var('step', dtype='int64')
        return self._step

    def _set_scale(self, scale):
        """Set the global scale.

        A value that is not already a ``Variable`` is wrapped in a new
        persistable float32 global variable.
        """
        assert scale is not None
        if not isinstance(scale, Variable):
            scale = self._create_scale_from_constant(scale)
        self._scale = scale

    def _create_scale_from_constant(self, value):
        """Create a persistable float32 global var of shape [1] set to ``value``."""
        name = unique_name.generate('global_scale')
        return paddle.static.create_global_var(
            name=name,
            shape=[1],
            dtype='float32',
            value=float(value),
            persistable=True,
        )

    def _get_or_create_scale(self):
        """Return the global scale variable, creating one set to 1.0 if unset."""
        if self._scale is None:
            self._scale = self._create_scale_from_constant(1.0)
        return self._scale

    def _create_persistable_var(self, name=None, shape=None, dtype='float32'):
        """Declare a persistable, stop-gradient variable in both programs.

        The variable is created in the startup program's global block, and a
        variable with the same name, shape and dtype is created in the main
        program's global block.

        Args:
            name: Name prefix, made unique via ``unique_name.generate``. When
                ``None``, naming is left to ``create_var``.
            shape: Variable shape; ``None`` means ``[-1]``.
            dtype: Variable dtype.

        Returns:
            The main-program variable.
        """
        # A fresh list per call avoids the shared mutable-default pitfall.
        if shape is None:
            shape = [-1]
        startup_block = self.helper.startup_program.global_block()
        if name is not None:
            name = unique_name.generate(name)
        startup_var = startup_block.create_var(
            name=name,
            shape=shape,
            dtype=dtype,
            persistable=True,
            stop_gradient=True,
        )
        main_block = self.helper.main_program.global_block()
        main_var = main_block.create_var(
            name=startup_var.name,
            shape=startup_var.shape,
            dtype=startup_var.dtype,
            persistable=True,
            stop_gradient=True,
        )
        return main_var

    def _get_parameter(self, name, scope=None):
        """Look up a parameter tensor and its FP32 master tensor in ``scope``.

        This requires gradients to have been applied already, because that
        step registers the parameter's master weight.

        Args:
            name: Parameter variable name.
            scope: Scope to search; defaults to ``global_scope()``.

        Returns:
            ``(param_tensor, None)`` when the parameter is FP32. In that case
            it must share its buffer with the master weight (pointer equality
            is asserted). Otherwise returns ``(param_tensor, master_tensor)``
            for an FP16 parameter of the same shape.
        """
        if scope is None:
            scope = global_scope()

        master_param = self._param_to_master_param.get(name)
        assert master_param is not None

        master_param_t = scope.find_var(master_param).get_tensor()
        assert master_param_t._dtype() == core.VarDesc.VarType.FP32

        param_t = scope.find_var(name).get_tensor()
        if param_t._dtype() == core.VarDesc.VarType.FP32:
            assert param_t._ptr() == master_param_t._ptr()
            return param_t, None
        else:
            assert param_t._dtype() == core.VarDesc.VarType.FP16
            assert param_t.shape() == master_param_t.shape()
            return param_t, master_param_t

    def apply_optimize(self, params_grads):
        """Append the fused LAMB ops for ``params_grads``.

        Returns:
            The list of appended optimizer ops.
        """
        return self.apply_gradients(params_grads)

    def apply_gradients(self, params_grads):
        """Append the fused LAMB ops inside the optimizer guard and name scope.

        Args:
            params_grads: Non-empty list of ``(param, grad)`` pairs.

        Returns:
            The list of appended optimizer ops (the fused LAMB op).
        """
        flattened = [var for p, g in params_grads for var in (p, g)]
        with flattened[0].block.program._optimized_guard(flattened), name_scope(
            "optimizer"
        ):
            return self._apply_gradients_impl(params_grads)

    def _apply_gradients_impl(self, params_grads):
        """Build the startup init op and the main-program fused LAMB op.

        Returns:
            ``[lamb_op]``, the op appended to the main program.
        """
        for p, g in params_grads:
            assert (
                g.type == core.VarDesc.VarType.LOD_TENSOR
            ), "Only support dense gradient"
            g.persistable = True  # the gradient must be persistable for fusion

        fp32_fused_param = self._create_persistable_var('fp32_fused_param')
        fp32_fused_grad = self._create_persistable_var('fp32_fused_grad')
        fp16_fused_param = self._create_persistable_var(
            'fp16_fused_param', dtype='float16'
        )
        fp16_fused_grad = self._create_persistable_var(
            'fp16_fused_grad', dtype='float16'
        )

        # One FP32 master weight per parameter, remembered by parameter name
        # for _get_parameter.
        master_params = []
        for p, _ in params_grads:
            master_p = self._create_persistable_var('master_weight')
            self._param_to_master_param[p.name] = master_p.name
            master_params.append(master_p)

        moment1 = self._create_persistable_var('moment1')
        moment1.is_distributed = True
        moment2 = self._create_persistable_var('moment2')
        moment2.is_distributed = True
        beta1pow = self._create_persistable_var('beta1pow')
        beta2pow = self._create_persistable_var('beta2pow')

        param_info = self._create_persistable_var('param_info', dtype='int32')
        param_info.is_distributed = True

        fused_offsets = self._create_persistable_var(
            'fused_offsets', dtype='int32'
        )

        fp32_partial_fused_offsets = self._create_persistable_var(
            'fp32_partial_fused_offsets', dtype='int32'
        )
        fp32_partial_fused_offsets.is_distributed = True

        fp16_partial_fused_offsets = self._create_persistable_var(
            'fp16_partial_fused_offsets', dtype='int32'
        )
        fp16_partial_fused_offsets.is_distributed = True

        param_order = self._create_persistable_var('param_order', dtype='int32')
        param_order.is_distributed = True

        # Accumulation buffers exist only when gradient accumulation is on.
        # Empty lists leave the corresponding op outputs unset.
        if self._gradient_accumulation_steps > 1:
            fp32_acc_fused_grad = [
                self._create_persistable_var('fp32_acc_fused_grad')
            ]
            fp16_acc_fused_grad = [
                self._create_persistable_var(
                    'fp16_acc_fused_grad', dtype='float16'
                )
            ]
            acc_step = [self._create_persistable_var('acc_step', dtype='int64')]
        else:
            fp32_acc_fused_grad = []
            fp16_acc_fused_grad = []
            acc_step = []

        step = self._get_or_create_step()

        rank = paddle.distributed.get_rank()
        nranks = paddle.distributed.get_world_size()
        if self._nproc_per_node is None:
            nproc_per_node = nranks
        else:
            nproc_per_node = self._nproc_per_node
        assert (
            nranks % nproc_per_node == 0
        ), "nranks should be exactly divided by nproc_per_node"

        shard_inside_node = nranks > nproc_per_node
        local_rank = rank % nproc_per_node
        node_id = rank // nproc_per_node
        node_num = nranks // nproc_per_node
        ring_ids = []
        startup_block = self.helper.startup_program.global_block()
        # Ring 0 spans all ranks.
        if nranks > 1:
            ring_id = init_communicator(
                startup_block, rank, list(range(nranks)), 0
            )
            ring_ids.append(ring_id)

        use_hierarchical_allreduce = False
        if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
            # Ring 1 spans the ranks on this node.
            local_group_ranks = list(
                range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node)
            )
            ring_id = init_communicator(
                startup_block, rank, local_group_ranks, 1
            )
            ring_ids.append(ring_id)

            if self._use_hierarchical_allreduce and nranks > nproc_per_node:
                use_hierarchical_allreduce = True
                # Outer ring: ranks with the same local rank on every node.
                outer_group_ranks = list(
                    range(rank % nproc_per_node, nranks, nproc_per_node)
                )
                ring_id = init_communicator(
                    startup_block, rank, outer_group_ranks, ring_ids[-1] + 1
                )
                ring_ids.append(ring_id)

        scale = self._get_or_create_scale()

        params = [p for p, _ in params_grads]
        grads = [g for _, g in params_grads]
        apply_weight_decay = [1] * len(params)
        if self._exclude_from_weight_decay_fn is not None:
            for i, p in enumerate(params):
                if self._exclude_from_weight_decay_fn(p):
                    apply_weight_decay[i] = 0

        # Mirror the gradient variables into the startup block so that the
        # init op appended there can reference them by name.
        for g in grads:
            startup_block.create_var(
                name=g.name,
                type=g.type,
                dtype=g.dtype,
                persistable=g.persistable,
                shape=g.shape,
            )

        if nranks > 1:
            broadcast_parameters(startup_block, params, ring_ids[0])

        startup_block.append_op(
            type='distributed_fused_lamb_init',
            inputs={
                'Param': params,
                'Grad': grads,
            },
            outputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'ParamOut': params,
                'MasterParamOut': master_params,
                'GradOut': grads,
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'FusedParamOffsets': [fused_offsets],
                'ParamOrder': [param_order],
                'Step': [step],
            },
            attrs={
                'alignment': self._alignment,
                # With multiple nodes, the init op partitions within a node.
                'rank': local_rank if shard_inside_node else rank,
                'nranks': nproc_per_node if shard_inside_node else nranks,
                'apply_weight_decay': apply_weight_decay,
                'moment1': 0.0,
                'moment2': 0.0,
                'beta1': self._beta1,
                'beta2': self._beta2,
            },
        )

        main_block = self.helper.main_program.global_block()
        self._create_global_learning_rate()
        # The fused op takes a single learning rate, so every parameter must
        # resolve to the same learning-rate variable.
        lr = None
        for p_g in params_grads:
            if lr is None:
                lr = self._create_param_lr(p_g)
            else:
                new_lr = self._create_param_lr(p_g)
                assert id(lr) == id(
                    new_lr
                ), "The learning rate for each parameter should be the same"
        assert lr is not None

        lamb_op = main_block.append_op(
            type='distributed_fused_lamb',
            inputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'LearningRate': [lr],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'Param': params,
                'Grad': grads,
                'FusedParamOffsets': [fused_offsets],
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'ParamOrder': [param_order],
            },
            outputs={
                'FP32FusedParamOut': [fp32_fused_param],
                'FP16FusedParamOut': [fp16_fused_param],
                'Moment1Out': [moment1],
                'Moment2Out': [moment2],
                'Beta1PowOut': [beta1pow],
                'Beta2PowOut': [beta2pow],
                'ParamOut': params,
                'GradOut': grads,
                'FoundInf': [self._found_inf],
                'FP32AccFusedGrad': fp32_acc_fused_grad,
                'FP16AccFusedGrad': fp16_acc_fused_grad,
                'AccStep': acc_step,
                'StopUpdate': self._stop_update
                if self._stop_update is not None
                else [],
                'Step': [step],
            },
            attrs={
                'weight_decay': self._weight_decay,
                'beta1': self._beta1,
                'beta2': self._beta2,
                'epsilon': self._epsilon,
                'max_global_grad_norm': self._max_global_grad_norm,
                'clip_after_allreduce': self._clip_after_allreduce,
                'rank': rank,
                'nranks': nranks,
                'ring_ids': ring_ids,
                'use_master_param_norm': self._use_master_param_norm,
                'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                'acc_steps': self._gradient_accumulation_steps,
                'use_master_acc_grad': self._use_master_acc_grad,
                'use_hierarchical_allreduce': use_hierarchical_allreduce,
            },
        )
        return [lamb_op]