model.py 7.0 KB
Newer Older
W
wuhuachaocoding 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import os
import numpy as np
from .base import topology as tp
from .base.topology import ParallelMode
from .meta_parallel import TensorParallel, model_parallel_random_seed
from .meta_parallel import PipelineParallel, ShardingParallel
from paddle.fluid import core
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
from paddle.distributed import fleet


class _RecomputeModelWrapper(paddle.nn.Layer):
    """Run a ``paddle.nn.Sequential`` model segment by segment with recompute.

    The direct children of ``model`` are split into ``segments`` contiguous
    groups of ``len(children) // segments`` layers each. Every group except
    the last is executed through ``LegacyRecomputeFunction.apply``. The final
    group, which also absorbs any remainder layers, is executed normally.

    Args:
        model (paddle.nn.Sequential): the model to wrap.
        segments (int): number of segments to split the children into.
        preserve_rng_state (bool): forwarded to
            ``LegacyRecomputeFunction.apply`` for every checkpointed segment.
    """

    def __init__(self, model, segments=2, preserve_rng_state=True):
        super(_RecomputeModelWrapper, self).__init__()
        assert isinstance(model, paddle.nn.Sequential), (
            "The model passed to RecomputeModelWrapper must be of type "
            "paddle.nn.Sequential.")
        self._model = model
        self._segments = segments
        self._preserve_rng_state = preserve_rng_state
        self._layers = list(model.children())
        # Integer division: this is 0 when there are fewer children than
        # segments. ``forward`` handles that case explicitly.
        self._segment_size = len(self._layers) // segments

    def _run_func(self, begin, end):
        """Return a callable that feeds its input through layers [begin, end)."""

        def do_run(input):
            for i in range(begin, end):
                input = self._layers[i](input)
            return input

        return do_run

    def _checkpoint(self, func, *args, **kwargs):
        """Run ``func(*args)`` through ``LegacyRecomputeFunction``.

        NOTE(review): ``kwargs`` is accepted but not forwarded.
        """
        return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
                                             *args)

    def forward(self, input):
        """Apply all wrapped layers to ``input``, checkpointing all but the
        last segment."""
        end = 0
        # A non-positive segment size means the children cannot be split into
        # the requested number of segments. Using it as a ``range`` step would
        # raise ValueError (step 0) or yield nothing (negative step), so run
        # every layer directly in that case.
        if self._segment_size > 0:
            for begin in range(0, self._segment_size * (self._segments - 1),
                               self._segment_size):
                end = begin + self._segment_size
                input = self._checkpoint(self._run_func(begin, end), input)
        # The last segment (plus any remainder layers) runs without recompute.
        return self._run_func(end, len(self._layers))(input)


# Module-level loss scaler. It stays None until ``distributed_model`` runs with
# ``strategy.amp`` enabled, at which point a ``paddle.amp.GradScaler`` is
# stored here.
# NOTE(review): this rebinds the ``_grad_scalar`` name imported from
# ``paddle.fluid.dygraph.varbase_patch_methods`` at the top of the file, so
# that import is effectively dead. Assigning this module-level name only
# changes this module's attribute, not
# ``varbase_patch_methods._grad_scalar``. Confirm which module's
# ``_grad_scalar`` the consumers actually read.
_grad_scalar = None


def _build_data_parallel(model, strategy):
    """Wrap ``model`` in ``paddle.DataParallel`` using the gradient-fusion and
    unused-parameter settings of ``strategy``."""
    return paddle.DataParallel(
        model,
        comm_buffer_size=strategy.fuse_grad_size_in_MB,
        last_comm_buffer_size=strategy.last_comm_group_size_MB,
        find_unused_parameters=strategy.find_unused_parameters)


