tensor_parallel_utils.py 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

# Module-level logger with its own timestamped stream handler (a bare
# StreamHandler() writes to sys.stderr), attached once at import time.
# NOTE(review): no level is set on `logger`, so whether the `logger.info(...)`
# calls in this module are emitted depends on the effective level inherited
# from ancestor loggers (WARNING for an unconfigured root) — confirm intended.
# Reloading the module would also attach a second handler (duplicate lines).
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY
from paddle.fluid import core
from paddle.fluid.framework import Parameter

# Optimizer op types that insert_synchronization treats as "update" ops: for
# each op of one of these types, it reads the "Param" input to decide whether
# extra synchronization ops must be inserted around it.
# NOTE(review): insert_synchronization asserts exactly one "Param" input per
# op; "merged_momentum" presumably updates several params in one op and would
# trip that assertion — confirm before relying on it.
_supported_optimizer_type = [
    "adam",
    "adamax",
    "adamw",
    "decayed_adagrad",
    "momentum",
    "dgc_momentum",
    "lars_momentum",
    "merged_momentum",
    "lamb",
    "sgd",
]


def tensor_parallel_sync_filter_fn(
    param, pos_emb=True, layer_norm=True, bias=True
):
    """
    Decide whether ``param`` must be kept identical across tensor parallel peers.

    In a tensor-parallel transformer-like model, four kinds of parameters are
    expected to be replicated on every tensor parallel peer. Each is selected
    by a name pattern:
        * position embedding  -> name starts with "pos_embedding"
        * layer norm scale    -> name ends with "_layer_norm_scale"
        * layer norm bias     -> name ends with "_layer_norm_bias"
        * row parallel linear bias -> name contains ".b_" and the parameter
          is not distributed (``is_distributed is False``)

    Args:
        param: parameter exposing ``name`` (and ``is_distributed``, which is
            only read for the bias check).
        pos_emb (bool): select position embeddings.
        layer_norm (bool): select layer norm scales and biases.
        bias (bool): select non-distributed ".b_" parameters.

    Returns:
        bool: True if ``param`` falls into one of the enabled categories.

    NOTE: matching relies on the parameter naming pattern of the transformer
    blocks; models with a different naming scheme need their own filter.
    """
    name = param.name
    layer_norm_suffixes = ("_layer_norm_bias", "_layer_norm_scale")

    if pos_emb and name.startswith("pos_embedding"):
        return True
    if layer_norm and name.endswith(layer_norm_suffixes):
        return True
    # ``is_distributed`` is only consulted once the cheaper name checks pass.
    return bool(bias and ".b_" in name and param.is_distributed is False)


def resolute_tensor_parallel_ring_id(program):
    """
    Find the communicator ring_id used by tensor parallelism in ``program``.

    The ring_id is read from every ``c_identity`` op in the global block; all
    of them must carry the same value.

    Returns:
        int: the shared ring_id.

    Raises:
        AssertionError: if two ``c_identity`` ops disagree on ring_id, or if
            the global block contains no ``c_identity`` op at all.
    """
    found = None
    identity_ops = (
        op for op in program.global_block().ops if op.type == "c_identity"
    )
    for identity_op in identity_ops:
        current = int(identity_op.attr("ring_id"))
        if found is None:
            found = current
            continue
        assert (
            found == current
        ), "Found two different ring_id for Tensor Parallel: ring_id={} and ring_id={}.".format(
            found, current
        )
    assert found is not None, "Could NOT found ring_id for Tensor Parallel."

    return found


def copy_parameters(block_, params):
    """
    Re-create each parameter of ``params`` as a ``Parameter`` in ``block_``.

    Each copy mirrors the source's shape, dtype, type, lod_level (only for
    LOD_TENSOR vars), stop_gradient, trainable, optimize_attr, regularizer,
    error_clip and name. It is marked non-distributed and registered in
    ``block_.vars`` under its name.

    Args:
        block_: destination block.
        params (iterable): source parameters; each must be non-distributed.

    Raises:
        AssertionError: if a parameter's ``is_distributed`` is not ``False``.
    """
    for param in params:
        # Validate before creating anything in the destination block, so a
        # distributed parameter leaves no half-made copy behind.
        assert (
            param.is_distributed is False
        ), "Try to sync Distribted Parameter: {}".format(param)
        new_p = Parameter(
            block=block_,
            shape=param.shape,
            dtype=param.dtype,
            type=param.type,
            lod_level=param.lod_level
            if param.type == core.VarDesc.VarType.LOD_TENSOR
            else None,
            stop_gradient=param.stop_gradient,
            trainable=param.trainable,
            optimize_attr=param.optimize_attr,
            regularizer=param.regularizer,
            error_clip=param.error_clip,
            name=param.name,
        )
        new_p.is_distributed = False
        # Register inside the loop so every copy, not only the last one, is
        # recorded (and an empty ``params`` is a harmless no-op).
        block_.vars[new_p.name] = new_p


