# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import paddle
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
    'merge_momentum'
]

# A heuristic upper bound on the number of distinct data-parallel process
# groups for which comm/calc overlap is still attempted
# (compared against in `_could_be_overlap`).
__max_stream_num_allow__ = 16


@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
    """
    Apply optimizations that are specialized for data parallelism in Auto Parallel.
    1. prune grad scaling
    2. overlap comm and calc
    3. fuse allreduce
    """

    def __init__(self):
        """Register default pass attributes and create empty bookkeeping state."""
        super().__init__()

        # NOTE: this pass does not depend on the loss or the param_grads.
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)

        # Forward map, e.g. {grad1: group1, grad2: group1, grad3: group2}.
        # Ordered so that the order for fusing grad data memory is recorded.
        self._grad_name_to_group_map = OrderedDict()

        # Reverse map, e.g. {group1: [grad1, grad2], group2: [grad3]}.
        self._group_to_grad_name_map = OrderedDict()

        # Set to True by `_analyze_program` when an optimizer op whose type is
        # listed in `__rescale_grad_supported_opts__` is found.
        self._support_rescale_grad = False

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if (not isinstance(self.get_attr("global_rank"),
                           int)) or self.get_attr("global_rank") < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        return True

    def _type(self):
        """Return the category of this pass (``PassType.COMM_OPT``)."""
        pass_type = PassType.COMM_OPT
        return pass_type

    def _apply_single_impl(self, main_program, startup_program, context):
        """Run the data-parallel optimizations on the given programs.

        Reads the ``dist_context`` and ``global_rank`` pass attributes, then,
        inside a ``program_guard`` for the two programs, analyzes the program
        and applies, in order: gradient-scaling pruning, comm/calc overlap and
        allreduce fusion.

        Args:
            main_program: program passed to ``program_guard`` as main program.
            startup_program: program passed to ``program_guard`` as startup
                program.
            context: pass context; not used by this method.
        """
        self.dist_context = self.get_attr("dist_context")
        self.global_rank = int(self.get_attr("global_rank"))

        # The helper methods fetch the program through default_main_program(),
        # so they have to run under the guard.
        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()
            self._prune_grad_scaling()
            self._calc_comm_overlap()
            self._fuse_allreduce()

    def _prune_grad_scaling(self):

        if not self._could_be_prune():
            return

        if self._all_dp_groups_same_degree():
            self._scale_backward_initial_grad()
        else:
            self._update_opt_rescale_grad()

        self._remove_grad_scaling()

    def _calc_comm_overlap(self):
        if not self._could_be_overlap():
            return
        self._calc_overlap_comms()
        self._update_wait_comms()

    def _fuse_allreduce(self):
        pass

    def _analyze_program(self):
        """Scan the global block and record the data-parallel gradient layout.

        Builds two maps:
            {param_grad_name: data_parallel_group}
            {data_parallel_group: [param_grad_name, ...]}

        Also sets ``self._support_rescale_grad`` when an optimizer op whose
        type is listed in ``__rescale_grad_supported_opts__`` is found.

        Raises:
            AssertionError: if a data-parallel reduce op has no ``ring_id``
                attribute, if its ring id maps to no process group, or if a
                gradient is scaled but never synchronized.
        """

        block = default_main_program().global_block()
        ops = block.ops
        scaled_grads = []

        for op in ops:
            # The first output name is only read for reduce/scale ops: other
            # ops may have no outputs at all, and indexing them
            # unconditionally would raise IndexError.
            if is_data_parallel_reduce_op(op):
                grad_name = op.output_arg_names[0]
                if grad_name in self._grad_name_to_group_map:
                    continue
                assert op.has_attr(
                    "ring_id"
                ), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert group is not None, "Unexception: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op))

                self._grad_name_to_group_map[grad_name] = group

                if group not in self._group_to_grad_name_map:
                    self._group_to_grad_name_map[group] = [grad_name]
                else:
                    self._group_to_grad_name_map[group].append(grad_name)

            elif is_data_parallel_scale_op(op):
                scaled_grads.append(op.output_arg_names[0])

            # TODO support multiple optimizers in one network in future.
            # here we assume that the optimizer is unique in network.
            elif is_optimize_op(
                    op) and op.type in __rescale_grad_supported_opts__:
                self._support_rescale_grad = True

        # Every scaled gradient must also have been seen on a reduce op.
        not_synchronized_grads = [
            grad_name for grad_name in scaled_grads
            if grad_name not in self._grad_name_to_group_map
        ]
        assert len(
            not_synchronized_grads
        ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads)

    def _could_be_prune(self):