def _apply_amp(model, strategy):
    """Apply the AMP settings of ``strategy`` and return the (possibly
    decorated) model.

    When ``amp_configs['use_pure_fp16']`` is set, the model is decorated at
    AMP level "O2". In all cases the module-level ``_grad_scalar`` is replaced
    by a ``paddle.amp.GradScaler`` built from ``amp_configs``.
    """
    global _grad_scalar
    # Read the config mapping once instead of re-querying the strategy for
    # every key.
    amp_configs = strategy.amp_configs
    if amp_configs['use_pure_fp16']:
        model = paddle.amp.decorate(models=model,
                                    optimizers=None,
                                    level="O2",
                                    master_weight=None,
                                    save_dtype=None)
    _grad_scalar = paddle.amp.GradScaler(
        init_loss_scaling=amp_configs['init_loss_scaling'],
        incr_ratio=amp_configs['incr_ratio'],
        decr_ratio=amp_configs['decr_ratio'],
        incr_every_n_steps=amp_configs['incr_every_n_steps'],
        decr_every_n_nan_or_inf=amp_configs['decr_every_n_nan_or_inf'],
        use_dynamic_loss_scaling=amp_configs['use_dynamic_loss_scaling'])
    return model


def distributed_model(model):
    """
    Return a distributed wrapper of ``model`` (only works in dygraph mode).

    The wrapping is driven by the user-defined ``DistributedStrategy`` of the
    current fleet environment:

    * with at most one worker, the model is returned unchanged;
    * ``strategy.amp`` decorates the model at AMP level "O2" when
      ``amp_configs['use_pure_fp16']`` is set, and stores a
      ``paddle.amp.GradScaler`` built from ``amp_configs`` in the
      module-level ``_grad_scalar``;
    * ``strategy.recompute`` wraps the model in ``_RecomputeModelWrapper``,
      which requires a ``paddle.nn.Sequential``;
    * ``strategy.heter_ccl_mode`` returns a ``paddle.DataParallel`` wrapper;
    * otherwise, ``fleet_env._hcg.get_parallel_mode()`` selects
      ``ShardingParallel``, ``paddle.DataParallel``, ``TensorParallel`` or
      ``PipelineParallel``. Any other mode returns the model as prepared so
      far.

    Args:
        model (Layer): the user-defined model which inherits Layer.

    Returns:
        The distributed (wrapped) model, which inherits Layer.

    Raises:
        AssertionError: if ``model`` is None, or if the fleet sharding degree
            disagrees with the sharding group size in data-parallel mode.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn
            from paddle.distributed import fleet

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            # 1. initialize fleet environment
            fleet.init(is_collective=True)

            # 2. create layer & optimizer
            layer = LinearNet()
            loss_fn = nn.MSELoss()
            adam = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=layer.parameters())

            # 3. get data_parallel model using fleet
            adam = fleet.distributed_optimizer(adam)
            dp_layer = fleet.distributed_model(layer)

            # 4. run layer
            inputs = paddle.randn([10, 10], 'float32')
            outputs = dp_layer(inputs)
            labels = paddle.randn([10, 1], 'float32')
            loss = loss_fn(outputs, labels)

            print("loss:", loss.numpy())

            loss.backward()

            adam.step()
            adam.clear_grad()


    """
    fleet_env = fleet.fleet

    assert model is not None, "model should not be None"
    if fleet_env.worker_num() <= 1:
        return model

    strategy = fleet_env._user_defined_strategy
    if strategy.amp:
        model = _apply_amp(model, strategy)

    if strategy.recompute:
        model = _RecomputeModelWrapper(model)

    if strategy.heter_ccl_mode:
        return _build_data_parallel(model, strategy)

    # Only touch the hybrid communicate group after the heter-ccl early
    # return, exactly as before. Query the parallel mode once.
    hcg = fleet_env._hcg
    parallel_mode = hcg.get_parallel_mode()

    if parallel_mode == ParallelMode.SHARDING_PARALLEL:
        model = ShardingParallel(model, hcg, strategy=strategy)
    elif parallel_mode == ParallelMode.DATA_PARALLEL:
        # NOTE (JZ-LIANG) init parameters broadcast within sharding group
        # normally it should be done inside DataParallel
        if fleet_env.sharding_degree > 1:
            from paddle.distributed.fleet.utils.hybrid_parallel_util import (
                broadcast_sharding_parameters)
            assert (fleet_env.sharding_degree ==
                    hcg.get_sharding_parallel_world_size())
            broadcast_sharding_parameters(model, hcg)
        model = _build_data_parallel(model, strategy)
    elif parallel_mode == ParallelMode.TENSOR_PARALLEL:
        model = TensorParallel(model, hcg, strategy=strategy)
    elif parallel_mode == ParallelMode.PIPELINE_PARALLEL:
        model = PipelineParallel(model, hcg, strategy=strategy)

    return model