def insert_sync_op(
    block, idx, tp_degree, sync_mode, sync_ring_id, src_rank, varname, op_role
):
    """
    Insert ops at position ``idx`` of ``block`` that synchronize ``varname`` in place.

    Modes:
        "broadcast": a single ``c_broadcast`` op rooted at ``src_rank``.
        "average": ``c_allreduce_sum`` followed by ``scale`` with factor
            ``1.0 / tp_degree``.

    The collective ops use ring ``sync_ring_id`` with ``use_calc_stream=True``;
    every inserted op is tagged with ``op_role``. Each op reads and writes
    ``varname`` itself.

    Raises:
        NotImplementedError: for any other ``sync_mode``.
    """
    if sync_mode not in ("broadcast", "average"):
        raise NotImplementedError(
            'Sync mode of [{}] is NOT supported.'.format(sync_mode)
        )

    def _insert_inplace_op(op_type, op_attrs):
        # Input and output are the same variable: the op updates it in place.
        block._insert_op_without_sync(
            idx,
            type=op_type,
            inputs={'X': varname},
            outputs={'Out': varname},
            attrs=op_attrs,
        )

    if sync_mode == "broadcast":
        _insert_inplace_op(
            'c_broadcast',
            {
                'ring_id': sync_ring_id,
                'root': src_rank,
                'use_calc_stream': True,
                OP_ROLE_KEY: op_role,
            },
        )
        return

    # "average": both ops go to the same ``idx``, so the allreduce inserted
    # second ends up in front of the scale -- sum first, then divide.
    _insert_inplace_op(
        'scale', {'scale': 1.0 / tp_degree, OP_ROLE_KEY: op_role}
    )
    _insert_inplace_op(
        'c_allreduce_sum',
        {
            'ring_id': sync_ring_id,
            'use_calc_stream': True,
            OP_ROLE_KEY: op_role,
        },
    )


def _single_output(op, slot):
    """Return the sole variable name bound to output ``slot`` of ``op``, or None
    when the slot is absent or does not hold exactly one variable."""
    if slot in op.output_names and len(op.output(slot)) == 1:
        return op.output(slot)[0]
    return None


def insert_synchronization(
    block,
    params_to_sync,
    tp_degree,
    sync_ring_id,
    sync_param,
    sync_grad,
    sync_moment,
    sync_mode,
    src_rank,
):
    """
    Insert sync ops around the optimizer op of each parameter in ``params_to_sync``.

    For every op in ``block`` whose type is in ``_supported_optimizer_type``
    and whose single "Param" input is still pending:
        * sync_param: sync "ParamOut" (must equal the param) and, if present,
          the single "MasterParamOut" right after the optimizer op;
        * sync_moment: sync the single "Moment1Out" / "Moment2Out" outputs,
          when present, right after the optimizer op;
        * sync_grad: sync the single "Grad" input right before the optimizer op.
    Each parameter is handled at most once. The sync ops themselves are built
    by ``insert_sync_op`` with ``tp_degree``, ``sync_mode``, ``sync_ring_id``
    and ``src_rank``, and inherit the optimizer op's op role.

    Raises:
        AssertionError: if an optimizer op violates the expected input/output
            layout, or if some parameter in ``params_to_sync`` was never
            matched by an optimizer op.
    """
    pending = [p.name for p in params_to_sync]

    def _emit(var_name, position, role):
        insert_sync_op(
            block,
            position,
            tp_degree,
            sync_mode,
            sync_ring_id,
            src_rank,
            var_name,
            role,
        )

    # Walk a snapshot of (index, op) backwards: ops inserted at or after the
    # current index never shift the indices that are still to be visited.
    for op_idx, opt_op in reversed(list(enumerate(block.ops))):
        if opt_op.type not in _supported_optimizer_type:
            continue

        assert "Param" in opt_op.input_names
        assert len(opt_op.input("Param")) == 1
        target = opt_op.input("Param")[0]
        role = opt_op.attr(OP_ROLE_KEY)

        if target not in pending:
            continue
        pending.remove(target)

        after_opt = op_idx + 1

        # Parameter (and its master copy, if any): sync after the update.
        if sync_param:
            assert (
                "ParamOut" in opt_op.output_names
                and opt_op.output("ParamOut")[0] == target
            )
            _emit(target, after_opt, role)
            master = _single_output(opt_op, "MasterParamOut")
            if master is not None:
                _emit(master, after_opt, role)

        # Optimizer moments: sync after the update.
        # NOTE(review): only the Moment1Out/Moment2Out slots are handled here.
        if sync_moment:
            for slot in ("Moment1Out", "Moment2Out"):
                moment = _single_output(opt_op, slot)
                if moment is not None:
                    _emit(moment, after_opt, role)

        # Gradient: sync before the update. Done last so that ``after_opt``
        # above still pointed just past the optimizer op.
        if sync_grad:
            assert (
                "Grad" in opt_op.input_names
                and len(opt_op.input("Grad")) == 1
            )
            _emit(opt_op.input("Grad")[0], op_idx, role)

    assert (
        not pending
    ), "The following param is unsync by some error: {}".format(pending)


