random.py 4.8 KB
Newer Older
J
JZ-LIANG 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import paddle

from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis

_logger = get_logger(logging.INFO)

_rng_name_to_seed = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42

# use Prime number as offset to avoid confict
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]


def is_enable_auto_rand_ctrl():
    global _enable_random_control
    return _enable_random_control


def enable_auto_rand_ctrl():
    global _enable_random_control
    _enable_random_control = True


def parallel_manual_seed(seed):
    """Enable auto parallel random control.
    Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
        * Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.

        * Consistency:  Meanwhile if the tensor is **Replicated** on another Mesh dimension, randomness of Devices along that Mesh dimension should be Consistent.

    For instance: rank0 ~ rank7 consist a Mesh of shape of [2, 4]; A 2D tensor is distributed in that Mesh using dims_mapping [-1, 1].
    Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent;
    Randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent.

    This function should be called only once before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).

    This seed only affects how randomness-relative **operators** (dropout, fuse op with dropout inside, etc) are execute amonge mesh, and would NOT affect other processe like Parameter initialization.

    Examples:
        # seed relative to training step
        auto_parallel_random_seed((step + 13) * 257)
        ...
        engine.prepare()
    """

    enable_auto_rand_ctrl()
    global _basic_seed
    _basic_seed = seed


def determinate_rng(rank, dims_mapping, process_mesh):

    # TODO(JZ-LIANG) Support Mesh with any high rank
    # use a string to unique integer hashing algorithm for seed computation.
    # instead of using offsets to coodinate seed across devices.
    if len(process_mesh.shape) > 4:
        raise NotImplementedError(
            "Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
                str(process_mesh)
            )
        )
    global _basic_seed
    seed_ = _basic_seed

    # FIXME
    # unique_id = process_mesh.unique_id
    unique_id = retrive_unique_id_for_process_mesh(
        process_mesh.shape, process_mesh.process_ids
    )
    sharding_expr = f'mesh:{unique_id}'
    seed_ += _mesh_offset * (unique_id + 1)

    for i in range(len(process_mesh.shape)):
        if i not in dims_mapping:
            relative_idx = -1
        else:
            relative_idx = _get_idx_in_axis(
                process_mesh.process_ids,
                process_mesh.shape,
                i,
                rank,
            )

        sharding_expr += f"_dim{i}:{relative_idx}"
        seed_ += _dim_offsets[i] * (relative_idx + 1)

    global _rng_name_to_seed
    if sharding_expr in _rng_name_to_seed:
        assert _rng_name_to_seed[sharding_expr] == seed_
    else:
        assert (
            seed_ not in _rng_name_to_seed.values()
        ), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format(
            seed_, sharding_expr, _rng_name_to_seed
        )
        _rng_name_to_seed[sharding_expr] = seed_

    return sharding_expr


def init_auto_parallel_rng():

    if not is_enable_auto_rand_ctrl():
        return

    global _rng_name_to_seed
    # NOTE init rng maybe call multiple times, avoid init same rng twice
    global _inited_rng_name_to_seed

    for rng_name, seed in _rng_name_to_seed.items():
        if rng_name in _inited_rng_name_to_seed:
            assert _inited_rng_name_to_seed[rng_name] == seed
        else:
            _logger.info(
                f"Init Auto Parallel RNG: {rng_name}, with seed {seed}"
            )
            paddle.framework.random.set_random_seed_generator(rng_name, seed)
            _inited_rng_name_to_seed[rng_name] = seed