hybrid_parallel_optimizer.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import sys
import paddle
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from ...utils.hybrid_parallel_util import fused_allreduce_gradients, sharding_reduce_gradients
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
from paddle.fluid.framework import Variable
from ...utils.log_util import logger
from paddle.fluid import core
from paddle.fluid import layers

__all__ = []


class HybridParallelClipGrad:
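    # Clips gradients by global norm under hybrid parallelism: per-gradient
    # squared norms are summed locally, all-reduced across the hcg check
    # group so every rank sees the same global norm, and each gradient is then
    # scaled by clip_norm / max(global_norm, clip_norm). Unknown attributes
    # (e.g. clip_norm) are resolved from the wrapped clip via __getattr__.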
    def __init__(self, clip, hcg):
        self._clip = clip
        self._hcg = hcg

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
        sum_square_list = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
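            # Sparse (SELECTED_ROWS) gradients are merged into a dense tensor
            # before their squared norm is accumulated.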
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)
            sum_square_list.append(sum_square)

        # all parameters have been filtered out (nothing to clip)
        if len(sum_square_list) == 0:
            return params_grads

        global_norm_var = layers.concat(sum_square_list)
        global_norm_var = layers.reduce_sum(global_norm_var)
        # all-reduce the summed squares across the check-parallel group to obtain the global norm
        paddle.distributed.all_reduce(global_norm_var,
                                      self._hcg.get_check_parallel_group())
        global_norm_var = layers.sqrt(global_norm_var)

        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=layers.elementwise_max(
                x=global_norm_var, y=max_global_norm))
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.elementwise_mul(x=g, y=clip_var)
            params_and_grads.append((p, new_grad))

        return params_and_grads

    def __getattr__(self, item):
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        # Route clipping through the hybrid-parallel implementation; delegating
        # to self._clip here would skip the cross-group all-reduce entirely.
        return self._dygraph_clip(params_grads)


class HybridParallelOptimizer:
    # Adapter/wrapper that adds hybrid-parallel behavior to a regular optimizer.
    def __init__(self, optimizer, hcg, strategy):
        self._inner_opt = optimizer
        self._strategy = strategy
        self._hcg = hcg
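
        # In pure data-parallel mode the manual gradient all-reduce performed
        # in step()/minimize() is skipped (gradient synchronization is expected
        # to be handled by the data-parallel model wrapper); it only runs when
        # a data-parallel group with world size > 1 coexists with other
        # parallelism.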
        self._use_dp_mode = (
            self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL)

        self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

        self._sharding_enable = (
            self._hcg.get_sharding_parallel_world_size() > 1)

        if isinstance(self._inner_opt._grad_clip,
                      ClipGradByGlobalNorm) and not self._use_dp_mode:
            logger.warning("While using ClipGradByGlobalNorm in TensorParallel, the original " \
                           "optimizer's grad clip will be changed.")
            self._inner_opt._grad_clip = HybridParallelClipGrad(
                self._inner_opt._grad_clip, hcg)

    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
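        # Order of operations: (1) reduce sharded gradients across the sharding
        # group, (2) fused all-reduce across the data-parallel group when
        # needed, (3) delegate to the wrapped optimizer's own step().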
        # Sharding gradient reduction must use the global parameter list.
        if self._sharding_enable:
            sharding_reduce_gradients(
                list(self._inner_opt._parameter_list), self._hcg)

        if not self._use_dp_mode and self._need_dp:
            fused_allreduce_gradients(
                list(self._inner_opt._parameter_list), self._hcg)
        self._inner_opt.step()

    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameters=None,
                 no_grad_set=None):
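        # A user-supplied `parameters` list only restricts the data-parallel
        # all-reduce; sharding reduction always uses the optimizer's full
        # parameter list.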

        parameter_list = parameters if parameters \
            else self._inner_opt._parameter_list

        # Sharding gradient reduction must use the global parameter list, not a user-supplied subset.
        if self._sharding_enable:
            sharding_reduce_gradients(
                list(self._inner_opt._parameter_list), self._hcg)

        if not self._use_dp_mode and self._need_dp:
            fused_allreduce_gradients(list(parameter_list), self._hcg)

        return self._inner_opt.minimize(loss, startup_program, parameter_list,
                                        no_grad_set)

    def __getattr__(self, item):
        return getattr(self._inner_opt, item)
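

# Minimal usage sketch (illustrative only): in practice this wrapper is created
# by fleet.distributed_optimizer() after fleet.init() with hybrid-parallel
# settings; `model` and `batch` below are placeholders, not part of this module.
#
#   import paddle
#   import paddle.distributed.fleet as fleet
#
#   strategy = fleet.DistributedStrategy()
#   strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
#   fleet.init(is_collective=True, strategy=strategy)
#
#   optimizer = paddle.optimizer.AdamW(
#       learning_rate=1e-4,
#       parameters=model.parameters(),
#       grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0))
#   model = fleet.distributed_model(model)
#   # In dygraph hybrid-parallel mode this returns a HybridParallelOptimizer.
#   optimizer = fleet.distributed_optimizer(optimizer)
#
#   loss = model(batch).mean()
#   loss.backward()
#   optimizer.step()        # sharding reduce + dp all-reduce + inner step
#   optimizer.clear_grad()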