def add_extra_synchronization(
    program,
    params_filter_fn=tensor_parallel_sync_filter_fn,
    tp_degree=8,
    sync_mode="broadcast",
    sync_param=True,
    sync_grad=False,
    sync_moment=False,
    src_rank=0,
    sync_ring_id=None,
):
    """
    Inplace add extra synchronization for input program.

    Args:
        program (Program): distributed train program. When pipeline
            parallelism is enabled (``program._pipeline_opt`` is not None),
            its ``section_program`` is modified instead.
        params_filter_fn (callable): ``params_filter_fn(param) -> bool``;
            selects the parameters of the global block to synchronize.
        tp_degree (int): tensor parallel degree; "average" mode scales the
            all-reduced sum by ``1 / tp_degree``.
        sync_mode (str): one of
            "broadcast": values are broadcast from ``src_rank`` to all other
                ranks.
            "average": values are averaged among all ranks.
        sync_param (bool): extra synchronize parameters (after the optimizer
            update).
        sync_grad (bool): extra synchronize gradients (before the optimizer
            update).
        sync_moment (bool): extra synchronize optimizer moments (after the
            optimizer update).
        src_rank (int): the source rank used by "broadcast" mode.
        sync_ring_id (int|None): communicator id used for synchronization; if
            None, the ring_id of tensor parallel is used.

    Raises:
        NotImplementedError: if ``sync_mode`` is not supported.
        AssertionError: if pipeline is enabled without a section program, the
            tensor parallel ring_id cannot be resolved, or a selected
            parameter is not updated by a supported optimizer op.
    """
    # Validate up front: otherwise an unsupported mode is only reported once
    # a sync op is actually inserted, and is silently accepted when nothing
    # matches.
    if sync_mode not in ("broadcast", "average"):
        raise NotImplementedError(
            'Sync mode of [{}] is NOT supported.'.format(sync_mode)
        )

    logger.info("Constructing Extra Parameter Synchronization.")
    logger.info(
        "Tensor Parallel Degree: {}, Synchronization mode: {}".format(
            tp_degree, sync_mode
        )
    )

    # adopt for pipeline opt: the ops to patch live in the section program.
    if program._pipeline_opt is not None:
        assert (
            program._pipeline_opt['section_program'] is not None
        ), "Pipeline is enable but section_program is None"
        program = program._pipeline_opt['section_program']

    block = program.global_block()

    # step1: collect the params that need to be synced
    # TODO support multiple blocks with different parameter.
    params_to_sync = [
        param for param in block.all_parameters() if params_filter_fn(param)
    ]
    logger.info(
        "The following param are goning to be synchronization everytime the optimizer update phase of the program is runned: "
    )
    logger.info([p.name for p in params_to_sync])

    # step2: resolute synchronization communicator group (ring_id)
    if sync_ring_id is None:
        sync_ring_id = resolute_tensor_parallel_ring_id(program)

    # step3: insert synchronization
    # TODO support gradient merge with different update block
    insert_synchronization(
        block,
        params_to_sync,
        tp_degree,
        sync_ring_id,
        sync_param,
        sync_grad,
        sync_moment,
        sync_mode,
        src_rank,
    )