# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed import fleet


def init_parallel_env(mode, global_batch_size, seed=1024):
    '''Initialize the fleet hybrid-parallel environment.

    Args:
        mode (str): parallel layout string such as "DP1-MP1-PP1-SH1-O1", giving
            the data-parallel (DP), model-parallel (MP), pipeline-parallel (PP)
            and sharding (SH) degrees, plus the sharding optimizer stage (O).
        global_batch_size (int): global batch size, used to derive the pipeline
            micro batch size when PP > 1.
        seed (int): seed used for tensor-parallel initialization.

    Returns:
        The hybrid communicate group created by fleet.
    '''

    def parse_mode(mode):
        assert "DP" == mode[:2]
        assert "-MP" in mode
        assert "-PP" in mode
        assert "-SH" in mode
        assert "-O" in mode
        modes = mode.split("-")
        DP = int(modes[0][2:])
        MP = int(modes[1][2:])
        PP = int(modes[2][2:])
        SH = int(modes[3][2:])
        Ostage = int(modes[4][1:])
        return DP, MP, PP, SH, Ostage

    DP, MP, PP, SH, Ostage = parse_mode(mode)
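    # For example, mode="DP2-MP2-PP2-SH1-O1" parses to DP=2, MP=2, PP=2, SH=1,
    # Ostage=1 (an illustrative value; any positive degrees follow the same pattern).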

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": DP,
        "mp_degree": MP,
        "pp_degree": PP,
        "sharding_degree": SH,
    }

    accumulate_steps = 1

    if PP > 1:
        strategy.pipeline_configs = {
            "accumulate_steps": accumulate_steps,
            "micro_batch_size": global_batch_size // DP // accumulate_steps,
        }
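        # Illustrative arithmetic: with global_batch_size=16, DP=2 and
        # accumulate_steps=1, each micro batch holds 16 // 2 // 1 = 8 samples
        # per data-parallel group.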

    # Set the random seed used for tensor-parallel initialization so that
    # weights stay consistent across model-parallel ranks.
    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
    fleet.init(is_collective=True, strategy=strategy)

    return fleet.get_hybrid_communicate_group()
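

# Illustrative usage sketch (an assumption, not part of the original module).
# Assuming the script is launched with `python -m paddle.distributed.launch`
# on 8 GPUs, a 2-way DP x 2-way MP x 2-way PP layout could be set up like this:
if __name__ == "__main__":
    hcg = init_parallel_env("DP2-MP2-PP2-SH1-O1", global_batch_size=16)
    # Query this rank's position in the data- and model-parallel groups.
    print("dp rank:", hcg.get_data_parallel_rank())
    print("mp rank:", hcg.get_model_parallel_rank())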