J
JZ-LIANG 已提交
154 155
        return self.dist_context._gradient_scale and (
            self._support_rescale_grad or self._all_dp_groups_same_degree())

    def _all_dp_groups_same_degree(self):
        return len(
            set([
                len(group.ranks)
                for group in self._group_to_grad_name_map.keys()
            ])) == 1

    def _scale_backward_initial_grad(self):
        """Divide the initial loss gradient by the data-parallel degree.

        The degree is the rank count of the first recorded group. The global
        block is searched backwards for the first op recognized by
        ``is_loss_grad_op``; that op's ``value`` attribute is divided by the
        degree and the search stops.
        """
        main_block = default_main_program().global_block()
        groups = list(self._group_to_grad_name_map.keys())
        dp_degree = len(groups[0].ranks)

        for candidate in reversed(list(main_block.ops)):
            if not is_loss_grad_op(candidate):
                continue

            assert candidate.type == 'fill_constant', \
                "loss_grad_op must be fill_constant op, " \
                "but this op is {}".format(candidate.type)
            assert candidate.has_attr('value')

            scaled_value = float(candidate.attr('value')) / dp_degree
            candidate._set_attr('value', scaled_value)
            break

    def _remove_grad_scaling(self):
        """Delete every data-parallel scale op from the global block."""
        main_block = default_main_program().global_block()

        # Walk a snapshot from the back so that removing an op never shifts
        # the index of an op that is still to be visited.
        snapshot = list(main_block.ops)
        for position in range(len(snapshot) - 1, -1, -1):
            if is_data_parallel_scale_op(snapshot[position]):
                main_block._remove_op(position, False)

        main_block._sync_with_cpp()

    def _update_opt_rescale_grad(self):
        """Divide each supported optimizer's ``rescale_grad`` by its dp degree.

        Optimizer ops whose type is in ``__rescale_grad_supported_opts__`` are
        visited from the end of the global block. For each one, the
        ``rescale_grad`` attribute is divided by the rank count of the group
        recorded for its single ``Grad`` input.

        Raises:
            AssertionError: if such an op lacks ``rescale_grad``, has more or
                fewer than one ``Grad`` input, or if the set of gradients
                handled differs from the set of recorded gradients.
        """
        main_block = default_main_program().global_block()
        handled_grads = set()

        for opt_op in reversed(list(main_block.ops)):
            if not (is_optimize_op(opt_op)
                    and opt_op.type in __rescale_grad_supported_opts__):
                continue

            assert opt_op.has_attr(
                'rescale_grad'
            ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
                str(opt_op))
            assert len(
                opt_op.input("Grad")
            ) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format(
                str(opt_op))

            grad_name = opt_op.input("Grad")[0]
            group = self._grad_name_to_group_map[grad_name]
            dp_degree = len(list(group.ranks))
            handled_grads.add(grad_name)

            opt_op._set_attr('rescale_grad',
                             float(opt_op.attr('rescale_grad')) / dp_degree)

        expected_grads = set(self._grad_name_to_group_map.keys())
        assert handled_grads == expected_grads, \
            "Unexception: gradients [{}] are unscaled.".format(
                expected_grads - handled_grads)

    def _could_be_overlap(self):
        """Return False when there are too many data-parallel groups to overlap.

        NOTE: currently each nccl comm uses its own cuda stream, so too many
        dp groups would mean too many streams to create and sync. Revise this
        when the framework supports custom streams in static mode.
        """
        distinct_groups = set(self._group_to_grad_name_map.keys())
        return len(distinct_groups) <= __max_stream_num_allow__

    def _calc_overlap_comms(self):
        """Take data-parallel reduce ops off the calc stream.

        For every data-parallel reduce op, ``use_calc_stream`` is set to False
        and a ``c_wait_compute`` op with the same ``ring_id`` is inserted at
        the reduce op's index, so the communication waits for the calculation
        to finish.
        """
        # TODO support InterpreterCore executor for overlap.
        # InterpreterCore has a different logic for overlapping
        # which is different from use_calc_stream
        main_block = default_main_program().global_block()

        # Iterate a reversed snapshot: inserting at `position` then leaves the
        # indices of the ops still to be visited untouched.
        for position, comm_op in reversed(list(enumerate(main_block.ops))):
            if not is_data_parallel_reduce_op(comm_op):
                continue

            assert comm_op.has_attr('use_calc_stream')
            assert comm_op.has_attr('ring_id')

            comm_op._set_attr('use_calc_stream', False)
            wait_attrs = {
                'op_role': OpRole.Backward,
                'ring_id': comm_op.attr("ring_id"),
            }
            main_block._insert_op_without_sync(position,
                                               type='c_wait_compute',
                                               inputs={'X': []},
                                               outputs={'Out': []},
                                               attrs=wait_attrs)

        main_block._sync_with_cpp()

    def _update_wait_comms(self):
        """Insert one ``c_wait_comm`` op per group before the first optimizer op.

        Raises:
            AssertionError: if the global block contains no optimizer op.
        """
        main_block = default_main_program().global_block()

        # Index of the first optimizer op, or -1 when there is none.
        first_optimize_op_idx = next(
            (position for position, candidate in enumerate(main_block.ops)
             if is_optimize_op(candidate)), -1)

        assert first_optimize_op_idx > -1, "Unexception: not found optimizer op in program"

        # update wait comm to finish
        for group in self._group_to_grad_name_map:
            wait_attrs = {'op_role': OpRole.Backward, 'ring_id': group.id}
            main_block._insert_op_without_sync(first_optimize_op_idx,
                                               type='c_wait_comm',
                                               inputs={'X': []},
                                               outputs={'Out': []},
                                               attrs=wait_attrs)